diff options
199 files changed, 4141 insertions, 1893 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 160cdf73..6891f1b3 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -34,9 +34,9 @@ install: test_script: - ps: "tox -- --verbose --cov-report=term" - ps: | - $Env:VERSION = $(python mitmproxy/version.py) + $Env:VERSION = $(python -m mitmproxy.version) $Env:SKIP_MITMPROXY = "python -c `"print('skip mitmproxy')`"" - tox -e wheel + tox -e rtool -- wheel tox -e rtool -- bdist - ps: | @@ -46,7 +46,7 @@ test_script: ) { echo "Decrypt license..." tox -e rtool -- decrypt release\installbuilder\license.xml.enc release\installbuilder\license.xml - $ibVersion = "17.9.0" + $ibVersion = "17.12.0" $ibSetup = "C:\projects\mitmproxy\release\installbuilder-installer.exe" $ibCli = "C:\Program Files (x86)\BitRock InstallBuilder Enterprise $ibVersion\bin\builder-cli.exe" if (!(Test-Path $ibSetup)) { @@ -79,6 +79,7 @@ deploy_script: ($Env:TOXENV -match "py35") -and (($Env:APPVEYOR_REPO_BRANCH -In ("master", "pyinstaller")) -or ($Env:APPVEYOR_REPO_TAG -match "true")) ) { + tox -e rtool -- decrypt release\known_hosts.enc release\known_hosts tox -e rtool -- upload-snapshot --bdist --wheel --installer } @@ -11,6 +11,7 @@ MANIFEST .cache/ .tox*/ build/ +dist/ mitmproxy/contrib/kaitaistruct/*.ksy # UI diff --git a/.travis.yml b/.travis.yml index a29d0c75..b7504097 100644 --- a/.travis.yml +++ b/.travis.yml @@ -73,6 +73,7 @@ after_success: - | if [[ $BDIST == "1" && $TRAVIS_PULL_REQUEST == "false" && ($TRAVIS_BRANCH == "pyinstaller" || $TRAVIS_BRANCH == "master" || -n $TRAVIS_TAG) ]] then + tox -e rtool -- decrypt release/known_hosts.enc release/known_hosts tox -e rtool -- upload-snapshot --bdist fi @@ -186,4 +186,4 @@ with the following command: .. _PEP8: https://www.python.org/dev/peps/pep-0008 .. _`Google Style Guide`: https://google.github.io/styleguide/pyguide.html .. _forums: https://discourse.mitmproxy.org/ -.. _`good first contributions`: https://github.com/mitmproxy/mitmproxy/issues?q=is%3Aissue+is%3Aopen+label%3Agood-first-contribution +.. _`good first contributions`: https://github.com/mitmproxy/mitmproxy/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22 diff --git a/docs/certinstall.rst b/docs/certinstall.rst index 2ec5f022..6662e34d 100644 --- a/docs/certinstall.rst +++ b/docs/certinstall.rst @@ -24,10 +24,6 @@ something like this: Click on the relevant icon, follow the setup instructions for the platform you're on and you are good to go. -For iOS version 10.3 or up, you need to make sure ``mitmproxy`` is enabled in -``Certificate Trust Settings``, you can check it by going to -``Settings > General > About > Certificate Trust Settings``. - Installing the mitmproxy CA certificate manually ------------------------------------------------ @@ -39,7 +35,6 @@ documentation for some common platforms. The mitmproxy CA cert is located in ``~/.mitmproxy`` after it has been generated at the first start of mitmproxy. - iOS ^^^ @@ -47,6 +42,11 @@ See http://jasdev.me/intercepting-ios-traffic and https://web.archive.org/web/20150920082614/http://kb.mit.edu/confluence/pages/viewpage.action?pageId=152600377 +On iOS 10.3 and onwards, you also need to enable full trust for the mitmproxy root certificate: + +1. Go to Settings > General > About > Certificate Trust Settings. +2. Under "Enable full trust for root certificates", turn on trust for the mitmproxy certificate. + iOS Simulator ^^^^^^^^^^^^^ diff --git a/docs/features/passthrough.rst b/docs/features/passthrough.rst index 00462e9d..91fcb9b6 100644 --- a/docs/features/passthrough.rst +++ b/docs/features/passthrough.rst @@ -13,7 +13,7 @@ mechanism: away. Note that mitmproxy's "Limit" option is often the better alternative here, as it is not affected by the limitations listed below. -If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcpproxy` +If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcp_proxy` feature. If you want to ignore traffic from mitmproxy's processing because of large response bodies, take a look at the :ref:`streaming` feature. @@ -38,7 +38,7 @@ There are two important quirks to consider: - **In transparent mode, the ignore pattern is matched against the IP and ClientHello SNI host.** While we usually infer the hostname from the Host header if the ``--host`` argument is passed to mitmproxy, we do not have access to this information before the SSL handshake. If the client uses SNI however, then we treat the SNI host as an ignore target. -- **In regular mode, explicit HTTP requests are never ignored.** [#explicithttp]_ The ignore pattern is +- **In regular and upstream proxy mode, explicit HTTP requests are never ignored.** [#explicithttp]_ The ignore pattern is applied on CONNECT requests, which initiate HTTPS or clear-text WebSocket connections. Tutorial @@ -88,7 +88,7 @@ Here are some other examples for ignore patterns: .. seealso:: - - :ref:`tcpproxy` + - :ref:`tcp_proxy` - :ref:`streaming` - mitmproxy's "Limit" feature diff --git a/docs/index.rst b/docs/index.rst index 7cf593ff..8dba4d04 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,6 +22,15 @@ .. toctree:: :hidden: + :caption: Protocols + + protocols/http1 + protocols/http2 + protocols/websocket + protocols/tcpproxy + +.. toctree:: + :hidden: :caption: Features features/anticache @@ -36,7 +45,6 @@ features/streaming features/socksproxy features/sticky - features/tcpproxy features/upstreamproxy features/upstreamcerts diff --git a/docs/protocols/http1.rst b/docs/protocols/http1.rst new file mode 100644 index 00000000..21e68785 --- /dev/null +++ b/docs/protocols/http1.rst @@ -0,0 +1,15 @@ +.. _http1_protocol: + +HTTP/1.0 and HTTP/1.1 +=========================== + +.. seealso:: + + - `RFC7230: HTTP/1.1: Message Syntax and Routing <http://tools.ietf.org/html/rfc7230>`_ + - `RFC7231: HTTP/1.1: Semantics and Content <http://tools.ietf.org/html/rfc7231>`_ + +HTTP/1.0 and HTTP/1.1 support in mitmproxy is based on our custom HTTP stack, +which takes care of all semantics and on-the-wire parsing/serialization tasks. + +mitmproxy currently does not support HTTP trailers - but if you want to send +us a PR, we promise to take look! diff --git a/docs/protocols/http2.rst b/docs/protocols/http2.rst new file mode 100644 index 00000000..b3268ae5 --- /dev/null +++ b/docs/protocols/http2.rst @@ -0,0 +1,16 @@ +.. _http2_protocol: + +HTTP/2 +====== + +.. seealso:: + + - `RFC7540: Hypertext Transfer Protocol Version 2 (HTTP/2) <http://tools.ietf.org/html/rfc7540>`_ + +HTTP/2 support in mitmproxy is based on the amazing work by the python-hyper +community with the `hyper-h2 <https://github.com/python-hyper/hyper-h2>`_ +project. It fully encapsulates the internal state of HTTP/2 connections and +provides an easy-to-use event-based API. + +mitmproxy currently does not support HTTP/2 trailers - but if you want to send +us a PR, we promise to take look! diff --git a/docs/features/tcpproxy.rst b/docs/protocols/tcpproxy.rst index cba374e3..77248573 100644 --- a/docs/features/tcpproxy.rst +++ b/docs/protocols/tcpproxy.rst @@ -1,7 +1,7 @@ -.. _tcpproxy: +.. _tcp_proxy: -TCP Proxy -========= +TCP Proxy / Fallback +==================== In case mitmproxy does not handle a specific protocol, you can exempt hostnames from processing, so that mitmproxy acts as a generic TCP forwarder. diff --git a/docs/protocols/websocket.rst b/docs/protocols/websocket.rst new file mode 100644 index 00000000..8a7e807f --- /dev/null +++ b/docs/protocols/websocket.rst @@ -0,0 +1,22 @@ +.. _websocket_protocol: + +WebSocket +========= + +.. seealso:: + + - `RFC6455: The WebSocket Protocol <http://tools.ietf.org/html/rfc6455>`_ + - `RFC7692: Compression Extensions for WebSocket <http://tools.ietf.org/html/rfc7692>`_ + +WebSocket support in mitmproxy is based on the amazing work by the python-hyper +community with the `wsproto <https://github.com/python-hyper/wsproto>`_ +project. It fully encapsulates WebSocket frames/messages/connections and +provides an easy-to-use event-based API. + +mitmproxy fully supports the compression extension for WebSocket messages, +provided by wsproto. + +If an endpoint sends a PING to mitmproxy, a PONG will be sent back immediately +(with the same payload if present). To keep the other connection alive, a new +PING (without a payload) is sent to the other endpoint. Unsolicited PONG's are +not forwarded. All PING's and PONG's are logged (with payload if present). diff --git a/docs/scripting/api.rst b/docs/scripting/api.rst index e82afef4..368b9ba8 100644 --- a/docs/scripting/api.rst +++ b/docs/scripting/api.rst @@ -10,6 +10,9 @@ API - `mitmproxy.http.HTTPRequest <#mitmproxy.http.HTTPRequest>`_ - `mitmproxy.http.HTTPResponse <#mitmproxy.http.HTTPResponse>`_ - `mitmproxy.http.HTTPFlow <#mitmproxy.http.HTTPFlow>`_ +- WebSocket + - `mitmproxy.websocket.WebSocketFlow <#mitmproxy.websocket.WebSocketFlow>`_ + - `mitmproxy.websocket.WebSocketMessage <#mitmproxy.websocket.WebSocketMessage>`_ - Logging - `mitmproxy.log.Log <#mitmproxy.controller.Log>`_ - `mitmproxy.log.LogEntry <#mitmproxy.controller.LogEntry>`_ @@ -33,6 +36,15 @@ HTTP .. autoclass:: mitmproxy.http.HTTPFlow :inherited-members: +WebSocket +--------- + +.. autoclass:: mitmproxy.websocket.WebSocketFlow + :inherited-members: + +.. autoclass:: mitmproxy.websocket.WebSocketMessage + :inherited-members: + Logging -------- diff --git a/docs/scripting/events.rst b/docs/scripting/events.rst index 8f9463ff..d8b1fbb8 100644 --- a/docs/scripting/events.rst +++ b/docs/scripting/events.rst @@ -100,10 +100,10 @@ HTTP Events * - .. py:function:: http_connect(flow) - Called when we receive an HTTP CONNECT request. Setting a non 2xx - response on the flow will return the response to the client abort the - connection. CONNECT requests and responses do not generate the usual - HTTP handler events. CONNECT requests are only valid in regular and - upstream proxy modes. + response on the flow will return the response to the client and abort + the connection. CONNECT requests and responses do not generate the + usual HTTP handler events. CONNECT requests are only valid in regular + and upstream proxy modes. *flow* A ``models.HTTPFlow`` object. The flow is guaranteed to have @@ -187,8 +187,8 @@ are issued, only new WebSocket messages are called. - Called when a WebSocket message is received from the client or server. The sender and receiver are identifiable. The most recent message will be - ``flow.messages[-1]``. The message is user-modifiable. Currently there are - two types of messages, corresponding to the BINARY and TEXT frame types. + ``flow.messages[-1]``. The message is user-modifiable and is killable. + A message is either of TEXT or BINARY type. *flow* A ``models.WebSocketFlow`` object. @@ -211,7 +211,7 @@ TCP Events ---------- These events are called only if the connection is in :ref:`TCP mode -<tcpproxy>`. So, for instance, TCP events are not called for ordinary HTTP/S +<tcp_proxy>`. So, for instance, TCP events are not called for ordinary HTTP/S connections. .. list-table:: diff --git a/docs/transparent/linux.rst b/docs/transparent/linux.rst index 1878008c..14f6a165 100644 --- a/docs/transparent/linux.rst +++ b/docs/transparent/linux.rst @@ -11,16 +11,16 @@ achieve transparent mode. 2. Enable IP forwarding: >>> sysctl -w net.ipv4.ip_forward=1 + >>> sysctl -w net.ipv6.conf.all.forwarding=1 - You may also want to consider enabling this permanently in ``/etc/sysctl.conf``. + You may also want to consider enabling this permanently in ``/etc/sysctl.conf`` or newly created ``/etc/sysctl.d/mitmproxy.conf``, see `here <https://superuser.com/a/625852>`__. 3. If your target machine is on the same physical network and you configured it to use a custom gateway, disable ICMP redirects: - >>> echo 0 | sudo tee /proc/sys/net/ipv4/conf/*/send_redirects + >>> sysctl -w net.ipv4.conf.all.send_redirects=0 - You may also want to consider enabling this permanently in ``/etc/sysctl.conf`` - as demonstrated `here <https://unix.stackexchange.com/a/58081>`_. + You may also want to consider enabling this permanently in ``/etc/sysctl.conf`` or a newly created ``/etc/sysctl.d/mitmproxy.conf``, see `here <https://superuser.com/a/625852>`__. 4. Create an iptables ruleset that redirects the desired traffic to the mitmproxy port. Details will differ according to your setup, but the @@ -30,6 +30,10 @@ achieve transparent mode. iptables -t nat -A PREROUTING -i eth0 -p tcp --dport 80 -j REDIRECT --to-port 8080 iptables -t nat -A PREROUTING -i eth0 -p tcp --dport 443 -j REDIRECT --to-port 8080 + ip6tables -t nat -A PREROUTING -i eth0 -p tcp --dport 80 -j REDIRECT --to-port 8080 + ip6tables -t nat -A PREROUTING -i eth0 -p tcp --dport 443 -j REDIRECT --to-port 8080 + + You may also want to consider enabling this permanently with the ``iptables-persistent`` package, see `here <http://www.microhowto.info/howto/make_the_configuration_of_iptables_persistent_on_debian.html>`__. 5. Fire up mitmproxy. You probably want a command like this: diff --git a/docs/transparent/osx.rst b/docs/transparent/osx.rst index 40e91fac..5d4ec612 100644 --- a/docs/transparent/osx.rst +++ b/docs/transparent/osx.rst @@ -17,8 +17,7 @@ Note that this means we don't support transparent mode for earlier versions of O .. code-block:: none - rdr on en2 inet proto tcp to any port 80 -> 127.0.0.1 port 8080 - rdr on en2 inet proto tcp to any port 443 -> 127.0.0.1 port 8080 + rdr on en0 inet proto tcp to any port {80, 443} -> 127.0.0.1 port 8080 These rules tell pf to redirect all traffic destined for port 80 or 443 to the local mitmproxy instance running on port 8080. You should diff --git a/examples/complex/change_upstream_proxy.py b/examples/complex/change_upstream_proxy.py index 49d5379f..089a9df5 100644 --- a/examples/complex/change_upstream_proxy.py +++ b/examples/complex/change_upstream_proxy.py @@ -1,3 +1,6 @@ +from mitmproxy import http +import typing + # This scripts demonstrates how mitmproxy can switch to a second/different upstream proxy # in upstream proxy mode. # @@ -6,7 +9,7 @@ # If you want to change the target server, you should modify flow.request.host and flow.request.port -def proxy_address(flow): +def proxy_address(flow: http.HTTPFlow) -> typing.Tuple[str, int]: # Poor man's loadbalancing: route every second domain through the alternative proxy. if hash(flow.request.host) % 2 == 1: return ("localhost", 8082) @@ -14,7 +17,7 @@ def proxy_address(flow): return ("localhost", 8081) -def request(flow): +def request(flow: http.HTTPFlow) -> None: if flow.request.method == "CONNECT": # If the decision is done by domain, one could also modify the server address here. # We do it after CONNECT here to have the request data available as well. diff --git a/examples/complex/dns_spoofing.py b/examples/complex/dns_spoofing.py index 632783a7..e28934ab 100644 --- a/examples/complex/dns_spoofing.py +++ b/examples/complex/dns_spoofing.py @@ -33,7 +33,7 @@ parse_host_header = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") class Rerouter: def request(self, flow): - if flow.client_conn.ssl_established: + if flow.client_conn.tls_established: flow.request.scheme = "https" sni = flow.client_conn.connection.get_servername() port = 443 diff --git a/examples/complex/har_dump.py b/examples/complex/har_dump.py index 21bcc341..9e287a19 100644 --- a/examples/complex/har_dump.py +++ b/examples/complex/har_dump.py @@ -7,22 +7,24 @@ import json import base64 import zlib import os +import typing # noqa from datetime import datetime from datetime import timezone import mitmproxy +from mitmproxy import connections # noqa from mitmproxy import version from mitmproxy import ctx from mitmproxy.utils import strutils from mitmproxy.net.http import cookies -HAR = {} +HAR = {} # type: typing.Dict # A list of server seen till now is maintained so we can avoid # using 'connect' time for entries that use an existing connection. -SERVERS_SEEN = set() +SERVERS_SEEN = set() # type: typing.Set[connections.ServerConnection] def load(l): @@ -58,8 +60,8 @@ def response(flow): connect_time = (flow.server_conn.timestamp_tcp_setup - flow.server_conn.timestamp_start) - if flow.server_conn.timestamp_ssl_setup is not None: - ssl_time = (flow.server_conn.timestamp_ssl_setup - + if flow.server_conn.timestamp_tls_setup is not None: + ssl_time = (flow.server_conn.timestamp_tls_setup - flow.server_conn.timestamp_tcp_setup) SERVERS_SEEN.add(flow.server_conn) diff --git a/examples/complex/sslstrip.py b/examples/complex/sslstrip.py index 2f60c8b9..c3f8c4f7 100644 --- a/examples/complex/sslstrip.py +++ b/examples/complex/sslstrip.py @@ -3,13 +3,16 @@ This script implements an sslstrip-like attack based on mitmproxy. https://moxie.org/software/sslstrip/ """ import re -import urllib +import urllib.parse +import typing # noqa + +from mitmproxy import http # set of SSL/TLS capable hosts -secure_hosts = set() +secure_hosts = set() # type: typing.Set[str] -def request(flow): +def request(flow: http.HTTPFlow) -> None: flow.request.headers.pop('If-Modified-Since', None) flow.request.headers.pop('Cache-Control', None) @@ -27,7 +30,7 @@ def request(flow): flow.request.host = flow.request.pretty_host -def response(flow): +def response(flow: http.HTTPFlow) -> None: flow.response.headers.pop('Strict-Transport-Security', None) flow.response.headers.pop('Public-Key-Pins', None) diff --git a/examples/complex/xss_scanner.py b/examples/complex/xss_scanner.py index 4b35c6c1..0ee38cd4 100755 --- a/examples/complex/xss_scanner.py +++ b/examples/complex/xss_scanner.py @@ -35,14 +35,17 @@ Line: 1029zxcs'd"ao<ac>so[sb]po(pc)se;sl/bsl\eq=3847asd """ -from mitmproxy import ctx +from html.parser import HTMLParser +from typing import Dict, Union, Tuple, Optional, List, NamedTuple from socket import gaierror, gethostbyname from urllib.parse import urlparse -import requests import re -from html.parser import HTMLParser + +import requests + from mitmproxy import http -from typing import Dict, Union, Tuple, Optional, List, NamedTuple +from mitmproxy import ctx + # The actual payload is put between a frontWall and a backWall to make it easy # to locate the payload with regular expressions @@ -83,15 +86,16 @@ def get_cookies(flow: http.HTTPFlow) -> Cookies: return {name: value for name, value in flow.request.cookies.fields} -def find_unclaimed_URLs(body: Union[str, bytes], requestUrl: bytes) -> None: +def find_unclaimed_URLs(body: str, requestUrl: bytes) -> None: """ Look for unclaimed URLs in script tags and log them if found""" - def getValue(attrs: List[Tuple[str, str]], attrName: str) -> str: + def getValue(attrs: List[Tuple[str, str]], attrName: str) -> Optional[str]: for name, value in attrs: if attrName == name: return value + return None class ScriptURLExtractor(HTMLParser): - script_URLs = [] + script_URLs = [] # type: List[str] def handle_starttag(self, tag, attrs): if (tag == "script" or tag == "iframe") and "src" in [name for name, value in attrs]: @@ -100,13 +104,10 @@ def find_unclaimed_URLs(body: Union[str, bytes], requestUrl: bytes) -> None: self.script_URLs.append(getValue(attrs, "href")) parser = ScriptURLExtractor() - try: - parser.feed(body) - except TypeError: - parser.feed(body.decode('utf-8')) + parser.feed(body) for url in parser.script_URLs: - parser = urlparse(url) - domain = parser.netloc + url_parser = urlparse(url) + domain = url_parser.netloc try: gethostbyname(domain) except gaierror: @@ -178,10 +179,11 @@ def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None: if not sqli_info: return ctx.log.error("===== SQLi Found =====") - ctx.log.error("SQLi URL: %s" % sqli_info.url.decode('utf-8')) - ctx.log.error("Injection Point: %s" % sqli_info.injection_point.decode('utf-8')) - ctx.log.error("Regex used: %s" % sqli_info.regex.decode('utf-8')) - ctx.log.error("Suspected DBMS: %s" % sqli_info.dbms.decode('utf-8')) + ctx.log.error("SQLi URL: %s" % sqli_info.url) + ctx.log.error("Injection Point: %s" % sqli_info.injection_point) + ctx.log.error("Regex used: %s" % sqli_info.regex) + ctx.log.error("Suspected DBMS: %s" % sqli_info.dbms) + return def get_SQLi_data(new_body: str, original_body: str, request_URL: str, injection_point: str) -> Optional[SQLiData]: @@ -202,20 +204,21 @@ def get_SQLi_data(new_body: str, original_body: str, request_URL: str, injection "Sybase": (r"(?i)Warning.*sybase.*", r"Sybase message", r"Sybase.*Server message.*"), } for dbms, regexes in DBMS_ERRORS.items(): - for regex in regexes: + for regex in regexes: # type: ignore if re.search(regex, new_body, re.IGNORECASE) and not re.search(regex, original_body, re.IGNORECASE): return SQLiData(request_URL, injection_point, regex, dbms) + return None # A qc is either ' or " -def inside_quote(qc: str, substring: bytes, text_index: int, body: bytes) -> bool: +def inside_quote(qc: str, substring_bytes: bytes, text_index: int, body_bytes: bytes) -> bool: """ Whether the Numberth occurence of the first string in the second string is inside quotes as defined by the supplied QuoteChar """ - substring = substring.decode('utf-8') - body = body.decode('utf-8') + substring = substring_bytes.decode('utf-8') + body = body_bytes.decode('utf-8') num_substrings_found = 0 in_quote = False for index, char in enumerate(body): @@ -238,20 +241,20 @@ def inside_quote(qc: str, substring: bytes, text_index: int, body: bytes) -> boo return False -def paths_to_text(html: str, str: str) -> List[str]: +def paths_to_text(html: str, string: str) -> List[str]: """ Return list of Paths to a given str in the given HTML tree - Note that it does a BFS """ - def remove_last_occurence_of_sub_string(str: str, substr: str): + def remove_last_occurence_of_sub_string(string: str, substr: str) -> str: """ Delete the last occurence of substr from str String String -> String """ - index = str.rfind(substr) - return str[:index] + str[index + len(substr):] + index = string.rfind(substr) + return string[:index] + string[index + len(substr):] class PathHTMLParser(HTMLParser): currentPath = "" - paths = [] + paths = [] # type: List[str] def handle_starttag(self, tag, attrs): self.currentPath += ("/" + tag) @@ -260,7 +263,7 @@ def paths_to_text(html: str, str: str) -> List[str]: self.currentPath = remove_last_occurence_of_sub_string(self.currentPath, "/" + tag) def handle_data(self, data): - if str in data: + if string in data: self.paths.append(self.currentPath) parser = PathHTMLParser() @@ -268,7 +271,7 @@ def paths_to_text(html: str, str: str) -> List[str]: return parser.paths -def get_XSS_data(body: str, request_URL: str, injection_point: str) -> Optional[XSSData]: +def get_XSS_data(body: Union[str, bytes], request_URL: str, injection_point: str) -> Optional[XSSData]: """ Return a XSSDict if there is a XSS otherwise return None """ def in_script(text, index, body) -> bool: """ Whether the Numberth occurence of the first string in the second @@ -314,9 +317,9 @@ def get_XSS_data(body: str, request_URL: str, injection_point: str) -> Optional[ matches = regex.findall(body) for index, match in enumerate(matches): # Where the string is injected into the HTML - in_script = in_script(match, index, body) - in_HTML = in_HTML(match, index, body) - in_tag = not in_script and not in_HTML + in_script_val = in_script(match, index, body) + in_HTML_val = in_HTML(match, index, body) + in_tag = not in_script_val and not in_HTML_val in_single_quotes = inside_quote("'", match, index, body) in_double_quotes = inside_quote('"', match, index, body) # Whether you can inject: @@ -327,17 +330,17 @@ def get_XSS_data(body: str, request_URL: str, injection_point: str) -> Optional[ inject_slash = b"sl/bsl" in match # forward slashes inject_semi = b"se;sl" in match # semicolons inject_equals = b"eq=" in match # equals sign - if in_script and inject_slash and inject_open_angle and inject_close_angle: # e.g. <script>PAYLOAD</script> + if in_script_val and inject_slash and inject_open_angle and inject_close_angle: # e.g. <script>PAYLOAD</script> return XSSData(request_URL, injection_point, '</script><script>alert(0)</script><script>', match.decode('utf-8')) - elif in_script and in_single_quotes and inject_single_quotes and inject_semi: # e.g. <script>t='PAYLOAD';</script> + elif in_script_val and in_single_quotes and inject_single_quotes and inject_semi: # e.g. <script>t='PAYLOAD';</script> return XSSData(request_URL, injection_point, "';alert(0);g='", match.decode('utf-8')) - elif in_script and in_double_quotes and inject_double_quotes and inject_semi: # e.g. <script>t="PAYLOAD";</script> + elif in_script_val and in_double_quotes and inject_double_quotes and inject_semi: # e.g. <script>t="PAYLOAD";</script> return XSSData(request_URL, injection_point, '";alert(0);g="', @@ -380,33 +383,35 @@ def get_XSS_data(body: str, request_URL: str, injection_point: str) -> Optional[ injection_point, " onmouseover=alert(0) t=", match.decode('utf-8')) - elif in_HTML and not in_script and inject_open_angle and inject_close_angle and inject_slash: # e.g. <html>PAYLOAD</html> + elif in_HTML_val and not in_script_val and inject_open_angle and inject_close_angle and inject_slash: # e.g. <html>PAYLOAD</html> return XSSData(request_URL, injection_point, '<script>alert(0)</script>', match.decode('utf-8')) else: return None + return None # response is mitmproxy's entry point def response(flow: http.HTTPFlow) -> None: - cookiesDict = get_cookies(flow) + cookies_dict = get_cookies(flow) + resp = flow.response.get_text(strict=False) # Example: http://xss.guru/unclaimedScriptTag.html - find_unclaimed_URLs(flow.response.content, flow.request.url) - results = test_end_of_URL_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) + find_unclaimed_URLs(resp, flow.request.url) + results = test_end_of_URL_injection(resp, flow.request.url, cookies_dict) log_XSS_data(results[0]) log_SQLi_data(results[1]) # Example: https://daviddworken.com/vulnerableReferer.php - results = test_referer_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) + results = test_referer_injection(resp, flow.request.url, cookies_dict) log_XSS_data(results[0]) log_SQLi_data(results[1]) # Example: https://daviddworken.com/vulnerableUA.php - results = test_user_agent_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) + results = test_user_agent_injection(resp, flow.request.url, cookies_dict) log_XSS_data(results[0]) log_SQLi_data(results[1]) if "?" in flow.request.url: # Example: https://daviddworken.com/vulnerable.php?name= - results = test_query_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) + results = test_query_injection(resp, flow.request.url, cookies_dict) log_XSS_data(results[0]) log_SQLi_data(results[1]) diff --git a/examples/simple/README.md b/examples/simple/README.md index 5a7782db..d140a84c 100644 --- a/examples/simple/README.md +++ b/examples/simple/README.md @@ -1,18 +1,18 @@ ## Simple Examples -| Filename | Description | -|:-----------------------------|:---------------------------------------------------------------------------| -| add_header.py | Simple script that just adds a header to every request. | -| custom_contentview.py | Add a custom content view to the mitmproxy UI. | -| filter_flows.py | This script demonstrates how to use mitmproxy's filter pattern in scripts. | -| io_read_dumpfile.py | Read a dumpfile generated by mitmproxy. | -| io_write_dumpfile.py | Only write selected flows into a mitmproxy dumpfile. | -| log_events.py | Use mitmproxy's logging API. | -| modify_body_inject_iframe.py | Inject configurable iframe into pages. | -| modify_form.py | Modify HTTP form submissions. | -| modify_querystring.py | Modify HTTP query strings. | -| redirect_requests.py | Redirect a request to a different server. | -| script_arguments.py | Add arguments to a script. | -| send_reply_from_proxy.py | Send a HTTP response directly from the proxy. | -| upsidedownternet.py | Turn all images upside down. | -| wsgi_flask_app.py | Embed a WSGI app into mitmproxy. | +| Filename | Description | +| :----------------------------- | :--------------------------------------------------------------------------- | +| add_header.py | Simple script that just adds a header to every request. | +| custom_contentview.py | Add a custom content view to the mitmproxy UI. | +| custom_option.py | Add arguments to a script. | +| filter_flows.py | This script demonstrates how to use mitmproxy's filter pattern in scripts. | +| io_read_dumpfile.py | Read a dumpfile generated by mitmproxy. | +| io_write_dumpfile.py | Only write selected flows into a mitmproxy dumpfile. | +| log_events.py | Use mitmproxy's logging API. | +| modify_body_inject_iframe.py | Inject configurable iframe into pages. | +| modify_form.py | Modify HTTP form submissions. | +| modify_querystring.py | Modify HTTP query strings. | +| redirect_requests.py | Redirect a request to a different server. | +| send_reply_from_proxy.py | Send a HTTP response directly from the proxy. | +| upsidedownternet.py | Turn all images upside down. | +| wsgi_flask_app.py | Embed a WSGI app into mitmproxy. | diff --git a/mitmproxy/__init__.py b/mitmproxy/__init__.py index 9697de87..e69de29b 100644 --- a/mitmproxy/__init__.py +++ b/mitmproxy/__init__.py @@ -1,3 +0,0 @@ -# https://github.com/mitmproxy/mitmproxy/issues/1809 -# import script here so that pyinstaller registers it. -from . import script # noqa diff --git a/mitmproxy/addonmanager.py b/mitmproxy/addonmanager.py index 70cfda30..37c501ee 100644 --- a/mitmproxy/addonmanager.py +++ b/mitmproxy/addonmanager.py @@ -230,7 +230,7 @@ class AddonManager: self.trigger(name, message) - if message.reply.state != "taken": + if message.reply.state == "start": message.reply.take() if not message.reply.has_message: message.reply.ack() diff --git a/mitmproxy/addons/browser.py b/mitmproxy/addons/browser.py index 6e8b2585..247c356b 100644 --- a/mitmproxy/addons/browser.py +++ b/mitmproxy/addons/browser.py @@ -1,15 +1,27 @@ +import shutil import subprocess -import sys import tempfile +import typing from mitmproxy import command from mitmproxy import ctx -platformPaths = { - "linux": "google-chrome", - "win32": "chrome.exe", - "darwin": "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome", -} + +def get_chrome_executable() -> typing.Optional[str]: + for browser in ( + "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome", + # https://stackoverflow.com/questions/40674914/google-chrome-path-in-windows-10 + r"C:\Program Files (x86)\Google\Chrome\Application\chrome.exe", + r"C:\Program Files (x86)\Google\Application\chrome.exe", + # Linux binary names from Python's webbrowser module. + "google-chrome", + "chrome", + "chromium", + "chromium-browser", + ): + if shutil.which(browser): + return browser + return None class Browser: @@ -29,8 +41,8 @@ class Browser: else: self.done() - cmd = platformPaths.get(sys.platform) - if not cmd: # pragma: no cover + cmd = get_chrome_executable() + if not cmd: ctx.log.alert("Your platform is not supported yet - please submit a patch.") return @@ -59,4 +71,4 @@ class Browser: self.browser.kill() self.tdir.cleanup() self.browser = None - self.tdir = None
\ No newline at end of file + self.tdir = None diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index 9e012b67..2dd488b9 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -3,6 +3,7 @@ from mitmproxy import ctx from mitmproxy import io from mitmproxy import flow from mitmproxy import command +import mitmproxy.types import typing @@ -26,6 +27,7 @@ class ClientPlayback: Stop client replay. """ self.flows = [] + ctx.log.alert("Client replay stopped.") ctx.master.addons.trigger("update", []) @command.command("replay.client") @@ -33,16 +35,22 @@ class ClientPlayback: """ Replay requests from flows. """ + for f in flows: + if f.live: + raise exceptions.CommandError("Can't replay live flow.") self.flows = list(flows) + ctx.log.alert("Replaying %s flows." % len(self.flows)) ctx.master.addons.trigger("update", []) @command.command("replay.client.file") - def load_file(self, path: str) -> None: + def load_file(self, path: mitmproxy.types.Path) -> None: try: flows = io.read_flows_from_paths([path]) except exceptions.FlowReadException as e: raise exceptions.CommandError(str(e)) + ctx.log.alert("Replaying %s flows." % len(self.flows)) self.flows = flows + ctx.master.addons.trigger("update", []) def configure(self, updated): if not self.configured and ctx.options.client_replay: diff --git a/mitmproxy/addons/core.py b/mitmproxy/addons/core.py index 33d67279..ca21e1dc 100644 --- a/mitmproxy/addons/core.py +++ b/mitmproxy/addons/core.py @@ -6,6 +6,7 @@ from mitmproxy import command from mitmproxy import flow from mitmproxy import optmanager from mitmproxy.net.http import status_codes +import mitmproxy.types class Core: @@ -96,19 +97,16 @@ class Core: ] @command.command("flow.set") + @command.argument("spec", type=mitmproxy.types.Choice("flow.set.options")) def flow_set( self, - flows: typing.Sequence[flow.Flow], spec: str, sval: str + flows: typing.Sequence[flow.Flow], + spec: str, + sval: str ) -> None: """ Quickly set a number of common values on flows. """ - opts = self.flow_set_options() - if spec not in opts: - raise exceptions.CommandError( - "Set spec must be one of: %s." % ", ".join(opts) - ) - val = sval # type: typing.Union[int, str] if spec == "status_code": try: @@ -166,6 +164,7 @@ class Core: for f in flows: p = getattr(f, part, None) if p: + f.backup() p.decode() updated.append(f) ctx.master.addons.trigger("update", updated) @@ -180,6 +179,7 @@ class Core: for f in flows: p = getattr(f, part, None) if p: + f.backup() current_enc = p.headers.get("content-encoding", "identity") if current_enc == "identity": p.encode("deflate") @@ -190,19 +190,23 @@ class Core: ctx.log.alert("Toggled encoding on %s flows." % len(updated)) @command.command("flow.encode") - def encode(self, flows: typing.Sequence[flow.Flow], part: str, enc: str) -> None: + @command.argument("enc", type=mitmproxy.types.Choice("flow.encode.options")) + def encode( + self, + flows: typing.Sequence[flow.Flow], + part: str, + enc: str, + ) -> None: """ Encode flows with a specified encoding. """ - if enc not in self.encode_options(): - raise exceptions.CommandError("Invalid encoding format: %s" % enc) - updated = [] for f in flows: p = getattr(f, part, None) if p: current_enc = p.headers.get("content-encoding", "identity") if current_enc == "identity": + f.backup() p.encode(enc) updated.append(f) ctx.master.addons.trigger("update", updated) @@ -212,12 +216,11 @@ class Core: def encode_options(self) -> typing.Sequence[str]: """ The possible values for an encoding specification. - """ return ["gzip", "deflate", "br"] @command.command("options.load") - def options_load(self, path: str) -> None: + def options_load(self, path: mitmproxy.types.Path) -> None: """ Load options from a file. """ @@ -229,7 +232,7 @@ class Core: ) from e @command.command("options.save") - def options_save(self, path: str) -> None: + def options_save(self, path: mitmproxy.types.Path) -> None: """ Save options to a file. """ diff --git a/mitmproxy/addons/cut.py b/mitmproxy/addons/cut.py index a4a2107b..1c8fbc05 100644 --- a/mitmproxy/addons/cut.py +++ b/mitmproxy/addons/cut.py @@ -7,6 +7,7 @@ from mitmproxy import flow from mitmproxy import ctx from mitmproxy import certs from mitmproxy.utils import strutils +import mitmproxy.types import pyperclip @@ -17,14 +18,6 @@ def headername(spec: str): return spec[len("header["):-1].strip() -flow_shortcuts = { - "q": "request", - "s": "response", - "cc": "client_conn", - "sc": "server_conn", -} - - def is_addr(v): return isinstance(v, tuple) and len(v) > 1 @@ -35,8 +28,6 @@ def extract(cut: str, f: flow.Flow) -> typing.Union[str, bytes]: for i, spec in enumerate(path): if spec.startswith("_"): raise exceptions.CommandError("Can't access internal attribute %s" % spec) - if isinstance(current, flow.Flow): - spec = flow_shortcuts.get(spec, spec) part = getattr(current, spec, None) if i == len(path) - 1: @@ -45,60 +36,49 @@ def extract(cut: str, f: flow.Flow) -> typing.Union[str, bytes]: if spec == "host" and is_addr(current): return str(current[0]) elif spec.startswith("header["): + if not current: + return "" return current.headers.get(headername(spec), "") elif isinstance(part, bytes): return part elif isinstance(part, bool): return "true" if part else "false" - elif isinstance(part, certs.SSLCert): + elif isinstance(part, certs.Cert): return part.to_pem().decode("ascii") current = part return str(current or "") -def parse_cutspec(s: str) -> typing.Tuple[str, typing.Sequence[str]]: - """ - Returns (flowspec, [cuts]). - - Raises exceptions.CommandError if input is invalid. - """ - parts = s.split("|", maxsplit=1) - flowspec = "@all" - if len(parts) == 2: - flowspec = parts[1].strip() - cuts = parts[0] - cutparts = [i.strip() for i in cuts.split(",") if i.strip()] - if len(cutparts) == 0: - raise exceptions.CommandError("Invalid cut specification.") - return flowspec, cutparts - - class Cut: @command.command("cut") - def cut(self, cutspec: str) -> command.Cuts: + def cut( + self, + flows: typing.Sequence[flow.Flow], + cuts: mitmproxy.types.CutSpec, + ) -> mitmproxy.types.Data: """ - Resolve a cut specification of the form "cuts|flowspec". The cuts - are a comma-separated list of cut snippets. Cut snippets are - attribute paths from the base of the flow object, with a few - conveniences - "q", "s", "cc" and "sc" are shortcuts for request, - response, client_conn and server_conn, "port" and "host" retrieve - parts of an address tuple, ".header[key]" retrieves a header value. - Return values converted sensibly: SSL certicates are converted to PEM - format, bools are "true" or "false", "bytes" are preserved, and all - other values are converted to strings. The flowspec is optional, and - if it is not specified, it is assumed to be @all. + Cut data from a set of flows. Cut specifications are attribute paths + from the base of the flow object, with a few conveniences - "port" + and "host" retrieve parts of an address tuple, ".header[key]" + retrieves a header value. Return values converted to strings or + bytes: SSL certicates are converted to PEM format, bools are "true" + or "false", "bytes" are preserved, and all other values are + converted to strings. """ - flowspec, cuts = parse_cutspec(cutspec) - flows = ctx.master.commands.call_args("view.resolve", [flowspec]) - ret = [] + ret = [] # type:typing.List[typing.List[typing.Union[str, bytes]]] for f in flows: ret.append([extract(c, f) for c in cuts]) - return ret + return ret # type: ignore @command.command("cut.save") - def save(self, cuts: command.Cuts, path: str) -> None: + def save( + self, + flows: typing.Sequence[flow.Flow], + cuts: mitmproxy.types.CutSpec, + path: mitmproxy.types.Path + ) -> None: """ - Save cuts to file. If there are multiple rows or columns, the format + Save cuts to file. If there are multiple flows or cuts, the format is UTF-8 encoded CSV. If there is exactly one row and one column, the data is written to file as-is, with raw bytes preserved. If the path is prefixed with a "+", values are appended if there is an @@ -107,35 +87,45 @@ class Cut: append = False if path.startswith("+"): append = True - path = path[1:] - if len(cuts) == 1 and len(cuts[0]) == 1: - with open(path, "ab" if append else "wb") as fp: - if fp.tell() > 0: - # We're appending to a file that already exists and has content - fp.write(b"\n") - v = cuts[0][0] - if isinstance(v, bytes): - fp.write(v) - else: - fp.write(v.encode("utf8")) - ctx.log.alert("Saved single cut.") - else: - with open(path, "a" if append else "w", newline='', encoding="utf8") as fp: - writer = csv.writer(fp) - for r in cuts: - writer.writerow( - [strutils.always_str(c) or "" for c in r] # type: ignore - ) - ctx.log.alert("Saved %s cuts as CSV." % len(cuts)) + path = mitmproxy.types.Path(path[1:]) + try: + if len(cuts) == 1 and len(flows) == 1: + with open(path, "ab" if append else "wb") as fp: + if fp.tell() > 0: + # We're appending to a file that already exists and has content + fp.write(b"\n") + v = extract(cuts[0], flows[0]) + if isinstance(v, bytes): + fp.write(v) + else: + fp.write(v.encode("utf8")) + ctx.log.alert("Saved single cut.") + else: + with open(path, "a" if append else "w", newline='', encoding="utf8") as fp: + writer = csv.writer(fp) + for f in flows: + vals = [extract(c, f) for c in cuts] + writer.writerow( + [strutils.always_str(x) or "" for x in vals] # type: ignore + ) + ctx.log.alert("Saved %s cuts over %d flows as CSV." % (len(cuts), len(flows))) + except IOError as e: + ctx.log.error(str(e)) @command.command("cut.clip") - def clip(self, cuts: command.Cuts) -> None: + def clip( + self, + flows: typing.Sequence[flow.Flow], + cuts: mitmproxy.types.CutSpec, + ) -> None: """ - Send cuts to the system clipboard. + Send cuts to the clipboard. If there are multiple flows or cuts, the + format is UTF-8 encoded CSV. If there is exactly one row and one + column, the data is written to file as-is, with raw bytes preserved. """ fp = io.StringIO(newline="") - if len(cuts) == 1 and len(cuts[0]) == 1: - v = cuts[0][0] + if len(cuts) == 1 and len(flows) == 1: + v = extract(cuts[0], flows[0]) if isinstance(v, bytes): fp.write(strutils.always_str(v)) else: @@ -143,9 +133,13 @@ class Cut: ctx.log.alert("Clipped single cut.") else: writer = csv.writer(fp) - for r in cuts: + for f in flows: + vals = [extract(c, f) for c in cuts] writer.writerow( - [strutils.always_str(c) or "" for c in r] # type: ignore + [strutils.always_str(v) or "" for v in vals] # type: ignore ) ctx.log.alert("Clipped %s cuts as CSV." % len(cuts)) - pyperclip.copy(fp.getvalue()) + try: + pyperclip.copy(fp.getvalue()) + except pyperclip.PyperclipException as e: + ctx.log.error(str(e)) diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 54526d5b..48bc8118 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -234,6 +234,8 @@ class Dumper: message = f.messages[-1] self.echo(f.message_info(message)) if ctx.options.flow_detail >= 3: + message = message.from_state(message.get_state()) + message.content = message.content.encode() if isinstance(message.content, str) else message.content self._echo_message(message) def websocket_end(self, f): diff --git a/mitmproxy/addons/eventstore.py b/mitmproxy/addons/eventstore.py index 4e410c98..73ffc70c 100644 --- a/mitmproxy/addons/eventstore.py +++ b/mitmproxy/addons/eventstore.py @@ -1,19 +1,30 @@ -from typing import List # noqa +import collections +import typing # noqa import blinker + +from mitmproxy import command from mitmproxy.log import LogEntry class EventStore: - def __init__(self): - self.data = [] # type: List[LogEntry] + def __init__(self, size=10000): + self.data = collections.deque(maxlen=size) # type: typing.Deque[LogEntry] self.sig_add = blinker.Signal() self.sig_refresh = blinker.Signal() - def log(self, entry: LogEntry): + @property + def size(self) -> int: + return self.data.maxlen + + def log(self, entry: LogEntry) -> None: self.data.append(entry) self.sig_add.send(self, entry=entry) - def clear(self): + @command.command("eventstore.clear") + def clear(self) -> None: + """ + Clear the event log. + """ self.data.clear() self.sig_refresh.send(self) diff --git a/mitmproxy/addons/export.py b/mitmproxy/addons/export.py index fd0c830e..4bb44548 100644 --- a/mitmproxy/addons/export.py +++ b/mitmproxy/addons/export.py @@ -1,10 +1,12 @@ import typing +from mitmproxy import ctx from mitmproxy import command from mitmproxy import flow from mitmproxy import exceptions from mitmproxy.utils import strutils from mitmproxy.net.http.http1 import assemble +import mitmproxy.types import pyperclip @@ -49,7 +51,7 @@ class Export(): return list(sorted(formats.keys())) @command.command("export.file") - def file(self, fmt: str, f: flow.Flow, path: str) -> None: + def file(self, fmt: str, f: flow.Flow, path: mitmproxy.types.Path) -> None: """ Export a flow to path. """ @@ -57,11 +59,14 @@ class Export(): raise exceptions.CommandError("No such export format: %s" % fmt) func = formats[fmt] # type: typing.Any v = func(f) - with open(path, "wb") as fp: - if isinstance(v, bytes): - fp.write(v) - else: - fp.write(v.encode("utf-8")) + try: + with open(path, "wb") as fp: + if isinstance(v, bytes): + fp.write(v) + else: + fp.write(v.encode("utf-8")) + except IOError as e: + ctx.log.error(str(e)) @command.command("export.clip") def clip(self, fmt: str, f: flow.Flow) -> None: @@ -72,4 +77,7 @@ class Export(): raise exceptions.CommandError("No such export format: %s" % fmt) func = formats[fmt] # type: typing.Any v = strutils.always_str(func(f)) - pyperclip.copy(v) + try: + pyperclip.copy(v) + except pyperclip.PyperclipException as e: + ctx.log.error(str(e)) diff --git a/mitmproxy/addons/onboardingapp/static/images/favicon.ico b/mitmproxy/addons/onboardingapp/static/images/favicon.ico Binary files differnew file mode 100644 index 00000000..3c3b891c --- /dev/null +++ b/mitmproxy/addons/onboardingapp/static/images/favicon.ico diff --git a/mitmproxy/addons/onboardingapp/static/images/mitmproxy-long.png b/mitmproxy/addons/onboardingapp/static/images/mitmproxy-long.png Binary files differnew file mode 100644 index 00000000..f9397d1e --- /dev/null +++ b/mitmproxy/addons/onboardingapp/static/images/mitmproxy-long.png diff --git a/mitmproxy/addons/onboardingapp/static/mitmproxy.css b/mitmproxy/addons/onboardingapp/static/mitmproxy.css index b390976a..969bd62b 100644 --- a/mitmproxy/addons/onboardingapp/static/mitmproxy.css +++ b/mitmproxy/addons/onboardingapp/static/mitmproxy.css @@ -1,8 +1,6 @@ - #certbank div { text-align: center; - - + padding-top: 20px; } .fronttable { @@ -40,7 +38,6 @@ section { .masthead { padding: 50px 0 60px; text-align: center; - } .header { diff --git a/mitmproxy/addons/onboardingapp/templates/index.html b/mitmproxy/addons/onboardingapp/templates/index.html index fc6213ea..38aa27ed 100644 --- a/mitmproxy/addons/onboardingapp/templates/index.html +++ b/mitmproxy/addons/onboardingapp/templates/index.html @@ -4,59 +4,135 @@ <script> function changeTo(device) { if (device == "apple") { - var text = `<h3>Apple: How to install on macOS / OSX</h3> - <ul> - <li>Double-click the PEM file</li> - <li>The "Keychain Access" applications opens</li> - <li>Find the new certificate "mitmproxy" in the list</li> - <li>Double-click the "mitmproxy" entry</li> - <li>A dialog window openes up</li> - <li>Change "Secure Socket Layer (SSL)" to "Always Trust"</li> - <li>Close the dialog window (and enter your password if prompted)</li> - <li>For iOS version 10.3 or up, you need to make sure mitmproxy is enabled in<br> - Certificate Trust Settings, you can check it by going to<br> - Settings > General > About > Certificate Trust Settings</li> - <li>Done!</li> - </ul>`; + var text = `<div class = "container"> + <div> + <div class="col-md-4"> + <h3 class="text-center">How to install on macOS</h3> + <ul class="left"> + <li>Double-click the PEM file</li> + <li>The "Keychain Access" applications opens</li> + <li>Find the new certificate "mitmproxy" in the list</li> + <li>Double-click the "mitmproxy" entry</li> + <li>A dialog window openes up</li> + <li>Change "Secure Socket Layer (SSL)" to "Always Trust"</li> + <li>Close the dialog window (and enter your password if prompted)</li> + <li>Done!</li> + </ul> + </div> + <div class="col-md-4"> + <h3 class="text-center">How to install on browsers</h3> + <ul> + <li>Safari on macOS uses the macOS keychain. So installing our CA in the system is enough.</li> + <li>Chrome on macOS uses the macOS keychain. So installing our CA in the system is enough.</li> + <li>Firefox on macOS has its own CA store and needs to be installed with Firefox-specific instructions that can be found <a href="https://wiki.mozilla.org/MozillaRootCertificate#Mozilla_Firefox">HERE</a> .</li> + </ul> + </div> + <div class="col-md-4"> + <h3 class="text-center">How to install on iOS v10.3</h3> + <ul> + <li>After certificate installation, open Settings</li> + <li>Navigate to General and then About</li> + <li>Select Certificate Trust Settings</li> + <li>Each root that has been installed via a profile will be listed below the heading Enable Full Trust For Root Certificates. Toggle mitmproxy on</li> + <li>Done!</li> + </div> + </div> + </div>`; } else if (device == "windows") { - var text = `<h3>Windows: How to install on Windows</h3> - <ul> - <li>Double-click the P12 file</li> - <li>Select Store Location for Current User and click Next</li> - <li>Click Next</li> - <li>Leave the Password column blank and click Next</li> - <li>Select Place all certificates in the following store</li> - <li>Click Browse and select Trusted Root Certification Authorities</li> - <li>Click Next and then click Finish</li> - <li>Click Yes if prompted for confirmation</li> - <li>Done!</li> - </ul>`; + var text = `<div class = "container"> + <div class="row"> + <div class="col-md-4"> + <h3 class="text-center">How to install on Windows</h3> + <ul> + <li>Double-click the P12 file</li> + <li>Select Store Location for Current User and click Next</li> + <li>Click Next</li> + <li>Leave the Password column blank and click Next</li> + <li>Select Place all certificates in the following store</li> + <li>Click Browse and select Trusted Root Certification Authorities</li> + <li>Click Next and then click Finish</li> + <li>Click Yes if prompted for confirmation</li> + <li>Done!</li> + </ul> + </div> + <div class="col-md-4"> + <h3 class="text-center">How to install on browsers</h3> + <ul> + <li>Edge and IE use the Windows CA store. So installing our CA in the system is enough.</li> + <li>Chrome on Windows uses the Windows CA store. So installing our CA in the system is enough.</li> + <li>Firefox on Windows has its own CA store and needs to be installed with Firefox-specific instructions that can be found <a href="https://wiki.mozilla.org/MozillaRootCertificate#Mozilla_Firefox">HERE</a> .</li> + </ul> + </div> + <div class="col-md-4"> + <h3 class="text-center">How to install on Windows (Automated)</h3> + <ul> + <li> >>> certutil.exe -importpfx Root mitmproxy-ca-cert.p12 </li> + <li> To know more click <a href="https://technet.microsoft.com/en-us/library/cc732443.aspx">HERE</a> </li> + </ul> + </div> + </div> + </div>`; } else if (device == "android") { - var text = `<h3>Android: How to install on Android</h3> - <ul> - <li>Open your device's Settings app</li> - <li>Under "Credential storage," tap Install from storage</li> - <li>Under "Open from," tap where you saved the certificate</li> - <li>Tap the file</li> - <li>If prompted, enter the key store password and tap OK</li> - <li>Type a name for the certificate</li> - <li>Pick VPN and apps</li> - <li>Tap OK</li> - <li>Done!</li> - </ul>`; + var text = `<div class = "container"> + <div class="row"> + <div class="col-md-4"> + <h3 class="text-center">How to install on Android</h3> + <ul> + <li>Open your device's Settings app</li> + <li>Under "Credential storage," tap Install from storage</li> + <li>Under "Open from," tap where you saved the certificate</li> + <li>Tap the file</li> + <li>If prompted, enter the key store password and tap OK</li> + <li>Type a name for the certificate</li> + <li>Pick VPN and apps</li> + <li>Tap OK</li> + <li>Done!</li> + </ul> + </div> + </div> + </div>`; } else if (device == "asterisk") { - var text = ""; + var text = `<div class = "container"> + <div class="row"> + <div class="col-md-4"> + <h3 class="text-center">How to install on Chrome on Debian/Ubuntu</h3> + <ul> + <li>Using Chrome, hit a page on your server via HTTPS and continue past the red warning page (assuming you haven't done this already)</li> + <li>Open up Chrome Settings > Show advanced settings > HTTPS/SSL > Manage Certificates</li> + <li>Click the Authorities tab and scroll down to find your certificate under the Organization Name that you gave to the certificate</li> + <li>Select it, click Edit (NOTE: in recent versions of Chrome, the button is now "Advanced" instead of "Edit"), check all the boxes and click OK. You may have to restart Chrome</li> + </ul> + </div> + <div class="col-md-4"> + <h3 class="text-center">How to install on Chrome on Linux</h3> + <ul> + <li>Open Developer Tools > Security, and select View certificate</li> + <li>Click the Details tab > Export. Choose PKCS #7, single certificate as the file format</li> + <li>Then follow my original instructions to get to the Manage Certificates page. Click the Authorities tab > Import and choose the file to which you exported the certificate, and make sure to choose PKCS #7, single certificate as the file type</li> + <li>If prompted certification store, choose Trusted Root Certificate Authorities</li> + <li>Check all boxes and click OK. Restart Chrome</li> + </ul> + </div> + <div class="col-md-4"> + <h3 class="text-center">How to install on Ubuntu (Manually)</h3> + <ul> + <li>Create a directory for extra CA certificates in /usr/share/ca-certificates: <div class="text-muted">$ sudo mkdir /usr/share/ca-certificates/extra<div></li> + <li>Copy the CA mitmproxy.crt file to this directory: <div class="text-muted">$ sudo cp mitmproxy.crt /usr/share/ca-certificates/extra/mitmproxy.crt<div></li> + <li>Let Ubuntu add the mitmproxy.crt file's path relative to /usr/share/ca-certificates to /etc/ca-certificates.conf: <div class="text-muted">$ sudo dpkg-reconfigure ca-certificates</div></li> + <li>In case of a .pem file on Ubuntu, it must first be converted to a .crt file: <div class="text-muted">$ openssl x509 -in foo.pem -inform PEM -out foo.crt</div></li> + </ul> + </div> + </div> + </div>`; } document.getElementById("dynamic").innerHTML = text; } </script> -<center> -<h2> Click to install your mitmproxy certificate: </h2> -</center> +<h2 class="text-center"> Click to install your mitmproxy certificate </h2> <div id="certbank" class="row"> <div class="col-md-3"> <a onclick="changeTo('apple')" href="/cert/pem"><i class="fa fa-apple fa-5x"></i></a> diff --git a/mitmproxy/addons/onboardingapp/templates/layout.html b/mitmproxy/addons/onboardingapp/templates/layout.html index 8726a788..f6e1b286 100644 --- a/mitmproxy/addons/onboardingapp/templates/layout.html +++ b/mitmproxy/addons/onboardingapp/templates/layout.html @@ -12,20 +12,23 @@ <link href="/static/bootstrap.min.css" rel="stylesheet"> <link href="/static/mitmproxy.css" rel="stylesheet"> <link href="/static/fontawesome/css/font-awesome.min.css" rel="stylesheet"> + <link rel="icon" href="/static/images/favicon.ico" type="image/x-icon"/> </head> <body> <div class="navbar navbar-default" role="navigation"> <div class="container"> <div class="navbar-header"> - <a class="navbar-brand" href="#">mitmproxy</a> + <a class="navbar-brand" href="#"> + <img height="20px" src="static/images/mitmproxy-long.png"/> + </a> </div> </div> </div> <div class="container"> - {% block content %} - {% end %} + {% block content %} + {% end %} </div> </body> diff --git a/mitmproxy/addons/proxyauth.py b/mitmproxy/addons/proxyauth.py index 64233e88..dc99d5cc 100644 --- a/mitmproxy/addons/proxyauth.py +++ b/mitmproxy/addons/proxyauth.py @@ -146,14 +146,14 @@ class ProxyAuth: ) elif ctx.options.proxyauth.startswith("ldap"): parts = ctx.options.proxyauth.split(':') - security = parts[0] - ldap_server = parts[1] - dn_baseauth = parts[2] - password_baseauth = parts[3] if len(parts) != 5: raise exceptions.OptionsError( "Invalid ldap specification" ) + security = parts[0] + ldap_server = parts[1] + dn_baseauth = parts[2] + password_baseauth = parts[3] if security == "ldaps": server = ldap3.Server(ldap_server, use_ssl=True) elif security == "ldap": diff --git a/mitmproxy/addons/save.py b/mitmproxy/addons/save.py index 5e739039..44afef68 100644 --- a/mitmproxy/addons/save.py +++ b/mitmproxy/addons/save.py @@ -1,11 +1,13 @@ import os.path import typing +from mitmproxy import command from mitmproxy import exceptions from mitmproxy import flowfilter from mitmproxy import io from mitmproxy import ctx from mitmproxy import flow +import mitmproxy.types class Save: @@ -48,7 +50,8 @@ class Save: if ctx.options.save_stream_file: self.start_stream_to_path(ctx.options.save_stream_file, self.filt) - def save(self, flows: typing.Sequence[flow.Flow], path: str) -> None: + @command.command("save.file") + def save(self, flows: typing.Sequence[flow.Flow], path: mitmproxy.types.Path) -> None: """ Save flows to a file. If the path starts with a +, flows are appended to the file, otherwise it is over-written. @@ -63,9 +66,6 @@ class Save: f.close() ctx.log.alert("Saved %s flows." % len(flows)) - def load(self, l): - l.add_command("save.file", self.save) - def tcp_start(self, flow): if self.stream: self.active_flows.add(flow) @@ -75,6 +75,15 @@ class Save: self.stream.add(flow) self.active_flows.discard(flow) + def websocket_start(self, flow): + if self.stream: + self.active_flows.add(flow) + + def websocket_end(self, flow): + if self.stream: + self.stream.add(flow) + self.active_flows.discard(flow) + def response(self, flow): if self.stream: self.stream.add(flow) diff --git a/mitmproxy/addons/script.py b/mitmproxy/addons/script.py index 2d030321..0a524359 100644 --- a/mitmproxy/addons/script.py +++ b/mitmproxy/addons/script.py @@ -44,13 +44,15 @@ class Script: def __init__(self, path): self.name = "scriptmanager:" + path self.path = path - self.fullpath = os.path.expanduser(path) + self.fullpath = os.path.expanduser( + path.strip("'\" ") + ) self.ns = None self.last_load = 0 self.last_mtime = 0 if not os.path.isfile(self.fullpath): - raise exceptions.OptionsError("No such script: %s" % path) + raise exceptions.OptionsError('No such script: "%s"' % self.fullpath) @property def addons(self): diff --git a/mitmproxy/addons/serverplayback.py b/mitmproxy/addons/serverplayback.py index 4a382754..d8b2299a 100644 --- a/mitmproxy/addons/serverplayback.py +++ b/mitmproxy/addons/serverplayback.py @@ -9,6 +9,7 @@ from mitmproxy import flow from mitmproxy import exceptions from mitmproxy import io from mitmproxy import command +import mitmproxy.types class ServerPlayback: @@ -31,7 +32,7 @@ class ServerPlayback: ctx.master.addons.trigger("update", []) @command.command("replay.server.file") - def load_file(self, path: str) -> None: + def load_file(self, path: mitmproxy.types.Path) -> None: try: flows = io.read_flows_from_paths([path]) except exceptions.FlowReadException as e: diff --git a/mitmproxy/addons/termlog.py b/mitmproxy/addons/termlog.py index 3a9f2c19..2a7e2d09 100644 --- a/mitmproxy/addons/termlog.py +++ b/mitmproxy/addons/termlog.py @@ -24,7 +24,8 @@ class TermLog: click.secho( e.msg, file=outfile, - fg=dict(error="red", warn="yellow").get(e.level), + fg=dict(error="red", warn="yellow", + alert="magenta").get(e.level), dim=(e.level == "debug"), err=(e.level == "error") ) diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py index 8ae1f341..e87daf35 100644 --- a/mitmproxy/addons/view.py +++ b/mitmproxy/addons/view.py @@ -238,7 +238,7 @@ class View(collections.Sequence): @command.command("view.order.options") def order_options(self) -> typing.Sequence[str]: """ - A list of all the orders we support. + Choices supported by the view_order option. """ return list(sorted(self.orders.keys())) @@ -351,13 +351,13 @@ class View(collections.Sequence): ctx.master.addons.trigger("update", updated) @command.command("view.load") - def load_file(self, path: str) -> None: + def load_file(self, path: mitmproxy.types.Path) -> None: """ Load flows into the view, without processing them with addons. """ - path = os.path.expanduser(path) + spath = os.path.expanduser(path) try: - with open(path, "rb") as f: + with open(spath, "rb") as f: for i in io.FlowReader(f).stream(): # Do this to get a new ID, so we can load the same file N times and # get new flows each time. It would be more efficient to just have a @@ -365,7 +365,8 @@ class View(collections.Sequence): self.add([i.copy()]) except IOError as e: ctx.log.error(e.strerror) - return + except exceptions.FlowReadException as e: + ctx.log.error(str(e)) @command.command("view.go") def go(self, dst: int) -> None: @@ -406,8 +407,11 @@ class View(collections.Sequence): if f.killable: f.kill() if f in self._view: + # We manually pass the index here because multiple flows may have the same + # sorting key, and we cannot reconstruct the index from that. + idx = self._view.index(f) self._view.remove(f) - self.sig_view_remove.send(self, flow=f) + self.sig_view_remove.send(self, flow=f, index=idx) del self._store[f.id] self.sig_store_remove.send(self, flow=f) if len(flows) > 1: @@ -438,7 +442,10 @@ class View(collections.Sequence): @command.command("view.create") def create(self, method: str, url: str) -> None: - req = http.HTTPRequest.make(method.upper(), url) + try: + req = http.HTTPRequest.make(method.upper(), url) + except ValueError as e: + raise exceptions.CommandError("Invalid URL: %s" % e) c = connections.ClientConnection.make_dummy(("", 0)) s = connections.ServerConnection.make_dummy((req.host, req.port)) f = http.HTTPFlow(c, s) @@ -507,11 +514,12 @@ class View(collections.Sequence): self.sig_view_update.send(self, flow=f) else: try: - self._view.remove(f) - self.sig_view_remove.send(self, flow=f) + idx = self._view.index(f) except ValueError: - # The value was not in the view - pass + pass # The value was not in the view + else: + self._view.remove(f) + self.sig_view_remove.send(self, flow=f, index=idx) class Focus: @@ -554,11 +562,11 @@ class Focus: def _nearest(self, f, v): return min(v._bisect(f), len(v) - 1) - def _sig_view_remove(self, view, flow): + def _sig_view_remove(self, view, flow, index): if len(view) == 0: self.flow = None elif flow is self.flow: - self.flow = view[self._nearest(self.flow, view)] + self.index = min(index, len(self.view) - 1) def _sig_view_refresh(self, view): if len(view) == 0: diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 572a12d0..4e10529a 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -11,7 +11,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 @@ -112,7 +112,7 @@ def dummy_cert(privkey, cacert, commonname, sans): [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha256") - return SSLCert(cert) + return Cert(cert) class CertStoreEntry: @@ -249,7 +249,7 @@ class CertStore: def add_cert_file(self, spec: str, path: str) -> None: with open(path, "rb") as f: raw = f.read() - cert = SSLCert( + cert = Cert( OpenSSL.crypto.load_certificate( OpenSSL.crypto.FILETYPE_PEM, raw)) @@ -345,7 +345,7 @@ class _GeneralNames(univ.SequenceOf): constraint.ValueSizeConstraint(1, 1024) -class SSLCert(serializable.Serializable): +class Cert(serializable.Serializable): def __init__(self, cert): """ @@ -436,7 +436,7 @@ class SSLCert(serializable.Serializable): Returns: All DNS altnames. """ - # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. + # tcp.TCPClient.convert_to_tls assumes that this property only contains DNS altnames for hostname verification. altnames = [] for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) diff --git a/mitmproxy/command.py b/mitmproxy/command.py index eae3d80c..48968c90 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -2,39 +2,43 @@ This module manges and invokes typed commands. """ import inspect +import types +import io import typing import shlex import textwrap import functools import sys -from mitmproxy.utils import typecheck from mitmproxy import exceptions -from mitmproxy import flow +import mitmproxy.types -Cuts = typing.Sequence[ - typing.Sequence[typing.Union[str, bytes]] -] +def verify_arg_signature(f: typing.Callable, args: list, kwargs: dict) -> None: + sig = inspect.signature(f) + try: + sig.bind(*args, **kwargs) + except TypeError as v: + raise exceptions.CommandError("command argument mismatch: %s" % v.args[0]) + +def lexer(s): + # mypy mis-identifies shlex.shlex as abstract + lex = shlex.shlex(s, posix=True) # type: ignore + lex.wordchars += "." + lex.whitespace_split = True + lex.commenters = '' + return lex -def typename(t: type, ret: bool) -> str: + +def typename(t: type) -> str: """ - Translates a type to an explanatory string. If ret is True, we're - looking at a return type, else we're looking at a parameter type. + Translates a type to an explanatory string. """ - if issubclass(t, (str, int, bool)): - return t.__name__ - elif t == typing.Sequence[flow.Flow]: - return "[flow]" if ret else "flowspec" - elif t == typing.Sequence[str]: - return "[str]" - elif t == Cuts: - return "[cuts]" if ret else "cutspec" - elif t == flow.Flow: - return "flow" - else: # pragma: no cover + to = mitmproxy.types.CommandTypes.get(t, None) + if not to: raise NotImplementedError(t) + return to.display class Command: @@ -57,13 +61,13 @@ class Command: self.returntype = sig.return_annotation def paramnames(self) -> typing.Sequence[str]: - v = [typename(i, False) for i in self.paramtypes] + v = [typename(i) for i in self.paramtypes] if self.has_positional: v[-1] = "*" + v[-1] return v def retname(self) -> str: - return typename(self.returntype, True) if self.returntype else "" + return typename(self.returntype) if self.returntype else "" def signature_help(self) -> str: params = " ".join(self.paramnames()) @@ -72,13 +76,8 @@ class Command: ret = " -> " + ret return "%s %s%s" % (self.path, params, ret) - def call(self, args: typing.Sequence[str]): - """ - Call the command with a list of arguments. At this point, all - arguments are strings. - """ - if not self.has_positional and (len(self.paramtypes) != len(args)): - raise exceptions.CommandError("Usage: %s" % self.signature_help()) + def prepare_args(self, args: typing.Sequence[str]) -> typing.List[typing.Any]: + verify_arg_signature(self.func, list(args), {}) remainder = [] # type: typing.Sequence[str] if self.has_positional: @@ -86,35 +85,47 @@ class Command: args = args[:len(self.paramtypes) - 1] pargs = [] - for i in range(len(args)): - if typecheck.check_command_type(args[i], self.paramtypes[i]): - pargs.append(args[i]) - else: - pargs.append(parsearg(self.manager, args[i], self.paramtypes[i])) + for arg, paramtype in zip(args, self.paramtypes): + pargs.append(parsearg(self.manager, arg, paramtype)) + pargs.extend(remainder) + return pargs - if remainder: - chk = typecheck.check_command_type( - remainder, - typing.Sequence[self.paramtypes[-1]] # type: ignore - ) - if chk: - pargs.extend(remainder) - else: - raise exceptions.CommandError("Invalid value type.") + def call(self, args: typing.Sequence[str]) -> typing.Any: + """ + Call the command with a list of arguments. At this point, all + arguments are strings. + """ + pargs = self.prepare_args(args) with self.manager.master.handlecontext(): ret = self.func(*pargs) - if not typecheck.check_command_type(ret, self.returntype): - raise exceptions.CommandError("Command returned unexpected data") - + if ret is None and self.returntype is None: + return + typ = mitmproxy.types.CommandTypes.get(self.returntype) + if not typ.is_valid(self.manager, typ, ret): + raise exceptions.CommandError( + "%s returned unexpected data - expected %s" % ( + self.path, typ.display + ) + ) return ret -class CommandManager: +ParseResult = typing.NamedTuple( + "ParseResult", + [ + ("value", str), + ("type", typing.Type), + ("valid", bool), + ], +) + + +class CommandManager(mitmproxy.types._CommandBase): def __init__(self, master): self.master = master - self.commands = {} + self.commands = {} # type: typing.Dict[str, Command] def collect_commands(self, addon): for i in dir(addon): @@ -126,7 +137,73 @@ class CommandManager: def add(self, path: str, func: typing.Callable): self.commands[path] = Command(self, path, func) - def call_args(self, path, args): + def parse_partial( + self, + cmdstr: str + ) -> typing.Tuple[typing.Sequence[ParseResult], typing.Sequence[str]]: + """ + Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items. + """ + buf = io.StringIO(cmdstr) + parts = [] # type: typing.List[str] + lex = lexer(buf) + while 1: + remainder = cmdstr[buf.tell():] + try: + t = lex.get_token() + except ValueError: + parts.append(remainder) + break + if not t: + break + parts.append(t) + if not parts: + parts = [""] + elif cmdstr.endswith(" "): + parts.append("") + + parse = [] # type: typing.List[ParseResult] + params = [] # type: typing.List[type] + typ = None # type: typing.Type + for i in range(len(parts)): + if i == 0: + typ = mitmproxy.types.Cmd + if parts[i] in self.commands: + params.extend(self.commands[parts[i]].paramtypes) + elif params: + typ = params.pop(0) + if typ == mitmproxy.types.Cmd and params and params[0] == mitmproxy.types.Arg: + if parts[i] in self.commands: + params[:] = self.commands[parts[i]].paramtypes + else: + typ = mitmproxy.types.Unknown + + to = mitmproxy.types.CommandTypes.get(typ, None) + valid = False + if to: + try: + to.parse(self, typ, parts[i]) + except exceptions.TypeError: + valid = False + else: + valid = True + + parse.append( + ParseResult( + value=parts[i], + type=typ, + valid=valid, + ) + ) + + remhelp = [] # type: typing.List[str] + for x in params: + remt = mitmproxy.types.CommandTypes.get(x, None) + remhelp.append(remt.display) + + return parse, remhelp + + def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: """ Call a command using a list of string arguments. May raise CommandError. """ @@ -138,7 +215,7 @@ class CommandManager: """ Call a command using a string. May raise CommandError. """ - parts = shlex.split(cmdstr) + parts = list(lexer(cmdstr)) if not len(parts) >= 1: raise exceptions.CommandError("Invalid command: %s" % cmdstr) return self.call_args(parts[0], parts[1:]) @@ -157,45 +234,13 @@ def parsearg(manager: CommandManager, spec: str, argtype: type) -> typing.Any: """ Convert a string to a argument to the appropriate type. """ - if issubclass(argtype, str): - return spec - elif argtype == bool: - if spec == "true": - return True - elif spec == "false": - return False - else: - raise exceptions.CommandError( - "Booleans are 'true' or 'false', got %s" % spec - ) - elif issubclass(argtype, int): - try: - return int(spec) - except ValueError as e: - raise exceptions.CommandError("Expected an integer, got %s." % spec) - elif argtype == typing.Sequence[flow.Flow]: - return manager.call_args("view.resolve", [spec]) - elif argtype == Cuts: - return manager.call_args("cut", [spec]) - elif argtype == flow.Flow: - flows = manager.call_args("view.resolve", [spec]) - if len(flows) != 1: - raise exceptions.CommandError( - "Command requires one flow, specification matched %s." % len(flows) - ) - return flows[0] - elif argtype == typing.Sequence[str]: - return [i.strip() for i in spec.split(",")] - else: + t = mitmproxy.types.CommandTypes.get(argtype, None) + if not t: raise exceptions.CommandError("Unsupported argument type: %s" % argtype) - - -def verify_arg_signature(f: typing.Callable, args: list, kwargs: dict) -> None: - sig = inspect.signature(f) try: - sig.bind(*args, **kwargs) - except TypeError as v: - raise exceptions.CommandError("Argument mismatch: %s" % v.args[0]) + return t.parse(manager, argtype, spec) # type: ignore + except exceptions.TypeError as e: + raise exceptions.CommandError from e def command(path): @@ -207,3 +252,16 @@ def command(path): wrapper.__dict__["command_path"] = path return wrapper return decorator + + +def argument(name, type): + """ + Set the type of a command argument at runtime. This is useful for more + specific types such as mitmproxy.types.Choice, which we cannot annotate + directly as mypy does not like that. + """ + def decorator(f: types.FunctionType) -> types.FunctionType: + assert name in f.__annotations__ + f.__annotations__[name] = type + return f + return decorator diff --git a/mitmproxy/connections.py b/mitmproxy/connections.py index 01721a71..9c47985c 100644 --- a/mitmproxy/connections.py +++ b/mitmproxy/connections.py @@ -1,11 +1,13 @@ import time import os +import typing import uuid -from mitmproxy import stateobject +from mitmproxy import stateobject, exceptions from mitmproxy import certs from mitmproxy.net import tcp +from mitmproxy.net import tls from mitmproxy.utils import strutils @@ -16,16 +18,17 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): Attributes: address: Remote address - ssl_established: True if TLS is established, False otherwise + tls_established: True if TLS is established, False otherwise clientcert: The TLS client certificate mitmcert: The MITM'ed TLS server certificate presented to the client timestamp_start: Connection start timestamp - timestamp_ssl_setup: TLS established timestamp + timestamp_tls_setup: TLS established timestamp timestamp_end: Connection end timestamp sni: Server Name Indication sent by client during the TLS handshake cipher_name: The current used cipher alpn_proto_negotiated: The negotiated application protocol tls_version: TLS version + tls_extensions: TLS ClientHello extensions """ def __init__(self, client_connection, address, server): @@ -40,23 +43,24 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): self.rfile = None self.address = None self.clientcert = None - self.ssl_established = None + self.tls_established = None self.id = str(uuid.uuid4()) self.mitmcert = None self.timestamp_start = time.time() self.timestamp_end = None - self.timestamp_ssl_setup = None + self.timestamp_tls_setup = None self.sni = None self.cipher_name = None self.alpn_proto_negotiated = None self.tls_version = None + self.tls_extensions = None def connected(self): return bool(self.connection) and not self.finished def __repr__(self): - if self.ssl_established: + if self.tls_established: tls = "[{}] ".format(self.tls_version) else: tls = "" @@ -83,27 +87,20 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): def __hash__(self): return hash(self.id) - @property - def tls_established(self): - return self.ssl_established - - @tls_established.setter - def tls_established(self, value): - self.ssl_established = value - _stateobject_attributes = dict( id=str, address=tuple, - ssl_established=bool, - clientcert=certs.SSLCert, - mitmcert=certs.SSLCert, + tls_established=bool, + clientcert=certs.Cert, + mitmcert=certs.Cert, timestamp_start=float, - timestamp_ssl_setup=float, + timestamp_tls_setup=float, timestamp_end=float, sni=str, cipher_name=str, alpn_proto_negotiated=bytes, tls_version=str, + tls_extensions=typing.List[typing.Tuple[int, bytes]], ) def send(self, message): @@ -125,19 +122,29 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): address=address, clientcert=None, mitmcert=None, - ssl_established=False, + tls_established=False, timestamp_start=None, timestamp_end=None, - timestamp_ssl_setup=None, + timestamp_tls_setup=None, sni=None, cipher_name=None, alpn_proto_negotiated=None, tls_version=None, + tls_extensions=None, )) - def convert_to_ssl(self, cert, *args, **kwargs): - super().convert_to_ssl(cert, *args, **kwargs) - self.timestamp_ssl_setup = time.time() + def convert_to_tls(self, cert, *args, **kwargs): + # Unfortunately OpenSSL provides no way to expose all TLS extensions, so we do this dance + # here and use our Kaitai parser. + try: + client_hello = tls.ClientHello.from_file(self.rfile) + except exceptions.TlsProtocolException: # pragma: no cover + pass # if this fails, we don't want everything to go down. + else: + self.tls_extensions = client_hello.extensions + + super().convert_to_tls(cert, *args, **kwargs) + self.timestamp_tls_setup = time.time() self.mitmcert = cert sni = self.connection.get_servername() if sni: @@ -162,7 +169,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): address: Remote address. Can be both a domain or an IP address. ip_address: Resolved remote IP address. source_address: Local IP address or client's source IP address. - ssl_established: True if TLS is established, False otherwise + tls_established: True if TLS is established, False otherwise cert: The certificate presented by the remote during the TLS handshake sni: Server Name Indication sent by the proxy during the TLS handshake alpn_proto_negotiated: The negotiated application protocol @@ -170,7 +177,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode) timestamp_start: Connection start timestamp timestamp_tcp_setup: TCP ACK received timestamp - timestamp_ssl_setup: TLS established timestamp + timestamp_tls_setup: TLS established timestamp timestamp_end: Connection end timestamp """ @@ -184,15 +191,15 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.timestamp_start = None self.timestamp_end = None self.timestamp_tcp_setup = None - self.timestamp_ssl_setup = None + self.timestamp_tls_setup = None def connected(self): return bool(self.connection) and not self.finished def __repr__(self): - if self.ssl_established and self.sni: + if self.tls_established and self.sni: tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni) - elif self.ssl_established: + elif self.tls_established: tls = "[{}] ".format(self.tls_version or "TLS") else: tls = "" @@ -217,27 +224,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): def __hash__(self): return hash(self.id) - @property - def tls_established(self): - return self.ssl_established - - @tls_established.setter - def tls_established(self, value): - self.ssl_established = value - _stateobject_attributes = dict( id=str, address=tuple, ip_address=tuple, source_address=tuple, - ssl_established=bool, - cert=certs.SSLCert, + tls_established=bool, + cert=certs.Cert, sni=str, alpn_proto_negotiated=bytes, tls_version=str, timestamp_start=float, timestamp_tcp_setup=float, - timestamp_ssl_setup=float, + timestamp_tls_setup=float, timestamp_end=float, ) @@ -254,14 +253,14 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): address=address, ip_address=address, cert=None, - sni=None, + sni=address[0], alpn_proto_negotiated=None, tls_version=None, source_address=('', 0), - ssl_established=False, + tls_established=False, timestamp_start=None, timestamp_tcp_setup=None, - timestamp_ssl_setup=None, + timestamp_tls_setup=None, timestamp_end=None, via=None )) @@ -277,25 +276,25 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.wfile.write(message) self.wfile.flush() - def establish_ssl(self, clientcerts, sni, **kwargs): + def establish_tls(self, *, sni=None, client_certs=None, **kwargs): if sni and not isinstance(sni, str): raise ValueError("sni must be str, not " + type(sni).__name__) - clientcert = None - if clientcerts: - if os.path.isfile(clientcerts): - clientcert = clientcerts + client_cert = None + if client_certs: + if os.path.isfile(client_certs): + client_cert = client_certs else: path = os.path.join( - clientcerts, + client_certs, self.address[0].encode("idna").decode()) + ".pem" if os.path.exists(path): - clientcert = path + client_cert = path - self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) + self.convert_to_tls(cert=client_cert, sni=sni, **kwargs) self.sni = sni self.alpn_proto_negotiated = self.get_alpn_proto_negotiated() self.tls_version = self.connection.get_protocol_version_name() - self.timestamp_ssl_setup = time.time() + self.timestamp_tls_setup = time.time() def finish(self): tcp.TCPClient.finish(self) diff --git a/mitmproxy/contentviews/base.py b/mitmproxy/contentviews/base.py index 97740eea..dbaa6ccc 100644 --- a/mitmproxy/contentviews/base.py +++ b/mitmproxy/contentviews/base.py @@ -43,12 +43,15 @@ def format_dict( ) -> typing.Iterator[TViewLine]: """ Helper function that transforms the given dictionary into a list of + [ ("key", key ) ("value", value) - tuples, where key is padded to a uniform width. + ] + entries, where key is padded to a uniform width. """ - max_key_len = max(len(k) for k in d.keys()) - max_key_len = min(max_key_len, KEY_MAX) + + max_key_len = max((len(k) for k in d.keys()), default=0) + max_key_len = min((max_key_len, KEY_MAX), default=0) for key, value in d.items(): if isinstance(key, bytes): key += b":" diff --git a/mitmproxy/contentviews/image/view.py b/mitmproxy/contentviews/image/view.py index 6f75473b..fde9b39d 100644 --- a/mitmproxy/contentviews/image/view.py +++ b/mitmproxy/contentviews/image/view.py @@ -1,7 +1,7 @@ import imghdr from mitmproxy.contentviews import base -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import image_parser diff --git a/mitmproxy/contentviews/multipart.py b/mitmproxy/contentviews/multipart.py index 0b0e51e2..be3dc135 100644 --- a/mitmproxy/contentviews/multipart.py +++ b/mitmproxy/contentviews/multipart.py @@ -1,5 +1,5 @@ from mitmproxy.net import http -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import base diff --git a/mitmproxy/contentviews/urlencoded.py b/mitmproxy/contentviews/urlencoded.py index 79fe9c1c..a24f342a 100644 --- a/mitmproxy/contentviews/urlencoded.py +++ b/mitmproxy/contentviews/urlencoded.py @@ -1,5 +1,5 @@ from mitmproxy.net.http import url -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import base diff --git a/mitmproxy/contrib/kaitaistruct/gif.py b/mitmproxy/contrib/kaitaistruct/gif.py index 820df568..76d7fc16 100644 --- a/mitmproxy/contrib/kaitaistruct/gif.py +++ b/mitmproxy/contrib/kaitaistruct/gif.py @@ -35,9 +35,11 @@ class Gif(KaitaiStruct): self.global_color_table = self._root.ColorTable(io, self, self._root) self.blocks = [] - while not self._io.is_eof(): - self.blocks.append(self._root.Block(self._io, self, self._root)) - + while True: + _ = self._root.Block(self._io, self, self._root) + self.blocks.append(_) + if ((self._io.is_eof()) or (_.block_type == self._root.BlockType.end_of_file)) : + break class ImageData(KaitaiStruct): def __init__(self, _io, _parent=None, _root=None): diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 63117ef0..f39c1b24 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -105,16 +105,16 @@ class Reply: self.q.put(self.value) def ack(self, force=False): - if self.state not in {"start", "taken"}: - raise exceptions.ControlException( - "Reply is {}, but expected it to be start or taken.".format(self.state) - ) self.send(self.obj, force) def kill(self, force=False): self.send(exceptions.Kill, force) def send(self, msg, force=False): + if self.state not in {"start", "taken"}: + raise exceptions.ControlException( + "Reply is {}, but expected it to be start or taken.".format(self.state) + ) if self.has_message and not force: raise exceptions.ControlException("There is already a reply message.") self.value = msg diff --git a/mitmproxy/types/__init__.py b/mitmproxy/coretypes/__init__.py index e69de29b..e69de29b 100644 --- a/mitmproxy/types/__init__.py +++ b/mitmproxy/coretypes/__init__.py diff --git a/mitmproxy/types/basethread.py b/mitmproxy/coretypes/basethread.py index a3c81d19..a3c81d19 100644 --- a/mitmproxy/types/basethread.py +++ b/mitmproxy/coretypes/basethread.py diff --git a/mitmproxy/types/bidi.py b/mitmproxy/coretypes/bidi.py index 0982a34a..0982a34a 100644 --- a/mitmproxy/types/bidi.py +++ b/mitmproxy/coretypes/bidi.py diff --git a/mitmproxy/types/multidict.py b/mitmproxy/coretypes/multidict.py index bd9766a3..90f3013e 100644 --- a/mitmproxy/types/multidict.py +++ b/mitmproxy/coretypes/multidict.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import MutableMapping -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable class _MultiDict(MutableMapping, metaclass=ABCMeta): diff --git a/mitmproxy/types/serializable.py b/mitmproxy/coretypes/serializable.py index cd8539b0..cd8539b0 100644 --- a/mitmproxy/types/serializable.py +++ b/mitmproxy/coretypes/serializable.py diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py index 71517480..d568898b 100644 --- a/mitmproxy/exceptions.py +++ b/mitmproxy/exceptions.py @@ -112,6 +112,10 @@ class AddonHalt(MitmproxyException): pass +class TypeError(MitmproxyException): + pass + + """ Net-layer exceptions """ diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index dc778404..6a27a4a8 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -1,13 +1,12 @@ import time +import typing # noqa import uuid -from mitmproxy import controller # noqa -from mitmproxy import stateobject from mitmproxy import connections +from mitmproxy import controller, exceptions # noqa +from mitmproxy import stateobject from mitmproxy import version -import typing # noqa - class Error(stateobject.StateObject): @@ -88,7 +87,7 @@ class Flow(stateobject.StateObject): type=str, intercepted=bool, marked=bool, - metadata=dict, + metadata=typing.Dict[str, typing.Any], ) def get_state(self): @@ -145,7 +144,11 @@ class Flow(stateobject.StateObject): @property def killable(self): - return self.reply and self.reply.state == "taken" + return ( + self.reply and + self.reply.state in {"start", "taken"} and + self.reply.value != exceptions.Kill + ) def kill(self): """ @@ -153,13 +156,7 @@ class Flow(stateobject.StateObject): """ self.error = Error("Connection killed") self.intercepted = False - - # reply.state should be "taken" here, or .take() will raise an - # exception. - if self.reply.state != "taken": - self.reply.take() self.reply.kill(force=True) - self.reply.commit() self.live = False def intercept(self): @@ -179,5 +176,7 @@ class Flow(stateobject.StateObject): if not self.intercepted: return self.intercepted = False - self.reply.ack() - self.reply.commit() + # If a flow is intercepted and then duplicated, the duplicated one is not taken. + if self.reply.state == "taken": + self.reply.ack() + self.reply.commit() diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index 23e47e2b..d1fd8299 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -322,8 +322,10 @@ class FDomain(_Rex): flags = re.IGNORECASE is_binary = False - @only(http.HTTPFlow) + @only(http.HTTPFlow, websocket.WebSocketFlow) def __call__(self, f): + if isinstance(f, websocket.WebSocketFlow): + f = f.handshake_flow return bool( self.re.search(f.request.host) or self.re.search(f.request.pretty_host) @@ -342,9 +344,11 @@ class FUrl(_Rex): toks = toks[1:] return klass(*toks) - @only(http.HTTPFlow) + @only(http.HTTPFlow, websocket.WebSocketFlow) def __call__(self, f): - if not f.request: + if isinstance(f, websocket.WebSocketFlow): + f = f.handshake_flow + if not f or not f.request: return False return self.re.search(f.request.pretty_url) diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index da9d2a44..51bd116b 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -1,5 +1,9 @@ """ This module handles the import of mitmproxy flows generated by old versions. + +The flow file version is decoupled from the mitmproxy release cycle (since +v3.0.0dev) and versioning. Every change or migration gets a new flow file +version number, this prevents issues with developer builds and snapshots. """ import uuid from typing import Any, Dict, Mapping, Union # noqa @@ -119,6 +123,7 @@ def convert_200_300(data): def convert_300_4(data): data["version"] = 4 + # Ths is an empty migration to transition to the new versioning scheme. return data @@ -149,6 +154,24 @@ def convert_4_5(data): return data +def convert_5_6(data): + data["version"] = 6 + data["client_conn"]["tls_established"] = data["client_conn"].pop("ssl_established") + data["client_conn"]["timestamp_tls_setup"] = data["client_conn"].pop("timestamp_ssl_setup") + data["server_conn"]["tls_established"] = data["server_conn"].pop("ssl_established") + data["server_conn"]["timestamp_tls_setup"] = data["server_conn"].pop("timestamp_ssl_setup") + if data["server_conn"]["via"]: + data["server_conn"]["via"]["tls_established"] = data["server_conn"]["via"].pop("ssl_established") + data["server_conn"]["via"]["timestamp_tls_setup"] = data["server_conn"]["via"].pop("timestamp_ssl_setup") + return data + + +def convert_6_7(data): + data["version"] = 7 + data["client_conn"]["tls_extensions"] = None + return data + + def _convert_dict_keys(o: Any) -> Any: if isinstance(o, dict): return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} @@ -201,6 +224,8 @@ converters = { (2, 0): convert_200_300, (3, 0): convert_300_4, 4: convert_4_5, + 5: convert_5_6, + 6: convert_6_7, } diff --git a/mitmproxy/master.py b/mitmproxy/master.py index b41e2a8d..a5e948f6 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -9,10 +9,11 @@ from mitmproxy import eventsequence from mitmproxy import exceptions from mitmproxy import command from mitmproxy import http +from mitmproxy import websocket from mitmproxy import log from mitmproxy.net import server_spec from mitmproxy.proxy.protocol import http_replay -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from . import ctx as mitmproxy_ctx @@ -41,6 +42,7 @@ class Master: self.should_exit = threading.Event() self._server = None self.first_tick = True + self.waiting_flows = [] @property def server(self): @@ -75,7 +77,7 @@ class Master: def add_log(self, e, level): """ - level: debug, info, warn, error + level: debug, alert, info, warn, error """ self.addons.trigger("log", log.LogEntry(e, level)) @@ -117,15 +119,33 @@ class Master: self.should_exit.set() self.addons.trigger("done") + def _change_reverse_host(self, f): + """ + When we load flows in reverse proxy mode, we adjust the target host to + the reverse proxy destination for all flows we load. This makes it very + easy to replay saved flows against a different host. + """ + if self.options.mode.startswith("reverse:"): + _, upstream_spec = server_spec.parse_with_mode(self.options.mode) + f.request.host, f.request.port = upstream_spec.address + f.request.scheme = upstream_spec.scheme + def load_flow(self, f): """ - Loads a flow + Loads a flow and links websocket & handshake flows """ + if isinstance(f, http.HTTPFlow): - if self.options.mode.startswith("reverse:"): - _, upstream_spec = server_spec.parse_with_mode(self.options.mode) - f.request.host, f.request.port = upstream_spec.address - f.request.scheme = upstream_spec.scheme + self._change_reverse_host(f) + if 'websocket' in f.metadata: + self.waiting_flows.append(f) + + if isinstance(f, websocket.WebSocketFlow): + hf = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']][0] + f.handshake_flow = hf + self.waiting_flows.remove(hf) + self._change_reverse_host(f.handshake_flow) + f.reply = controller.DummyReply() for e, o in eventsequence.iterate(f): self.addons.handle_lifecycle(e, o) diff --git a/mitmproxy/net/http/cookies.py b/mitmproxy/net/http/cookies.py index 5b410acc..7bef8757 100644 --- a/mitmproxy/net/http/cookies.py +++ b/mitmproxy/net/http/cookies.py @@ -3,7 +3,7 @@ import re import time from typing import Tuple, List, Iterable -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict """ A flexible module for cookie parsing and manipulation. @@ -114,11 +114,10 @@ def _read_cookie_pairs(s, off=0): lhs, off = _read_key(s, off) lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s) and s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - + rhs = "" + if off < len(s) and s[off] == "=": + rhs, off = _read_value(s, off + 1, ";") + if rhs or lhs: pairs.append([lhs, rhs]) off += 1 @@ -143,25 +142,24 @@ def _read_set_cookie_pairs(s: str, off=0) -> Tuple[List[TPairs], int]: lhs, off = _read_key(s, off, ";=,") lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s) and s[off] == "=": - rhs, off = _read_value(s, off + 1, ";,") - - # Special handliing of attributes - if lhs.lower() == "expires": - # 'expires' values can contain commas in them so they need to - # be handled separately. + rhs = "" + if off < len(s) and s[off] == "=": + rhs, off = _read_value(s, off + 1, ";,") - # We actually bank on the fact that the expires value WILL - # contain a comma. Things will fail, if they don't. + # Special handling of attributes + if lhs.lower() == "expires": + # 'expires' values can contain commas in them so they need to + # be handled separately. - # '3' is just a heuristic we use to determine whether we've - # only read a part of the expires value and we should read more. - if len(rhs) <= 3: - trail, off = _read_value(s, off + 1, ";,") - rhs = rhs + "," + trail + # We actually bank on the fact that the expires value WILL + # contain a comma. Things will fail, if they don't. + # '3' is just a heuristic we use to determine whether we've + # only read a part of the expires value and we should read more. + if len(rhs) <= 3: + trail, off = _read_value(s, off + 1, ";,") + rhs = rhs + "," + trail + if rhs or lhs: pairs.append([lhs, rhs]) # comma marks the beginning of a new cookie @@ -196,13 +194,10 @@ def _format_pairs(pairs, specials=(), sep="; "): """ vals = [] for k, v in pairs: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) return sep.join(vals) diff --git a/mitmproxy/net/http/headers.py b/mitmproxy/net/http/headers.py index 8fc0cd43..8a58cbbc 100644 --- a/mitmproxy/net/http/headers.py +++ b/mitmproxy/net/http/headers.py @@ -1,7 +1,7 @@ import re import collections -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from mitmproxy.utils import strutils # See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ diff --git a/mitmproxy/net/http/message.py b/mitmproxy/net/http/message.py index cb32aee4..65820f67 100644 --- a/mitmproxy/net/http/message.py +++ b/mitmproxy/net/http/message.py @@ -3,7 +3,7 @@ from typing import Optional, Union # noqa from mitmproxy.utils import strutils from mitmproxy.net.http import encoding -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable from mitmproxy.net.http import headers diff --git a/mitmproxy/net/http/request.py b/mitmproxy/net/http/request.py index 6f366a4f..6b4041f6 100644 --- a/mitmproxy/net/http/request.py +++ b/mitmproxy/net/http/request.py @@ -2,7 +2,7 @@ import re import urllib from typing import Optional, AnyStr, Dict, Iterable, Tuple, Union -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from mitmproxy.utils import strutils from mitmproxy.net.http import multipart from mitmproxy.net.http import cookies diff --git a/mitmproxy/net/http/response.py b/mitmproxy/net/http/response.py index 18950fc7..48527d63 100644 --- a/mitmproxy/net/http/response.py +++ b/mitmproxy/net/http/response.py @@ -1,7 +1,7 @@ import time from email.utils import parsedate_tz, formatdate, mktime_tz from mitmproxy.utils import human -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from mitmproxy.net.http import cookies from mitmproxy.net.http import headers as nheaders from mitmproxy.net.http import message diff --git a/mitmproxy/net/http/url.py b/mitmproxy/net/http/url.py index 86f65cfd..f938cb12 100644 --- a/mitmproxy/net/http/url.py +++ b/mitmproxy/net/http/url.py @@ -76,7 +76,7 @@ def encode(s: Sequence[Tuple[str, str]], similar_to: str=None) -> str: encoded = urllib.parse.urlencode(s, False, errors="surrogateescape") - if remove_trailing_equal: + if encoded and remove_trailing_equal: encoded = encoded.replace("=&", "&") if encoded[-1] == '=': encoded = encoded[:-1] diff --git a/mitmproxy/net/socks.py b/mitmproxy/net/socks.py index fdfcfb80..0b2790df 100644 --- a/mitmproxy/net/socks.py +++ b/mitmproxy/net/socks.py @@ -3,7 +3,7 @@ import array import ipaddress from mitmproxy.net import check -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi class SocksError(Exception): diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index 47c80e80..85217794 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -14,7 +14,7 @@ from OpenSSL import SSL from mitmproxy import certs from mitmproxy import exceptions -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread socket_fileobject = socket.SocketIO @@ -301,11 +301,11 @@ class _Connection: self.rfile = None self.wfile = None - self.ssl_established = False + self.tls_established = False self.finished = False def get_current_cipher(self): - if not self.ssl_established: + if not self.tls_established: return None name = self.connection.get_cipher_name() @@ -381,7 +381,7 @@ class TCPClient(_Connection): else: close_socket(self.connection) - def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): + def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs): context = tls.create_client_context( alpn_protos=alpn_protos, sni=sni, @@ -400,13 +400,13 @@ class TCPClient(_Connection): else: raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) - self.cert = certs.SSLCert(self.connection.get_peer_certificate()) + self.cert = certs.Cert(self.connection.get_peer_certificate()) # Keep all server certificates in a list for i in self.connection.get_peer_cert_chain(): - self.server_certs.append(certs.SSLCert(i)) + self.server_certs.append(certs.Cert(i)) - self.ssl_established = True + self.tls_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -473,7 +473,7 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if self.ssl_established: + if self.tls_established: return self.connection.get_alpn_proto_negotiated() else: return b"" @@ -491,7 +491,7 @@ class BaseHandler(_Connection): self.server = server self.clientcert = None - def convert_to_ssl(self, cert, key, **sslctx_kwargs): + def convert_to_tls(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see tls.create_server_context(...) @@ -507,10 +507,10 @@ class BaseHandler(_Connection): self.connection.do_handshake() except SSL.Error as v: raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) - self.ssl_established = True + self.tls_established = True cert = self.connection.get_peer_certificate() if cert: - self.clientcert = certs.SSLCert(cert) + self.clientcert = certs.Cert(cert) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -521,7 +521,7 @@ class BaseHandler(_Connection): self.connection.settimeout(n) def get_alpn_proto_negotiated(self): - if self.ssl_established: + if self.tls_established: return self.connection.get_alpn_proto_negotiated() else: return b"" diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index 74911f1e..f8eeb44b 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -2,15 +2,21 @@ # then add options to disable certain methods # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 import binascii +import io import os +import struct import threading import typing from ssl import match_hostname, CertificateError import certifi from OpenSSL import SSL +from kaitaistruct import KaitaiStream +import mitmproxy.options # noqa from mitmproxy import exceptions, certs +from mitmproxy.contrib.kaitaistruct import tls_client_hello +from mitmproxy.net import check BASIC_OPTIONS = ( SSL.OP_CIPHER_SERVER_PREFERENCE @@ -52,6 +58,26 @@ METHOD_NAMES = { } +def client_arguments_from_options(options: "mitmproxy.options.Options") -> dict: + + if options.ssl_insecure: + verify = SSL.VERIFY_NONE + else: + verify = SSL.VERIFY_PEER + + method, tls_options = VERSION_CHOICES[options.ssl_version_server] + + return { + "verify": verify, + "method": method, + "options": tls_options, + "ca_path": options.ssl_verify_upstream_trusted_cadir, + "ca_pemfile": options.ssl_verify_upstream_trusted_ca, + "client_certs": options.client_certs, + "cipher_list": options.ciphers_server, + } + + class MasterSecretLogger: def __init__(self, filename): self.filename = filename @@ -189,7 +215,7 @@ def _create_ssl_context( def create_client_context( cert: str = None, sni: str = None, - address: str=None, + address: str = None, verify: int = SSL.VERIFY_NONE, **sslctx_kwargs ) -> SSL.Context: @@ -213,7 +239,7 @@ def create_client_context( ) -> bool: if is_cert_verified and depth == 0: # Verify hostname of leaf certificate. - cert = certs.SSLCert(x509) + cert = certs.Cert(x509) try: crt = dict( subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in cert.altnames] @@ -270,17 +296,17 @@ def create_client_context( def create_server_context( - cert: typing.Union[certs.SSLCert, str], + cert: typing.Union[certs.Cert, str], key: SSL.PKey, handle_sni: typing.Optional[typing.Callable[[SSL.Connection], None]] = None, request_client_cert: bool = False, chain_file=None, dhparams=None, - extra_chain_certs: typing.Iterable[certs.SSLCert] = None, + extra_chain_certs: typing.Iterable[certs.Cert] = None, **sslctx_kwargs ) -> SSL.Context: """ - cert: A certs.SSLCert object or the path to a certificate + cert: A certs.Cert object or the path to a certificate chain file. handle_sni: SNI handler, should take a connection object. Server @@ -321,7 +347,7 @@ def create_server_context( ) context.use_privatekey(key) - if isinstance(cert, certs.SSLCert): + if isinstance(cert, certs.Cert): context.use_certificate(cert.x509) else: context.use_certificate_chain_file(cert) @@ -338,3 +364,119 @@ def create_server_context( SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) return context + + +def is_tls_record_magic(d): + """ + Returns: + True, if the passed bytes start with the TLS record magic bytes. + False, otherwise. + """ + d = d[:3] + + # TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2 + # http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello + return ( + len(d) == 3 and + d[0] == 0x16 and + d[1] == 0x03 and + 0x0 <= d[2] <= 0x03 + ) + + +def get_client_hello(rfile): + """ + Peek into the socket and read all records that contain the initial client hello message. + + client_conn: + The :py:class:`client connection <mitmproxy.connections.ClientConnection>`. + + Returns: + The raw handshake packet bytes, without TLS record header(s). + """ + client_hello = b"" + client_hello_size = 1 + offset = 0 + while len(client_hello) < client_hello_size: + record_header = rfile.peek(offset + 5)[offset:] + if not is_tls_record_magic(record_header) or len(record_header) < 5: + raise exceptions.TlsProtocolException( + 'Expected TLS record, got "%s" instead.' % record_header) + record_size = struct.unpack_from("!H", record_header, 3)[0] + 5 + record_body = rfile.peek(offset + record_size)[offset + 5:] + if len(record_body) != record_size - 5: + raise exceptions.TlsProtocolException( + "Unexpected EOF in TLS handshake: %s" % record_body) + client_hello += record_body + offset += record_size + client_hello_size = struct.unpack("!I", b'\x00' + client_hello[1:4])[0] + 4 + return client_hello + + +class ClientHello: + + def __init__(self, raw_client_hello): + self._client_hello = tls_client_hello.TlsClientHello( + KaitaiStream(io.BytesIO(raw_client_hello)) + ) + + @property + def cipher_suites(self): + return self._client_hello.cipher_suites.cipher_suites + + @property + def sni(self): + if self._client_hello.extensions: + for extension in self._client_hello.extensions.extensions: + is_valid_sni_extension = ( + extension.type == 0x00 and + len(extension.body.server_names) == 1 and + extension.body.server_names[0].name_type == 0 and + check.is_valid_host(extension.body.server_names[0].host_name) + ) + if is_valid_sni_extension: + return extension.body.server_names[0].host_name.decode("idna") + return None + + @property + def alpn_protocols(self): + if self._client_hello.extensions: + for extension in self._client_hello.extensions.extensions: + if extension.type == 0x10: + return list(x.name for x in extension.body.alpn_protocols) + return [] + + @property + def extensions(self) -> typing.List[typing.Tuple[int, bytes]]: + ret = [] + if self._client_hello.extensions: + for extension in self._client_hello.extensions.extensions: + body = getattr(extension, "_raw_body", extension.body) + ret.append((extension.type, body)) + return ret + + @classmethod + def from_file(cls, client_conn) -> "ClientHello": + """ + Peek into the connection, read the initial client hello and parse it to obtain ALPN values. + client_conn: + The :py:class:`client connection <mitmproxy.connections.ClientConnection>`. + Returns: + :py:class:`client hello <mitmproxy.net.tls.ClientHello>`. + """ + try: + raw_client_hello = get_client_hello(client_conn)[4:] # exclude handshake header. + except exceptions.ProtocolException as e: + raise exceptions.TlsProtocolException('Cannot read raw Client Hello: %s' % repr(e)) + + try: + return cls(raw_client_hello) + except EOFError as e: + raise exceptions.TlsProtocolException( + 'Cannot parse Client Hello: %s, Raw Client Hello: %s' % + (repr(e), binascii.hexlify(raw_client_hello)) + ) + + def __repr__(self): + return "ClientHello(sni: %s, alpn_protocols: %s, cipher_suites: %s)" % \ + (self.sni, self.alpn_protocols, self.cipher_suites) diff --git a/mitmproxy/net/websockets/frame.py b/mitmproxy/net/websockets/frame.py index 28881f64..ac6a0812 100644 --- a/mitmproxy/net/websockets/frame.py +++ b/mitmproxy/net/websockets/frame.py @@ -6,7 +6,7 @@ from mitmproxy.net import tcp from mitmproxy.utils import strutils from mitmproxy.utils import bits from mitmproxy.utils import human -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi from .masker import Masker diff --git a/mitmproxy/options.py b/mitmproxy/options.py index 8ffa1cad..76060548 100644 --- a/mitmproxy/options.py +++ b/mitmproxy/options.py @@ -44,8 +44,6 @@ class Options(optmanager.OptManager): console_layout = None # type: str console_layout_headers = None # type: bool console_mouse = None # type: bool - console_order = None # type: str - console_order_reversed = None # type: bool console_palette = None # type: str console_palette_transparent = None # type: bool default_contentview = None # type: str @@ -98,6 +96,8 @@ class Options(optmanager.OptManager): upstream_cert = None # type: bool verbosity = None # type: str view_filter = None # type: Optional[str] + view_order = None # type: str + view_order_reversed = None # type: bool web_debug = None # type: bool web_iface = None # type: str web_open_browser = None # type: bool diff --git a/mitmproxy/platform/linux.py b/mitmproxy/platform/linux.py index 4fa3191a..f446bb72 100644 --- a/mitmproxy/platform/linux.py +++ b/mitmproxy/platform/linux.py @@ -1,12 +1,34 @@ import socket import struct +import typing -# Python socket module does not have this constant +# Python's socket module does not have these constants SO_ORIGINAL_DST = 80 +SOL_IPV6 = 41 -def original_addr(csock: socket.socket): - odestdata = csock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16) - _, port, a1, a2, a3, a4 = struct.unpack("!HHBBBBxxxxxxxx", odestdata) - address = "%d.%d.%d.%d" % (a1, a2, a3, a4) - return address, port +def original_addr(csock: socket.socket) -> typing.Tuple[str, int]: + # Get the original destination on Linux. + # In theory, this can be done using the following syscalls: + # sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16) + # sock.getsockopt(SOL_IPV6, SO_ORIGINAL_DST, 28) + # + # In practice, it is a bit more complex: + # 1. We cannot rely on sock.family to decide which syscall to use because of IPv4-mapped + # IPv6 addresses. If sock.family is AF_INET6 while sock.getsockname() is ::ffff:127.0.0.1, + # we need to call the IPv4 version to get a result. + # 2. We can't just try the IPv4 syscall and then do IPv6 if that doesn't work, + # because doing the wrong syscall can apparently crash the whole Python runtime. + # As such, we use a heuristic to check which syscall to do. + is_ipv4 = "." in csock.getsockname()[0] # either 127.0.0.1 or ::ffff:127.0.0.1 + if is_ipv4: + # the struct returned here should only have 8 bytes, but invoking sock.getsockopt + # with buflen=8 doesn't work. + dst = csock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, 16) + port, raw_ip = struct.unpack_from("!2xH4s", dst) + ip = socket.inet_ntop(socket.AF_INET, raw_ip) + else: + dst = csock.getsockopt(SOL_IPV6, SO_ORIGINAL_DST, 28) + port, raw_ip = struct.unpack_from("!2xH4x16s", dst) + ip = socket.inet_ntop(socket.AF_INET6, raw_ip) + return ip, port diff --git a/mitmproxy/platform/pf.py b/mitmproxy/platform/pf.py index c0397d78..bb5eb515 100644 --- a/mitmproxy/platform/pf.py +++ b/mitmproxy/platform/pf.py @@ -1,3 +1,4 @@ +import re import sys @@ -8,6 +9,9 @@ def lookup(address, port, s): Returns an (address, port) tuple, or None. """ + # We may get an ipv4-mapped ipv6 address here, e.g. ::ffff:127.0.0.1. + # Those still appear as "127.0.0.1" in the table, so we need to strip the prefix. + address = re.sub("^::ffff:(?=\d+.\d+.\d+.\d+$)", "", address) s = s.decode() spec = "%s:%s" % (address, port) for i in s.split("\n"): diff --git a/mitmproxy/proxy/protocol/__init__.py b/mitmproxy/proxy/protocol/__init__.py index 6dbdd13c..5860542a 100644 --- a/mitmproxy/proxy/protocol/__init__.py +++ b/mitmproxy/proxy/protocol/__init__.py @@ -36,13 +36,11 @@ from .http1 import Http1Layer from .http2 import Http2Layer from .websocket import WebSocketLayer from .rawtcp import RawTCPLayer -from .tls import TlsClientHello from .tls import TlsLayer -from .tls import is_tls_record_magic __all__ = [ "Layer", "ServerConnectionMixin", - "TlsLayer", "is_tls_record_magic", "TlsClientHello", + "TlsLayer", "UpstreamConnectLayer", "HttpLayer", "Http1Layer", diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index 57ac0f16..076ffa62 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -321,6 +321,7 @@ class HttpLayer(base.Layer): try: if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): + f.metadata['websocket'] = True # We only support RFC6455 with WebSocket version 13 # allow inline scripts to manipulate the client handshake self.channel.ask("websocket_handshake", f) diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index cf021291..cc99a715 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -15,7 +15,7 @@ from mitmproxy.proxy.protocol import base from mitmproxy.proxy.protocol import http as httpbase import mitmproxy.net.http from mitmproxy.net import tcp -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from mitmproxy.net.http import http2, headers from mitmproxy.utils import human diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py index 00bb31c9..0f3be1ea 100644 --- a/mitmproxy/proxy/protocol/http_replay.py +++ b/mitmproxy/proxy/protocol/http_replay.py @@ -9,9 +9,9 @@ from mitmproxy import http from mitmproxy import flow from mitmproxy import options from mitmproxy import connections -from mitmproxy.net import server_spec +from mitmproxy.net import server_spec, tls from mitmproxy.net.http import http1 -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from mitmproxy.utils import human @@ -75,9 +75,9 @@ class RequestReplayThread(basethread.BaseThread): ) if resp.status_code != 200: raise exceptions.ReplayException("Upstream server refuses CONNECT request") - server.establish_ssl( - self.options.client_certs, - sni=self.f.server_conn.sni + server.establish_tls( + sni=self.f.server_conn.sni, + **tls.client_arguments_from_options(self.options) ) r.first_line_format = "relative" else: @@ -90,9 +90,9 @@ class RequestReplayThread(basethread.BaseThread): ) server.connect() if r.scheme == "https": - server.establish_ssl( - self.options.client_certs, - sni=self.f.server_conn.sni + server.establish_tls( + sni=self.f.server_conn.sni, + **tls.client_arguments_from_options(self.options) ) r.first_line_format = "relative" diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index 21bf1417..876c1162 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -1,14 +1,9 @@ -import struct from typing import Optional # noqa from typing import Union -import io -from kaitaistruct import KaitaiStream from mitmproxy import exceptions -from mitmproxy.contrib.kaitaistruct import tls_client_hello +from mitmproxy.net import tls as net_tls from mitmproxy.proxy.protocol import base -from mitmproxy.net import check - # taken from https://testssl.sh/openssl-rfc.mappping.html CIPHER_ID_NAME_MAP = { @@ -200,7 +195,6 @@ CIPHER_ID_NAME_MAP = { 0x080080: 'RC4-64-MD5', } - # We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default. # https://mozilla.github.io/server-side-tls/ssl-config-generator/?server=apache-2.2.15&openssl=1.0.2&hsts=yes&profile=old DEFAULT_CLIENT_CIPHERS = ( @@ -216,114 +210,7 @@ DEFAULT_CLIENT_CIPHERS = ( ) -def is_tls_record_magic(d): - """ - Returns: - True, if the passed bytes start with the TLS record magic bytes. - False, otherwise. - """ - d = d[:3] - - # TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2 - # http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello - return ( - len(d) == 3 and - d[0] == 0x16 and - d[1] == 0x03 and - 0x0 <= d[2] <= 0x03 - ) - - -def get_client_hello(client_conn): - """ - Peek into the socket and read all records that contain the initial client hello message. - - client_conn: - The :py:class:`client connection <mitmproxy.connections.ClientConnection>`. - - Returns: - The raw handshake packet bytes, without TLS record header(s). - """ - client_hello = b"" - client_hello_size = 1 - offset = 0 - while len(client_hello) < client_hello_size: - record_header = client_conn.rfile.peek(offset + 5)[offset:] - if not is_tls_record_magic(record_header) or len(record_header) != 5: - raise exceptions.TlsProtocolException('Expected TLS record, got "%s" instead.' % record_header) - record_size = struct.unpack("!H", record_header[3:])[0] + 5 - record_body = client_conn.rfile.peek(offset + record_size)[offset + 5:] - if len(record_body) != record_size - 5: - raise exceptions.TlsProtocolException("Unexpected EOF in TLS handshake: %s" % record_body) - client_hello += record_body - offset += record_size - client_hello_size = struct.unpack("!I", b'\x00' + client_hello[1:4])[0] + 4 - return client_hello - - -class TlsClientHello: - - def __init__(self, raw_client_hello): - self._client_hello = tls_client_hello.TlsClientHello(KaitaiStream(io.BytesIO(raw_client_hello))) - - def raw(self): - return self._client_hello - - @property - def cipher_suites(self): - return self._client_hello.cipher_suites.cipher_suites - - @property - def sni(self): - if self._client_hello.extensions: - for extension in self._client_hello.extensions.extensions: - is_valid_sni_extension = ( - extension.type == 0x00 and - len(extension.body.server_names) == 1 and - extension.body.server_names[0].name_type == 0 and - check.is_valid_host(extension.body.server_names[0].host_name) - ) - if is_valid_sni_extension: - return extension.body.server_names[0].host_name.decode("idna") - return None - - @property - def alpn_protocols(self): - if self._client_hello.extensions: - for extension in self._client_hello.extensions.extensions: - if extension.type == 0x10: - return list(x.name for x in extension.body.alpn_protocols) - return [] - - @classmethod - def from_client_conn(cls, client_conn): - """ - Peek into the connection, read the initial client hello and parse it to obtain ALPN values. - client_conn: - The :py:class:`client connection <mitmproxy.connections.ClientConnection>`. - Returns: - :py:class:`client hello <mitmproxy.proxy.protocol.tls.TlsClientHello>`. - """ - try: - raw_client_hello = get_client_hello(client_conn)[4:] # exclude handshake header. - except exceptions.ProtocolException as e: - raise exceptions.TlsProtocolException('Cannot read raw Client Hello: %s' % repr(e)) - - try: - return cls(raw_client_hello) - except EOFError as e: - raise exceptions.TlsProtocolException( - 'Cannot parse Client Hello: %s, Raw Client Hello: %s' % - (repr(e), raw_client_hello.encode("hex")) - ) - - def __repr__(self): - return "TlsClientHello( sni: %s alpn_protocols: %s, cipher_suites: %s)" % \ - (self.sni, self.alpn_protocols, self.cipher_suites) - - class TlsLayer(base.Layer): - """ The TLS layer implements transparent TLS connections. @@ -334,13 +221,13 @@ class TlsLayer(base.Layer): the server connection. """ - def __init__(self, ctx, client_tls, server_tls, custom_server_sni = None): + def __init__(self, ctx, client_tls, server_tls, custom_server_sni=None): super().__init__(ctx) self._client_tls = client_tls self._server_tls = server_tls self._custom_server_sni = custom_server_sni - self._client_hello = None # type: Optional[TlsClientHello] + self._client_hello = None # type: Optional[net_tls.ClientHello] def __call__(self): """ @@ -355,7 +242,7 @@ class TlsLayer(base.Layer): if self._client_tls: # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values. try: - self._client_hello = TlsClientHello.from_client_conn(self.client_conn) + self._client_hello = net_tls.ClientHello.from_file(self.client_conn.rfile) except exceptions.TlsProtocolException as e: self.log("Cannot parse Client Hello: %s" % repr(e), "error") @@ -414,7 +301,7 @@ class TlsLayer(base.Layer): if self._server_tls and not self.server_conn.tls_established: self._establish_tls_with_server() - def set_server_tls(self, server_tls: bool, sni: Union[str, None, bool]=None) -> None: + def set_server_tls(self, server_tls: bool, sni: Union[str, None, bool] = None) -> None: """ Set the TLS settings for the next server connection that will be established. This function will not alter an existing connection. @@ -487,7 +374,7 @@ class TlsLayer(base.Layer): extra_certs = None try: - self.client_conn.convert_to_ssl( + self.client_conn.convert_to_tls( cert, key, method=self.config.openssl_method_client, options=self.config.openssl_options_client, @@ -519,12 +406,14 @@ class TlsLayer(base.Layer): # We only support http/1.1 and h2. # If the server only supports spdy (next to http/1.1), it may select that # and mitmproxy would enter TCP passthrough mode, which we want to avoid. - alpn = [x for x in self._client_hello.alpn_protocols if - not (x.startswith(b"h2-") or x.startswith(b"spdy"))] + alpn = [ + x for x in self._client_hello.alpn_protocols if + not (x.startswith(b"h2-") or x.startswith(b"spdy")) + ] if alpn and b"h2" in alpn and not self.config.options.http2: alpn.remove(b"h2") - if self.client_conn.ssl_established and self.client_conn.get_alpn_proto_negotiated(): + if self.client_conn.tls_established and self.client_conn.get_alpn_proto_negotiated(): # If the client has already negotiated an ALP, then force the # server to use the same. This can only happen if the host gets # changed after the initial connection was established. E.g.: @@ -535,6 +424,9 @@ class TlsLayer(base.Layer): # * which results in garbage because the layers don' match. alpn = [self.client_conn.get_alpn_proto_negotiated()] + # We pass through the list of ciphers send by the client, because some HTTP/2 servers + # will select a non-HTTP/2 compatible cipher from our default list and then hang up + # because it's incompatible with h2. :-) ciphers_server = self.config.options.ciphers_server if not ciphers_server and self._client_tls: ciphers_server = [] @@ -543,16 +435,12 @@ class TlsLayer(base.Layer): ciphers_server.append(CIPHER_ID_NAME_MAP[id]) ciphers_server = ':'.join(ciphers_server) - self.server_conn.establish_ssl( - self.config.client_certs, - self.server_sni, - method=self.config.openssl_method_server, - options=self.config.openssl_options_server, - verify=self.config.openssl_verification_mode_server, - ca_path=self.config.options.ssl_verify_upstream_trusted_cadir, - ca_pemfile=self.config.options.ssl_verify_upstream_trusted_ca, - cipher_list=ciphers_server, + args = net_tls.client_arguments_from_options(self.config.options) + args["cipher_list"] = ciphers_server + self.server_conn.establish_tls( + sni=self.server_sni, alpn_protos=alpn, + **args ) tls_cert_err = self.server_conn.ssl_verification_error if tls_cert_err is not None: diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 19546eb2..2d8458a5 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,14 +1,19 @@ -import os import socket -import struct from OpenSSL import SSL + +import wsproto +from wsproto import events +from wsproto.connection import ConnectionType, WSConnection +from wsproto.extensions import PerMessageDeflate + from mitmproxy import exceptions from mitmproxy import flow from mitmproxy.proxy.protocol import base from mitmproxy.net import tcp from mitmproxy.net import websockets from mitmproxy.websocket import WebSocketFlow, WebSocketMessage +from mitmproxy.utils import strutils class WebSocketLayer(base.Layer): @@ -44,26 +49,59 @@ class WebSocketLayer(base.Layer): self.client_frame_buffer = [] self.server_frame_buffer = [] - def _handle_frame(self, frame, source_conn, other_conn, is_server): - if frame.header.opcode & 0x8 == 0: - return self._handle_data_frame(frame, source_conn, other_conn, is_server) - elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): - return self._handle_ping_pong(frame, source_conn, other_conn, is_server) - elif frame.header.opcode == websockets.OPCODE.CLOSE: - return self._handle_close(frame, source_conn, other_conn, is_server) - else: - return self._handle_unknown_frame(frame, source_conn, other_conn, is_server) + self.connections = {} # type: Dict[object, WSConnection] + + extensions = [] + if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers: + if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']: + extensions = [PerMessageDeflate()] + self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER, + extensions=extensions) + self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT, + host=handshake_flow.request.host, + resource=handshake_flow.request.path, + extensions=extensions) + if extensions: + for conn in self.connections.values(): + conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions']) + + data = self.connections[self.server_conn].bytes_to_send() + self.connections[self.client_conn].receive_bytes(data) + + event = next(self.connections[self.client_conn].events()) + assert isinstance(event, events.ConnectionRequested) + + self.connections[self.client_conn].accept(event) + self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send()) + assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished) + + def _handle_event(self, event, source_conn, other_conn, is_server): + if isinstance(event, events.DataReceived): + return self._handle_data_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.PingReceived): + return self._handle_ping_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.PongReceived): + return self._handle_pong_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.ConnectionClosed): + return self._handle_connection_closed(event, source_conn, other_conn, is_server) + + # fail-safe for unhandled events + return True # pragma: no cover + + def _handle_data_received(self, event, source_conn, other_conn, is_server): + fb = self.server_frame_buffer if is_server else self.client_frame_buffer + fb.append(event.data) - def _handle_data_frame(self, frame, source_conn, other_conn, is_server): + if event.message_finished: + original_chunk_sizes = [len(f) for f in fb] - fb = self.server_frame_buffer if is_server else self.client_frame_buffer - fb.append(frame) + if isinstance(event, events.TextReceived): + message_type = wsproto.frame_protocol.Opcode.TEXT + payload = ''.join(fb) + else: + message_type = wsproto.frame_protocol.Opcode.BINARY + payload = b''.join(fb) - if frame.header.fin: - payload = b''.join(f.payload for f in fb) - original_chunk_sizes = [len(f.payload) for f in fb] - message_type = fb[0].header.opcode - compressed_message = fb[0].header.rsv1 fb.clear() websocket_message = WebSocketMessage(message_type, not is_server, payload) @@ -71,13 +109,13 @@ class WebSocketLayer(base.Layer): self.flow.messages.append(websocket_message) self.channel.ask("websocket_message", self.flow) - if not self.flow.stream: + if not self.flow.stream and not websocket_message.killed: def get_chunk(payload): if len(payload) == length: # message has the same length, we can reuse the same sizes pos = 0 for s in original_chunk_sizes: - yield payload[pos:pos + s] + yield (payload[pos:pos + s], True if pos + s == length else False) pos += s else: # just re-chunk everything into 4kB frames @@ -85,95 +123,76 @@ class WebSocketLayer(base.Layer): chunk_size = 4092 if is_server else 4088 chunks = range(0, len(payload), chunk_size) for i in chunks: - yield payload[i:i + chunk_size] - - frms = [ - websockets.Frame( - payload=chunk, - opcode=frame.header.opcode, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4))) - for chunk in get_chunk(websocket_message.content) - ] - - if len(frms) > 0: - frms[-1].header.fin = True - else: - frms.append(websockets.Frame( - fin=True, - opcode=websockets.OPCODE.CONTINUE, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4)))) - - frms[0].header.opcode = message_type - frms[0].header.rsv1 = compressed_message - - for frm in frms: - other_conn.send(bytes(frm)) + yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False) - else: - other_conn.send(bytes(frame)) + for chunk, final in get_chunk(websocket_message.content): + self.connections[other_conn].send_data(chunk, final) + other_conn.send(self.connections[other_conn].bytes_to_send()) - elif self.flow.stream: - other_conn.send(bytes(frame)) + if self.flow.stream: + self.connections[other_conn].send_data(event.data, event.message_finished) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True + def _handle_ping_received(self, event, source_conn, other_conn, is_server): + # PING is automatically answered with a PONG by wsproto + self.connections[other_conn].ping() + other_conn.send(self.connections[other_conn].bytes_to_send()) + source_conn.send(self.connections[source_conn].bytes_to_send()) + self.log( + "Ping Received from {}".format("server" if is_server else "client"), + "info", + [strutils.bytes_to_escaped_str(bytes(event.payload))] + ) return True - def _handle_ping_pong(self, frame, source_conn, other_conn, is_server): - # just forward the ping/pong to the other side - other_conn.send(bytes(frame)) + def _handle_pong_received(self, event, source_conn, other_conn, is_server): + self.log( + "Pong Received from {}".format("server" if is_server else "client"), + "info", + [strutils.bytes_to_escaped_str(bytes(event.payload))] + ) return True - def _handle_close(self, frame, source_conn, other_conn, is_server): + def _handle_connection_closed(self, event, source_conn, other_conn, is_server): self.flow.close_sender = "server" if is_server else "client" - if len(frame.payload) >= 2: - code, = struct.unpack('!H', frame.payload[:2]) - self.flow.close_code = code - self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code') - if len(frame.payload) > 2: - self.flow.close_reason = frame.payload[2:] + self.flow.close_code = event.code + self.flow.close_reason = event.reason - other_conn.send(bytes(frame)) + self.connections[other_conn].close(event.code, event.reason) + other_conn.send(self.connections[other_conn].bytes_to_send()) + source_conn.send(self.connections[source_conn].bytes_to_send()) - # initiate close handshake return False - def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server): - # unknown frame - just forward it - other_conn.send(bytes(frame)) - - sender = "server" if is_server else "client" - self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)]) - - return True - def __call__(self): self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) self.flow.metadata['websocket_handshake'] = self.handshake_flow.id self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.channel.ask("websocket_start", self.flow) - client = self.client_conn.connection - server = self.server_conn.connection - conns = [client, server] + conns = [c.connection for c in self.connections.keys()] close_received = False try: while not self.channel.should_exit.is_set(): r = tcp.ssl_read_select(conns, 0.1) for conn in r: - source_conn = self.client_conn if conn == client else self.server_conn - other_conn = self.server_conn if conn == client else self.client_conn - is_server = (conn == self.server_conn.connection) + source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn + other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + is_server = (source_conn == self.server_conn) frame = websockets.Frame.from_file(source_conn.rfile) + self.connections[source_conn].receive_bytes(bytes(frame)) + source_conn.send(self.connections[source_conn].bytes_to_send()) + + if close_received: + return - cont = self._handle_frame(frame, source_conn, other_conn, is_server) - if not cont: - if close_received: - return - else: - close_received = True + for event in self.connections[source_conn].events(): + if not self._handle_event(event, source_conn, other_conn, is_server): + if not close_received: + close_received = True except (socket.error, exceptions.TcpException, SSL.Error) as e: s = 'server' if is_server else 'client' self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e))) diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index c0ec64c9..eb0008cf 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -1,5 +1,6 @@ from mitmproxy import log from mitmproxy import exceptions +from mitmproxy.net import tls from mitmproxy.proxy import protocol from mitmproxy.proxy import modes from mitmproxy.proxy.protocol import http @@ -45,14 +46,14 @@ class RootContext: d = top_layer.client_conn.rfile.peek(3) except exceptions.TcpException as e: raise exceptions.ProtocolException(str(e)) - client_tls = protocol.is_tls_record_magic(d) + client_tls = tls.is_tls_record_magic(d) # 1. check for --ignore if self.config.check_ignore: ignore = self.config.check_ignore(top_layer.server_conn.address) if not ignore and client_tls: try: - client_hello = protocol.TlsClientHello.from_client_conn(self.client_conn) + client_hello = tls.ClientHello.from_file(self.client_conn.rfile) except exceptions.TlsProtocolException as e: self.log("Cannot parse Client Hello: %s" % repr(e), "error") else: @@ -76,10 +77,10 @@ class RootContext: # if the user manually sets a scheme for connect requests, we use this to decide if we # want TLS or not. if top_layer.connect_request.scheme: - tls = top_layer.connect_request.scheme == "https" + server_tls = top_layer.connect_request.scheme == "https" else: - tls = client_tls - return protocol.TlsLayer(top_layer, client_tls, tls) + server_tls = client_tls + return protocol.TlsLayer(top_layer, client_tls, server_tls) # 3. In Http Proxy mode and Upstream Proxy mode, the next layer is fixed. if isinstance(top_layer, protocol.TlsLayer): diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 5171fbee..5df5383a 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -114,9 +114,9 @@ class ConnectionHandler: def handle(self): self.log("clientconnect", "info") - root_layer = self._create_root_layer() - + root_layer = None try: + root_layer = self._create_root_layer() root_layer = self.channel.ask("clientconnect", root_layer) root_layer() except exceptions.Kill: @@ -151,7 +151,8 @@ class ConnectionHandler: print("Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy", file=sys.stderr) self.log("clientdisconnect", "info") - self.channel.tell("clientdisconnect", root_layer) + if root_layer is not None: + self.channel.tell("clientdisconnect", root_layer) self.client_conn.finish() def log(self, msg, level): diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index cbb3beb0..217fab9d 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -4,7 +4,7 @@ offload computations from mitmproxy's main master thread. """ from mitmproxy import eventsequence -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread class ScriptThread(basethread.BaseThread): @@ -17,9 +17,14 @@ def concurrent(fn): "Concurrent decorator not supported for '%s' method." % fn.__name__ ) - def _concurrent(obj): + def _concurrent(*args): + # When annotating classmethods, "self" is passed as the first argument. + # To support both class and static methods, we accept a variable number of arguments + # and take the last one as our actual hook object. + obj = args[-1] + def run(): - fn(obj) + fn(*args) if obj.reply.state == "taken": if not obj.reply.has_message: obj.reply.ack() @@ -29,8 +34,5 @@ def concurrent(fn): "script.concurrent (%s)" % fn.__name__, target=run ).start() - # Support @concurrent for class-based addons - if "." in fn.__qualname__: - return staticmethod(_concurrent) - else: - return _concurrent + + return _concurrent diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index a0deaec9..ffaf285f 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -1,18 +1,12 @@ -from typing import Any -from typing import List +import typing +from typing import Any # noqa from typing import MutableMapping # noqa -from mitmproxy.types import serializable - - -def _is_list(cls): - # The typing module is broken on Python 3.5.0, fixed on 3.5.1. - is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True) - return issubclass(cls, List) or is_list_bugfix +from mitmproxy.coretypes import serializable +from mitmproxy.utils import typecheck class StateObject(serializable.Serializable): - """ An object with serializable state. @@ -34,22 +28,7 @@ class StateObject(serializable.Serializable): state = {} for attr, cls in self._stateobject_attributes.items(): val = getattr(self, attr) - if val is None: - state[attr] = None - elif hasattr(val, "get_state"): - state[attr] = val.get_state() - elif _is_list(cls): - state[attr] = [x.get_state() for x in val] - elif isinstance(val, dict): - s = {} - for k, v in val.items(): - if hasattr(v, "get_state"): - s[k] = v.get_state() - else: - s[k] = v - state[attr] = s - else: - state[attr] = val + state[attr] = get_state(cls, val) return state def set_state(self, state): @@ -65,13 +44,51 @@ class StateObject(serializable.Serializable): curr = getattr(self, attr) if hasattr(curr, "set_state"): curr.set_state(val) - elif hasattr(cls, "from_state"): - obj = cls.from_state(val) - setattr(self, attr, obj) - elif _is_list(cls): - cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0] - setattr(self, attr, [cls.from_state(x) for x in val]) - else: # primitive types such as int, str, ... - setattr(self, attr, cls(val)) + else: + setattr(self, attr, make_object(cls, val)) if state: raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state)) + + +def _process(typeinfo: typecheck.Type, val: typing.Any, make: bool) -> typing.Any: + if val is None: + return None + elif make and hasattr(typeinfo, "from_state"): + return typeinfo.from_state(val) + elif not make and hasattr(val, "get_state"): + return val.get_state() + + typename = str(typeinfo) + + if typename.startswith("typing.List"): + T = typecheck.sequence_type(typeinfo) + return [_process(T, x, make) for x in val] + elif typename.startswith("typing.Tuple"): + Ts = typecheck.tuple_types(typeinfo) + if len(Ts) != len(val): + raise ValueError("Invalid data. Expected {}, got {}.".format(Ts, val)) + return tuple( + _process(T, x, make) for T, x in zip(Ts, val) + ) + elif typename.startswith("typing.Dict"): + k_cls, v_cls = typecheck.mapping_types(typeinfo) + return { + _process(k_cls, k, make): _process(v_cls, v, make) + for k, v in val.items() + } + elif typename.startswith("typing.Any"): + # FIXME: Remove this when we remove flow.metadata + assert isinstance(val, (int, str, bool, bytes)) + return val + else: + return typeinfo(val) + + +def make_object(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any: + """Create an object based on the state given in val.""" + return _process(typeinfo, val, True) + + +def get_state(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any: + """Get the state of the object given as val.""" + return _process(typeinfo, val, False) diff --git a/mitmproxy/tcp.py b/mitmproxy/tcp.py index fe9f217b..11de80e9 100644 --- a/mitmproxy/tcp.py +++ b/mitmproxy/tcp.py @@ -3,7 +3,7 @@ import time from typing import List from mitmproxy import flow -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable class TCPMessage(serializable.Serializable): diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index e754cb54..204c7526 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -44,7 +44,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, "GET", "http", "example.com", - "80", + 80, "/ws", "HTTP/1.1", headers=net_http.Headers( @@ -53,6 +53,8 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, sec_websocket_version="13", sec_websocket_key="1234", ), + timestamp_start=946681200, + timestamp_end=946681201, content=b'' ) resp = http.HTTPResponse( @@ -64,6 +66,8 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, upgrade='websocket', sec_websocket_accept=b'', ), + timestamp_start=946681202, + timestamp_end=946681203, content=b'', ) handshake_flow = http.HTTPFlow(client_conn, server_conn) @@ -71,7 +75,9 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, handshake_flow.response = resp f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow) - handshake_flow.metadata['websocket_flow'] = f + f.metadata['websocket_handshake'] = handshake_flow.id + handshake_flow.metadata['websocket_flow'] = f.id + handshake_flow.metadata['websocket'] = True if messages is True: messages = [ @@ -151,14 +157,15 @@ def tclient_conn(): address=("127.0.0.1", 22), clientcert=None, mitmcert=None, - ssl_established=False, - timestamp_start=1, - timestamp_ssl_setup=2, - timestamp_end=3, + tls_established=False, + timestamp_start=946681200, + timestamp_tls_setup=946681201, + timestamp_end=946681206, sni="address", cipher_name="cipher", alpn_proto_negotiated=b"http/1.1", tls_version="TLSv1.2", + tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))], )) c.reply = controller.DummyReply() c.rfile = io.BytesIO() @@ -176,11 +183,11 @@ def tserver_conn(): source_address=("address", 22), ip_address=("192.168.0.1", 22), cert=None, - timestamp_start=1, - timestamp_tcp_setup=2, - timestamp_ssl_setup=3, - timestamp_end=4, - ssl_established=False, + timestamp_start=946681202, + timestamp_tcp_setup=946681203, + timestamp_tls_setup=946681204, + timestamp_end=946681205, + tls_established=False, sni="address", alpn_proto_negotiated=None, tls_version="TLSv1.2", diff --git a/mitmproxy/test/tutils.py b/mitmproxy/test/tutils.py index 80e5b6fd..d5b52bbe 100644 --- a/mitmproxy/test/tutils.py +++ b/mitmproxy/test/tutils.py @@ -1,4 +1,3 @@ -import time from io import BytesIO from mitmproxy.utils import data @@ -31,7 +30,9 @@ def treq(**kwargs): path=b"/path", http_version=b"HTTP/1.1", headers=http.Headers(((b"header", b"qvalue"), (b"content-length", b"7"))), - content=b"content" + content=b"content", + timestamp_start=946681200, + timestamp_end=946681201, ) default.update(kwargs) return http.Request(**default) @@ -48,8 +49,8 @@ def tresp(**kwargs): reason=b"OK", headers=http.Headers(((b"header-response", b"svalue"), (b"content-length", b"7"))), content=b"message", - timestamp_start=time.time(), - timestamp_end=time.time(), + timestamp_start=946681202, + timestamp_end=946681203, ) default.update(kwargs) return http.Response(**default) diff --git a/test/mitmproxy/types/__init__.py b/mitmproxy/tools/console/commander/__init__.py index e69de29b..e69de29b 100644 --- a/test/mitmproxy/types/__init__.py +++ b/mitmproxy/tools/console/commander/__init__.py diff --git a/mitmproxy/tools/console/commander/commander.py b/mitmproxy/tools/console/commander/commander.py new file mode 100644 index 00000000..566c42e6 --- /dev/null +++ b/mitmproxy/tools/console/commander/commander.py @@ -0,0 +1,182 @@ +import abc +import typing + +import urwid +from urwid.text_layout import calc_coords + +import mitmproxy.flow +import mitmproxy.master +import mitmproxy.command +import mitmproxy.types + + +class Completer: # pragma: no cover + @abc.abstractmethod + def cycle(self) -> str: + pass + + +class ListCompleter(Completer): + def __init__( + self, + start: str, + options: typing.Sequence[str], + ) -> None: + self.start = start + self.options = [] # type: typing.Sequence[str] + for o in options: + if o.startswith(start): + self.options.append(o) + self.options.sort() + self.offset = 0 + + def cycle(self) -> str: + if not self.options: + return self.start + ret = self.options[self.offset] + self.offset = (self.offset + 1) % len(self.options) + return ret + + +CompletionState = typing.NamedTuple( + "CompletionState", + [ + ("completer", Completer), + ("parse", typing.Sequence[mitmproxy.command.ParseResult]) + ] +) + + +class CommandBuffer: + def __init__(self, master: mitmproxy.master.Master, start: str = "") -> None: + self.master = master + self.text = self.flatten(start) + # Cursor is always within the range [0:len(buffer)]. + self._cursor = len(self.text) + self.completion = None # type: CompletionState + + @property + def cursor(self) -> int: + return self._cursor + + @cursor.setter + def cursor(self, x) -> None: + if x < 0: + self._cursor = 0 + elif x > len(self.text): + self._cursor = len(self.text) + else: + self._cursor = x + + def render(self): + """ + This function is somewhat tricky - in order to make the cursor + position valid, we have to make sure there is a + character-for-character offset match in the rendered output, up + to the cursor. Beyond that, we can add stuff. + """ + parts, remhelp = self.master.commands.parse_partial(self.text) + ret = [] + for p in parts: + if p.valid: + if p.type == mitmproxy.types.Cmd: + ret.append(("commander_command", p.value)) + else: + ret.append(("text", p.value)) + elif p.value: + ret.append(("commander_invalid", p.value)) + else: + ret.append(("text", "")) + ret.append(("text", " ")) + if remhelp: + ret.append(("text", " ")) + for v in remhelp: + ret.append(("commander_hint", "%s " % v)) + return ret + + def flatten(self, txt): + parts, _ = self.master.commands.parse_partial(txt) + return " ".join([x.value for x in parts]) + + def left(self) -> None: + self.cursor = self.cursor - 1 + + def right(self) -> None: + self.cursor = self.cursor + 1 + + def cycle_completion(self) -> None: + if not self.completion: + parts, remainhelp = self.master.commands.parse_partial(self.text[:self.cursor]) + last = parts[-1] + ct = mitmproxy.types.CommandTypes.get(last.type, None) + if ct: + self.completion = CompletionState( + completer = ListCompleter( + parts[-1].value, + ct.completion(self.master.commands, last.type, parts[-1].value) + ), + parse = parts, + ) + if self.completion: + nxt = self.completion.completer.cycle() + buf = " ".join([i.value for i in self.completion.parse[:-1]]) + " " + nxt + buf = buf.strip() + self.text = self.flatten(buf) + self.cursor = len(self.text) + + def backspace(self) -> None: + if self.cursor == 0: + return + self.text = self.flatten(self.text[:self.cursor - 1] + self.text[self.cursor:]) + self.cursor = self.cursor - 1 + self.completion = None + + def insert(self, k: str) -> None: + """ + Inserts text at the cursor. + """ + self.text = self.flatten(self.text[:self.cursor] + k + self.text[self.cursor:]) + self.cursor += 1 + self.completion = None + + +class CommandEdit(urwid.WidgetWrap): + leader = ": " + + def __init__(self, master: mitmproxy.master.Master, text: str) -> None: + super().__init__(urwid.Text(self.leader)) + self.master = master + self.cbuf = CommandBuffer(master, text) + self.update() + + def keypress(self, size, key): + if key == "backspace": + self.cbuf.backspace() + elif key == "left": + self.cbuf.left() + elif key == "right": + self.cbuf.right() + elif key == "tab": + self.cbuf.cycle_completion() + elif len(key) == 1: + self.cbuf.insert(key) + self.update() + + def update(self): + self._w.set_text([self.leader, self.cbuf.render()]) + + def render(self, size, focus=False): + (maxcol,) = size + canv = self._w.render((maxcol,)) + canv = urwid.CompositeCanvas(canv) + canv.cursor = self.get_cursor_coords((maxcol,)) + return canv + + def get_cursor_coords(self, size): + p = self.cbuf.cursor + len(self.leader) + trans = self._w.get_line_translation(size[0]) + x, y = calc_coords(self._w.get_text()[0], trans, p) + return x, y + + def get_edit_text(self): + return self.cbuf.text diff --git a/mitmproxy/tools/console/commandeditor.py b/mitmproxy/tools/console/commandexecutor.py index 17d1506b..26f92238 100644 --- a/mitmproxy/tools/console/commandeditor.py +++ b/mitmproxy/tools/console/commandexecutor.py @@ -1,17 +1,10 @@ import typing -import urwid from mitmproxy import exceptions from mitmproxy import flow -from mitmproxy.tools.console import signals - - -class CommandEdit(urwid.Edit): - def __init__(self, partial): - urwid.Edit.__init__(self, ":", partial) - def keypress(self, size, key): - return urwid.Edit.keypress(self, size, key) +from mitmproxy.tools.console import overlay +from mitmproxy.tools.console import signals class CommandExecutor: @@ -30,9 +23,15 @@ class CommandExecutor: signals.status_message.send( message="Command returned %s flows" % len(ret) ) - elif len(str(ret)) < 50: - signals.status_message.send(message=str(ret)) - else: + elif type(ret) == flow.Flow: signals.status_message.send( - message="Command returned too much data to display." + message="Command returned 1 flow" ) + else: + self.master.overlay( + overlay.DataViewerOverlay( + self.master, + ret, + ), + valign="top" + )
\ No newline at end of file diff --git a/mitmproxy/tools/console/commands.py b/mitmproxy/tools/console/commands.py index 20efcee3..1183ee9d 100644 --- a/mitmproxy/tools/console/commands.py +++ b/mitmproxy/tools/console/commands.py @@ -124,7 +124,7 @@ class CommandHelp(urwid.Frame): class Commands(urwid.Pile, layoutwidget.LayoutWidget): - title = "Commands" + title = "Command Reference" keyctx = "commands" def __init__(self, master): diff --git a/mitmproxy/tools/console/common.py b/mitmproxy/tools/console/common.py index 47a30272..8a842799 100644 --- a/mitmproxy/tools/console/common.py +++ b/mitmproxy/tools/console/common.py @@ -1,9 +1,10 @@ import platform +import typing +from functools import lru_cache import urwid import urwid.util -from functools import lru_cache from mitmproxy.utils import human # Detect Windows Subsystem for Linux @@ -43,41 +44,48 @@ def highlight_key(str, key, textattr="text", keyattr="key"): KEY_MAX = 30 -def format_keyvals(lst, key="key", val="text", indent=0): +def format_keyvals( + entries: typing.List[typing.Tuple[str, typing.Union[None, str, urwid.Widget]]], + key_format: str = "key", + value_format: str = "text", + indent: int = 0 +) -> typing.List[urwid.Columns]: """ - Format a list of (key, value) tuples. - - If key is None, it's treated specially: - - We assume a sub-value, and add an extra indent. - - The value is treated as a pre-formatted list of directives. + Format a list of (key, value) tuples. + + Args: + entries: The list to format. keys must be strings, values can also be None or urwid widgets. + The latter makes it possible to use the result of format_keyvals() as a value. + key_format: The display attribute for the key. + value_format: The display attribute for the value. + indent: Additional indent to apply. """ + max_key_len = max((len(k) for k, v in entries if k is not None), default=0) + max_key_len = min(max_key_len, KEY_MAX) + + if indent > 2: + indent -= 2 # We use dividechars=2 below, which already adds two empty spaces + ret = [] - if lst: - maxk = min(max(len(i[0]) for i in lst if i and i[0]), KEY_MAX) - for i, kv in enumerate(lst): - if kv is None: - ret.append(urwid.Text("")) - else: - if isinstance(kv[1], urwid.Widget): - v = kv[1] - elif kv[1] is None: - v = urwid.Text("") - else: - v = urwid.Text([(val, kv[1])]) - ret.append( - urwid.Columns( - [ - ("fixed", indent, urwid.Text("")), - ( - "fixed", - maxk, - urwid.Text([(key, kv[0] or "")]) - ), - v - ], - dividechars = 2 - ) - ) + for k, v in entries: + if v is None: + v = urwid.Text("") + elif not isinstance(v, urwid.Widget): + v = urwid.Text([(value_format, v)]) + ret.append( + urwid.Columns( + [ + ("fixed", indent, urwid.Text("")), + ( + "fixed", + max_key_len, + urwid.Text([(key_format, k)]) + ), + v + ], + dividechars=2 + ) + ) return ret @@ -205,19 +213,15 @@ def format_flow(f, focus, extended=False, hostheader=False, max_url_len=False): focus=focus, extended=extended, max_url_len=max_url_len, - - intercepted = f.intercepted, - acked = acked, - - req_timestamp = f.request.timestamp_start, - req_is_replay = f.request.is_replay, - req_method = f.request.method, - req_url = f.request.pretty_url if hostheader else f.request.url, - req_http_version = f.request.http_version, - - err_msg = f.error.msg if f.error else None, - - marked = f.marked, + intercepted=f.intercepted, + acked=acked, + req_timestamp=f.request.timestamp_start, + req_is_replay=f.request.is_replay, + req_method=f.request.method, + req_url=f.request.pretty_url if hostheader else f.request.url, + req_http_version=f.request.http_version, + err_msg=f.error.msg if f.error else None, + marked=f.marked, ) if f.response: if f.response.raw_content: @@ -232,11 +236,11 @@ def format_flow(f, focus, extended=False, hostheader=False, max_url_len=False): roundtrip = human.pretty_duration(duration) d.update(dict( - resp_code = f.response.status_code, - resp_reason = f.response.reason, - resp_is_replay = f.response.is_replay, - resp_clen = contentdesc, - roundtrip = roundtrip, + resp_code=f.response.status_code, + resp_reason=f.response.reason, + resp_is_replay=f.response.is_replay, + resp_clen=contentdesc, + roundtrip=roundtrip, )) t = f.response.headers.get("content-type") diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 49934e4d..03f2e240 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -1,11 +1,15 @@ +import csv import typing from mitmproxy import ctx from mitmproxy import command from mitmproxy import exceptions from mitmproxy import flow +from mitmproxy import http from mitmproxy import contentviews from mitmproxy.utils import strutils +import mitmproxy.types + from mitmproxy.tools.console import overlay from mitmproxy.tools.console import signals @@ -32,43 +36,33 @@ console_layouts = [ ] -class Logger: - def log(self, evt): - signals.add_log(evt.msg, evt.level) - if evt.level == "alert": - signals.status_message.send( - message=str(evt.msg), - expire=2 - ) - - class UnsupportedLog: """ A small addon to dump info on flow types we don't support yet. """ def websocket_message(self, f): message = f.messages[-1] - signals.add_log(f.message_info(message), "info") - signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug") + ctx.log.info(f.message_info(message)) + ctx.log.debug(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content)) def websocket_end(self, f): - signals.add_log("WebSocket connection closed by {}: {} {}, {}".format( + ctx.log.info("WebSocket connection closed by {}: {} {}, {}".format( f.close_sender, f.close_code, f.close_message, - f.close_reason), "info") + f.close_reason)) def tcp_message(self, f): message = f.messages[-1] direction = "->" if message.from_client else "<-" - signals.add_log("{client_host}:{client_port} {direction} tcp {direction} {server_host}:{server_port}".format( + ctx.log.info("{client_host}:{client_port} {direction} tcp {direction} {server_host}:{server_port}".format( client_host=f.client_conn.address[0], client_port=f.client_conn.address[1], server_host=f.server_conn.address[0], server_port=f.server_conn.address[1], direction=direction, - ), "info") - signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug") + )) + ctx.log.debug(strutils.bytes_to_escaped_str(message.content)) class ConsoleAddon: @@ -111,8 +105,7 @@ class ConsoleAddon: @command.command("console.layout.options") def layout_options(self) -> typing.Sequence[str]: """ - Returns the valid options for console layout. Use these by setting - the console_layout option. + Returns the available options for the console_layout option. """ return ["single", "vertical", "horizontal"] @@ -225,7 +218,11 @@ class ConsoleAddon: @command.command("console.choose") def console_choose( - self, prompt: str, choices: typing.Sequence[str], *cmd: str + self, + prompt: str, + choices: typing.Sequence[str], + cmd: mitmproxy.types.Cmd, + *args: mitmproxy.types.Arg ) -> None: """ Prompt the user to choose from a specified list of strings, then @@ -234,7 +231,7 @@ class ConsoleAddon: """ def callback(opt): # We're now outside of the call context... - repl = " ".join(cmd) + repl = cmd + " " + " ".join(args) repl = repl.replace("{choice}", opt) try: self.master.commands.call(repl) @@ -247,7 +244,11 @@ class ConsoleAddon: @command.command("console.choose.cmd") def console_choose_cmd( - self, prompt: str, choicecmd: str, *cmd: str + self, + prompt: str, + choicecmd: mitmproxy.types.Cmd, + subcmd: mitmproxy.types.Cmd, + *args: mitmproxy.types.Arg ) -> None: """ Prompt the user to choose from a list of strings returned by a @@ -258,10 +259,10 @@ class ConsoleAddon: def callback(opt): # We're now outside of the call context... - repl = " ".join(cmd) + repl = " ".join(args) repl = repl.replace("{choice}", opt) try: - self.master.commands.call(repl) + self.master.commands.call(subcmd + " " + repl) except exceptions.CommandError as e: signals.status_message.send(message=str(e)) @@ -272,10 +273,21 @@ class ConsoleAddon: @command.command("console.command") def console_command(self, *partial: str) -> None: """ - Prompt the user to edit a command with a (possilby empty) starting value. + Prompt the user to edit a command with a (possibly empty) starting value. """ signals.status_prompt_command.send(partial=" ".join(partial)) # type: ignore + @command.command("console.command.set") + def console_command_set(self, option: str) -> None: + """ + Prompt the user to set an option of the form "key[=value]". + """ + option_value = getattr(self.master.options, option, None) + current_value = option_value if option_value else "" + self.master.commands.call( + "console.command set %s=%s" % (option, current_value) + ) + @command.command("console.view.keybindings") def view_keybindings(self) -> None: """View the commands list.""" @@ -322,6 +334,7 @@ class ConsoleAddon: signals.pop_view_state.send(self) @command.command("console.bodyview") + @command.argument("part", type=mitmproxy.types.Choice("console.bodyview.options")) def bodyview(self, f: flow.Flow, part: str) -> None: """ Spawn an external viewer for a flow request or response body based @@ -329,17 +342,27 @@ class ConsoleAddon: correct viewier, and fall back to the programs in $PAGER or $EDITOR if necessary. """ - fpart = getattr(f, part) + fpart = getattr(f, part, None) if not fpart: - raise exceptions.CommandError("Could not view part %s." % part) + raise exceptions.CommandError("Part must be either request or response, not %s." % part) t = fpart.headers.get("content-type") content = fpart.get_content(strict=False) if not content: raise exceptions.CommandError("No content to view.") self.master.spawn_external_viewer(content, t) + @command.command("console.bodyview.options") + def bodyview_options(self) -> typing.Sequence[str]: + """ + Possible parts for console.bodyview. + """ + return ["request", "response"] + @command.command("console.edit.focus.options") def edit_focus_options(self) -> typing.Sequence[str]: + """ + Possible components for console.edit.focus. + """ return [ "cookies", "form", @@ -349,16 +372,32 @@ class ConsoleAddon: "reason", "request-headers", "response-headers", + "request-body", + "response-body", "status_code", "set-cookies", "url", ] @command.command("console.edit.focus") + @command.argument("part", type=mitmproxy.types.Choice("console.edit.focus.options")) def edit_focus(self, part: str) -> None: """ - Edit the query of the current focus. + Edit a component of the currently focused flow. """ + flow = self.master.view.focus.flow + # This shouldn't be necessary once this command is "console.edit @focus", + # but for now it is. + if not flow: + raise exceptions.CommandError("No flow selected.") + flow.backup() + + require_dummy_response = ( + part in ("response-headers", "response-body", "set-cookies") and + flow.response is None + ) + if require_dummy_response: + flow.response = http.HTTPResponse.make() if part == "cookies": self.master.switch_view("edit_focus_cookies") elif part == "form": @@ -371,6 +410,19 @@ class ConsoleAddon: self.master.switch_view("edit_focus_request_headers") elif part == "response-headers": self.master.switch_view("edit_focus_response_headers") + elif part in ("request-body", "response-body"): + if part == "request-body": + message = flow.request + else: + message = flow.response + c = self.master.spawn_editor(message.get_content(strict=False) or b"") + # Fix an issue caused by some editors when editing a + # request/response body. Many editors make it hard to save a + # file without a terminating newline on the last line. When + # editing message bodies, this can cause problems. For now, I + # just strip the newlines off the end of the body when we return + # from an editor. + message.content = c.rstrip(b"\n") elif part == "set-cookies": self.master.switch_view("edit_focus_setcookies") elif part in ["url", "method", "status_code", "reason"]: @@ -405,21 +457,38 @@ class ConsoleAddon: """ self._grideditor().cmd_delete() - @command.command("console.grideditor.readfile") - def grideditor_readfile(self, path: str) -> None: + @command.command("console.grideditor.load") + def grideditor_load(self, path: mitmproxy.types.Path) -> None: """ Read a file into the currrent cell. """ self._grideditor().cmd_read_file(path) - @command.command("console.grideditor.readfile_escaped") - def grideditor_readfile_escaped(self, path: str) -> None: + @command.command("console.grideditor.load_escaped") + def grideditor_load_escaped(self, path: mitmproxy.types.Path) -> None: """ - Read a file containing a Python-style escaped stringinto the + Read a file containing a Python-style escaped string into the currrent cell. """ self._grideditor().cmd_read_file_escaped(path) + @command.command("console.grideditor.save") + def grideditor_save(self, path: mitmproxy.types.Path) -> None: + """ + Save data to file as a CSV. + """ + rows = self._grideditor().value + try: + with open(path, "w", newline='', encoding="utf8") as fp: + writer = csv.writer(fp) + for row in rows: + writer.writerow( + [strutils.always_str(x) or "" for x in row] # type: ignore + ) + ctx.log.alert("Saved %s rows as CSV." % (len(rows))) + except IOError as e: + ctx.log.error(str(e)) + @command.command("console.grideditor.editor") def grideditor_editor(self) -> None: """ @@ -428,26 +497,33 @@ class ConsoleAddon: self._grideditor().cmd_spawn_editor() @command.command("console.flowview.mode.set") - def flowview_mode_set(self) -> None: + @command.argument("mode", type=mitmproxy.types.Choice("console.flowview.mode.options")) + def flowview_mode_set(self, mode: str) -> None: """ Set the display mode for the current flow view. """ - fv = self.master.window.current("flowview") + fv = self.master.window.current_window("flowview") if not fv: raise exceptions.CommandError("Not viewing a flow.") idx = fv.body.tab_offset - def callback(opt): - try: - self.master.commands.call_args( - "view.setval", - ["@focus", "flowview_mode_%s" % idx, opt] - ) - except exceptions.CommandError as e: - signals.status_message.send(message=str(e)) + if mode not in [i.name.lower() for i in contentviews.views]: + raise exceptions.CommandError("Invalid flowview mode.") - opts = [i.name.lower() for i in contentviews.views] - self.master.overlay(overlay.Chooser(self.master, "Mode", opts, "", callback)) + try: + self.master.commands.call_args( + "view.setval", + ["@focus", "flowview_mode_%s" % idx, mode] + ) + except exceptions.CommandError as e: + signals.status_message.send(message=str(e)) + + @command.command("console.flowview.mode.options") + def flowview_mode_options(self) -> typing.Sequence[str]: + """ + Returns the valid options for the flowview mode. + """ + return [i.name.lower() for i in contentviews.views] @command.command("console.flowview.mode") def flowview_mode(self) -> str: @@ -467,13 +543,6 @@ class ConsoleAddon: ] ) - @command.command("console.eventlog.clear") - def eventlog_clear(self) -> None: - """ - Clear the event log. - """ - signals.sig_clear_log.send(self) - @command.command("console.key.contexts") def key_contexts(self) -> typing.Sequence[str]: """ @@ -482,14 +551,20 @@ class ConsoleAddon: return list(sorted(keymap.Contexts)) @command.command("console.key.bind") - def key_bind(self, contexts: typing.Sequence[str], key: str, *command: str) -> None: + def key_bind( + self, + contexts: typing.Sequence[str], + key: str, + cmd: mitmproxy.types.Cmd, + *args: mitmproxy.types.Arg + ) -> None: """ Bind a shortcut key. """ try: self.master.keymap.add( key, - " ".join(command), + cmd + " " + " ".join(args), contexts, "" ) @@ -510,7 +585,7 @@ class ConsoleAddon: kwidget = self.master.window.current("keybindings") if not kwidget: raise exceptions.CommandError("Not viewing key bindings.") - f = kwidget.focus() + f = kwidget.get_focused_binding() if not f: raise exceptions.CommandError("No key binding focused") return f diff --git a/mitmproxy/tools/console/defaultkeys.py b/mitmproxy/tools/console/defaultkeys.py index 8c28524a..c7876288 100644 --- a/mitmproxy/tools/console/defaultkeys.py +++ b/mitmproxy/tools/console/defaultkeys.py @@ -2,8 +2,8 @@ def map(km): km.add(":", "console.command ", ["global"], "Command prompt") km.add("?", "console.view.help", ["global"], "View help") - km.add("B", "browser.start", ["global"], "View commands") - km.add("C", "console.view.commands", ["global"], "Start an attached browser") + km.add("B", "browser.start", ["global"], "Start an attached browser") + km.add("C", "console.view.commands", ["global"], "View commands") km.add("K", "console.view.keybindings", ["global"], "View key bindings") km.add("O", "console.view.options", ["global"], "View options") km.add("E", "console.view.eventlog", ["global"], "View event log") @@ -11,6 +11,7 @@ def map(km): km.add("q", "console.view.pop", ["global"], "Exit the current view") km.add("-", "console.layout.cycle", ["global"], "Cycle to next layout") km.add("shift tab", "console.panes.next", ["global"], "Focus next layout pane") + km.add("ctrl right", "console.panes.next", ["global"], "Focus next layout pane") km.add("P", "console.view.flow @focus", ["global"], "View flow details") km.add("g", "console.nav.start", ["global"], "Go to start") @@ -26,12 +27,12 @@ def map(km): km.add("ctrl b", "console.nav.pageup", ["global"], "Page up") km.add("I", "console.intercept.toggle", ["global"], "Toggle intercept") - km.add("i", "console.command set intercept=", ["global"], "Set intercept") - km.add("W", "console.command set save_stream_file=", ["global"], "Stream to file") + km.add("i", "console.command.set intercept", ["global"], "Set intercept") + km.add("W", "console.command.set save_stream_file", ["global"], "Stream to file") km.add("A", "flow.resume @all", ["flowlist", "flowview"], "Resume all intercepted flows") km.add("a", "flow.resume @focus", ["flowlist", "flowview"], "Resume this intercepted flow") km.add( - "b", "console.command cut.save s.content|@focus ''", + "b", "console.command cut.save @focus response.content ", ["flowlist", "flowview"], "Save response body to file" ) @@ -41,12 +42,12 @@ def map(km): "e", """ console.choose.cmd Format export.formats - console.command export.file {choice} @focus '' + console.command export.file {choice} @focus """, ["flowlist", "flowview"], "Export this flow to file" ) - km.add("f", "console.command set view_filter=", ["flowlist"], "Set view filter") + km.add("f", "console.command.set view_filter", ["flowlist"], "Set view filter") km.add("F", "set console_focus_follow=toggle", ["flowlist"], "Set focus follow") km.add( "ctrl l", @@ -59,7 +60,7 @@ def map(km): km.add("M", "view.marked.toggle", ["flowlist"], "Toggle viewing marked flows") km.add( "n", - "console.command view.create get https://google.com", + "console.command view.create get https://example.com/", ["flowlist"], "Create a new flow" ) @@ -67,14 +68,14 @@ def map(km): "o", """ console.choose.cmd Order view.order.options - set console_order={choice} + set view_order={choice} """, ["flowlist"], "Set flow list order" ) km.add("r", "replay.client @focus", ["flowlist", "flowview"], "Replay this flow") km.add("S", "console.command replay.server ", ["flowlist"], "Start server replay") - km.add("v", "set console_order_reversed=toggle", ["flowlist"], "Reverse flow list order") + km.add("v", "set view_order_reversed=toggle", ["flowlist"], "Reverse flow list order") km.add("U", "flow.mark @all false", ["flowlist"], "Un-set all marks") km.add("w", "console.command save.file @shown ", ["flowlist"], "Save listed flows to file") km.add("V", "flow.revert @focus", ["flowlist", "flowview"], "Revert changes to this flow") @@ -116,7 +117,15 @@ def map(km): "View flow body in an external viewer" ) km.add("p", "view.focus.prev", ["flowview"], "Go to previous flow") - km.add("m", "console.flowview.mode.set", ["flowview"], "Set flow view mode") + km.add( + "m", + """ + console.choose.cmd Mode console.flowview.mode.options + console.flowview.mode.set {choice} + """, + ["flowview"], + "Set flow view mode" + ) km.add( "z", """ @@ -137,19 +146,25 @@ def map(km): km.add("d", "console.grideditor.delete", ["grideditor"], "Delete this row") km.add( "r", - "console.command console.grideditor.readfile", + "console.command console.grideditor.load", ["grideditor"], - "Read unescaped data from file" + "Read unescaped data into the current cell from file" ) km.add( "R", - "console.command console.grideditor.readfile_escaped", + "console.command console.grideditor.load_escaped", ["grideditor"], - "Read a Python-style escaped string from file" + "Load a Python-style escaped string into the current cell from file" ) km.add("e", "console.grideditor.editor", ["grideditor"], "Edit in external editor") + km.add( + "w", + "console.command console.grideditor.save ", + ["grideditor"], + "Save data to file as CSV" + ) - km.add("z", "console.eventlog.clear", ["eventlog"], "Clear") + km.add("z", "eventstore.clear", ["eventlog"], "Clear") km.add( "a", diff --git a/mitmproxy/tools/console/eventlog.py b/mitmproxy/tools/console/eventlog.py index c3e5dd39..8083180d 100644 --- a/mitmproxy/tools/console/eventlog.py +++ b/mitmproxy/tools/console/eventlog.py @@ -1,11 +1,9 @@ +import collections + import urwid -from mitmproxy.tools.console import signals from mitmproxy.tools.console import layoutwidget -from mitmproxy import ctx from mitmproxy import log -EVENTLOG_SIZE = 10000 - class LogBufferWalker(urwid.SimpleListWalker): pass @@ -16,11 +14,17 @@ class EventLog(urwid.ListBox, layoutwidget.LayoutWidget): title = "Events" def __init__(self, master): - self.walker = LogBufferWalker([]) self.master = master - urwid.ListBox.__init__(self, self.walker) - signals.sig_add_log.connect(self.sig_add_log) - signals.sig_clear_log.connect(self.sig_clear_log) + self.walker = LogBufferWalker( + collections.deque(maxlen=self.master.events.size) + ) + + master.events.sig_add.connect(self.add_event) + master.events.sig_refresh.connect(self.refresh_events) + self.master.options.subscribe(self.refresh_events, ["verbosity"]) + self.refresh_events() + + super().__init__(self.walker) def load(self, loader): loader.add_option( @@ -37,21 +41,21 @@ class EventLog(urwid.ListBox, layoutwidget.LayoutWidget): self.set_focus(len(self.walker) - 1) elif key == "m_start": self.set_focus(0) - return urwid.ListBox.keypress(self, size, key) + return super().keypress(size, key) - def sig_add_log(self, sender, e, level): - if log.log_tier(ctx.options.verbosity) < log.log_tier(level): + def add_event(self, event_store, entry: log.LogEntry): + if log.log_tier(self.master.options.verbosity) < log.log_tier(entry.level): return - txt = "%s: %s" % (level, str(e)) - if level in ("error", "warn"): - e = urwid.Text((level, txt)) + txt = "%s: %s" % (entry.level, str(entry.msg)) + if entry.level in ("error", "warn", "alert"): + e = urwid.Text((entry.level, txt)) else: e = urwid.Text(txt) self.walker.append(e) - if len(self.walker) > EVENTLOG_SIZE: - self.walker.pop(0) if self.master.options.console_focus_follow: self.walker.set_focus(len(self.walker) - 1) - def sig_clear_log(self, sender): - self.walker[:] = [] + def refresh_events(self, *_): + self.walker.clear() + for event in self.master.events.data: + self.add_event(None, event) diff --git a/mitmproxy/tools/console/flowdetailview.py b/mitmproxy/tools/console/flowdetailview.py index 28fe1fbc..443ca526 100644 --- a/mitmproxy/tools/console/flowdetailview.py +++ b/mitmproxy/tools/console/flowdetailview.py @@ -23,157 +23,157 @@ def flowdetails(state, flow: http.HTTPFlow): metadata = flow.metadata if metadata is not None and len(metadata) > 0: - parts = [[str(k), repr(v)] for k, v in metadata.items()] + parts = [(str(k), repr(v)) for k, v in metadata.items()] text.append(urwid.Text([("head", "Metadata:")])) - text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) + text.extend(common.format_keyvals(parts, indent=4)) if sc is not None and sc.ip_address: text.append(urwid.Text([("head", "Server Connection:")])) parts = [ - ["Address", human.format_address(sc.address)], + ("Address", human.format_address(sc.address)), ] if sc.ip_address: - parts.append(["Resolved Address", human.format_address(sc.ip_address)]) + parts.append(("Resolved Address", human.format_address(sc.ip_address))) if resp: - parts.append(["HTTP Version", resp.http_version]) + parts.append(("HTTP Version", resp.http_version)) if sc.alpn_proto_negotiated: - parts.append(["ALPN", sc.alpn_proto_negotiated]) + parts.append(("ALPN", sc.alpn_proto_negotiated)) text.extend( - common.format_keyvals(parts, key="key", val="text", indent=4) + common.format_keyvals(parts, indent=4) ) c = sc.cert if c: text.append(urwid.Text([("head", "Server Certificate:")])) parts = [ - ["Type", "%s, %s bits" % c.keyinfo], - ["SHA1 digest", c.digest("sha1")], - ["Valid to", str(c.notafter)], - ["Valid from", str(c.notbefore)], - ["Serial", str(c.serial)], - [ + ("Type", "%s, %s bits" % c.keyinfo), + ("SHA1 digest", c.digest("sha1")), + ("Valid to", str(c.notafter)), + ("Valid from", str(c.notbefore)), + ("Serial", str(c.serial)), + ( "Subject", urwid.BoxAdapter( urwid.ListBox( common.format_keyvals( c.subject, - key="highlight", - val="text" + key_format="highlight" ) ), len(c.subject) ) - ], - [ + ), + ( "Issuer", urwid.BoxAdapter( urwid.ListBox( common.format_keyvals( - c.issuer, key="highlight", val="text" + c.issuer, + key_format="highlight" ) ), len(c.issuer) ) - ] + ) ] if c.altnames: parts.append( - [ + ( "Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames) - ] + ) ) text.extend( - common.format_keyvals(parts, key="key", val="text", indent=4) + common.format_keyvals(parts, indent=4) ) if cc is not None: text.append(urwid.Text([("head", "Client Connection:")])) parts = [ - ["Address", "{}:{}".format(cc.address[0], cc.address[1])], + ("Address", "{}:{}".format(cc.address[0], cc.address[1])), ] if req: - parts.append(["HTTP Version", req.http_version]) + parts.append(("HTTP Version", req.http_version)) if cc.tls_version: - parts.append(["TLS Version", cc.tls_version]) + parts.append(("TLS Version", cc.tls_version)) if cc.sni: - parts.append(["Server Name Indication", cc.sni]) + parts.append(("Server Name Indication", cc.sni)) if cc.cipher_name: - parts.append(["Cipher Name", cc.cipher_name]) + parts.append(("Cipher Name", cc.cipher_name)) if cc.alpn_proto_negotiated: - parts.append(["ALPN", cc.alpn_proto_negotiated]) + parts.append(("ALPN", cc.alpn_proto_negotiated)) text.extend( - common.format_keyvals(parts, key="key", val="text", indent=4) + common.format_keyvals(parts, indent=4) ) parts = [] if cc is not None and cc.timestamp_start: parts.append( - [ + ( "Client conn. established", maybe_timestamp(cc, "timestamp_start") - ] + ) ) - if cc.ssl_established: + if cc.tls_established: parts.append( - [ + ( "Client conn. TLS handshake", - maybe_timestamp(cc, "timestamp_ssl_setup") - ] + maybe_timestamp(cc, "timestamp_tls_setup") + ) ) if sc is not None and sc.timestamp_start: parts.append( - [ + ( "Server conn. initiated", maybe_timestamp(sc, "timestamp_start") - ] + ) ) parts.append( - [ + ( "Server conn. TCP handshake", maybe_timestamp(sc, "timestamp_tcp_setup") - ] + ) ) - if sc.ssl_established: + if sc.tls_established: parts.append( - [ + ( "Server conn. TLS handshake", - maybe_timestamp(sc, "timestamp_ssl_setup") - ] + maybe_timestamp(sc, "timestamp_tls_setup") + ) ) if req is not None and req.timestamp_start: parts.append( - [ + ( "First request byte", maybe_timestamp(req, "timestamp_start") - ] + ) ) parts.append( - [ + ( "Request complete", maybe_timestamp(req, "timestamp_end") - ] + ) ) if resp is not None and resp.timestamp_start: parts.append( - [ + ( "First response byte", maybe_timestamp(resp, "timestamp_start") - ] + ) ) parts.append( - [ + ( "Response complete", maybe_timestamp(resp, "timestamp_end") - ] + ) ) if parts: @@ -181,6 +181,6 @@ def flowdetails(state, flow: http.HTTPFlow): parts = sorted(parts, key=lambda p: p[1]) text.append(urwid.Text([("head", "Timing:")])) - text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) + text.extend(common.format_keyvals(parts, indent=4)) return searchable.Searchable(text) diff --git a/mitmproxy/tools/console/flowview.py b/mitmproxy/tools/console/flowview.py index 651c4330..9420c105 100644 --- a/mitmproxy/tools/console/flowview.py +++ b/mitmproxy/tools/console/flowview.py @@ -11,9 +11,9 @@ from mitmproxy.tools.console import common from mitmproxy.tools.console import layoutwidget from mitmproxy.tools.console import flowdetailview from mitmproxy.tools.console import searchable -from mitmproxy.tools.console import signals from mitmproxy.tools.console import tabs -import mitmproxy.tools.console.master # noqa +import mitmproxy.tools.console.master # noqa +from mitmproxy.utils import strutils class SearchError(Exception): @@ -55,7 +55,9 @@ class FlowDetails(tabs.Tabs): (self.tab_response, self.view_response), (self.tab_details, self.view_details), ] - self.show() + self.show() + else: + self.master.window.pop() @property def view(self): @@ -117,7 +119,7 @@ class FlowDetails(tabs.Tabs): viewmode, message ) if error: - signals.add_log(error, "error") + self.master.add_log(error, "debug") # Give hint that you have to tab for the response. if description == "No content" and isinstance(message, http.HTTPRequest): description = "No request content (press tab to view response)" @@ -153,10 +155,31 @@ class FlowDetails(tabs.Tabs): def conn_text(self, conn): if conn: + hdrs = [] + for k, v in conn.headers.fields: + # This will always force an ascii representation of headers. For example, if the server sends a + # + # X-Authors: Made with ❤ in Hamburg + # + # header, mitmproxy will display the following: + # + # X-Authors: Made with \xe2\x9d\xa4 in Hamburg. + # + # The alternative would be to just use the header's UTF-8 representation and maybe + # do `str.replace("\t", "\\t")` to exempt tabs from urwid's special characters escaping [1]. + # That would in some terminals allow rendering UTF-8 characters, but the mapping + # wouldn't be bijective, i.e. a user couldn't distinguish "\\t" and "\t". + # Also, from a security perspective, a mitmproxy user couldn't be fooled by homoglyphs. + # + # 1) https://github.com/mitmproxy/mitmproxy/issues/1833 + # https://github.com/urwid/urwid/blob/6608ee2c9932d264abd1171468d833b7a4082e13/urwid/display_common.py#L35-L36, + + k = strutils.bytes_to_escaped_str(k) + ":" + v = strutils.bytes_to_escaped_str(v) + hdrs.append((k, v)) txt = common.format_keyvals( - [(h + ":", v) for (h, v) in conn.headers.items(multi=True)], - key = "header", - val = "text" + hdrs, + key_format="header" ) viewmode = self.master.commands.call("console.flowview.mode") msg, body = self.content_view(viewmode, conn) diff --git a/mitmproxy/tools/console/grideditor/base.py b/mitmproxy/tools/console/grideditor/base.py index cdda3def..204820a8 100644 --- a/mitmproxy/tools/console/grideditor/base.py +++ b/mitmproxy/tools/console/grideditor/base.py @@ -433,7 +433,6 @@ class FocusEditor(urwid.WidgetWrap, layoutwidget.LayoutWidget): def __init__(self, master): self.master = master - self.focus_changed() def call(self, v, name, *args, **kwargs): f = getattr(v, name, None) @@ -462,7 +461,7 @@ class FocusEditor(urwid.WidgetWrap, layoutwidget.LayoutWidget): def layout_popping(self): self.call(self._w, "layout_popping") - def focus_changed(self): + def layout_pushed(self, prev): if self.master.view.focus.flow: self._w = BaseGridEditor( self.master, diff --git a/mitmproxy/tools/console/grideditor/col_bytes.py b/mitmproxy/tools/console/grideditor/col_bytes.py index da10cbaf..990253ea 100644 --- a/mitmproxy/tools/console/grideditor/col_bytes.py +++ b/mitmproxy/tools/console/grideditor/col_bytes.py @@ -46,7 +46,7 @@ class Edit(base.Cell): except ValueError: signals.status_message.send( self, - message="Invalid Python-style string encoding.", + message="Invalid data.", expire=1000 ) raise diff --git a/mitmproxy/tools/console/grideditor/col_subgrid.py b/mitmproxy/tools/console/grideditor/col_subgrid.py index 95995cd2..c9cbf66d 100644 --- a/mitmproxy/tools/console/grideditor/col_subgrid.py +++ b/mitmproxy/tools/console/grideditor/col_subgrid.py @@ -27,15 +27,8 @@ class Column(base.Column): ) return elif key == "m_select": - editor.master.view_grideditor( - self.subeditor( - editor.master, - editor.walker.get_current_value(), - editor.set_subeditor_value, - editor.walker.focus, - editor.walker.focus_col - ) - ) + self.subeditor.grideditor = editor + editor.master.switch_view("edit_focus_setcookie_attrs") else: return key diff --git a/mitmproxy/tools/console/grideditor/col_text.py b/mitmproxy/tools/console/grideditor/col_text.py index f0ac06f8..32518670 100644 --- a/mitmproxy/tools/console/grideditor/col_text.py +++ b/mitmproxy/tools/console/grideditor/col_text.py @@ -21,7 +21,7 @@ class Column(col_bytes.Column): return TEdit(data, self.encoding_args) def blank(self): - return u"" + return "" # This is the same for both edit and display. diff --git a/mitmproxy/tools/console/grideditor/col_viewany.py b/mitmproxy/tools/console/grideditor/col_viewany.py new file mode 100644 index 00000000..f5d35eee --- /dev/null +++ b/mitmproxy/tools/console/grideditor/col_viewany.py @@ -0,0 +1,33 @@ +""" +A display-only column that displays any data type. +""" + +import typing + +import urwid +from mitmproxy.tools.console.grideditor import base +from mitmproxy.utils import strutils + + +class Column(base.Column): + def Display(self, data): + return Display(data) + + Edit = Display + + def blank(self): + return "" + + +class Display(base.Cell): + def __init__(self, data: typing.Any) -> None: + self.data = data + if isinstance(data, bytes): + data = strutils.bytes_to_escaped_str(data) + if not isinstance(data, str): + data = repr(data) + w = urwid.Text(data, wrap="any") + super().__init__(w) + + def get_data(self) -> typing.Any: + return self.data diff --git a/mitmproxy/tools/console/grideditor/editors.py b/mitmproxy/tools/console/grideditor/editors.py index 074cdb77..fffd782c 100644 --- a/mitmproxy/tools/console/grideditor/editors.py +++ b/mitmproxy/tools/console/grideditor/editors.py @@ -1,12 +1,15 @@ +import urwid +import typing from mitmproxy import exceptions +from mitmproxy.net.http import Headers from mitmproxy.tools.console import layoutwidget +from mitmproxy.tools.console import signals from mitmproxy.tools.console.grideditor import base -from mitmproxy.tools.console.grideditor import col_text from mitmproxy.tools.console.grideditor import col_bytes from mitmproxy.tools.console.grideditor import col_subgrid -from mitmproxy.tools.console import signals -from mitmproxy.net.http import Headers +from mitmproxy.tools.console.grideditor import col_text +from mitmproxy.tools.console.grideditor import col_viewany class QueryEditor(base.FocusEditor): @@ -66,7 +69,6 @@ class RequestFormEditor(base.FocusEditor): class PathEditor(base.FocusEditor): # TODO: Next row on enter? - title = "Edit Path Components" columns = [ col_text.Column("Component"), @@ -99,12 +101,13 @@ class CookieEditor(base.FocusEditor): flow.request.cookies = vals -class CookieAttributeEditor(base.GridEditor): +class CookieAttributeEditor(base.FocusEditor): title = "Editing Set-Cookie attributes" columns = [ col_text.Column("Name"), col_text.Column("Value"), ] + grideditor = None # type: base.BaseGridEditor def data_in(self, data): return [(k, v or "") for k, v in data] @@ -118,6 +121,20 @@ class CookieAttributeEditor(base.GridEditor): ret.append(i) return ret + def layout_pushed(self, prev): + if self.grideditor.master.view.focus.flow: + self._w = base.BaseGridEditor( + self.grideditor.master, + self.title, + self.columns, + self.grideditor.walker.get_current_value(), + self.grideditor.set_subeditor_value, + self.grideditor.walker.focus, + self.grideditor.walker.focus_col + ) + else: + self._w = urwid.Pile([]) + class SetCookieEditor(base.FocusEditor): title = "Edit SetCookie Header" @@ -169,3 +186,31 @@ class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget): def is_error(self, col, val): pass + + +class DataViewer(base.GridEditor, layoutwidget.LayoutWidget): + title = None # type: str + + def __init__( + self, + master, + vals: typing.Union[ + typing.List[typing.List[typing.Any]], + typing.List[typing.Any], + str, + ]) -> None: + if vals: + # Whatever vals is, make it a list of rows containing lists of column values. + if isinstance(vals, str): + vals = [vals] + if not isinstance(vals[0], list): + vals = [[i] for i in vals] + + self.columns = [col_viewany.Column("")] * len(vals[0]) + super().__init__(master, vals, self.callback) + + def callback(self, vals): + pass + + def is_error(self, col, val): + pass diff --git a/mitmproxy/tools/console/help.py b/mitmproxy/tools/console/help.py index 439289f6..1b4b9ac6 100644 --- a/mitmproxy/tools/console/help.py +++ b/mitmproxy/tools/console/help.py @@ -76,7 +76,7 @@ class HelpView(tabs.Tabs, layoutwidget.LayoutWidget): def filtexp(self): text = [] - text.extend(common.format_keyvals(flowfilter.help, key="key", val="text", indent=4)) + text.extend(common.format_keyvals(flowfilter.help, indent=4)) text.append( urwid.Text( [ @@ -96,7 +96,7 @@ class HelpView(tabs.Tabs, layoutwidget.LayoutWidget): ("!(~q & ~t \"text/html\")", "Anything but requests with a text/html content type."), ] text.extend( - common.format_keyvals(examples, key="key", val="text", indent=4) + common.format_keyvals(examples, indent=4) ) return CListBox(text) diff --git a/mitmproxy/tools/console/keybindings.py b/mitmproxy/tools/console/keybindings.py index 45f5c33c..312c19f9 100644 --- a/mitmproxy/tools/console/keybindings.py +++ b/mitmproxy/tools/console/keybindings.py @@ -135,7 +135,7 @@ class KeyBindings(urwid.Pile, layoutwidget.LayoutWidget): ) self.master = master - def focus(self): + def get_focused_binding(self): if self.focus_position != 0: return None f = self.widget_list[0] diff --git a/mitmproxy/tools/console/keymap.py b/mitmproxy/tools/console/keymap.py index e406905d..fbb569a4 100644 --- a/mitmproxy/tools/console/keymap.py +++ b/mitmproxy/tools/console/keymap.py @@ -1,5 +1,5 @@ import typing -from mitmproxy.tools.console import commandeditor +from mitmproxy.tools.console import commandexecutor from mitmproxy.tools.console import signals @@ -17,6 +17,13 @@ Contexts = { } +navkeys = [ + "m_start", "m_end", "m_next", "m_select", + "up", "down", "page_up", "page_down", + "left", "right" +] + + class Binding: def __init__(self, key, command, contexts, help): self.key, self.command, self.contexts = key, command, sorted(contexts) @@ -35,7 +42,7 @@ class Binding: class Keymap: def __init__(self, master): - self.executor = commandeditor.CommandExecutor(master) + self.executor = commandexecutor.CommandExecutor(master) self.keys = {} for c in Contexts: self.keys[c] = {} @@ -122,3 +129,13 @@ class Keymap: if b: return self.executor(b.command) return key + + def handle_only(self, context: str, key: str) -> typing.Optional[str]: + """ + Like handle, but ignores global bindings. Returns the key if it has + not been handled, or None. + """ + b = self.get(context, key) + if b: + return self.executor(b.command) + return key diff --git a/mitmproxy/tools/console/master.py b/mitmproxy/tools/console/master.py index 4c7f9cc1..da35047e 100644 --- a/mitmproxy/tools/console/master.py +++ b/mitmproxy/tools/console/master.py @@ -9,6 +9,7 @@ import subprocess import sys import tempfile import traceback +import typing # noqa import urwid @@ -16,6 +17,7 @@ from mitmproxy import addons from mitmproxy import master from mitmproxy import log from mitmproxy.addons import intercept +from mitmproxy.addons import eventstore from mitmproxy.addons import readfile from mitmproxy.addons import view from mitmproxy.tools.console import consoleaddons @@ -31,7 +33,12 @@ class ConsoleMaster(master.Master): def __init__(self, opts): super().__init__(opts) + self.start_err = None # type: typing.Optional[log.LogEntry] + self.view = view.View() # type: view.View + self.events = eventstore.EventStore() + self.events.sig_add.connect(self.sig_add_log) + self.stream_path = None self.keymap = keymap.Keymap(self) defaultkeys.map(self.keymap) @@ -40,12 +47,11 @@ class ConsoleMaster(master.Master): self.view_stack = [] signals.call_in.connect(self.sig_call_in) - signals.sig_add_log.connect(self.sig_add_log) - self.addons.add(consoleaddons.Logger()) self.addons.add(*addons.default_addons()) self.addons.add( intercept.Intercept(), self.view, + self.events, consoleaddons.UnsupportedLog(), readfile.ReadFile(), consoleaddons.ConsoleAddon(self), @@ -79,13 +85,17 @@ class ConsoleMaster(master.Master): callback = self.quit, ) - def sig_add_log(self, sender, e, level): - if log.log_tier(self.options.verbosity) < log.log_tier(level): + def sig_add_log(self, event_store, entry: log.LogEntry): + if log.log_tier(self.options.verbosity) < log.log_tier(entry.level): return - if level in ("error", "warn"): - signals.status_message.send( - message = "{}: {}".format(level.title(), e) - ) + if entry.level in ("error", "warn", "alert"): + if self.first_tick: + self.start_err = entry + else: + signals.status_message.send( + message=(entry.level, "{}: {}".format(entry.level.title(), entry.msg)), + expire=5 + ) def sig_call_in(self, sender, seconds, callback, args=()): def cb(*_): @@ -195,6 +205,12 @@ class ConsoleMaster(master.Master): self.loop.set_alarm_in(0.01, self.ticker) + if self.start_err: + def display_err(*_): + self.sig_add_log(None, self.start_err) + self.start_err = None + self.loop.set_alarm_in(0.01, display_err) + self.start() try: self.loop.run() diff --git a/mitmproxy/tools/console/options.py b/mitmproxy/tools/console/options.py index 4d55aeec..54772cf0 100644 --- a/mitmproxy/tools/console/options.py +++ b/mitmproxy/tools/console/options.py @@ -117,6 +117,7 @@ class OptionListWalker(urwid.ListWalker): def stop_editing(self): self.editing = False self.focus_obj = self._get(self.index, False) + self.set_focus(self.index) self._modified() def get_edit_text(self): diff --git a/mitmproxy/tools/console/overlay.py b/mitmproxy/tools/console/overlay.py index 7072d00e..d255bc8c 100644 --- a/mitmproxy/tools/console/overlay.py +++ b/mitmproxy/tools/console/overlay.py @@ -5,6 +5,7 @@ import urwid from mitmproxy.tools.console import signals from mitmproxy.tools.console import grideditor from mitmproxy.tools.console import layoutwidget +from mitmproxy.tools.console import keymap class SimpleOverlay(urwid.Overlay, layoutwidget.LayoutWidget): @@ -39,12 +40,17 @@ class SimpleOverlay(urwid.Overlay, layoutwidget.LayoutWidget): class Choice(urwid.WidgetWrap): - def __init__(self, txt, focus, current): + def __init__(self, txt, focus, current, shortcut): + if shortcut: + selection_type = "option_selected_key" if focus else "key" + txt = [(selection_type, shortcut), ") ", txt] + else: + txt = " " + txt if current: s = "option_active_selected" if focus else "option_active" else: s = "option_selected" if focus else "text" - return super().__init__( + super().__init__( urwid.AttrWrap( urwid.Padding(urwid.Text(txt)), s, @@ -59,6 +65,8 @@ class Choice(urwid.WidgetWrap): class ChooserListWalker(urwid.ListWalker): + shortcuts = "123456789abcdefghijklmnoprstuvwxyz" + def __init__(self, choices, current): self.index = 0 self.choices = choices @@ -66,7 +74,7 @@ class ChooserListWalker(urwid.ListWalker): def _get(self, idx, focus): c = self.choices[idx] - return Choice(c, focus, c == self.current) + return Choice(c, focus, c == self.current, self.shortcuts[idx:idx + 1]) def set_focus(self, index): self.index = index @@ -86,6 +94,12 @@ class ChooserListWalker(urwid.ListWalker): return None, None return self._get(pos, False), pos + def choice_by_shortcut(self, shortcut): + for i, choice in enumerate(self.choices): + if shortcut == self.shortcuts[i:i + 1]: + return choice + return None + class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): keyctx = "chooser" @@ -95,7 +109,8 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): self.choices = choices self.callback = callback choicewidth = max([len(i) for i in choices]) - self.width = max(choicewidth, len(title)) + 5 + self.width = max(choicewidth, len(title)) + 7 + self.walker = ChooserListWalker(choices, current) super().__init__( urwid.AttrWrap( @@ -104,7 +119,7 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): urwid.ListBox(self.walker), len(choices) ), - title= title + title=title ), "background" ) @@ -114,13 +129,26 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): return True def keypress(self, size, key): - key = self.master.keymap.handle("chooser", key) + key = self.master.keymap.handle_only("chooser", key) + choice = self.walker.choice_by_shortcut(key) + if choice: + self.callback(choice) + signals.pop_view_state.send(self) + return if key == "m_select": self.callback(self.choices[self.walker.index]) signals.pop_view_state.send(self) - elif key == "esc": + return + elif key in ["q", "esc"]: signals.pop_view_state.send(self) - return super().keypress(size, key) + return + + binding = self.master.keymap.get("global", key) + # This is extremely awkward. We need a better way to match nav keys only. + if binding and binding.command.startswith("console.nav"): + self.master.keymap.handle("global", key) + elif key in keymap.navkeys: + return super().keypress(size, key) class OptionsOverlay(urwid.WidgetWrap, layoutwidget.LayoutWidget): @@ -148,3 +176,30 @@ class OptionsOverlay(urwid.WidgetWrap, layoutwidget.LayoutWidget): def layout_popping(self): return self.ge.layout_popping() + + +class DataViewerOverlay(urwid.WidgetWrap, layoutwidget.LayoutWidget): + keyctx = "grideditor" + + def __init__(self, master, vals): + """ + vspace: how much vertical space to keep clear + """ + cols, rows = master.ui.get_cols_rows() + self.ge = grideditor.DataViewer(master, vals) + super().__init__( + urwid.AttrWrap( + urwid.LineBox( + urwid.BoxAdapter(self.ge, rows - 5), + title="Data viewer" + ), + "background" + ) + ) + self.width = math.ceil(cols * 0.8) + + def key_responder(self): + return self.ge.key_responder() + + def layout_popping(self): + return self.ge.layout_popping() diff --git a/mitmproxy/tools/console/palettes.py b/mitmproxy/tools/console/palettes.py index 7fbdcfd8..df69ff2f 100644 --- a/mitmproxy/tools/console/palettes.py +++ b/mitmproxy/tools/console/palettes.py @@ -24,7 +24,7 @@ class Palette: # List and Connections 'method', 'focus', 'code_200', 'code_300', 'code_400', 'code_500', 'code_other', - 'error', "warn", + 'error', "warn", "alert", 'header', 'highlight', 'intercept', 'replay', 'mark', # Hex view @@ -32,6 +32,9 @@ class Palette: # Grid Editor 'focusfield', 'focusfield_error', 'field_error', 'editfield', + + # Commander + 'commander_command', 'commander_invalid', 'commander_hint' ] high = None # type: typing.Mapping[str, typing.Sequence[str]] @@ -100,6 +103,7 @@ class LowDark(Palette): code_500 = ('light red', 'default'), code_other = ('dark red', 'default'), + alert = ('light magenta', 'default'), warn = ('brown', 'default'), error = ('light red', 'default'), @@ -117,6 +121,11 @@ class LowDark(Palette): focusfield_error = ('dark red', 'light gray'), field_error = ('dark red', 'default'), editfield = ('white', 'default'), + + + commander_command = ('white,bold', 'default'), + commander_invalid = ('light red', 'default'), + commander_hint = ('dark gray', 'default'), ) @@ -168,6 +177,7 @@ class LowLight(Palette): error = ('light red', 'default'), warn = ('brown', 'default'), + alert = ('light magenta', 'default'), header = ('dark blue', 'default'), highlight = ('black,bold', 'default'), @@ -183,6 +193,10 @@ class LowLight(Palette): focusfield_error = ('dark red', 'light gray'), field_error = ('dark red', 'black'), editfield = ('black', 'default'), + + commander_command = ('dark magenta', 'default'), + commander_invalid = ('light red', 'default'), + commander_hint = ('light gray', 'default'), ) @@ -253,6 +267,7 @@ class SolarizedLight(LowLight): error = (sol_red, 'default'), warn = (sol_orange, 'default'), + alert = (sol_magenta, 'default'), header = (sol_blue, 'default'), highlight = (sol_base01, 'default'), @@ -267,6 +282,10 @@ class SolarizedLight(LowLight): focusfield_error = (sol_red, sol_base2), field_error = (sol_red, 'default'), editfield = (sol_base01, 'default'), + + commander_command = (sol_cyan, 'default'), + commander_invalid = (sol_orange, 'default'), + commander_hint = (sol_base1, 'default'), ) @@ -303,6 +322,7 @@ class SolarizedDark(LowDark): error = (sol_red, 'default'), warn = (sol_orange, 'default'), + alert = (sol_magenta, 'default'), header = (sol_blue, 'default'), highlight = (sol_base01, 'default'), @@ -317,6 +337,10 @@ class SolarizedDark(LowDark): focusfield_error = (sol_red, sol_base02), field_error = (sol_red, 'default'), editfield = (sol_base1, 'default'), + + commander_command = (sol_blue, 'default'), + commander_invalid = (sol_orange, 'default'), + commander_hint = (sol_base00, 'default'), ) diff --git a/mitmproxy/tools/console/pathedit.py b/mitmproxy/tools/console/pathedit.py deleted file mode 100644 index 10ee1416..00000000 --- a/mitmproxy/tools/console/pathedit.py +++ /dev/null @@ -1,71 +0,0 @@ -import glob -import os.path - -import urwid - - -class _PathCompleter: - - def __init__(self, _testing=False): - """ - _testing: disables reloading of the lookup table to make testing - possible. - """ - self.lookup, self.offset = None, None - self.final = None - self._testing = _testing - - def reset(self): - self.lookup = None - self.offset = -1 - - def complete(self, txt): - """ - Returns the next completion for txt, or None if there is no - completion. - """ - path = os.path.expanduser(txt) - if not self.lookup: - if not self._testing: - # Lookup is a set of (display value, actual value) tuples. - self.lookup = [] - if os.path.isdir(path): - files = glob.glob(os.path.join(path, "*")) - prefix = txt - else: - files = glob.glob(path + "*") - prefix = os.path.dirname(txt) - prefix = prefix or "./" - for f in files: - display = os.path.join(prefix, os.path.basename(f)) - if os.path.isdir(f): - display += "/" - self.lookup.append((display, f)) - if not self.lookup: - self.final = path - return path - self.lookup.sort() - self.offset = -1 - self.lookup.append((txt, txt)) - self.offset += 1 - if self.offset >= len(self.lookup): - self.offset = 0 - ret = self.lookup[self.offset] - self.final = ret[1] - return ret[0] - - -class PathEdit(urwid.Edit, _PathCompleter): - - def __init__(self, prompt, last_path): - urwid.Edit.__init__(self, prompt, last_path) - _PathCompleter.__init__(self) - - def keypress(self, size, key): - if key == "tab": - comp = self.complete(self.get_edit_text()) - self.set_edit_text(comp) - self.set_edit_pos(len(comp)) - else: - self.reset() - return urwid.Edit.keypress(self, size, key) diff --git a/mitmproxy/tools/console/signals.py b/mitmproxy/tools/console/signals.py index 5d39d96a..9c44b361 100644 --- a/mitmproxy/tools/console/signals.py +++ b/mitmproxy/tools/console/signals.py @@ -1,19 +1,5 @@ import blinker -# Clear the eventlog -sig_clear_log = blinker.Signal() - -# Add an entry to the eventlog -sig_add_log = blinker.Signal() - - -def add_log(e, level): - sig_add_log.send( - None, - e=e, - level=level - ) - # Show a status message in the action bar status_message = blinker.Signal() diff --git a/mitmproxy/tools/console/statusbar.py b/mitmproxy/tools/console/statusbar.py index e85b1bc9..bdb39013 100644 --- a/mitmproxy/tools/console/statusbar.py +++ b/mitmproxy/tools/console/statusbar.py @@ -4,8 +4,9 @@ import urwid from mitmproxy.tools.console import common from mitmproxy.tools.console import signals -from mitmproxy.tools.console import commandeditor +from mitmproxy.tools.console import commandexecutor import mitmproxy.tools.console.master # noqa +from mitmproxy.tools.console.commander import commander class PromptPath: @@ -66,8 +67,8 @@ class ActionBar(urwid.WidgetWrap): def sig_prompt_command(self, sender, partial=""): signals.focus.send(self, section="footer") - self._w = commandeditor.CommandEdit(partial) - self.prompting = commandeditor.CommandExecutor(self.master) + self._w = commander.CommandEdit(self.master, partial) + self.prompting = commandexecutor.CommandExecutor(self.master) def sig_prompt_onekey(self, sender, prompt, keys, callback, args=()): """ @@ -139,9 +140,10 @@ class StatusBar(urwid.WidgetWrap): signals.flowlist_change.connect(self.sig_update) master.options.changed.connect(self.sig_update) master.view.focus.sig_change.connect(self.sig_update) + master.view.sig_view_add.connect(self.sig_update) self.redraw() - def sig_update(self, sender, updated=None): + def sig_update(self, sender, flow=None, updated=None): self.redraw() def keypress(self, *args, **kwargs): @@ -228,9 +230,7 @@ class StatusBar(urwid.WidgetWrap): if self.master.options.mode != "regular": r.append("[%s]" % self.master.options.mode) if self.master.options.scripts: - r.append("[") - r.append(("heading_key", "s")) - r.append("cripts:%s]" % len(self.master.options.scripts)) + r.append("[scripts:%s]" % len(self.master.options.scripts)) if self.master.options.save_stream_file: r.append("[W:%s]" % self.master.options.save_stream_file) diff --git a/mitmproxy/tools/console/window.py b/mitmproxy/tools/console/window.py index c6ff78f8..f2b6d3f4 100644 --- a/mitmproxy/tools/console/window.py +++ b/mitmproxy/tools/console/window.py @@ -15,16 +15,38 @@ from mitmproxy.tools.console import grideditor from mitmproxy.tools.console import eventlog -class Header(urwid.Frame): - def __init__(self, widget, title, focus): - super().__init__( - widget, +class StackWidget(urwid.Frame): + def __init__(self, window, widget, title, focus): + self.is_focused = focus + self.window = window + + if title: header = urwid.AttrWrap( urwid.Text(title), "heading" if focus else "heading_inactive" ) + else: + header = None + super().__init__( + widget, + header=header ) + def mouse_event(self, size, event, button, col, row, focus): + if event == "mouse press" and button == 1 and not self.is_focused: + self.window.switch() + return super().mouse_event(size, event, button, col, row, focus) + + def keypress(self, size, key): + # Make sure that we don't propagate cursor events outside of the widget. + # Otherwise, in a horizontal layout, urwid's Pile would change the focused widget + # if we cannot scroll any further. + ret = super().keypress(size, key) + command = self._command_map[ret] # awkward as they don't implement a full dict api + if command and command.startswith("cursor"): + return None + return ret + class WindowStack: def __init__(self, master, base): @@ -41,6 +63,7 @@ class WindowStack: edit_focus_query = grideditor.QueryEditor(master), edit_focus_cookies = grideditor.CookieEditor(master), edit_focus_setcookies = grideditor.SetCookieEditor(master), + edit_focus_setcookie_attrs = grideditor.CookieAttributeEditor(master), edit_focus_form = grideditor.RequestFormEditor(master), edit_focus_path = grideditor.PathEditor(master), edit_focus_request_headers = grideditor.RequestHeaderEditor(master), @@ -142,12 +165,17 @@ class Window(urwid.Frame): self.pane = 0 def wrapped(idx): - window = self.stacks[idx].top_window() widget = self.stacks[idx].top_widget() - if self.master.options.console_layout_headers and window.title: - return Header(widget, window.title, self.pane == idx) + if self.master.options.console_layout_headers: + title = self.stacks[idx].top_window().title else: - return widget + title = None + return StackWidget( + self, + widget, + title, + self.pane == idx + ) w = None if c == "single": @@ -156,12 +184,14 @@ class Window(urwid.Frame): w = urwid.Pile( [ wrapped(i) for i, s in enumerate(self.stacks) - ] + ], + focus_item=self.pane ) else: w = urwid.Columns( [wrapped(i) for i, s in enumerate(self.stacks)], - dividechars=1 + dividechars=1, + focus_column=self.pane ) self.body = urwid.AttrWrap(w, "background") @@ -214,28 +244,34 @@ class Window(urwid.Frame): self.view_changed() self.focus_changed() - def current(self, keyctx): + def stacks_sorted_by_focus(self): """ - Returns the active widget, but only the current focus or overlay has - a matching key context. + Returns: + self.stacks, with the focused stack first. """ - t = self.focus_stack().top_widget() - if t.keyctx == keyctx: - return t + stacks = self.stacks.copy() + stacks.insert(0, stacks.pop(self.pane)) + return stacks - def current_window(self, keyctx): + def current(self, keyctx): """ - Returns the active window, ignoring overlays. + Returns the active widget with a matching key context, including overlays. + If multiple stacks have an active widget with a matching key context, + the currently focused stack is preferred. """ - t = self.focus_stack().top_window() - if t.keyctx == keyctx: - return t + for s in self.stacks_sorted_by_focus(): + t = s.top_widget() + if t.keyctx == keyctx: + return t - def any(self, keyctx): + def current_window(self, keyctx): """ - Returns the top window of either stack if they match the context. + Returns the active window with a matching key context, ignoring overlays. + If multiple stacks have an active widget with a matching key context, + the currently focused stack is preferred. """ - for t in [x.top_window() for x in self.stacks]: + for s in self.stacks_sorted_by_focus(): + t = s.top_window() if t.keyctx == keyctx: return t @@ -270,13 +306,12 @@ class Window(urwid.Frame): return True def keypress(self, size, k): - if self.focus_part == "footer": - return super().keypress(size, k) - else: - fs = self.focus_stack().top_widget() - k = fs.keypress(size, k) - if k: - return self.master.keymap.handle(fs.keyctx, k) + k = super().keypress(size, k) + if k: + return self.master.keymap.handle( + self.focus_stack().top_widget().keyctx, + k + ) class Screen(urwid.raw_display.Screen): diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 77695515..36c9d917 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -43,6 +43,8 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: continue f[conn]["alpn_proto_negotiated"] = \ f[conn]["alpn_proto_negotiated"].decode(errors="backslashreplace") + # There are some bytes in here as well, let's skip it until we have them in the UI. + f["client_conn"].pop("tls_extensions", None) if flow.error: f["error"] = flow.error.get_state() diff --git a/mitmproxy/tools/web/master.py b/mitmproxy/tools/web/master.py index 694ee2f7..4c597f0e 100644 --- a/mitmproxy/tools/web/master.py +++ b/mitmproxy/tools/web/master.py @@ -60,7 +60,7 @@ class WebMaster(master.Master): data=app.flow_to_json(flow) ) - def _sig_view_remove(self, view, flow): + def _sig_view_remove(self, view, flow, index): app.ClientConnection.broadcast( resource="flows", cmd="remove", diff --git a/mitmproxy/types.py b/mitmproxy/types.py new file mode 100644 index 00000000..3875128d --- /dev/null +++ b/mitmproxy/types.py @@ -0,0 +1,445 @@ +import os +import glob +import typing + +from mitmproxy import exceptions +from mitmproxy import flow + + +class Path(str): + pass + + +class Cmd(str): + pass + + +class Arg(str): + pass + + +class Unknown(str): + pass + + +class CutSpec(typing.Sequence[str]): + pass + + +class Data(typing.Sequence[typing.Sequence[typing.Union[str, bytes]]]): + pass + + +class Choice: + def __init__(self, options_command): + self.options_command = options_command + + def __instancecheck__(self, instance): # pragma: no cover + # return false here so that arguments are piped through parsearg, + # which does extended validation. + return False + + +# One of the many charming things about mypy is that introducing type +# annotations can cause circular dependencies where there were none before. +# Rather than putting types and the CommandManger in the same file, we introduce +# a stub type with the signature we use. +class _CommandBase: + commands = {} # type: typing.MutableMapping[str, typing.Any] + + def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: + raise NotImplementedError + + def call(self, cmd: str) -> typing.Any: + raise NotImplementedError + + +class _BaseType: + typ = object # type: typing.Type + display = "" # type: str + + def completion( + self, manager: _CommandBase, t: typing.Any, s: str + ) -> typing.Sequence[str]: + """ + Returns a list of completion strings for a given prefix. The strings + returned don't necessarily need to be suffixes of the prefix, since + completers will do prefix filtering themselves.. + """ + raise NotImplementedError + + def parse( + self, manager: _CommandBase, typ: typing.Any, s: str + ) -> typing.Any: + """ + Parse a string, given the specific type instance (to allow rich type annotations like Choice) and a string. + + Raises exceptions.TypeError if the value is invalid. + """ + raise NotImplementedError + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + """ + Check if data is valid for this type. + """ + raise NotImplementedError + + +class _BoolType(_BaseType): + typ = bool + display = "bool" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return ["false", "true"] + + def parse(self, manager: _CommandBase, t: type, s: str) -> bool: + if s == "true": + return True + elif s == "false": + return False + else: + raise exceptions.TypeError( + "Booleans are 'true' or 'false', got %s" % s + ) + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return val in [True, False] + + +class _StrType(_BaseType): + typ = str + display = "str" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return [] + + def parse(self, manager: _CommandBase, t: type, s: str) -> str: + return s + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return isinstance(val, str) + + +class _UnknownType(_BaseType): + typ = Unknown + display = "unknown" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return [] + + def parse(self, manager: _CommandBase, t: type, s: str) -> str: + return s + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return False + + +class _IntType(_BaseType): + typ = int + display = "int" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return [] + + def parse(self, manager: _CommandBase, t: type, s: str) -> int: + try: + return int(s) + except ValueError as e: + raise exceptions.TypeError from e + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return isinstance(val, int) + + +class _PathType(_BaseType): + typ = Path + display = "path" + + def completion(self, manager: _CommandBase, t: type, start: str) -> typing.Sequence[str]: + if not start: + start = "./" + path = os.path.expanduser(start) + ret = [] + if os.path.isdir(path): + files = glob.glob(os.path.join(path, "*")) + prefix = start + else: + files = glob.glob(path + "*") + prefix = os.path.dirname(start) + prefix = prefix or "./" + for f in files: + display = os.path.join(prefix, os.path.normpath(os.path.basename(f))) + if os.path.isdir(f): + display += "/" + ret.append(display) + if not ret: + ret = [start] + ret.sort() + return ret + + def parse(self, manager: _CommandBase, t: type, s: str) -> str: + return s + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return isinstance(val, str) + + +class _CmdType(_BaseType): + typ = Cmd + display = "cmd" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return list(manager.commands.keys()) + + def parse(self, manager: _CommandBase, t: type, s: str) -> str: + if s not in manager.commands: + raise exceptions.TypeError("Unknown command: %s" % s) + return s + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return val in manager.commands + + +class _ArgType(_BaseType): + typ = Arg + display = "arg" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return [] + + def parse(self, manager: _CommandBase, t: type, s: str) -> str: + return s + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return isinstance(val, str) + + +class _StrSeqType(_BaseType): + typ = typing.Sequence[str] + display = "[str]" + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return [] + + def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return [x.strip() for x in s.split(",")] + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + if isinstance(val, str) or isinstance(val, bytes): + return False + try: + for v in val: + if not isinstance(v, str): + return False + except TypeError: + return False + return True + + +class _CutSpecType(_BaseType): + typ = CutSpec + display = "[cut]" + valid_prefixes = [ + "request.method", + "request.scheme", + "request.host", + "request.http_version", + "request.port", + "request.path", + "request.url", + "request.text", + "request.content", + "request.raw_content", + "request.timestamp_start", + "request.timestamp_end", + "request.header[", + + "response.status_code", + "response.reason", + "response.text", + "response.content", + "response.timestamp_start", + "response.timestamp_end", + "response.raw_content", + "response.header[", + + "client_conn.address.port", + "client_conn.address.host", + "client_conn.tls_version", + "client_conn.sni", + "client_conn.tls_established", + + "server_conn.address.port", + "server_conn.address.host", + "server_conn.ip_address.host", + "server_conn.tls_version", + "server_conn.sni", + "server_conn.tls_established", + ] + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + spec = s.split(",") + opts = [] + for pref in self.valid_prefixes: + spec[-1] = pref + opts.append(",".join(spec)) + return opts + + def parse(self, manager: _CommandBase, t: type, s: str) -> CutSpec: + parts = s.split(",") # type: typing.Any + return parts + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + if not isinstance(val, str): + return False + parts = [x.strip() for x in val.split(",")] + for p in parts: + for pref in self.valid_prefixes: + if p.startswith(pref): + break + else: + return False + return True + + +class _BaseFlowType(_BaseType): + viewmarkers = [ + "@all", + "@focus", + "@shown", + "@hidden", + "@marked", + "@unmarked", + ] + valid_prefixes = viewmarkers + [ + "~q", + "~s", + "~a", + "~hq", + "~hs", + "~b", + "~bq", + "~bs", + "~t", + "~d", + "~m", + "~u", + "~c", + ] + + def completion(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]: + return self.valid_prefixes + + +class _FlowType(_BaseFlowType): + typ = flow.Flow + display = "flow" + + def parse(self, manager: _CommandBase, t: type, s: str) -> flow.Flow: + try: + flows = manager.call_args("view.resolve", [s]) + except exceptions.CommandError as e: + raise exceptions.TypeError from e + if len(flows) != 1: + raise exceptions.TypeError( + "Command requires one flow, specification matched %s." % len(flows) + ) + return flows[0] + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + return isinstance(val, flow.Flow) + + +class _FlowsType(_BaseFlowType): + typ = typing.Sequence[flow.Flow] + display = "[flow]" + + def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]: + try: + return manager.call_args("view.resolve", [s]) + except exceptions.CommandError as e: + raise exceptions.TypeError from e + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + try: + for v in val: + if not isinstance(v, flow.Flow): + return False + except TypeError: + return False + return True + + +class _DataType(_BaseType): + typ = Data + display = "[data]" + + def completion( + self, manager: _CommandBase, t: type, s: str + ) -> typing.Sequence[str]: # pragma: no cover + raise exceptions.TypeError("data cannot be passed as argument") + + def parse( + self, manager: _CommandBase, t: type, s: str + ) -> typing.Any: # pragma: no cover + raise exceptions.TypeError("data cannot be passed as argument") + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + # FIXME: validate that all rows have equal length, and all columns have equal types + try: + for row in val: + for cell in row: + if not (isinstance(cell, str) or isinstance(cell, bytes)): + return False + except TypeError: + return False + return True + + +class _ChoiceType(_BaseType): + typ = Choice + display = "choice" + + def completion(self, manager: _CommandBase, t: Choice, s: str) -> typing.Sequence[str]: + return manager.call(t.options_command) + + def parse(self, manager: _CommandBase, t: Choice, s: str) -> str: + opts = manager.call(t.options_command) + if s not in opts: + raise exceptions.TypeError("Invalid choice.") + return s + + def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: + try: + opts = manager.call(typ.options_command) + except exceptions.CommandError: + return False + return val in opts + + +class TypeManager: + def __init__(self, *types): + self.typemap = {} + for t in types: + self.typemap[t.typ] = t() + + def get(self, t: type, default=None) -> _BaseType: + if type(t) in self.typemap: + return self.typemap[type(t)] + return self.typemap.get(t, default) + + +CommandTypes = TypeManager( + _ArgType, + _BoolType, + _ChoiceType, + _CmdType, + _CutSpecType, + _DataType, + _FlowType, + _FlowsType, + _IntType, + _PathType, + _StrType, + _StrSeqType, +) diff --git a/mitmproxy/utils/arg_check.py b/mitmproxy/utils/arg_check.py index 73f7047c..873bef06 100644 --- a/mitmproxy/utils/arg_check.py +++ b/mitmproxy/utils/arg_check.py @@ -66,9 +66,9 @@ REPLACEMENTS = { "--palette": "console_palette", "--palette-transparent": "console_palette_transparent:", "--follow": "console_focus_follow", - "--order": "console_order", + "--order": "view_order", "--no-mouse": "console_mouse", - "--reverse": "console_order_reversed", + "--reverse": "view_order_reversed", "--no-http2-priority": "http2_priority", "--no-websocket": "websocket", "--no-upstream-cert": "upstream_cert", diff --git a/mitmproxy/utils/debug.py b/mitmproxy/utils/debug.py index de01b12c..e8eca906 100644 --- a/mitmproxy/utils/debug.py +++ b/mitmproxy/utils/debug.py @@ -1,43 +1,24 @@ import gc import os +import platform +import re +import signal import sys import threading -import signal -import platform import traceback -import subprocess - -from mitmproxy import version from OpenSSL import SSL +from mitmproxy import version -def dump_system_info(): - mitmproxy_version = version.VERSION - here = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - try: - git_describe = subprocess.check_output( - ['git', 'describe', '--tags', '--long'], - stderr=subprocess.STDOUT, - cwd=here, - ) - except: - pass - else: - last_tag, tag_dist, commit = git_describe.decode().strip().rsplit("-", 2) - - commit = commit.lstrip("g") # remove the 'g' prefix added by recent git versions - tag_dist = int(tag_dist) - - if tag_dist > 0: - tag_dist = "dev{:04}".format(tag_dist) - else: - tag_dist = "" - mitmproxy_version += "{tag_dist} ({commit})".format( - tag_dist=tag_dist, - commit=commit, - ) +def dump_system_info(): + mitmproxy_version = version.get_version(True, True) + mitmproxy_version = re.sub( + r"-0x([0-9a-f]+)", + r" (commit \1)", + mitmproxy_version + ) # PyInstaller builds indicator, if using precompiled binary if getattr(sys, 'frozen', False): diff --git a/mitmproxy/utils/human.py b/mitmproxy/utils/human.py index e2e3142a..b21ac0b8 100644 --- a/mitmproxy/utils/human.py +++ b/mitmproxy/utils/human.py @@ -80,6 +80,8 @@ def format_address(address: tuple) -> str: """ try: host = ipaddress.ip_address(address[0]) + if host.is_unspecified: + return "*:{}".format(address[1]) if isinstance(host, ipaddress.IPv4Address): return "{}:{}".format(str(host), address[1]) # If IPv6 is mapped to IPv4 diff --git a/mitmproxy/utils/typecheck.py b/mitmproxy/utils/typecheck.py index 87a0e804..22db68f5 100644 --- a/mitmproxy/utils/typecheck.py +++ b/mitmproxy/utils/typecheck.py @@ -1,42 +1,40 @@ import typing +Type = typing.Union[ + typing.Any # anything more elaborate really fails with mypy at the moment. +] + + +def sequence_type(typeinfo: typing.Type[typing.List]) -> Type: + """Return the type of a sequence, e.g. typing.List""" + try: + return typeinfo.__args__[0] # type: ignore + except AttributeError: # Python 3.5.0 + return typeinfo.__parameters__[0] # type: ignore -def check_command_type(value: typing.Any, typeinfo: typing.Any) -> bool: - """ - Check if the provided value is an instance of typeinfo. Returns True if the - types match, False otherwise. This function supports only those types - required for command return values. - """ - typename = str(typeinfo) - if typename.startswith("typing.Sequence"): - try: - T = typeinfo.__args__[0] # type: ignore - except AttributeError: - # Python 3.5.0 - T = typeinfo.__parameters__[0] # type: ignore - if not isinstance(value, (tuple, list)): - return False - for v in value: - if not check_command_type(v, T): - return False - elif typename.startswith("typing.Union"): - try: - types = typeinfo.__args__ # type: ignore - except AttributeError: - # Python 3.5.x - types = typeinfo.__union_params__ # type: ignore - for T in types: - checks = [check_command_type(value, T) for T in types] - if not any(checks): - return False - elif value is None and typeinfo is None: - return True - elif not isinstance(value, typeinfo): - return False - return True +def tuple_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]: + """Return the types of a typing.Tuple""" + try: + return typeinfo.__args__ # type: ignore + except AttributeError: # Python 3.5.x + return typeinfo.__tuple_params__ # type: ignore -def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> None: + +def union_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]: + """return the types of a typing.Union""" + try: + return typeinfo.__args__ # type: ignore + except AttributeError: # Python 3.5.x + return typeinfo.__union_params__ # type: ignore + + +def mapping_types(typeinfo: typing.Type[typing.Mapping]) -> typing.Tuple[Type, Type]: + """return the types of a mapping, e.g. typing.Dict""" + return typeinfo.__args__ # type: ignore + + +def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None: """ Check if the provided value is an instance of typeinfo and raises a TypeError otherwise. This function supports only those types required for @@ -51,13 +49,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non typename = str(typeinfo) if typename.startswith("typing.Union"): - try: - types = typeinfo.__args__ # type: ignore - except AttributeError: - # Python 3.5.x - types = typeinfo.__union_params__ # type: ignore - - for T in types: + for T in union_types(typeinfo): try: check_option_type(name, value, T) except TypeError: @@ -66,12 +58,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non return raise e elif typename.startswith("typing.Tuple"): - try: - types = typeinfo.__args__ # type: ignore - except AttributeError: - # Python 3.5.x - types = typeinfo.__tuple_params__ # type: ignore - + types = tuple_types(typeinfo) if not isinstance(value, (tuple, list)): raise e if len(types) != len(value): @@ -80,11 +67,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non check_option_type("{}[{}]".format(name, i), x, T) return elif typename.startswith("typing.Sequence"): - try: - T = typeinfo.__args__[0] # type: ignore - except AttributeError: - # Python 3.5.0 - T = typeinfo.__parameters__[0] # type: ignore + T = sequence_type(typeinfo) if not isinstance(value, (tuple, list)): raise e for v in value: diff --git a/mitmproxy/version.py b/mitmproxy/version.py index 3cae2a04..c2cb3822 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -1,11 +1,64 @@ -IVERSION = (3, 0, 0) -VERSION = ".".join(str(i) for i in IVERSION) +import os +import subprocess + +# The actual version string. For precompiled binaries, this will be changed to include the build +# tag, e.g. "3.0.0.dev0042-0xcafeabc" +VERSION = "3.0.0" PATHOD = "pathod " + VERSION MITMPROXY = "mitmproxy " + VERSION # Serialization format version. This is displayed nowhere, it just needs to be incremented by one # for each change in the file format. -FLOW_FORMAT_VERSION = 5 +FLOW_FORMAT_VERSION = 7 + + +def get_version(dev: bool = False, build: bool = False, refresh: bool = False) -> str: + """ + Return a detailed version string, sourced either from a hardcoded VERSION constant + or obtained dynamically using git. + + Args: + dev: If True, non-tagged releases will include a ".devXXXX" suffix, where XXXX is the number + of commits since the last tagged release. + build: If True, non-tagged releases will include a "-0xXXXXXXX" suffix, where XXXXXXX are + the first seven digits of the commit hash. + refresh: If True, always try to use git instead of a potentially hardcoded constant. + """ + + mitmproxy_version = VERSION + + if "dev" in VERSION and not refresh: + pass # There is a hardcoded build tag, so we just use what's there. + elif dev or build: + here = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + try: + git_describe = subprocess.check_output( + ['git', 'describe', '--long'], + stderr=subprocess.STDOUT, + cwd=here, + ) + last_tag, tag_dist, commit = git_describe.decode().strip().rsplit("-", 2) + commit = commit.lstrip("g")[:7] + tag_dist = int(tag_dist) + except Exception: + pass + else: + # Remove current suffix + mitmproxy_version = mitmproxy_version.split(".dev")[0] + + # Add suffix for non-tagged releases + if tag_dist > 0: + mitmproxy_version += ".dev{tag_dist}".format(tag_dist=tag_dist) + # The wheel build tag (we use the commit) must start with a digit, so we include "0x" + mitmproxy_version += "-0x{commit}".format(commit=commit) + + if not dev: + mitmproxy_version = mitmproxy_version.split(".dev")[0] + elif not build: + mitmproxy_version = mitmproxy_version.split("-0x")[0] + + return mitmproxy_version + -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover print(VERSION) diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index ded09f65..66257852 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,51 +1,82 @@ import time from typing import List, Optional +from wsproto.frame_protocol import CloseReason +from wsproto.frame_protocol import Opcode + from mitmproxy import flow from mitmproxy.net import websockets -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable from mitmproxy.utils import strutils, human class WebSocketMessage(serializable.Serializable): + """ + A WebSocket message sent from one endpoint to the other. + """ + def __init__( - self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None + self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None, killed: bool=False ) -> None: - self.type = type + self.type = Opcode(type) # type: ignore + """indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode).""" self.from_client = from_client + """True if this messages was sent by the client.""" self.content = content + """A byte-string representing the content of this message.""" self.timestamp = timestamp or int(time.time()) # type: int + """Timestamp of when this message was received or created.""" + self.killed = killed + """True if this messages was killed and should not be sent to the other endpoint.""" @classmethod def from_state(cls, state): return cls(*state) def get_state(self): - return self.type, self.from_client, self.content, self.timestamp + return int(self.type), self.from_client, self.content, self.timestamp, self.killed def set_state(self, state): - self.type, self.from_client, self.content, self.timestamp = state + self.type, self.from_client, self.content, self.timestamp, self.killed = state + self.type = Opcode(self.type) # replace enum with bare int def __repr__(self): - if self.type == websockets.OPCODE.TEXT: + if self.type == Opcode.TEXT: return "text message: {}".format(repr(self.content)) else: return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) + def kill(self): + """ + Kill this message. + + It will not be sent to the other endpoint. This has no effect in streaming mode. + """ + self.killed = True + class WebSocketFlow(flow.Flow): """ - A WebsocketFlow is a simplified representation of a Websocket session. + A WebsocketFlow is a simplified representation of a Websocket connection. """ def __init__(self, client_conn, server_conn, handshake_flow, live=None): super().__init__("websocket", client_conn, server_conn, live) + self.messages = [] # type: List[WebSocketMessage] + """A list containing all WebSocketMessage's.""" self.close_sender = 'client' - self.close_code = '(status code missing)' + """'client' if the client initiated connection closing.""" + self.close_code = CloseReason.NORMAL_CLOSURE + """WebSocket close code.""" self.close_message = '(message missing)' + """WebSocket close message.""" self.close_reason = 'unknown status code' + """WebSocket close reason.""" self.stream = False + """True of this connection is streaming directly to the other endpoint.""" + self.handshake_flow = handshake_flow + """The HTTP flow containing the initial WebSocket handshake.""" if handshake_flow: self.client_key = websockets.get_client_key(handshake_flow.request.headers) @@ -62,14 +93,12 @@ class WebSocketFlow(flow.Flow): self.server_protocol = '' self.server_extensions = '' - self.handshake_flow = handshake_flow - _stateobject_attributes = flow.Flow._stateobject_attributes.copy() # mypy doesn't support update with kwargs _stateobject_attributes.update(dict( messages=List[WebSocketMessage], close_sender=str, - close_code=str, + close_code=int, close_message=str, close_reason=str, client_key=str, @@ -83,6 +112,11 @@ class WebSocketFlow(flow.Flow): # dumping the handshake_flow will include the WebSocketFlow too. )) + def get_state(self): + d = super().get_state() + d['close_code'] = int(d['close_code']) # replace enum with bare int + return d + @classmethod def from_state(cls, state): f = cls(None, None, None) diff --git a/pathod/pathoc.py b/pathod/pathoc.py index e1052750..b177d556 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -17,7 +17,7 @@ from mitmproxy.net import tcp, tls from mitmproxy.net import websockets from mitmproxy.net import socks from mitmproxy.net import http as net_http -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from mitmproxy.utils import strutils from pathod import log @@ -79,7 +79,7 @@ class SSLInfo: } t = types.get(pk.type(), "Uknown") parts.append("\tPubkey: %s bit %s" % (pk.bits(), t)) - s = certs.SSLCert(i) + s = certs.Cert(i) if s.altnames: parts.append("\tSANs: %s" % " ".join(strutils.always_str(n, "utf8") for n in s.altnames)) return "\n".join(parts) @@ -244,6 +244,7 @@ class Pathoc(tcp.TCPClient): port=connect_to[1], path=None, http_version='HTTP/1.1', + headers=[(b"Host", connect_to[0].encode("idna"))], content=b'', ) self.wfile.write(net_http.http1.assemble_request(req)) @@ -312,7 +313,7 @@ class Pathoc(tcp.TCPClient): if self.use_http2: alpn_protos.append(b'h2') - self.convert_to_ssl( + self.convert_to_tls( sni=self.sni, cert=self.clientcert, method=self.ssl_version, diff --git a/pathod/pathod.py b/pathod/pathod.py index f8e64f9e..17db57ee 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -170,7 +170,7 @@ class PathodHandler(tcp.BaseHandler): ), cipher=None, ) - if self.ssl_established: + if self.tls_established: retlog["cipher"] = self.get_current_cipher() m = utils.MemBool() @@ -244,7 +244,7 @@ class PathodHandler(tcp.BaseHandler): if self.server.ssl: try: cert, key, _ = self.server.ssloptions.get_cert(None) - self.convert_to_ssl( + self.convert_to_tls( cert, key, handle_sni=self.handle_sni, diff --git a/pathod/protocols/http.py b/pathod/protocols/http.py index 4387b4fb..5fcb6618 100644 --- a/pathod/protocols/http.py +++ b/pathod/protocols/http.py @@ -27,7 +27,7 @@ class HTTPProtocol: cert, key, chain_file_ = self.pathod_handler.server.ssloptions.get_cert( connect[0].encode() ) - self.pathod_handler.convert_to_ssl( + self.pathod_handler.convert_to_tls( cert, key, handle_sni=self.pathod_handler.handle_sni, diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index cfc71650..c56d304d 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -8,7 +8,7 @@ from mitmproxy.net.http import http2 import mitmproxy.net.http.headers import mitmproxy.net.http.response import mitmproxy.net.http.request -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi from .. import language diff --git a/pathod/protocols/websockets.py b/pathod/protocols/websockets.py index 2d1f1bf6..63e6ee0b 100644 --- a/pathod/protocols/websockets.py +++ b/pathod/protocols/websockets.py @@ -30,7 +30,7 @@ class WebsocketsProtocol: ), cipher=None, ) - if self.pathod_handler.ssl_established: + if self.pathod_handler.tls_established: retlog["cipher"] = self.pathod_handler.get_current_cipher() self.pathod_handler.addlog(retlog) ld = language.websockets.NESTED_LEADER diff --git a/pathod/test.py b/pathod/test.py index 52f3ba02..819c7a94 100644 --- a/pathod/test.py +++ b/pathod/test.py @@ -2,7 +2,7 @@ import io import time import queue from . import pathod -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread import typing # noqa diff --git a/release/.gitignore b/release/.gitignore index 2247d5f9..905eec6e 100644 --- a/release/.gitignore +++ b/release/.gitignore @@ -1,2 +1,3 @@ /build /dist +known_hosts diff --git a/release/README.md b/release/README.md index a60b7f98..7bb89638 100644 --- a/release/README.md +++ b/release/README.md @@ -5,6 +5,10 @@ Make sure run all these steps on the correct branch you want to create a new rel - Update CHANGELOG - Verify that all CI tests pass - Tag the release and push to Github + - For alphas, betas, and release candidates, use lightweight tags. + This is necessary so that the .devXXXX counter does not reset. + - For final releases, use annotated tags. + This makes the .devXXXX counter reset. - Wait for tag CI to complete ## GitHub Release diff --git a/release/hooks/hook-mitmproxy.py b/release/hooks/hook-mitmproxy.py new file mode 100644 index 00000000..21932507 --- /dev/null +++ b/release/hooks/hook-mitmproxy.py @@ -0,0 +1 @@ +hiddenimports = ["mitmproxy.script"] diff --git a/release/known_hosts.enc b/release/known_hosts.enc new file mode 100644 index 00000000..585ee678 --- /dev/null +++ b/release/known_hosts.enc @@ -0,0 +1 @@ +gAAAAABaTif138dCP2-G3sAJxqh5icnwM0Zy7qh4HFCxeKQBMiVDr4nJyf9T82U677M_QKWRJmp_PsbnrshHXPylq0FuHwak7Yx7kdiLue6d85VQ7_kkMs-MlPM7_Xn54_zyuj1c0b3TVAuix2xHfFLdSd_mCxygFukLzf47OyYbno7lMY_-q0HZfVPz3PBZdk95wDcbYprmgEkVJZd64Tu_LG1JDDiz56LlqADMA4znMcSAoRmbVtHu-II09HMcX3TkmcqJsNv-IVHMs4fxW_DFsq9w5ARggL6ANMfhnFQPyMtgVHjGLkSjOMRshLkQUBVYx8yWEGaQOkP0doVtDS3fZ-MKc6OJC_NSs6gkm1rswjVsQsmgZGPIqjcVf9oCbFYcw0m-JrfB1irdsLoGzpfJaSGxveC7XqOd9ArBpCHFPVO-6ilu-E1qZelvL0HiplrFvJCMEev1U2YvznC1BWKpy81vJfH--64QKZ35yQBHMV_VoH-wi80EfWtz4ISvCMQWdjRAvhLHKHSYYhUSIgBZvCCQcPySdFpbDtwsQnzIqC8MQKG787w1FiYAwzdIHTWZuanENaPMALo0t0GgMSqPV4UUyw7dto8XSMqoUXOCuZNYjunVh7AzAKS7oMUYjDs38o92sWh5sZUpPfv2WYIiecTiQw4uPae7PdSwMhkI3WIOsSb8LURnG484vvgFc2jMpQThw-BHJx7tGYC0yFLouRH2O7m9x6xgiCiVA_u_BdOj_2PFufvOCaB9wno5Vo7C1hUERGWqoBZH0htBqxYci27hh8GFwkvj6OjFUyV_kk920cBYBDG4jS4bTrTzn_znJ9TNw2XkP98nA8cwlRYhDQG9FypJG0WwYkft3TVLSQ3Hq7t0nhvhSZvXts-3LR4S0_Hm0QgFUpUc-VHViinwK8_vQH3ZjvVlEWiXnzPdpAujjX_tQXsi13UE1Zp90wGeLrmdxGXq2K76Shytu8IwTcLNZ7m0jh8KmmfNwn6oZv-czqNmC4hh0OqRDFBrv3nnjDg2Vw74uKSZmXgtZlF_Zj9hPqxVWzj7lJUcyRqABBFbBH6lTSWPHLrzQ4eTex5dnOkXC8c3hRYDUt06xUkmDqaLK0rGFcfNXawZj1YqpUJW0qaNgbtBZRsSs92kblkETxCzcwxOfupmAhWdSkmCoxt019crodz3heREcyN2xcD9qHvdY49_FD3l3U6UhrWvmkDkzyLMd7VmRPWqlW0lkzrwav8e92leIq-xKFcvbnWgSdSCWWbXvIVJKcQ6hML3jX4oY7SoBs33U1Q0HfC7SuS5lqTASuRIOVCfIGeFfRwlIfEszbWg_WDoUjR6StaVq9tbtIC3mimWND82Z9r1NfUNxr8kFYIpH_6hbxhcW26HNBKr4wLxWFFE9l1QZORPM3s6z-lT4LzUPCkFExd_eYFx3X6yUJ3cHZhkQQzCLQqG7jQqvcMwDIfM-MXkJnttLfpBq0yiq0-mc-SEas5uy27iSJgbXnsV7G3YiKEelKW_uWP2bw-rQGG_AXMGNGF2A_aREsvGrEqPnyeHAxfS1bBcnqslpIzEwr9vyyJ5v_bxfHFQC4bwYMUvPGkjHVFc0Wrk7ss9P5Kd1bzh46H7OfroUbocmYBmHMMWEg-LvsG0RZil3KWh_CSyIIPETkDjuC3W7teT-wZK0zbTEaKCuz99Dg-tjzT6fP25ipoI70cX5R3KPwrLP3XNODRTsg_Jh7IpaXo9O3o8yLV9R6_rST_1KKJwzR2MMIXIvKaJQD9w2DZIaYx3tcVsXGCDnU4Tw2hhdB5wMCl3vHx83UHfjLxnc1tJ6ObpQUjwHM1SgHK8wLW409SVHphBbSjSilX5mIaR1S1SOTK53iFj5z6asZHY9JgDj11rng1uLKeirbrNZDnUme3NNYU-HX8Ret6oOesn3374uIHux1giqgR8VsPdkcMhvunx2oTP9R2fRBTSQ8sKNqDznRC8_qlQaRC94RnWO6VRNXVBT24cXq7HTepNp4f02UvUqQRyaIUmyn2S02mjLFECDm1iMxRhuacCKbI-WSKwJcm-7p39_Uh7m_nTl2VTseeQ-3NS6i-BiGmCHt3iDxR1Fkm31b50kWW3jCe6fcwMDeu3I_8mkQs_7mCFUjSDbvFUr2Y45a5guRlw63_KUW_mNN9td9hk8POWfxWEGhcZ9eRXh_eEdEaYZmviZdHi0I8pV52CqiEO-ZrnMw-w4rSpUQeRn9oKwp3GgB9j51RNlLqK9LTp-jfSGGi5GM-ab9sPgFCJLQ-HvHdGu0tQsF2wTD3qbJwNqapx28yNVfY6e8F2jOWjmP-zzFez8VNXcfoS--Ji_zI-VqsDx-cfz3DccWEjL6vjQOvaQTRwzhI7
\ No newline at end of file diff --git a/release/rtool.py b/release/rtool.py index 271392ba..9050107e 100755 --- a/release/rtool.py +++ b/release/rtool.py @@ -4,6 +4,7 @@ import contextlib import fnmatch import os import platform +import re import runpy import shlex import shutil @@ -79,26 +80,21 @@ def git(args: str) -> str: return subprocess.check_output(["git"] + shlex.split(args)).decode() -def get_version() -> str: - return runpy.run_path(VERSION_FILE)["VERSION"] +def get_version(dev: bool = False, build: bool = False) -> str: + x = runpy.run_path(VERSION_FILE) + return x["get_version"](dev, build, True) -def get_snapshot_version() -> str: - last_tag, tag_dist, commit = git("describe --tags --long").strip().rsplit("-", 2) - tag_dist = int(tag_dist) - if tag_dist == 0: - return get_version() - else: - # remove the 'g' prefix added by recent git versions - if commit.startswith('g'): - commit = commit[1:] - - # The wheel build tag (we use the commit) must start with a digit, so we include "0x" - return "{version}dev{tag_dist:04}-0x{commit}".format( - version=get_version(), # this should already be the next version - tag_dist=tag_dist, - commit=commit - ) +def set_version(dev: bool) -> None: + """ + Update version information in mitmproxy's version.py to either include hardcoded information or not. + """ + version = get_version(dev, dev) + with open(VERSION_FILE, "r") as f: + content = f.read() + content = re.sub(r'^VERSION = ".+?"', 'VERSION = "{}"'.format(version), content, flags=re.M) + with open(VERSION_FILE, "w") as f: + f.write(content) def archive_name(bdist: str) -> str: @@ -116,7 +112,7 @@ def archive_name(bdist: str) -> str: def wheel_name() -> str: return "mitmproxy-{version}-py3-none-any.whl".format( - version=get_version(), + version=get_version(True), ) @@ -179,6 +175,23 @@ def contributors(): f.write(contributors_data.encode()) +@cli.command("wheel") +def make_wheel(): + """ + Build a Python wheel + """ + set_version(True) + try: + subprocess.check_call([ + "tox", "-e", "wheel", + ], env={ + **os.environ, + "VERSION": get_version(True), + }) + finally: + set_version(False) + + @cli.command("bdist") def make_bdist(): """ @@ -206,24 +219,30 @@ def make_bdist(): excludes.append("mitmproxy.tools.web") if tool != "mitmproxy_main": excludes.append("mitmproxy.tools.console") - subprocess.check_call( - [ - "pyinstaller", - "--clean", - "--workpath", PYINSTALLER_TEMP, - "--distpath", PYINSTALLER_DIST, - "--additional-hooks-dir", PYINSTALLER_HOOKS, - "--onefile", - "--console", - "--icon", "icon.ico", - # This is PyInstaller, so setting a - # different log level obviously breaks it :-) - # "--log-level", "WARN", - ] - + [x for e in excludes for x in ["--exclude-module", e]] - + PYINSTALLER_ARGS - + [tool] - ) + + # Overwrite mitmproxy/version.py to include commit info + set_version(True) + try: + subprocess.check_call( + [ + "pyinstaller", + "--clean", + "--workpath", PYINSTALLER_TEMP, + "--distpath", PYINSTALLER_DIST, + "--additional-hooks-dir", PYINSTALLER_HOOKS, + "--onefile", + "--console", + "--icon", "icon.ico", + # This is PyInstaller, so setting a + # different log level obviously breaks it :-) + # "--log-level", "WARN", + ] + + [x for e in excludes for x in ["--exclude-module", e]] + + PYINSTALLER_ARGS + + [tool] + ) + finally: + set_version(False) # Delete the spec file - we're good without. os.remove("{}.spec".format(tool)) @@ -280,11 +299,15 @@ def upload_snapshot(host, port, user, private_key, private_key_password, wheel, """ Upload snapshot to snapshot server """ + cnopts = pysftp.CnOpts( + knownhosts=join(RELEASE_DIR, 'known_hosts') + ) with pysftp.Connection(host=host, port=port, username=user, private_key=private_key, - private_key_pass=private_key_password) as sftp: + private_key_pass=private_key_password, + cnopts=cnopts) as sftp: dir_name = "snapshots/v{}".format(get_version()) sftp.makedirs(dir_name) with sftp.cd(dir_name): @@ -299,7 +322,11 @@ def upload_snapshot(host, port, user, private_key, private_key_password, wheel, for f in files: local_path = join(DIST_DIR, f) - remote_filename = f.replace(get_version(), get_snapshot_version()) + remote_filename = re.sub( + r"{version}(\.dev\d+(-0x[0-9a-f]+)?)?".format(version=get_version()), + get_version(True, True), + f + ) symlink_path = "../{}".format(f.replace(get_version(), "latest")) # Upload new version diff --git a/release/setup.py b/release/setup.py deleted file mode 100644 index 0c4e6605..00000000 --- a/release/setup.py +++ /dev/null @@ -1,18 +0,0 @@ -from setuptools import setup - -setup( - name='mitmproxy-rtool', - version="1.0", - py_modules=["rtool"], - install_requires=[ - "click>=6.2, <7.0", - "twine>=1.6.5, <1.10", - "pysftp==0.2.8", - "cryptography>=2.0.0, <2.1", - ], - entry_points={ - "console_scripts": [ - "rtool=rtool:cli", - ], - }, -) @@ -19,9 +19,18 @@ exclude_lines = pragma: no cover raise NotImplementedError() +[mypy-mitmproxy.contrib.*] +ignore_errors = True + [tool:full_coverage] exclude = - mitmproxy/proxy/protocol/ + mitmproxy/proxy/protocol/base.py + mitmproxy/proxy/protocol/http.py + mitmproxy/proxy/protocol/http1.py + mitmproxy/proxy/protocol/http2.py + mitmproxy/proxy/protocol/http_replay.py + mitmproxy/proxy/protocol/rawtcp.py + mitmproxy/proxy/protocol/tls.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py mitmproxy/tools/ @@ -64,10 +73,8 @@ exclude = mitmproxy/proxy/protocol/http_replay.py mitmproxy/proxy/protocol/rawtcp.py mitmproxy/proxy/protocol/tls.py - mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py - mitmproxy/stateobject.py mitmproxy/utils/bits.py pathod/language/actions.py pathod/language/base.py @@ -1,7 +1,7 @@ import os -import runpy from codecs import open +import re from setuptools import setup, find_packages # Based on https://github.com/pypa/sampleproject/blob/master/setup.py @@ -12,7 +12,8 @@ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() -VERSION = runpy.run_path(os.path.join(here, "mitmproxy", "version.py"))["VERSION"] +with open(os.path.join(here, "mitmproxy", "version.py")) as f: + VERSION = re.search(r'VERSION = "(.+?)(?:-0x|")', f.read()).group(1) setup( name="mitmproxy", @@ -61,24 +62,26 @@ setup( # It is not considered best practice to use install_requires to pin dependencies to specific versions. install_requires=[ "blinker>=1.4, <1.5", - "brotlipy>=0.5.1, <0.8", + "brotlipy>=0.7.0,<0.8", "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! "click>=6.2, <7", - "cryptography>=2.0,<2.2", - "h2>=3.0, <4", - "hyperframe>=5.0, <6", - "kaitaistruct>=0.7, <0.8", + "cryptography>=2.1.4,<2.2", + 'h11>=0.7.0,<0.8', + "h2>=3.0.1,<4", + "hyperframe>=5.1.0,<6", + "kaitaistruct>=0.7,<0.9", "ldap3>=2.4,<2.5", "passlib>=1.6.5, <1.8", "pyasn1>=0.3.1,<0.5", - "pyOpenSSL>=17.2,<17.6", + "pyOpenSSL>=17.5,<17.6", "pyparsing>=2.1.3, <2.3", "pyperclip>=1.5.22, <1.7", "requests>=2.9.1, <3", "ruamel.yaml>=0.13.2, <0.16", "sortedcontainers>=1.5.4, <1.6", "tornado>=4.3, <4.6", - "urwid>=1.3.1, <1.4", + "urwid>=2.0.1,<2.1", + "wsproto>=0.11.0,<0.12.0", ], extras_require={ ':sys_platform == "win32"': [ @@ -87,22 +90,22 @@ setup( 'dev': [ "flake8>=3.5, <3.6", "Flask>=0.10.1, <0.13", - "mypy>=0.550,<0.551", - "pytest-cov>=2.2.1, <3", - "pytest-faulthandler>=1.3.0, <2", - "pytest-timeout>=1.0.0, <2", - "pytest-xdist>=1.14, <2", - "pytest>=3.1, <4", + "mypy>=0.560,<0.561", + "pytest-cov>=2.5.1,<3", + "pytest-faulthandler>=1.3.1,<2", + "pytest-timeout>=1.2.1,<2", + "pytest-xdist>=1.22,<2", + "pytest>=3.3,<4", "rstcheck>=2.2, <4.0", "sphinx_rtd_theme>=0.1.9, <0.3", "sphinx-autobuild>=0.5.2, <0.8", - "sphinx>=1.3.5, <1.7", + "sphinx>=1.7,<1.8", "sphinxcontrib-documentedlist>=0.5.0, <0.7", "tox>=2.3, <3", ], 'examples': [ "beautifulsoup4>=4.4.1, <4.7", - "Pillow>=4.3,<4.4", + "Pillow>=4.3,<5.1", ] } ) diff --git a/test/examples/test_xss_scanner.py b/test/examples/test_xss_scanner.py index e15d7e10..8cf06a2a 100644 --- a/test/examples/test_xss_scanner.py +++ b/test/examples/test_xss_scanner.py @@ -343,10 +343,10 @@ class TestXSSScanner(): monkeypatch.setattr("mitmproxy.ctx.log", logger) xss.log_SQLi_data(None) assert logger.args == [] - xss.log_SQLi_data(xss.SQLiData(b'https://example.com', - b'Location', - b'Oracle.*Driver', - b'Oracle')) + xss.log_SQLi_data(xss.SQLiData('https://example.com', + 'Location', + 'Oracle.*Driver', + 'Oracle')) assert logger.args[0] == '===== SQLi Found =====' assert logger.args[1] == 'SQLi URL: https://example.com' assert logger.args[2] == 'Injection Point: Location' diff --git a/test/mitmproxy/addons/test_browser.py b/test/mitmproxy/addons/test_browser.py index d1b32186..407a3fe6 100644 --- a/test/mitmproxy/addons/test_browser.py +++ b/test/mitmproxy/addons/test_browser.py @@ -5,7 +5,8 @@ from mitmproxy.test import taddons def test_browser(): - with mock.patch("subprocess.Popen") as po: + with mock.patch("subprocess.Popen") as po, mock.patch("shutil.which") as which: + which.return_value = "chrome" b = browser.Browser() with taddons.context() as tctx: b.start() @@ -18,3 +19,13 @@ def test_browser(): assert tctx.master.has_log("already running") b.done() assert not b.browser + + +def test_no_browser(): + with mock.patch("shutil.which") as which: + which.return_value = False + + b = browser.Browser() + with taddons.context() as tctx: + b.start() + assert tctx.master.has_log("platform is not supported") diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 2dc7eb92..3f990668 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -52,6 +52,10 @@ class TestClientPlayback: cp.stop_replay() assert not cp.flows + df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True) + with pytest.raises(exceptions.CommandError, match="Can't replay live flow."): + cp.start_replay([df]) + def test_load_file(self, tmpdir): cp = clientplayback.ClientPlayback() with taddons.context(): diff --git a/test/mitmproxy/addons/test_core.py b/test/mitmproxy/addons/test_core.py index c132d80a..5aa4ef37 100644 --- a/test/mitmproxy/addons/test_core.py +++ b/test/mitmproxy/addons/test_core.py @@ -69,9 +69,6 @@ def test_flow_set(): f = tflow.tflow(resp=True) assert sa.flow_set_options() - with pytest.raises(exceptions.CommandError): - sa.flow_set([f], "flibble", "post") - assert f.request.method != "post" sa.flow_set([f], "method", "post") assert f.request.method == "POST" @@ -126,9 +123,6 @@ def test_encoding(): sa.encode_toggle([f], "request") assert "content-encoding" not in f.request.headers - with pytest.raises(exceptions.CommandError): - sa.encode([f], "request", "invalid") - def test_options(tmpdir): p = str(tmpdir.join("path")) diff --git a/test/mitmproxy/addons/test_cut.py b/test/mitmproxy/addons/test_cut.py index 242c6c2f..56568f21 100644 --- a/test/mitmproxy/addons/test_cut.py +++ b/test/mitmproxy/addons/test_cut.py @@ -7,89 +7,58 @@ from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.test import tutils import pytest +import pyperclip from unittest import mock def test_extract(): tf = tflow.tflow(resp=True) tests = [ - ["q.method", "GET"], - ["q.scheme", "http"], - ["q.host", "address"], - ["q.port", "22"], - ["q.path", "/path"], - ["q.url", "http://address:22/path"], - ["q.text", "content"], - ["q.content", b"content"], - ["q.raw_content", b"content"], - ["q.header[header]", "qvalue"], - - ["s.status_code", "200"], - ["s.reason", "OK"], - ["s.text", "message"], - ["s.content", b"message"], - ["s.raw_content", b"message"], - ["s.header[header-response]", "svalue"], - - ["cc.address.port", "22"], - ["cc.address.host", "127.0.0.1"], - ["cc.tls_version", "TLSv1.2"], - ["cc.sni", "address"], - ["cc.ssl_established", "false"], - - ["sc.address.port", "22"], - ["sc.address.host", "address"], - ["sc.ip_address.host", "192.168.0.1"], - ["sc.tls_version", "TLSv1.2"], - ["sc.sni", "address"], - ["sc.ssl_established", "false"], + ["request.method", "GET"], + ["request.scheme", "http"], + ["request.host", "address"], + ["request.http_version", "HTTP/1.1"], + ["request.port", "22"], + ["request.path", "/path"], + ["request.url", "http://address:22/path"], + ["request.text", "content"], + ["request.content", b"content"], + ["request.raw_content", b"content"], + ["request.timestamp_start", "946681200"], + ["request.timestamp_end", "946681201"], + ["request.header[header]", "qvalue"], + + ["response.status_code", "200"], + ["response.reason", "OK"], + ["response.text", "message"], + ["response.content", b"message"], + ["response.raw_content", b"message"], + ["response.header[header-response]", "svalue"], + ["response.timestamp_start", "946681202"], + ["response.timestamp_end", "946681203"], + + ["client_conn.address.port", "22"], + ["client_conn.address.host", "127.0.0.1"], + ["client_conn.tls_version", "TLSv1.2"], + ["client_conn.sni", "address"], + ["client_conn.tls_established", "false"], + + ["server_conn.address.port", "22"], + ["server_conn.address.host", "address"], + ["server_conn.ip_address.host", "192.168.0.1"], + ["server_conn.tls_version", "TLSv1.2"], + ["server_conn.sni", "address"], + ["server_conn.tls_established", "false"], ] - for t in tests: - ret = cut.extract(t[0], tf) - if ret != t[1]: - raise AssertionError("%s: Expected %s, got %s" % (t[0], t[1], ret)) + for spec, expected in tests: + ret = cut.extract(spec, tf) + assert spec and ret == expected with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() - c1 = certs.SSLCert.from_pem(d) + c1 = certs.Cert.from_pem(d) tf.server_conn.cert = c1 - assert "CERTIFICATE" in cut.extract("sc.cert", tf) - - -def test_parse_cutspec(): - tests = [ - ("", None, True), - ("req.method", ("@all", ["req.method"]), False), - ( - "req.method,req.host", - ("@all", ["req.method", "req.host"]), - False - ), - ( - "req.method,req.host|~b foo", - ("~b foo", ["req.method", "req.host"]), - False - ), - ( - "req.method,req.host|~b foo | ~b bar", - ("~b foo | ~b bar", ["req.method", "req.host"]), - False - ), - ( - "req.method, req.host | ~b foo | ~b bar", - ("~b foo | ~b bar", ["req.method", "req.host"]), - False - ), - ] - for cutspec, output, err in tests: - try: - assert cut.parse_cutspec(cutspec) == output - except exceptions.CommandError: - if not err: - raise - else: - if err: - raise AssertionError("Expected error.") + assert "CERTIFICATE" in cut.extract("server_conn.cert", tf) def test_headername(): @@ -110,69 +79,95 @@ def test_cut_clip(): v.add([tflow.tflow(resp=True)]) with mock.patch('pyperclip.copy') as pc: - tctx.command(c.clip, "q.method|@all") + tctx.command(c.clip, "@all", "request.method") assert pc.called with mock.patch('pyperclip.copy') as pc: - tctx.command(c.clip, "q.content|@all") + tctx.command(c.clip, "@all", "request.content") assert pc.called with mock.patch('pyperclip.copy') as pc: - tctx.command(c.clip, "q.method,q.content|@all") + tctx.command(c.clip, "@all", "request.method,request.content") assert pc.called + with mock.patch('pyperclip.copy') as pc: + log_message = "Pyperclip could not find a " \ + "copy/paste mechanism for your system." + pc.side_effect = pyperclip.PyperclipException(log_message) + tctx.command(c.clip, "@all", "request.method") + assert tctx.master.has_log(log_message, level="error") + -def test_cut_file(tmpdir): +def test_cut_save(tmpdir): f = str(tmpdir.join("path")) v = view.View() c = cut.Cut() with taddons.context() as tctx: tctx.master.addons.add(v, c) - v.add([tflow.tflow(resp=True)]) - tctx.command(c.save, "q.method|@all", f) + tctx.command(c.save, "@all", "request.method", f) assert qr(f) == b"GET" - tctx.command(c.save, "q.content|@all", f) + tctx.command(c.save, "@all", "request.content", f) assert qr(f) == b"content" - tctx.command(c.save, "q.content|@all", "+" + f) + tctx.command(c.save, "@all", "request.content", "+" + f) assert qr(f) == b"content\ncontent" v.add([tflow.tflow(resp=True)]) - tctx.command(c.save, "q.method|@all", f) + tctx.command(c.save, "@all", "request.method", f) assert qr(f).splitlines() == [b"GET", b"GET"] - tctx.command(c.save, "q.method,q.content|@all", f) + tctx.command(c.save, "@all", "request.method,request.content", f) assert qr(f).splitlines() == [b"GET,content", b"GET,content"] -def test_cut(): +@pytest.mark.parametrize("exception, log_message", [ + (PermissionError, "Permission denied"), + (IsADirectoryError, "Is a directory"), + (FileNotFoundError, "No such file or directory") +]) +def test_cut_save_open(exception, log_message, tmpdir): + f = str(tmpdir.join("path")) v = view.View() c = cut.Cut() with taddons.context() as tctx: - v.add([tflow.tflow(resp=True)]) tctx.master.addons.add(v, c) - assert c.cut("q.method|@all") == [["GET"]] - assert c.cut("q.scheme|@all") == [["http"]] - assert c.cut("q.host|@all") == [["address"]] - assert c.cut("q.port|@all") == [["22"]] - assert c.cut("q.path|@all") == [["/path"]] - assert c.cut("q.url|@all") == [["http://address:22/path"]] - assert c.cut("q.content|@all") == [[b"content"]] - assert c.cut("q.header[header]|@all") == [["qvalue"]] - assert c.cut("q.header[unknown]|@all") == [[""]] - - assert c.cut("s.status_code|@all") == [["200"]] - assert c.cut("s.reason|@all") == [["OK"]] - assert c.cut("s.content|@all") == [[b"message"]] - assert c.cut("s.header[header-response]|@all") == [["svalue"]] - assert c.cut("moo") == [[""]] + v.add([tflow.tflow(resp=True)]) + + with mock.patch("mitmproxy.addons.cut.open") as m: + m.side_effect = exception(log_message) + tctx.command(c.save, "@all", "request.method", f) + assert tctx.master.has_log(log_message, level="error") + + +def test_cut(): + c = cut.Cut() + with taddons.context(): + tflows = [tflow.tflow(resp=True)] + assert c.cut(tflows, ["request.method"]) == [["GET"]] + assert c.cut(tflows, ["request.scheme"]) == [["http"]] + assert c.cut(tflows, ["request.host"]) == [["address"]] + assert c.cut(tflows, ["request.port"]) == [["22"]] + assert c.cut(tflows, ["request.path"]) == [["/path"]] + assert c.cut(tflows, ["request.url"]) == [["http://address:22/path"]] + assert c.cut(tflows, ["request.content"]) == [[b"content"]] + assert c.cut(tflows, ["request.header[header]"]) == [["qvalue"]] + assert c.cut(tflows, ["request.header[unknown]"]) == [[""]] + + assert c.cut(tflows, ["response.status_code"]) == [["200"]] + assert c.cut(tflows, ["response.reason"]) == [["OK"]] + assert c.cut(tflows, ["response.content"]) == [[b"message"]] + assert c.cut(tflows, ["response.header[header-response]"]) == [["svalue"]] + assert c.cut(tflows, ["moo"]) == [[""]] with pytest.raises(exceptions.CommandError): - assert c.cut("__dict__") == [[""]] + assert c.cut(tflows, ["__dict__"]) == [[""]] + + with taddons.context(): + tflows = [tflow.tflow(resp=False)] + assert c.cut(tflows, ["response.reason"]) == [[""]] + assert c.cut(tflows, ["response.header[key]"]) == [[""]] - v = view.View() c = cut.Cut() - with taddons.context() as tctx: - tctx.master.addons.add(v, c) - v.add([tflow.ttcpflow()]) - assert c.cut("q.method|@all") == [[""]] - assert c.cut("s.status|@all") == [[""]] + with taddons.context(): + tflows = [tflow.ttcpflow()] + assert c.cut(tflows, ["request.method"]) == [[""]] + assert c.cut(tflows, ["response.status"]) == [[""]] diff --git a/test/mitmproxy/addons/test_eventstore.py b/test/mitmproxy/addons/test_eventstore.py index f54b9980..8ac26b05 100644 --- a/test/mitmproxy/addons/test_eventstore.py +++ b/test/mitmproxy/addons/test_eventstore.py @@ -30,3 +30,18 @@ def test_simple(): assert not sig_add.called assert sig_refresh.called + + +def test_max_size(): + store = eventstore.EventStore(3) + assert store.size == 3 + store.log(log.LogEntry("foo", "info")) + store.log(log.LogEntry("bar", "info")) + store.log(log.LogEntry("baz", "info")) + assert len(store.data) == 3 + assert ["foo", "bar", "baz"] == [x.msg for x in store.data] + + # overflow + store.log(log.LogEntry("boo", "info")) + assert len(store.data) == 3 + assert ["bar", "baz", "boo"] == [x.msg for x in store.data] diff --git a/test/mitmproxy/addons/test_export.py b/test/mitmproxy/addons/test_export.py index 233c62d5..07227a7a 100644 --- a/test/mitmproxy/addons/test_export.py +++ b/test/mitmproxy/addons/test_export.py @@ -1,6 +1,8 @@ -import pytest import os +import pytest +import pyperclip + from mitmproxy import exceptions from mitmproxy.addons import export # heh from mitmproxy.test import tflow @@ -94,9 +96,24 @@ def test_export(tmpdir): os.unlink(f) +@pytest.mark.parametrize("exception, log_message", [ + (PermissionError, "Permission denied"), + (IsADirectoryError, "Is a directory"), + (FileNotFoundError, "No such file or directory") +]) +def test_export_open(exception, log_message, tmpdir): + f = str(tmpdir.join("path")) + e = export.Export() + with taddons.context() as tctx: + with mock.patch("mitmproxy.addons.export.open") as m: + m.side_effect = exception(log_message) + e.file("raw", tflow.tflow(resp=True), f) + assert tctx.master.has_log(log_message, level="error") + + def test_clip(tmpdir): e = export.Export() - with taddons.context(): + with taddons.context() as tctx: with pytest.raises(exceptions.CommandError): e.clip("nonexistent", tflow.tflow(resp=True)) @@ -107,3 +124,10 @@ def test_clip(tmpdir): with mock.patch('pyperclip.copy') as pc: e.clip("curl", tflow.tflow(resp=True)) assert pc.called + + with mock.patch('pyperclip.copy') as pc: + log_message = "Pyperclip could not find a " \ + "copy/paste mechanism for your system." + pc.side_effect = pyperclip.PyperclipException(log_message) + e.clip("raw", tflow.tflow(resp=True)) + assert tctx.master.has_log(log_message, level="error") diff --git a/test/mitmproxy/addons/test_proxyauth.py b/test/mitmproxy/addons/test_proxyauth.py index 1d05e137..97259d1c 100644 --- a/test/mitmproxy/addons/test_proxyauth.py +++ b/test/mitmproxy/addons/test_proxyauth.py @@ -190,7 +190,7 @@ class TestProxyAuth: with pytest.raises(exceptions.OptionsError): ctx.configure(up, proxyauth="ldap:test:test:test") - with pytest.raises(IndexError): + with pytest.raises(exceptions.OptionsError): ctx.configure(up, proxyauth="ldap:fake_serveruid=?dc=example,dc=com:person") with pytest.raises(exceptions.OptionsError): diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index a4e425cd..2dee708f 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -44,6 +44,19 @@ def test_tcp(tmpdir): assert rd(p) +def test_websocket(tmpdir): + sa = save.Save() + with taddons.context() as tctx: + p = str(tmpdir.join("foo")) + tctx.configure(sa, save_stream_file=p) + + f = tflow.twebsocketflow() + sa.websocket_start(f) + sa.websocket_end(f) + tctx.configure(sa, save_stream_file=None) + assert rd(p) + + def test_save_command(tmpdir): sa = save.Save() with taddons.context() as tctx: diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index c4fe6b43..78a5be6c 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -68,6 +68,18 @@ class TestScript: with pytest.raises(exceptions.OptionsError): script.Script("nonexistent") + def test_quotes_around_filename(self): + """ + Test that a script specified as '"foo.py"' works to support the calling convention of + mitmproxy 2.0, as e.g. used by Cuckoo Sandbox. + """ + path = tutils.test_data.path("mitmproxy/data/addonscripts/recorder/recorder.py") + + s = script.Script( + '"{}"'.format(path) + ) + assert '"' not in s.fullpath + def test_simple(self): with taddons.context() as tctx: sc = script.Script( diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 1e0c3b55..6f2a9ca5 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -30,7 +30,7 @@ def test_order_refresh(): with taddons.context() as tctx: tctx.configure(v, view_order="time") v.add([tf]) - tf.request.timestamp_start = 1 + tf.request.timestamp_start = 10 assert not sargs v.update([tf]) assert sargs @@ -41,7 +41,7 @@ def test_order_generators(): tf = tflow.tflow(resp=True) rs = view.OrderRequestStart(v) - assert rs.generate(tf) == 0 + assert rs.generate(tf) == 946681200 rm = view.OrderRequestMethod(v) assert rm.generate(tf) == tf.request.method @@ -147,6 +147,10 @@ def test_create(): assert v[0].request.url == "http://foo.com/" v.create("get", "http://foo.com") assert len(v) == 2 + with pytest.raises(exceptions.CommandError, match="Invalid URL"): + v.create("get", "http://foo.com\\") + with pytest.raises(exceptions.CommandError, match="Invalid URL"): + v.create("get", "http://") def test_orders(): @@ -175,6 +179,10 @@ def test_load(tmpdir): v.load_file("nonexistent_file_path") except IOError: assert False + with open(path, "wb") as f: + f.write(b"invalidflows") + v.load_file(path) + assert tctx.master.has_log("Invalid data format.") def test_resolve(): diff --git a/test/mitmproxy/contentviews/test_auto.py b/test/mitmproxy/contentviews/test_auto.py index 2ff43139..cd888a2d 100644 --- a/test/mitmproxy/contentviews/test_auto.py +++ b/test/mitmproxy/contentviews/test_auto.py @@ -1,6 +1,6 @@ from mitmproxy.contentviews import auto from mitmproxy.net import http -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import full_eval diff --git a/test/mitmproxy/contentviews/test_base.py b/test/mitmproxy/contentviews/test_base.py index 777ab4dd..c94d8be2 100644 --- a/test/mitmproxy/contentviews/test_base.py +++ b/test/mitmproxy/contentviews/test_base.py @@ -1 +1,17 @@ -# TODO: write tests +import pytest +from mitmproxy.contentviews import base + + +def test_format_dict(): + d = {"one": "two", "three": "four"} + f_d = base.format_dict(d) + assert next(f_d) + + d = {"adsfa": ""} + f_d = base.format_dict(d) + assert next(f_d) + + d = {} + f_d = base.format_dict(d) + with pytest.raises(StopIteration): + next(f_d) diff --git a/test/mitmproxy/contentviews/test_query.py b/test/mitmproxy/contentviews/test_query.py index d2bddd05..741b23f1 100644 --- a/test/mitmproxy/contentviews/test_query.py +++ b/test/mitmproxy/contentviews/test_query.py @@ -1,5 +1,5 @@ from mitmproxy.contentviews import query -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import full_eval diff --git a/test/mitmproxy/coretypes/__init__.py b/test/mitmproxy/coretypes/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/mitmproxy/coretypes/__init__.py diff --git a/test/mitmproxy/types/test_basethread.py b/test/mitmproxy/coretypes/test_basethread.py index a91588eb..4a383fea 100644 --- a/test/mitmproxy/types/test_basethread.py +++ b/test/mitmproxy/coretypes/test_basethread.py @@ -1,5 +1,5 @@ import re -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread def test_basethread(): diff --git a/test/mitmproxy/types/test_bidi.py b/test/mitmproxy/coretypes/test_bidi.py index e3a259fd..3bdad3c2 100644 --- a/test/mitmproxy/types/test_bidi.py +++ b/test/mitmproxy/coretypes/test_bidi.py @@ -1,5 +1,5 @@ import pytest -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi def test_bidi(): diff --git a/test/mitmproxy/types/test_multidict.py b/test/mitmproxy/coretypes/test_multidict.py index c76cd753..273d8ca2 100644 --- a/test/mitmproxy/types/test_multidict.py +++ b/test/mitmproxy/coretypes/test_multidict.py @@ -1,6 +1,6 @@ import pytest -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict class _TMulti: diff --git a/test/mitmproxy/types/test_serializable.py b/test/mitmproxy/coretypes/test_serializable.py index 390d17e1..a316f876 100644 --- a/test/mitmproxy/types/test_serializable.py +++ b/test/mitmproxy/coretypes/test_serializable.py @@ -1,6 +1,6 @@ import copy -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable class SerializableDummy(serializable.Serializable): diff --git a/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py b/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py index 2a7d300c..b52f55c5 100644 --- a/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py +++ b/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py @@ -5,7 +5,7 @@ from mitmproxy.script import concurrent class ConcurrentClass: @concurrent - def request(flow): + def request(self, flow): time.sleep(0.1) diff --git a/test/mitmproxy/net/http/test_cookies.py b/test/mitmproxy/net/http/test_cookies.py index 77549d9e..e12b0f00 100644 --- a/test/mitmproxy/net/http/test_cookies.py +++ b/test/mitmproxy/net/http/test_cookies.py @@ -7,6 +7,10 @@ from mitmproxy.net.http import cookies cookie_pairs = [ [ + "=uno", + [["", "uno"]] + ], + [ "", [] ], @@ -16,7 +20,7 @@ cookie_pairs = [ ], [ "one", - [["one", None]] + [["one", ""]] ], [ "one=uno; two=due", @@ -36,7 +40,7 @@ cookie_pairs = [ ], [ "one=uno; two; three=tre", - [["one", "uno"], ["two", None], ["three", "tre"]] + [["one", "uno"], ["two", ""], ["three", "tre"]] ], [ "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " @@ -79,8 +83,12 @@ def test_read_quoted_string(): def test_read_cookie_pairs(): vals = [ [ + "=uno", + [["", "uno"]] + ], + [ "one", - [["one", None]] + [["one", ""]] ], [ "one=two", @@ -100,7 +108,7 @@ def test_read_cookie_pairs(): ], [ 'one="two"; three=four; five', - [["one", "two"], ["three", "four"], ["five", None]] + [["one", "two"], ["three", "four"], ["five", ""]] ], [ 'one="\\"two"; three=four', @@ -135,6 +143,12 @@ def test_cookie_roundtrips(): def test_parse_set_cookie_pairs(): pairs = [ [ + "=uno", + [[ + ["", "uno"] + ]] + ], + [ "one=uno", [[ ["one", "uno"] @@ -150,7 +164,7 @@ def test_parse_set_cookie_pairs(): "one=uno; foo", [[ ["one", "uno"], - ["foo", None] + ["foo", ""] ]] ], [ @@ -200,6 +214,12 @@ def test_parse_set_cookie_header(): ";", [] ], [ + "=uno", + [ + ("", "uno", ()) + ] + ], + [ "one=uno", [ ("one", "uno", ()) diff --git a/test/mitmproxy/net/http/test_response.py b/test/mitmproxy/net/http/test_response.py index fa1770fe..f3470384 100644 --- a/test/mitmproxy/net/http/test_response.py +++ b/test/mitmproxy/net/http/test_response.py @@ -113,7 +113,7 @@ class TestResponseUtils: assert attrs["domain"] == "example.com" assert attrs["expires"] == "Wed Oct 21 16:29:41 2015" assert attrs["path"] == "/" - assert attrs["httponly"] is None + assert attrs["httponly"] == "" def test_get_cookies_no_value(self): resp = tresp() @@ -150,10 +150,10 @@ class TestResponseUtils: n = time.time() r.headers["date"] = email.utils.formatdate(n) pre = r.headers["date"] - r.refresh(n) + r.refresh(946681202) assert pre == r.headers["date"] - r.refresh(n + 60) + r.refresh(946681262) d = email.utils.parsedate_tz(r.headers["date"]) d = email.utils.mktime_tz(d) # Weird that this is not exact... diff --git a/test/mitmproxy/net/http/test_url.py b/test/mitmproxy/net/http/test_url.py index 2064aab8..c9f61faf 100644 --- a/test/mitmproxy/net/http/test_url.py +++ b/test/mitmproxy/net/http/test_url.py @@ -108,6 +108,7 @@ def test_empty_key_trailing_equal_sign(): def test_encode(): assert url.encode([('foo', 'bar')]) assert url.encode([('foo', surrogates)]) + assert not url.encode([], similar_to="justatext") def test_decode(): diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index 3e27929d..8c012e42 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -1,4 +1,5 @@ from io import BytesIO +import re import queue import time import socket @@ -95,7 +96,13 @@ class TestServerBind(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): - self.wfile.write(str(self.connection.getpeername()).encode()) + # We may get an ipv4-mapped ipv6 address here, e.g. ::ffff:127.0.0.1. + # Those still appear as "127.0.0.1" in the table, so we need to strip the prefix. + peername = self.connection.getpeername() + address = re.sub("^::ffff:(?=\d+.\d+.\d+.\d+$)", "", peername[0]) + port = peername[1] + + self.wfile.write(str((address, port)).encode()) self.wfile.flush() def test_bind(self): @@ -171,7 +178,7 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) + c.convert_to_tls(sni="foo.com", options=SSL.OP_ALL) testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -181,7 +188,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): assert not c.get_current_cipher() - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -198,7 +205,7 @@ class TestSSLv3Only(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") class TestInvalidTrustFile(tservers.ServerTestBase): @@ -206,7 +213,7 @@ class TestInvalidTrustFile(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/generate.py") @@ -224,7 +231,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): def test_mode_default_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() # Verification errors should be saved even if connection isn't aborted # aborted @@ -238,7 +245,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): def test_mode_none_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(verify=SSL.VERIFY_NONE) + c.convert_to_tls(verify=SSL.VERIFY_NONE) # Verification errors should be saved even if connection isn't aborted assert c.ssl_verification_error @@ -252,7 +259,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.InvalidCertificateException): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -277,7 +284,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl( + c.convert_to_tls( verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") ) @@ -285,7 +292,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): def test_mode_none_should_pass_without_sni(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( verify=SSL.VERIFY_NONE, ca_path=tutils.test_data.path("mitmproxy/net/data/verificationcerts/") ) @@ -296,7 +303,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.InvalidCertificateException): - c.convert_to_ssl( + c.convert_to_tls( sni="mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -315,7 +322,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): def test_mode_strict_w_pemfile_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -331,7 +338,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): def test_mode_strict_w_cadir_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_path=tutils.test_data.path("mitmproxy/net/data/verificationcerts/") @@ -365,7 +372,7 @@ class TestSSLClientCert(tservers.ServerTestBase): def test_clientcert(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( cert=tutils.test_data.path("mitmproxy/net/data/clientcert/client.pem")) assert c.rfile.readline().strip() == b"1" @@ -373,7 +380,7 @@ class TestSSLClientCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl(cert=tutils.test_data.path("mitmproxy/net/data/clientcert/make")) + c.convert_to_tls(cert=tutils.test_data.path("mitmproxy/net/data/clientcert/make")) class TestSNI(tservers.ServerTestBase): @@ -393,15 +400,15 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") assert c.sni == "foo.com" assert c.rfile.readline() == b"foo.com" def test_idn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="mitmproxyäöüß.example.com") - assert c.ssl_established + c.convert_to_tls(sni="mitmproxyäöüß.example.com") + assert c.tls_established assert "doesn't match" not in str(c.ssl_verification_error) @@ -414,7 +421,7 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") expected = b"['AES256-GCM-SHA384']" assert c.rfile.read(len(expected) + 2) == expected @@ -435,7 +442,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") assert b'AES256-GCM-SHA384' in c.rfile.readline() @@ -449,7 +456,7 @@ class TestServerCipherListError(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(Exception, match="handshake error"): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -462,7 +469,7 @@ class TestClientCipherListError(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(Exception, match="cipher specification"): - c.convert_to_ssl(sni="foo.com", cipher_list="bogus") + c.convert_to_tls(sni="foo.com", cipher_list="bogus") class TestSSLDisconnect(tservers.ServerTestBase): @@ -477,7 +484,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() # Excercise SSL.ZeroReturnError c.rfile.read(10) c.close() @@ -494,7 +501,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() # Exercise SSL.SysCallError c.rfile.read(10) c.close() @@ -558,7 +565,7 @@ class TestALPNClient(tservers.ServerTestBase): def test_alpn(self, monkeypatch, alpn_protos, expected_negotiated, expected_response): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(alpn_protos=alpn_protos) + c.convert_to_tls(alpn_protos=alpn_protos) assert c.get_alpn_proto_negotiated() == expected_negotiated assert c.rfile.readline().strip() == expected_response @@ -580,7 +587,7 @@ class TestSSLTimeOut(tservers.ServerTestBase): def test_timeout_client(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() c.settimeout(0.1) with pytest.raises(exceptions.TcpTimeout): c.rfile.read(10) @@ -598,7 +605,7 @@ class TestDHParams(tservers.ServerTestBase): def test_dhparams(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() ret = c.get_current_cipher() assert ret[0] == "DHE-RSA-AES256-SHA" @@ -794,5 +801,5 @@ class TestPeekSSL(TestPeek): def _connect(self, c): with c.connect() as conn: - c.convert_to_ssl() + c.convert_to_tls() return conn.pop() diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index d0583d34..489bf89f 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -1,3 +1,5 @@ +import io + import pytest from mitmproxy import exceptions @@ -6,6 +8,17 @@ from mitmproxy.net.tcp import TCPClient from test.mitmproxy.net.test_tcp import EchoHandler from . import tservers +CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( + "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" + "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" + "61006200640100" +) +FULL_CLIENT_HELLO_NO_EXTENSIONS = ( + b"\x16\x03\x03\x00\x65" # record layer + b"\x01\x00\x00\x61" + # handshake header + CLIENT_HELLO_NO_EXTENSIONS +) + class TestMasterSecretLogger(tservers.ServerTestBase): handler = EchoHandler @@ -22,7 +35,7 @@ class TestMasterSecretLogger(tservers.ServerTestBase): c = TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -53,3 +66,92 @@ class TestTLSInvalid: with pytest.raises(exceptions.TlsException, match="ALPN error"): tls.create_client_context(alpn_select="foo", alpn_select_callback="bar") + + +def test_is_record_magic(): + assert not tls.is_tls_record_magic(b"POST /") + assert not tls.is_tls_record_magic(b"\x16\x03") + assert not tls.is_tls_record_magic(b"\x16\x03\x04") + assert tls.is_tls_record_magic(b"\x16\x03\x00") + assert tls.is_tls_record_magic(b"\x16\x03\x01") + assert tls.is_tls_record_magic(b"\x16\x03\x02") + assert tls.is_tls_record_magic(b"\x16\x03\x03") + + +def test_get_client_hello(): + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS + )) + assert tls.get_client_hello(rfile) + + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS[:30] + )) + with pytest.raises(exceptions.TlsProtocolException, message="Unexpected EOF"): + tls.get_client_hello(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"GET /" + )) + with pytest.raises(exceptions.TlsProtocolException, message="Expected TLS record"): + tls.get_client_hello(rfile) + + +class TestClientHello: + def test_no_extensions(self): + c = tls.ClientHello(CLIENT_HELLO_NO_EXTENSIONS) + assert repr(c) + assert c.sni is None + assert c.cipher_suites == [53, 47, 10, 5, 4, 9, 3, 6, 8, 96, 97, 98, 100] + assert c.alpn_protocols == [] + assert c.extensions == [] + + def test_extensions(self): + data = bytes.fromhex( + "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" + "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" + "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" + "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" + "170018" + ) + c = tls.ClientHello(data) + assert repr(c) + assert c.sni == 'example.com' + assert c.cipher_suites == [ + 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161, + 49171, 49162, 49172, 156, 157, 47, 53, 10 + ] + assert c.alpn_protocols == [b'h2', b'http/1.1'] + assert c.extensions == [ + (65281, b'\x00'), + (0, b'\x00\x0e\x00\x00\x0bexample.com'), + (23, b''), + (35, b''), + (13, b'\x00\x10\x06\x01\x06\x03\x05\x01\x05\x03\x04\x01\x04\x03\x02\x01\x02\x03'), + (5, b'\x01\x00\x00\x00\x00'), + (18, b''), + (16, b'\x00\x0c\x02h2\x08http/1.1'), + (30032, b''), + (11, b'\x01\x00'), + (10, b'\x00\x06\x00\x1d\x00\x17\x00\x18') + ] + + def test_from_file(self): + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS + )) + assert tls.ClientHello.from_file(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"" + )) + with pytest.raises(exceptions.TlsProtocolException): + tls.ClientHello.from_file(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"\x16\x03\x03\x00\x07" # record layer + b"\x01\x00\x00\x03" + # handshake header + b"foo" + )) + with pytest.raises(exceptions.TlsProtocolException, message='Cannot parse Client Hello'): + tls.ClientHello.from_file(rfile) diff --git a/test/mitmproxy/net/tools/getcertnames b/test/mitmproxy/net/tools/getcertnames index d64e5ff5..9349415f 100644 --- a/test/mitmproxy/net/tools/getcertnames +++ b/test/mitmproxy/net/tools/getcertnames @@ -7,7 +7,7 @@ from mitmproxy.net import tcp def get_remote_cert(host, port, sni): c = tcp.TCPClient((host, port)) c.connect() - c.convert_to_ssl(sni=sni) + c.convert_to_tls(sni=sni) return c.cert if len(sys.argv) > 2: diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index 44701aa5..22e195e3 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -60,7 +60,7 @@ class _TServer(tcp.TCPServer): else: method = OpenSSL.SSL.SSLv23_METHOD options = None - h.convert_to_ssl( + h.convert_to_tls( cert, key, method=method, diff --git a/test/mitmproxy/platform/test_pf.py b/test/mitmproxy/platform/test_pf.py index 3292d345..b048a697 100644 --- a/test/mitmproxy/platform/test_pf.py +++ b/test/mitmproxy/platform/test_pf.py @@ -15,6 +15,7 @@ class TestLookup: d = f.read() assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80) + assert pf.lookup("::ffff:192.168.1.111", 40000, d) == ("5.5.5.5", 80) with pytest.raises(Exception, match="Could not resolve original destination"): pf.lookup("192.168.1.112", 40000, d) with pytest.raises(Exception, match="Could not resolve original destination"): diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index 4f161ef5..194a57c9 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -141,7 +141,7 @@ class _Http2TestBase: while self.client.rfile.readline() != b"\r\n": pass - self.client.convert_to_ssl(alpn_protos=[b'h2']) + self.client.convert_to_tls(alpn_protos=[b'h2']) config = h2.config.H2Configuration( client_side=True, diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py index e17ee46f..e69de29b 100644 --- a/test/mitmproxy/proxy/protocol/test_tls.py +++ b/test/mitmproxy/proxy/protocol/test_tls.py @@ -1,26 +0,0 @@ -from mitmproxy.proxy.protocol.tls import TlsClientHello - - -class TestClientHello: - - def test_no_extensions(self): - data = bytes.fromhex( - "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" - "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" - "61006200640100" - ) - c = TlsClientHello(data) - assert c.sni is None - assert c.alpn_protocols == [] - - def test_extensions(self): - data = bytes.fromhex( - "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" - "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" - "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" - "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" - "170018" - ) - c = TlsClientHello(data) - assert c.sni == 'example.com' - assert c.alpn_protocols == [b'h2', b'http/1.1'] diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 460d85f8..5cd9601c 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -1,5 +1,6 @@ import pytest import os +import struct import tempfile import traceback @@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase): connection='upgrade', upgrade='websocket', sec_websocket_accept=b'', + sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else '' ), content=b'', ) @@ -80,7 +82,7 @@ class _WebSocketTestBase: if self.client: self.client.close() - def setup_connection(self): + def setup_connection(self, extension=False): self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) self.client.connect() @@ -99,8 +101,8 @@ class _WebSocketTestBase: response = http.http1.read_response(self.client.rfile, request) if self.ssl: - self.client.convert_to_ssl() - assert self.client.ssl_established + self.client.convert_to_tls() + assert self.client.tls_established request = http.Request( "relative", @@ -115,6 +117,7 @@ class _WebSocketTestBase: upgrade="websocket", sec_websocket_version="13", sec_websocket_key="1234", + sec_websocket_extensions="permessage-deflate" if extension else "" ), content=b'') self.client.wfile.write(http.http1.assemble_request(request)) @@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() @pytest.mark.parametrize('streaming', [True, False]) @@ -164,36 +167,78 @@ class TestSimple(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() assert len(self.master.state.flows) == 2 assert isinstance(self.master.state.flows[0], HTTPFlow) assert isinstance(self.master.state.flows[1], WebSocketFlow) assert len(self.master.state.flows[1].messages) == 5 - assert self.master.state.flows[1].messages[0].content == b'server-foobar' + assert self.master.state.flows[1].messages[0].content == 'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[1].content == 'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[2].content == 'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY + def test_change_payload(self): + class Addon: + def websocket_message(self, f): + f.messages[-1].content = "foo" + + self.master.addons.add(Addon()) + self.setup_connection() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + +class TestKillFlow(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + def test_kill(self): + class KillFlow: + def websocket_message(self, f): + f.kill() + + self.master.addons.add(KillFlow()) + self.setup_connection() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(self.client.rfile) + class TestSimpleTLS(_WebSocketTest): ssl = True @@ -204,7 +249,7 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() def test_simple_tls(self): @@ -213,13 +258,13 @@ class TestSimpleTLS(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() @@ -234,22 +279,24 @@ class TestPing(_WebSocketTest): assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done'))) + wfile.flush() + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) wfile.flush() + websockets.Frame.from_file(rfile) def test_ping(self): self.setup_connection() frame = websockets.Frame.from_file(self.client.rfile) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' + assert frame.payload == b'' # We don't send payload to other end - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - self.client.wfile.flush() - - frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.TEXT - assert frame.payload == b'pong-received' + assert self.master.has_log("Pong Received from server", "info") class TestPong(_WebSocketTest): @@ -258,20 +305,29 @@ class TestPong(_WebSocketTest): def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' + assert frame.payload == b'' wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) wfile.flush() + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + websockets.Frame.from_file(rfile) + def test_pong(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() + assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' + assert self.master.has_log("Pong Received from server", "info") class TestClose(_WebSocketTest): @@ -279,7 +335,7 @@ class TestClose(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) wfile.flush() @@ -289,7 +345,7 @@ class TestClose(_WebSocketTest): def test_close(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -299,7 +355,7 @@ class TestClose(_WebSocketTest): def test_close_payload_1(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -309,7 +365,7 @@ class TestClose(_WebSocketTest): def test_close_payload_2(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -329,8 +385,9 @@ class TestInvalidFrame(_WebSocketTest): # with pytest.raises(exceptions.TcpDisconnect): frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == 15 - assert frame.payload == b'foobar' + code, = struct.unpack('!H', frame.payload[:2]) + assert code == 1002 + assert frame.payload[2:].startswith(b'Invalid opcode') class TestStreaming(_WebSocketTest): @@ -360,3 +417,51 @@ class TestStreaming(_WebSocketTest): assert frame assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received + + +class TestExtension(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00') + wfile.flush() + + def test_extension(self): + self.setup_connection(True) + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + assert len(self.master.state.flows[1].messages) == 5 + assert self.master.state.flows[1].messages[0].content == 'server-foobar' + assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[1].content == 'client-foobar' + assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[2].content == 'client-foobar' + assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY + assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index 8dce9bcd..87ec443a 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -143,9 +143,9 @@ class TcpMixin: # Test that we get the original SSL cert if self.ssl: - i_cert = certs.SSLCert(i.sslinfo.certchain[0]) - i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) - n_cert = certs.SSLCert(n.sslinfo.certchain[0]) + i_cert = certs.Cert(i.sslinfo.certchain[0]) + i2_cert = certs.Cert(i2.sslinfo.certchain[0]) + n_cert = certs.Cert(n.sslinfo.certchain[0]) assert i_cert == i2_cert assert i_cert != n_cert @@ -188,9 +188,9 @@ class TcpMixin: # Test that we get the original SSL cert if self.ssl: - i_cert = certs.SSLCert(i.sslinfo.certchain[0]) - i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) - n_cert = certs.SSLCert(n.sslinfo.certchain[0]) + i_cert = certs.Cert(i.sslinfo.certchain[0]) + i2_cert = certs.Cert(i2.sslinfo.certchain[0]) + n_cert = certs.Cert(n.sslinfo.certchain[0]) assert i_cert == i2_cert assert i_cert != n_cert @@ -511,6 +511,14 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin): req = self.master.state.flows[0].request assert req.host_header == "127.0.0.1" + def test_selfconnection(self): + self.options.mode = "reverse:http://127.0.0.1:0" + + p = self.pathoc() + with p.connect(): + p.request("get:/") + assert self.master.has_log("The proxy shall not connect to itself.") + class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin): reverse = True @@ -579,7 +587,7 @@ class TestSocks5SSL(tservers.SocksModeTest): p = self.pathoc_raw() with p.connect(): p.socks_connect(("localhost", self.server.port)) - p.convert_to_ssl() + p.convert_to_tls() f = p.request("get:/p/200") assert f.status_code == 200 @@ -709,7 +717,7 @@ class TestProxy(tservers.HTTPProxyTest): first_flow = self.master.state.flows[0] second_flow = self.master.state.flows[1] assert first_flow.server_conn.timestamp_tcp_setup - assert first_flow.server_conn.timestamp_ssl_setup is None + assert first_flow.server_conn.timestamp_tls_setup is None assert second_flow.server_conn.timestamp_tcp_setup assert first_flow.server_conn.timestamp_tcp_setup == second_flow.server_conn.timestamp_tcp_setup @@ -723,12 +731,13 @@ class TestProxy(tservers.HTTPProxyTest): class TestProxySSL(tservers.HTTPProxyTest): ssl = True - def test_request_ssl_setup_timestamp_presence(self): + def test_request_tls_attribute_presence(self): # tests that the ssl timestamp is present when ssl is used f = self.pathod("304:b@10k") assert f.status_code == 304 first_flow = self.master.state.flows[0] - assert first_flow.server_conn.timestamp_ssl_setup + assert first_flow.server_conn.timestamp_tls_setup + assert first_flow.client_conn.tls_extensions def test_via(self): # tests that the ssl timestamp is present when ssl is used @@ -1149,7 +1158,7 @@ class AddUpstreamCertsToClientChainMixin: def test_add_upstream_certs_to_client_chain(self): with open(self.servercert, "rb") as f: d = f.read() - upstreamCert = certs.SSLCert.from_pem(d) + upstreamCert = certs.Cert.from_pem(d) p = self.pathoc() with p.connect(): upstream_cert_found_in_client_chain = False diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 693bebc6..dcc185c0 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -136,18 +136,18 @@ class TestDummyCert: assert r.altnames == [] -class TestSSLCert: +class TestCert: def test_simple(self): with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() - c1 = certs.SSLCert.from_pem(d) + c1 = certs.Cert.from_pem(d) assert c1.cn == b"google.com" assert len(c1.altnames) == 436 with open(tutils.test_data.path("mitmproxy/net/data/text_cert_2"), "rb") as f: d = f.read() - c2 = certs.SSLCert.from_pem(d) + c2 = certs.Cert.from_pem(d) assert c2.cn == b"www.inode.co.nz" assert len(c2.altnames) == 2 assert c2.digest("sha1") @@ -165,20 +165,20 @@ class TestSSLCert: def test_err_broken_sans(self): with open(tutils.test_data.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f: d = f.read() - c = certs.SSLCert.from_pem(d) + c = certs.Cert.from_pem(d) # This breaks unless we ignore a decoding error. assert c.altnames is not None def test_der(self): with open(tutils.test_data.path("mitmproxy/net/data/dercert"), "rb") as f: d = f.read() - s = certs.SSLCert.from_der(d) + s = certs.Cert.from_der(d) assert s.cn def test_state(self): with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() - c = certs.SSLCert.from_pem(d) + c = certs.Cert.from_pem(d) c.get_state() c2 = c.copy() @@ -188,6 +188,6 @@ class TestSSLCert: assert c == c2 assert c is not c2 - x = certs.SSLCert('') + x = certs.Cert('') x.set_state(a) assert x == c diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index 43b97742..c777192d 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -4,27 +4,56 @@ from mitmproxy import flow from mitmproxy import exceptions from mitmproxy.test import tflow from mitmproxy.test import taddons +import mitmproxy.types import io import pytest class TAddon: + @command.command("cmd1") def cmd1(self, foo: str) -> str: """cmd1 help""" return "ret " + foo + @command.command("cmd2") def cmd2(self, foo: str) -> str: return 99 + @command.command("cmd3") def cmd3(self, foo: int) -> int: return foo + @command.command("cmd4") + def cmd4(self, a: int, b: str, c: mitmproxy.types.Path) -> str: + return "ok" + + @command.command("subcommand") + def subcommand(self, cmd: mitmproxy.types.Cmd, *args: mitmproxy.types.Arg) -> str: + return "ok" + + @command.command("empty") def empty(self) -> None: pass + @command.command("varargs") def varargs(self, one: str, *var: str) -> typing.Sequence[str]: return list(var) + def choices(self) -> typing.Sequence[str]: + return ["one", "two", "three"] + + @command.argument("arg", type=mitmproxy.types.Choice("choices")) + def choose(self, arg: str) -> typing.Sequence[str]: + return ["one", "two", "three"] + + @command.command("path") + def path(self, arg: mitmproxy.types.Path) -> None: + pass + + @command.command("flow") + def flow(self, f: flow.Flow, s: str) -> None: + pass + class TestCommand: def test_varargs(self): @@ -52,6 +81,144 @@ class TestCommand: c = command.Command(cm, "cmd.three", a.cmd3) assert c.call(["1"]) == 1 + def test_parse_partial(self): + tests = [ + [ + "foo bar", + [ + command.ParseResult( + value = "foo", type = mitmproxy.types.Cmd, valid = False + ), + command.ParseResult( + value = "bar", type = mitmproxy.types.Unknown, valid = False + ) + ], + [], + ], + [ + "cmd1 'bar", + [ + command.ParseResult(value = "cmd1", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "'bar", type = str, valid = True) + ], + [], + ], + [ + "a", + [command.ParseResult(value = "a", type = mitmproxy.types.Cmd, valid = False)], + [], + ], + [ + "", + [command.ParseResult(value = "", type = mitmproxy.types.Cmd, valid = False)], + [] + ], + [ + "cmd3 1", + [ + command.ParseResult(value = "cmd3", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "1", type = int, valid = True), + ], + [] + ], + [ + "cmd3 ", + [ + command.ParseResult(value = "cmd3", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = int, valid = False), + ], + [] + ], + [ + "subcommand ", + [ + command.ParseResult( + value = "subcommand", type = mitmproxy.types.Cmd, valid = True, + ), + command.ParseResult(value = "", type = mitmproxy.types.Cmd, valid = False), + ], + ["arg"], + ], + [ + "subcommand cmd3 ", + [ + command.ParseResult(value = "subcommand", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "cmd3", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = int, valid = False), + ], + [] + ], + [ + "cmd4", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + ], + ["int", "str", "path"] + ], + [ + "cmd4 ", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = int, valid = False), + ], + ["str", "path"] + ], + [ + "cmd4 1", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "1", type = int, valid = True), + ], + ["str", "path"] + ], + [ + "cmd4 1", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "1", type = int, valid = True), + ], + ["str", "path"] + ], + [ + "flow", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + ], + ["flow", "str"] + ], + [ + "flow ", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = flow.Flow, valid = False), + ], + ["str"] + ], + [ + "flow x", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "x", type = flow.Flow, valid = False), + ], + ["str"] + ], + [ + "flow x ", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "x", type = flow.Flow, valid = False), + command.ParseResult(value = "", type = str, valid = True), + ], + [] + ], + ] + with taddons.context() as tctx: + tctx.master.addons.add(TAddon()) + for s, expected, expectedremain in tests: + current, remain = tctx.master.commands.parse_partial(s) + assert current == expected + assert expectedremain == remain + def test_simple(): with taddons.context() as tctx: @@ -64,7 +231,7 @@ def test_simple(): c.call("nonexistent") with pytest.raises(exceptions.CommandError, match="Invalid"): c.call("") - with pytest.raises(exceptions.CommandError, match="Usage"): + with pytest.raises(exceptions.CommandError, match="argument mismatch"): c.call("one.two too many args") c.add("empty", a.empty) @@ -76,15 +243,18 @@ def test_simple(): def test_typename(): - assert command.typename(str, True) == "str" - assert command.typename(typing.Sequence[flow.Flow], True) == "[flow]" - assert command.typename(typing.Sequence[flow.Flow], False) == "flowspec" + assert command.typename(str) == "str" + assert command.typename(typing.Sequence[flow.Flow]) == "[flow]" + + assert command.typename(mitmproxy.types.Data) == "[data]" + assert command.typename(mitmproxy.types.CutSpec) == "[cut]" - assert command.typename(command.Cuts, False) == "cutspec" - assert command.typename(command.Cuts, True) == "[cuts]" + assert command.typename(flow.Flow) == "flow" + assert command.typename(typing.Sequence[str]) == "[str]" - assert command.typename(flow.Flow, False) == "flow" - assert command.typename(typing.Sequence[str], False) == "[str]" + assert command.typename(mitmproxy.types.Choice("foo")) == "choice" + assert command.typename(mitmproxy.types.Path) == "path" + assert command.typename(mitmproxy.types.Cmd) == "cmd" class DummyConsole: @@ -94,7 +264,7 @@ class DummyConsole: return [tflow.tflow(resp=True)] * n @command.command("cut") - def cut(self, spec: str) -> command.Cuts: + def cut(self, spec: str) -> mitmproxy.types.Data: return [["test"]] @@ -102,38 +272,11 @@ def test_parsearg(): with taddons.context() as tctx: tctx.master.addons.add(DummyConsole()) assert command.parsearg(tctx.master.commands, "foo", str) == "foo" - - assert command.parsearg(tctx.master.commands, "1", int) == 1 + with pytest.raises(exceptions.CommandError, match="Unsupported"): + command.parsearg(tctx.master.commands, "foo", type) with pytest.raises(exceptions.CommandError): command.parsearg(tctx.master.commands, "foo", int) - assert command.parsearg(tctx.master.commands, "true", bool) is True - assert command.parsearg(tctx.master.commands, "false", bool) is False - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "flobble", bool) - - assert len(command.parsearg( - tctx.master.commands, "2", typing.Sequence[flow.Flow] - )) == 2 - assert command.parsearg(tctx.master.commands, "1", flow.Flow) - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "2", flow.Flow) - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "0", flow.Flow) - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "foo", Exception) - - assert command.parsearg( - tctx.master.commands, "foo", command.Cuts - ) == [["test"]] - - assert command.parsearg( - tctx.master.commands, "foo", typing.Sequence[str] - ) == ["foo"] - assert command.parsearg( - tctx.master.commands, "foo, bar", typing.Sequence[str] - ) == ["foo", "bar"] - class TDec: @command.command("cmd1") @@ -169,4 +312,4 @@ def test_verify_arg_signature(): with pytest.raises(exceptions.CommandError): command.verify_arg_signature(lambda: None, [1, 2], {}) print('hello there') - command.verify_arg_signature(lambda a, b: None, [1, 2], {})
\ No newline at end of file + command.verify_arg_signature(lambda a, b: None, [1, 2], {}) diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 83f0bd34..00cdbc87 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -41,10 +41,10 @@ class TestClientConnection: def test_tls_established_property(self): c = tflow.tclient_conn() c.tls_established = True - assert c.ssl_established + assert c.tls_established assert c.tls_established c.tls_established = False - assert not c.ssl_established + assert not c.tls_established assert not c.tls_established def test_make_dummy(self): @@ -113,10 +113,10 @@ class TestServerConnection: def test_tls_established_property(self): c = tflow.tserver_conn() c.tls_established = True - assert c.ssl_established + assert c.tls_established assert c.tls_established c.tls_established = False - assert not c.ssl_established + assert not c.tls_established assert not c.tls_established def test_make_dummy(self): @@ -155,7 +155,7 @@ class TestServerConnection: def test_sni(self): c = connections.ServerConnection(('', 1234)) with pytest.raises(ValueError, matches='sni must be str, not '): - c.establish_ssl(None, b'foobar') + c.establish_tls(sni=b'foobar') def test_state(self): c = tflow.tserver_conn() @@ -206,7 +206,7 @@ class TestClientConnectionTLS: key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, raw_key) - c.convert_to_ssl(cert, key) + c.convert_to_tls(cert, key) assert c.connected() assert c.sni == sni assert c.tls_established @@ -222,17 +222,16 @@ class TestServerConnectionTLS(tservers.ServerTestBase): def handle(self): self.finish() - @pytest.mark.parametrize("clientcert", [ + @pytest.mark.parametrize("client_certs", [ None, tutils.test_data.path("mitmproxy/data/clientcert"), tutils.test_data.path("mitmproxy/data/clientcert/client.pem"), ]) - def test_tls(self, clientcert): + def test_tls(self, client_certs): c = connections.ServerConnection(("127.0.0.1", self.port)) c.connect() - c.establish_ssl(clientcert, "foo.com") + c.establish_tls(client_certs=client_certs) assert c.connected() - assert c.sni == "foo.com" assert c.tls_established c.close() c.finish() diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index fcc766b5..8cc11a16 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -97,7 +97,7 @@ class TestSerialize: class TestFlowMaster: - def test_load_flow_reverse(self): + def test_load_http_flow_reverse(self): s = tservers.TestState() opts = options.Options( mode="reverse:https://use-this-domain" @@ -108,6 +108,20 @@ class TestFlowMaster: fm.load_flow(f) assert s.flows[0].request.host == "use-this-domain" + def test_load_websocket_flow(self): + s = tservers.TestState() + opts = options.Options( + mode="reverse:https://use-this-domain" + ) + fm = master.Master(opts) + fm.addons.add(s) + f = tflow.twebsocketflow() + fm.load_flow(f.handshake_flow) + fm.load_flow(f) + assert s.flows[0].request.host == "use-this-domain" + assert s.flows[1].handshake_flow == f.handshake_flow + assert len(s.flows[1].messages) == len(f.messages) + def test_replay(self): opts = options.Options() fm = master.Master(opts) diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index c411258a..4eb37d81 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -420,6 +420,20 @@ class TestMatchingWebSocketFlow: e = self.err() assert self.q("~e", e) + def test_domain(self): + q = self.flow() + assert self.q("~d example.com", q) + assert not self.q("~d none", q) + + def test_url(self): + q = self.flow() + assert self.q("~u example.com", q) + assert self.q("~u example.com/ws", q) + assert not self.q("~u moo/path", q) + + q.handshake_flow = None + assert not self.q("~u example.com", q) + def test_body(self): f = self.flow() diff --git a/test/mitmproxy/test_http.py b/test/mitmproxy/test_http.py index 4463961a..49e61e25 100644 --- a/test/mitmproxy/test_http.py +++ b/test/mitmproxy/test_http.py @@ -203,6 +203,15 @@ class TestHTTPFlow: f.resume() assert f.reply.state == "committed" + def test_resume_duplicated(self): + f = tflow.tflow() + f.intercept() + f2 = f.copy() + assert f.intercepted is f2.intercepted is True + f.resume() + f2.resume() + assert f.intercepted is f2.intercepted is False + def test_replace_unicode(self): f = tflow.tflow(resp=True) f.response.content = b"\xc2foo" diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py index d8c7a8e9..bd5d1792 100644 --- a/test/mitmproxy/test_stateobject.py +++ b/test/mitmproxy/test_stateobject.py @@ -1,101 +1,146 @@ -from typing import List +import typing + import pytest from mitmproxy.stateobject import StateObject -class Child(StateObject): +class TObject(StateObject): def __init__(self, x): self.x = x - _stateobject_attributes = dict( - x=int - ) - @classmethod def from_state(cls, state): obj = cls(None) obj.set_state(state) return obj + +class Child(TObject): + _stateobject_attributes = dict( + x=int + ) + def __eq__(self, other): return isinstance(other, Child) and self.x == other.x -class Container(StateObject): - def __init__(self): - self.child = None - self.children = None - self.dictionary = None +class TTuple(TObject): + _stateobject_attributes = dict( + x=typing.Tuple[int, Child] + ) + + +class TList(TObject): + _stateobject_attributes = dict( + x=typing.List[Child] + ) + +class TDict(TObject): _stateobject_attributes = dict( - child=Child, - children=List[Child], - dictionary=dict, + x=typing.Dict[str, Child] ) - @classmethod - def from_state(cls, state): - obj = cls() - obj.set_state(state) - return obj + +class TAny(TObject): + _stateobject_attributes = dict( + x=typing.Any + ) + + +class TSerializableChild(TObject): + _stateobject_attributes = dict( + x=Child + ) def test_simple(): a = Child(42) + assert a.get_state() == {"x": 42} b = a.copy() - assert b.get_state() == {"x": 42} a.set_state({"x": 44}) assert a.x == 44 assert b.x == 42 -def test_container(): - a = Container() - a.child = Child(42) +def test_serializable_child(): + child = Child(42) + a = TSerializableChild(child) + assert a.get_state() == { + "x": {"x": 42} + } + a.set_state({ + "x": {"x": 43} + }) + assert a.x.x == 43 + assert a.x is child b = a.copy() - assert a.child.x == b.child.x - b.child.x = 44 - assert a.child.x != b.child.x + assert a.x == b.x + assert a.x is not b.x -def test_container_list(): - a = Container() - a.children = [Child(42), Child(44)] +def test_tuple(): + a = TTuple((42, Child(43))) assert a.get_state() == { - "child": None, - "children": [{"x": 42}, {"x": 44}], - "dictionary": None, + "x": (42, {"x": 43}) } - copy = a.copy() - assert len(copy.children) == 2 - assert copy.children is not a.children - assert copy.children[0] is not a.children[0] - assert Container.from_state(a.get_state()) + b = a.copy() + a.set_state({"x": (44, {"x": 45})}) + assert a.x == (44, Child(45)) + assert b.x == (42, Child(43)) + +def test_tuple_err(): + a = TTuple(None) + with pytest.raises(ValueError, msg="Invalid data"): + a.set_state({"x": (42,)}) -def test_container_dict(): - a = Container() - a.dictionary = dict() - a.dictionary['foo'] = 'bar' - a.dictionary['bar'] = Child(44) + +def test_list(): + a = TList([Child(1), Child(2)]) assert a.get_state() == { - "child": None, - "children": None, - "dictionary": {'bar': {'x': 44}, 'foo': 'bar'}, + "x": [{"x": 1}, {"x": 2}], } copy = a.copy() - assert len(copy.dictionary) == 2 - assert copy.dictionary is not a.dictionary - assert copy.dictionary['bar'] is not a.dictionary['bar'] + assert len(copy.x) == 2 + assert copy.x is not a.x + assert copy.x[0] is not a.x[0] + + +def test_dict(): + a = TDict({"foo": Child(42)}) + assert a.get_state() == { + "x": {"foo": {"x": 42}} + } + b = a.copy() + assert list(a.x.items()) == list(b.x.items()) + assert a.x is not b.x + assert a.x["foo"] is not b.x["foo"] + + +def test_any(): + a = TAny(42) + b = a.copy() + assert a.x == b.x + + a = TAny(object()) + with pytest.raises(AssertionError): + a.get_state() def test_too_much_state(): - a = Container() - a.child = Child(42) + a = Child(42) s = a.get_state() s['foo'] = 'bar' - b = Container() with pytest.raises(RuntimeWarning): - b.set_state(s) + a.set_state(s) + + +def test_none(): + a = Child(None) + assert a.get_state() == {"x": None} + a = Child(42) + a.set_state({"x": None}) + assert a.x is None diff --git a/test/mitmproxy/test_typemanager.py b/test/mitmproxy/test_typemanager.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/mitmproxy/test_typemanager.py diff --git a/test/mitmproxy/test_types.py b/test/mitmproxy/test_types.py new file mode 100644 index 00000000..72492fa9 --- /dev/null +++ b/test/mitmproxy/test_types.py @@ -0,0 +1,237 @@ +import pytest +import os +import typing +import contextlib + +from mitmproxy.test import tutils +import mitmproxy.exceptions +import mitmproxy.types +from mitmproxy.test import taddons +from mitmproxy.test import tflow +from mitmproxy import command +from mitmproxy import flow + +from . import test_command + + +@contextlib.contextmanager +def chdir(path: str): + old_dir = os.getcwd() + os.chdir(path) + yield + os.chdir(old_dir) + + +def test_bool(): + with taddons.context() as tctx: + b = mitmproxy.types._BoolType() + assert b.completion(tctx.master.commands, bool, "b") == ["false", "true"] + assert b.parse(tctx.master.commands, bool, "true") is True + assert b.parse(tctx.master.commands, bool, "false") is False + assert b.is_valid(tctx.master.commands, bool, True) is True + assert b.is_valid(tctx.master.commands, bool, "foo") is False + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, bool, "foo") + + +def test_str(): + with taddons.context() as tctx: + b = mitmproxy.types._StrType() + assert b.is_valid(tctx.master.commands, str, "foo") is True + assert b.is_valid(tctx.master.commands, str, 1) is False + assert b.completion(tctx.master.commands, str, "") == [] + assert b.parse(tctx.master.commands, str, "foo") == "foo" + + +def test_unknown(): + with taddons.context() as tctx: + b = mitmproxy.types._UnknownType() + assert b.is_valid(tctx.master.commands, mitmproxy.types.Unknown, "foo") is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.Unknown, 1) is False + assert b.completion(tctx.master.commands, mitmproxy.types.Unknown, "") == [] + assert b.parse(tctx.master.commands, mitmproxy.types.Unknown, "foo") == "foo" + + +def test_int(): + with taddons.context() as tctx: + b = mitmproxy.types._IntType() + assert b.is_valid(tctx.master.commands, int, "foo") is False + assert b.is_valid(tctx.master.commands, int, 1) is True + assert b.completion(tctx.master.commands, int, "b") == [] + assert b.parse(tctx.master.commands, int, "1") == 1 + assert b.parse(tctx.master.commands, int, "999") == 999 + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, int, "foo") + + +def test_path(): + with taddons.context() as tctx: + b = mitmproxy.types._PathType() + assert b.parse(tctx.master.commands, mitmproxy.types.Path, "/foo") == "/foo" + assert b.parse(tctx.master.commands, mitmproxy.types.Path, "/bar") == "/bar" + assert b.is_valid(tctx.master.commands, mitmproxy.types.Path, "foo") is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Path, 3) is False + + def normPathOpts(prefix, match): + ret = [] + for s in b.completion(tctx.master.commands, mitmproxy.types.Path, match): + s = s[len(prefix):] + s = s.replace(os.sep, "/") + ret.append(s) + return ret + + cd = os.path.normpath(tutils.test_data.path("mitmproxy/completion")) + assert normPathOpts(cd, cd) == ['/aaa', '/aab', '/aac', '/bbb/'] + assert normPathOpts(cd, os.path.join(cd, "a")) == ['/aaa', '/aab', '/aac'] + with chdir(cd): + assert normPathOpts("", "./") == ['./aaa', './aab', './aac', './bbb/'] + assert normPathOpts("", "") == ['./aaa', './aab', './aac', './bbb/'] + assert b.completion( + tctx.master.commands, mitmproxy.types.Path, "nonexistent" + ) == ["nonexistent"] + + +def test_cmd(): + with taddons.context() as tctx: + tctx.master.addons.add(test_command.TAddon()) + b = mitmproxy.types._CmdType() + assert b.is_valid(tctx.master.commands, mitmproxy.types.Cmd, "foo") is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.Cmd, "cmd1") is True + assert b.parse(tctx.master.commands, mitmproxy.types.Cmd, "cmd1") == "cmd1" + with pytest.raises(mitmproxy.exceptions.TypeError): + assert b.parse(tctx.master.commands, mitmproxy.types.Cmd, "foo") + assert len( + b.completion(tctx.master.commands, mitmproxy.types.Cmd, "") + ) == len(tctx.master.commands.commands.keys()) + + +def test_cutspec(): + with taddons.context() as tctx: + b = mitmproxy.types._CutSpecType() + b.parse(tctx.master.commands, mitmproxy.types.CutSpec, "foo,bar") == ["foo", "bar"] + assert b.is_valid(tctx.master.commands, mitmproxy.types.CutSpec, 1) is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.CutSpec, "foo") is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.CutSpec, "request.path") is True + + assert b.completion( + tctx.master.commands, mitmproxy.types.CutSpec, "request.p" + ) == b.valid_prefixes + ret = b.completion(tctx.master.commands, mitmproxy.types.CutSpec, "request.port,f") + assert ret[0].startswith("request.port,") + assert len(ret) == len(b.valid_prefixes) + + +def test_arg(): + with taddons.context() as tctx: + b = mitmproxy.types._ArgType() + assert b.completion(tctx.master.commands, mitmproxy.types.Arg, "") == [] + assert b.parse(tctx.master.commands, mitmproxy.types.Arg, "foo") == "foo" + assert b.is_valid(tctx.master.commands, mitmproxy.types.Arg, "foo") is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Arg, 1) is False + + +def test_strseq(): + with taddons.context() as tctx: + b = mitmproxy.types._StrSeqType() + assert b.completion(tctx.master.commands, typing.Sequence[str], "") == [] + assert b.parse(tctx.master.commands, typing.Sequence[str], "foo") == ["foo"] + assert b.parse(tctx.master.commands, typing.Sequence[str], "foo,bar") == ["foo", "bar"] + assert b.is_valid(tctx.master.commands, typing.Sequence[str], ["foo"]) is True + assert b.is_valid(tctx.master.commands, typing.Sequence[str], ["a", "b", 3]) is False + assert b.is_valid(tctx.master.commands, typing.Sequence[str], 1) is False + assert b.is_valid(tctx.master.commands, typing.Sequence[str], "foo") is False + + +class DummyConsole: + @command.command("view.resolve") + def resolve(self, spec: str) -> typing.Sequence[flow.Flow]: + if spec == "err": + raise mitmproxy.exceptions.CommandError() + n = int(spec) + return [tflow.tflow(resp=True)] * n + + @command.command("cut") + def cut(self, spec: str) -> mitmproxy.types.Data: + return [["test"]] + + @command.command("options") + def options(self) -> typing.Sequence[str]: + return ["one", "two", "three"] + + +def test_flow(): + with taddons.context() as tctx: + tctx.master.addons.add(DummyConsole()) + b = mitmproxy.types._FlowType() + assert len(b.completion(tctx.master.commands, flow.Flow, "")) == len(b.valid_prefixes) + assert b.parse(tctx.master.commands, flow.Flow, "1") + assert b.is_valid(tctx.master.commands, flow.Flow, tflow.tflow()) is True + assert b.is_valid(tctx.master.commands, flow.Flow, "xx") is False + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, flow.Flow, "0") + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, flow.Flow, "2") + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, flow.Flow, "err") + + +def test_flows(): + with taddons.context() as tctx: + tctx.master.addons.add(DummyConsole()) + b = mitmproxy.types._FlowsType() + assert len( + b.completion(tctx.master.commands, typing.Sequence[flow.Flow], "") + ) == len(b.valid_prefixes) + assert b.is_valid(tctx.master.commands, typing.Sequence[flow.Flow], [tflow.tflow()]) is True + assert b.is_valid(tctx.master.commands, typing.Sequence[flow.Flow], "xx") is False + assert b.is_valid(tctx.master.commands, typing.Sequence[flow.Flow], 0) is False + assert len(b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "0")) == 0 + assert len(b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "1")) == 1 + assert len(b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "2")) == 2 + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "err") + + +def test_data(): + with taddons.context() as tctx: + b = mitmproxy.types._DataType() + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, 0) is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, []) is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, [["x"]]) is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, [[b"x"]]) is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, [[1]]) is False + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, mitmproxy.types.Data, "foo") + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, mitmproxy.types.Data, "foo") + + +def test_choice(): + with taddons.context() as tctx: + tctx.master.addons.add(DummyConsole()) + b = mitmproxy.types._ChoiceType() + assert b.is_valid( + tctx.master.commands, + mitmproxy.types.Choice("options"), + "one", + ) is True + assert b.is_valid( + tctx.master.commands, + mitmproxy.types.Choice("options"), + "invalid", + ) is False + assert b.is_valid( + tctx.master.commands, + mitmproxy.types.Choice("nonexistent"), + "invalid", + ) is False + comp = b.completion(tctx.master.commands, mitmproxy.types.Choice("options"), "") + assert comp == ["one", "two", "three"] + assert b.parse(tctx.master.commands, mitmproxy.types.Choice("options"), "one") == "one" + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, mitmproxy.types.Choice("options"), "invalid") + + +def test_typemanager(): + assert mitmproxy.types.CommandTypes.get(bool, None) + assert mitmproxy.types.CommandTypes.get(mitmproxy.types.Choice("choide"), None) diff --git a/test/mitmproxy/test_version.py b/test/mitmproxy/test_version.py index f87b0851..8c176542 100644 --- a/test/mitmproxy/test_version.py +++ b/test/mitmproxy/test_version.py @@ -1,10 +1,36 @@ +import pathlib import runpy +import subprocess +from unittest import mock from mitmproxy import version def test_version(capsys): - runpy.run_module('mitmproxy.version', run_name='__main__') + here = pathlib.Path(__file__).absolute().parent + version_file = here / ".." / ".." / "mitmproxy" / "version.py" + runpy.run_path(str(version_file), run_name='__main__') stdout, stderr = capsys.readouterr() assert len(stdout) > 0 assert stdout.strip() == version.VERSION + + +def test_get_version_hardcoded(): + version.VERSION = "3.0.0.dev123-0xcafebabe" + assert version.get_version() == "3.0.0" + assert version.get_version(True) == "3.0.0.dev123" + assert version.get_version(True, True) == "3.0.0.dev123-0xcafebabe" + + +def test_get_version(): + version.VERSION = "3.0.0" + + with mock.patch('subprocess.check_output') as m: + m.return_value = b"tag-0-cafecafe" + assert version.get_version(True, True) == "3.0.0" + + m.return_value = b"tag-2-cafecafe" + assert version.get_version(True, True) == "3.0.0.dev2-0xcafecaf" + + m.side_effect = subprocess.CalledProcessError(-1, 'git describe --long') + assert version.get_version(True, True) == "3.0.0" diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index 7c53a4b0..fcacec36 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -3,6 +3,7 @@ import pytest from mitmproxy.io import tnetstring from mitmproxy import flowfilter +from mitmproxy.exceptions import Kill, ControlException from mitmproxy.test import tflow @@ -42,6 +43,20 @@ class TestWebSocketFlow: assert f.error.get_state() == f2.error.get_state() assert f.error is not f2.error + def test_kill(self): + f = tflow.twebsocketflow() + with pytest.raises(ControlException): + f.intercept() + f.resume() + f.kill() + + f = tflow.twebsocketflow() + f.intercept() + assert f.killable + f.kill() + assert not f.killable + assert f.reply.value == Kill + def test_match(self): f = tflow.twebsocketflow() assert not flowfilter.match("~b nonexistent", f) @@ -71,3 +86,9 @@ class TestWebSocketFlow: d = tflow.twebsocketflow().handshake_flow.get_state() tnetstring.dump(d, b) assert b.getvalue() + + def test_message_kill(self): + f = tflow.twebsocketflow() + assert not f.messages[-1].killed + f.messages[-1].kill() + assert f.messages[-1].killed diff --git a/test/mitmproxy/tools/console/test_commander.py b/test/mitmproxy/tools/console/test_commander.py new file mode 100644 index 00000000..2a96995d --- /dev/null +++ b/test/mitmproxy/tools/console/test_commander.py @@ -0,0 +1,98 @@ + +from mitmproxy.tools.console.commander import commander +from mitmproxy.test import taddons + + +class TestListCompleter: + def test_cycle(self): + tests = [ + [ + "", + ["a", "b", "c"], + ["a", "b", "c", "a"] + ], + [ + "xxx", + ["a", "b", "c"], + ["xxx", "xxx", "xxx"] + ], + [ + "b", + ["a", "b", "ba", "bb", "c"], + ["b", "ba", "bb", "b"] + ], + ] + for start, options, cycle in tests: + c = commander.ListCompleter(start, options) + for expected in cycle: + assert c.cycle() == expected + + +class TestCommandBuffer: + + def test_backspace(self): + tests = [ + [("", 0), ("", 0)], + [("1", 0), ("1", 0)], + [("1", 1), ("", 0)], + [("123", 3), ("12", 2)], + [("123", 2), ("13", 1)], + [("123", 0), ("123", 0)], + ] + with taddons.context() as tctx: + for start, output in tests: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = start[0], start[1] + cb.backspace() + assert cb.text == output[0] + assert cb.cursor == output[1] + + def test_left(self): + cursors = [3, 2, 1, 0, 0] + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = "abcd", 4 + for c in cursors: + cb.left() + assert cb.cursor == c + + def test_right(self): + cursors = [1, 2, 3, 4, 4] + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = "abcd", 0 + for c in cursors: + cb.right() + assert cb.cursor == c + + def test_insert(self): + tests = [ + [("", 0), ("x", 1)], + [("a", 0), ("xa", 1)], + [("xa", 2), ("xax", 3)], + ] + with taddons.context() as tctx: + for start, output in tests: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = start[0], start[1] + cb.insert("x") + assert cb.text == output[0] + assert cb.cursor == output[1] + + def test_cycle_completion(self): + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text = "foo bar" + cb.cursor = len(cb.text) + cb.cycle_completion() + + def test_render(self): + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text = "foo" + assert cb.render() + + def test_flatten(self): + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + assert cb.flatten("foo bar") == "foo bar" diff --git a/test/mitmproxy/tools/console/test_common.py b/test/mitmproxy/tools/console/test_common.py index 3ab4fd67..72438c49 100644 --- a/test/mitmproxy/tools/console/test_common.py +++ b/test/mitmproxy/tools/console/test_common.py @@ -1,12 +1,34 @@ +import urwid + from mitmproxy.test import tflow from mitmproxy.tools.console import common -from ....conftest import skip_appveyor - -@skip_appveyor def test_format_flow(): f = tflow.tflow(resp=True) assert common.format_flow(f, True) assert common.format_flow(f, True, hostheader=True) assert common.format_flow(f, True, extended=True) + + +def test_format_keyvals(): + assert common.format_keyvals( + [ + ("aa", "bb"), + ("cc", "dd"), + ("ee", None), + ] + ) + wrapped = urwid.BoxAdapter( + urwid.ListBox( + urwid.SimpleFocusListWalker( + common.format_keyvals([("foo", "bar")]) + ) + ), 1 + ) + assert wrapped.render((30, )) + assert common.format_keyvals( + [ + ("aa", wrapped) + ] + ) diff --git a/test/mitmproxy/tools/console/test_defaultkeys.py b/test/mitmproxy/tools/console/test_defaultkeys.py new file mode 100644 index 00000000..1f17c888 --- /dev/null +++ b/test/mitmproxy/tools/console/test_defaultkeys.py @@ -0,0 +1,23 @@ +from mitmproxy.test.tflow import tflow +from mitmproxy.tools.console import defaultkeys +from mitmproxy.tools.console import keymap +from mitmproxy.tools.console import master +from mitmproxy import command + + +def test_commands_exist(): + km = keymap.Keymap(None) + defaultkeys.map(km) + assert km.bindings + m = master.ConsoleMaster(None) + m.load_flow(tflow()) + + for binding in km.bindings: + cmd, *args = command.lexer(binding.command) + assert cmd in m.commands.commands + + cmd_obj = m.commands.commands[cmd] + try: + cmd_obj.prepare_args(args) + except Exception as e: + raise ValueError("Invalid command: {}".format(binding.command)) from e diff --git a/test/mitmproxy/tools/console/test_keymap.py b/test/mitmproxy/tools/console/test_keymap.py index 00e64991..7b475ff8 100644 --- a/test/mitmproxy/tools/console/test_keymap.py +++ b/test/mitmproxy/tools/console/test_keymap.py @@ -42,7 +42,7 @@ def test_join(): km = keymap.Keymap(tctx.master) km.add("key", "str", ["options"], "help1") km.add("key", "str", ["commands"]) - return + assert len(km.bindings) == 1 assert len(km.bindings[0].contexts) == 2 assert km.bindings[0].help == "help1" diff --git a/test/mitmproxy/tools/console/test_master.py b/test/mitmproxy/tools/console/test_master.py index fd9b301e..6f46ce9e 100644 --- a/test/mitmproxy/tools/console/test_master.py +++ b/test/mitmproxy/tools/console/test_master.py @@ -4,22 +4,9 @@ from mitmproxy import options from mitmproxy.test import tflow from mitmproxy.test import tutils from mitmproxy.tools import console -from mitmproxy.tools.console import common from ... import tservers -def test_format_keyvals(): - assert common.format_keyvals( - [ - ("aa", "bb"), - None, - ("cc", "dd"), - (None, "dd"), - (None, "dd"), - ] - ) - - def test_options(): assert options.Options(server_replay_kill_extra=True) diff --git a/test/mitmproxy/tools/console/test_pathedit.py b/test/mitmproxy/tools/console/test_pathedit.py deleted file mode 100644 index b9f51f5a..00000000 --- a/test/mitmproxy/tools/console/test_pathedit.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -from os.path import normpath -from unittest import mock - -from mitmproxy.tools.console import pathedit -from mitmproxy.test import tutils - - -class TestPathCompleter: - - def test_lookup_construction(self): - c = pathedit._PathCompleter() - - cd = os.path.normpath(tutils.test_data.path("mitmproxy/completion")) - ca = os.path.join(cd, "a") - assert c.complete(ca).endswith(normpath("/completion/aaa")) - assert c.complete(ca).endswith(normpath("/completion/aab")) - c.reset() - ca = os.path.join(cd, "aaa") - assert c.complete(ca).endswith(normpath("/completion/aaa")) - assert c.complete(ca).endswith(normpath("/completion/aaa")) - c.reset() - assert c.complete(cd).endswith(normpath("/completion/aaa")) - - def test_completion(self): - c = pathedit._PathCompleter(True) - c.reset() - c.lookup = [ - ("a", "x/a"), - ("aa", "x/aa"), - ] - assert c.complete("a") == "a" - assert c.final == "x/a" - assert c.complete("a") == "aa" - assert c.complete("a") == "a" - - c = pathedit._PathCompleter(True) - r = c.complete("l") - assert c.final.endswith(r) - - c.reset() - assert c.complete("/nonexistent") == "/nonexistent" - assert c.final == "/nonexistent" - c.reset() - assert c.complete("~") != "~" - - c.reset() - s = "thisisatotallynonexistantpathforsure" - assert c.complete(s) == s - assert c.final == s - - -class TestPathEdit: - - def test_keypress(self): - - pe = pathedit.PathEdit("", "") - - with mock.patch('urwid.widget.Edit.get_edit_text') as get_text, \ - mock.patch('urwid.widget.Edit.set_edit_text') as set_text: - - cd = os.path.normpath(tutils.test_data.path("mitmproxy/completion")) - get_text.return_value = os.path.join(cd, "a") - - # Pressing tab should set completed path - pe.keypress((1,), "tab") - set_text_called_with = set_text.call_args[0][0] - assert set_text_called_with.endswith(normpath("/completion/aaa")) - - # Pressing any other key should reset - pe.keypress((1,), "a") - assert pe.lookup is None diff --git a/test/mitmproxy/tools/web/test_app.py b/test/mitmproxy/tools/web/test_app.py index 248581b9..5afc0bca 100644 --- a/test/mitmproxy/tools/web/test_app.py +++ b/test/mitmproxy/tools/web/test_app.py @@ -322,7 +322,7 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): ws_client2 = yield websocket.websocket_connect(ws_url) ws_client2.close() - def test_generate_tflow_js(self): + def _test_generate_tflow_js(self): _tflow = app.flow_to_json(tflow.tflow(resp=True, err=True)) # Set some value as constant, so that _tflow.js would not change every time. _tflow['client_conn']['id'] = "4a18d1a0-50a1-48dd-9aa6-d45d74282939" diff --git a/test/mitmproxy/utils/test_debug.py b/test/mitmproxy/utils/test_debug.py index a8e1054d..0ca6ead0 100644 --- a/test/mitmproxy/utils/test_debug.py +++ b/test/mitmproxy/utils/test_debug.py @@ -1,5 +1,4 @@ import io -import subprocess import sys from unittest import mock import pytest @@ -14,18 +13,6 @@ def test_dump_system_info_precompiled(precompiled): assert ("binary" in debug.dump_system_info()) == precompiled -def test_dump_system_info_version(): - with mock.patch('subprocess.check_output') as m: - m.return_value = b"v2.0.0-0-cafecafe" - x = debug.dump_system_info() - assert 'dev' not in x - assert 'cafecafe' in x - - with mock.patch('subprocess.check_output') as m: - m.side_effect = subprocess.CalledProcessError(-1, 'git describe --tags --long') - assert 'dev' not in debug.dump_system_info() - - def test_dump_info(): cs = io.StringIO() debug.dump_info(None, None, file=cs, testing=True) diff --git a/test/mitmproxy/utils/test_human.py b/test/mitmproxy/utils/test_human.py index e8ffaad4..947cfa4a 100644 --- a/test/mitmproxy/utils/test_human.py +++ b/test/mitmproxy/utils/test_human.py @@ -54,3 +54,5 @@ def test_format_address(): assert human.format_address(("::ffff:127.0.0.1", "54010", "0", "0")) == "127.0.0.1:54010" assert human.format_address(("127.0.0.1", "54010")) == "127.0.0.1:54010" assert human.format_address(("example.com", "54010")) == "example.com:54010" + assert human.format_address(("::", "8080")) == "*:8080" + assert human.format_address(("0.0.0.0", "8080")) == "*:8080" diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py index 66b1884e..9cb4334e 100644 --- a/test/mitmproxy/utils/test_typecheck.py +++ b/test/mitmproxy/utils/test_typecheck.py @@ -4,7 +4,6 @@ from unittest import mock import pytest from mitmproxy.utils import typecheck -from mitmproxy import command class TBase: @@ -88,34 +87,14 @@ def test_check_any(): typecheck.check_option_type("foo", None, typing.Any) -def test_check_command_type(): - assert(typecheck.check_command_type("foo", str)) - assert(typecheck.check_command_type(["foo"], typing.Sequence[str])) - assert(not typecheck.check_command_type(["foo", 1], typing.Sequence[str])) - assert(typecheck.check_command_type(None, None)) - assert(not typecheck.check_command_type(["foo"], typing.Sequence[int])) - assert(not typecheck.check_command_type("foo", typing.Sequence[int])) - assert(typecheck.check_command_type([["foo", b"bar"]], command.Cuts)) - assert(not typecheck.check_command_type(["foo", b"bar"], command.Cuts)) - assert(not typecheck.check_command_type([["foo", 22]], command.Cuts)) - - # Python 3.5 only defines __parameters__ - m = mock.Mock() - m.__str__ = lambda self: "typing.Sequence" - m.__parameters__ = (int,) - - typecheck.check_command_type([10], m) - - # Python 3.5 only defines __union_params__ - m = mock.Mock() - m.__str__ = lambda self: "typing.Union" - m.__union_params__ = (int,) - assert not typecheck.check_command_type([22], m) - - def test_typesec_to_str(): assert(typecheck.typespec_to_str(str)) == "str" assert(typecheck.typespec_to_str(typing.Sequence[str])) == "sequence of str" assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str" with pytest.raises(NotImplementedError): typecheck.typespec_to_str(dict) + + +def test_mapping_types(): + # this is not covered by check_option_type, but still belongs in this module + assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int]) diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py index b1eebc73..95965cee 100644 --- a/test/pathod/protocols/test_http2.py +++ b/test/pathod/protocols/test_http2.py @@ -75,7 +75,7 @@ class TestCheckALPNMatch(net_tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(alpn_protos=[b'h2']) + c.convert_to_tls(alpn_protos=[b'h2']) protocol = HTTP2StateProtocol(c) assert protocol.check_alpn() @@ -89,7 +89,7 @@ class TestCheckALPNMismatch(net_tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(alpn_protos=[b'h2']) + c.convert_to_tls(alpn_protos=[b'h2']) protocol = HTTP2StateProtocol(c) with pytest.raises(NotImplementedError): protocol.check_alpn() @@ -207,7 +207,7 @@ class TestApplySettings(net_tservers.ServerTestBase): def test_apply_settings(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c) protocol._apply_settings({ @@ -302,7 +302,7 @@ class TestReadRequest(net_tservers.ServerTestBase): def test_read_request(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c, is_server=True) protocol.connection_preface_performed = True @@ -328,7 +328,7 @@ class TestReadRequestRelative(net_tservers.ServerTestBase): def test_asterisk_form(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c, is_server=True) protocol.connection_preface_performed = True @@ -351,7 +351,7 @@ class TestReadRequestAbsolute(net_tservers.ServerTestBase): def test_absolute_form(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c, is_server=True) protocol.connection_preface_performed = True @@ -378,7 +378,7 @@ class TestReadResponse(net_tservers.ServerTestBase): def test_read_response(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c) protocol.connection_preface_performed = True @@ -404,7 +404,7 @@ class TestReadEmptyResponse(net_tservers.ServerTestBase): def test_read_empty_response(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c) protocol.connection_preface_performed = True diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index 4b50e2a7..297b54d4 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -238,11 +238,11 @@ class TestDaemonHTTP2(PathocTestDaemon): http2_skip_connection_preface=True, ) - tmp_convert_to_ssl = c.convert_to_ssl - c.convert_to_ssl = Mock() - c.convert_to_ssl.side_effect = tmp_convert_to_ssl + tmp_convert_to_tls = c.convert_to_tls + c.convert_to_tls = Mock() + c.convert_to_tls.side_effect = tmp_convert_to_tls with c.connect(): - _, kwargs = c.convert_to_ssl.call_args + _, kwargs = c.convert_to_tls.call_args assert set(kwargs['alpn_protos']) == set([b'http/1.1', b'h2']) def test_request(self): diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index c0011952..d6522cb6 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -153,7 +153,7 @@ class CommonTests(tservers.DaemonTests): c = tcp.TCPClient(("localhost", self.d.port)) with c.connect(): if self.ssl: - c.convert_to_ssl() + c.convert_to_tls() c.wfile.write(b"foo\n\n\n") c.wfile.flush() l = self.d.last_log() @@ -241,7 +241,7 @@ class TestDaemonSSL(CommonTests): with c.connect(): c.wfile.write(b"\0\0\0\0") with pytest.raises(exceptions.TlsException): - c.convert_to_ssl() + c.convert_to_tls() l = self.d.last_log() assert l["type"] == "error" assert "SSL" in l["msg"] @@ -25,17 +25,16 @@ commands = sphinx-build -W -b html -d {envtmpdir}/doctrees . {envtmpdir}/html commands = mitmdump --version flake8 --jobs 8 mitmproxy pathod examples test release - python3 test/filename_matching.py + python test/filename_matching.py rstcheck README.rst - mypy --ignore-missing-imports ./mitmproxy - mypy --ignore-missing-imports ./pathod - mypy --ignore-missing-imports --follow-imports=skip ./examples/simple/ + mypy --ignore-missing-imports ./mitmproxy ./pathod + mypy --ignore-missing-imports --follow-imports=skip ./examples/simple/ ./examples/pathod/ ./examples/complex/ [testenv:individual_coverage] deps = -rrequirements.txt commands = - python3 test/individual_coverage.py + python test/individual_coverage.py [testenv:wheel] recreate = True @@ -51,14 +50,13 @@ commands = pathoc --version [testenv:rtool] +passenv = SKIP_MITMPROXY SNAPSHOT_HOST SNAPSHOT_PORT SNAPSHOT_USER SNAPSHOT_PASS RTOOL_KEY deps = -rrequirements.txt - -e./release - # The 3.2 release is broken - # the next commit after this updates the bootloaders, which then segfault! - # https://github.com/pyinstaller/pyinstaller/issues/2232 - git+https://github.com/pyinstaller/pyinstaller.git@483c819d6a256b58db6740696a901bd41c313f0c; sys_platform == 'win32' - git+https://github.com/mhils/pyinstaller.git@d094401e4196b1a6a03818b80164a5f555861cef; sys_platform != 'win32' + pyinstaller==3.3.1 + twine==1.9.1 + pysftp==0.2.9 commands = - rtool {posargs} + mitmdump --version + python ./release/rtool.py {posargs} diff --git a/web/package.json b/web/package.json index 31c2d6d6..77b13e8b 100644 --- a/web/package.json +++ b/web/package.json @@ -37,7 +37,8 @@ "redux-logger": "^3.0.6", "redux-mock-store": "^1.3.0", "redux-thunk": "^2.2.0", - "shallowequal": "^1.0.2" + "shallowequal": "^1.0.2", + "stable": "^0.1.6" }, "devDependencies": { "babel-core": "^6.26.0", diff --git a/web/src/js/components/FlowTable/FlowColumns.jsx b/web/src/js/components/FlowTable/FlowColumns.jsx index 02a4fba1..e60ed487 100644 --- a/web/src/js/components/FlowTable/FlowColumns.jsx +++ b/web/src/js/components/FlowTable/FlowColumns.jsx @@ -119,7 +119,7 @@ export function TimeColumn({ flow }) { return ( <td className="col-time"> {flow.response ? ( - formatTimeDelta(1000 * (flow.response.timestamp_end - flow.server_conn.timestamp_start)) + formatTimeDelta(1000 * (flow.response.timestamp_end - flow.request.timestamp_start)) ) : ( '...' )} diff --git a/web/src/js/ducks/utils/store.js b/web/src/js/ducks/utils/store.js index ac272650..ad2242ee 100644 --- a/web/src/js/ducks/utils/store.js +++ b/web/src/js/ducks/utils/store.js @@ -1,3 +1,5 @@ +import stable from 'stable' + export const SET_FILTER = 'LIST_SET_FILTER' export const SET_SORT = 'LIST_SET_SORT' export const ADD = 'LIST_ADD' @@ -35,7 +37,7 @@ export default function reduce(state = defaultState, action) { switch (action.type) { case SET_FILTER: - view = list.filter(action.filter).sort(action.sort) + view = stable(list.filter(action.filter), action.sort) viewIndex = {} view.forEach((item, index) => { viewIndex[item.id] = index @@ -43,7 +45,7 @@ export default function reduce(state = defaultState, action) { break case SET_SORT: - view = [...view].sort(action.sort) + view = stable([...view], action.sort) viewIndex = {} view.forEach((item, index) => { viewIndex[item.id] = index diff --git a/web/src/js/filt/filt.js b/web/src/js/filt/filt.js index 26058649..19a41af2 100644 --- a/web/src/js/filt/filt.js +++ b/web/src/js/filt/filt.js @@ -1929,7 +1929,7 @@ module.exports = (function() { function body(regex){ regex = new RegExp(regex, "i"); function bodyFilter(flow){ - return True; + return true; } bodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return bodyFilter; @@ -1937,7 +1937,7 @@ module.exports = (function() { function requestBody(regex){ regex = new RegExp(regex, "i"); function requestBodyFilter(flow){ - return True; + return true; } requestBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return requestBodyFilter; @@ -1945,7 +1945,7 @@ module.exports = (function() { function responseBody(regex){ regex = new RegExp(regex, "i"); function responseBodyFilter(flow){ - return True; + return true; } responseBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return responseBodyFilter; @@ -2104,4 +2104,4 @@ module.exports = (function() { SyntaxError: peg$SyntaxError, parse: peg$parse }; -})();
\ No newline at end of file +})(); diff --git a/web/src/js/filt/filt.peg b/web/src/js/filt/filt.peg index 12959474..e4b151ad 100644 --- a/web/src/js/filt/filt.peg +++ b/web/src/js/filt/filt.peg @@ -1,4 +1,4 @@ -// PEG.js filter rules - see http://pegjs.majda.cz/online +// PEG.js filter rules - see https://pegjs.org/ { var flowutils = require("../flow/utils.js"); @@ -72,7 +72,7 @@ function responseCode(code){ function body(regex){ regex = new RegExp(regex, "i"); function bodyFilter(flow){ - return True; + return true; } bodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return bodyFilter; @@ -80,7 +80,7 @@ function body(regex){ function requestBody(regex){ regex = new RegExp(regex, "i"); function requestBodyFilter(flow){ - return True; + return true; } requestBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return requestBodyFilter; @@ -88,7 +88,7 @@ function requestBody(regex){ function responseBody(regex){ regex = new RegExp(regex, "i"); function responseBodyFilter(flow){ - return True; + return true; } responseBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return responseBodyFilter; diff --git a/web/yarn.lock b/web/yarn.lock index aa5ae85f..1930fded 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -5449,6 +5449,10 @@ sshpk@^1.7.0: jsbn "~0.1.0"
tweetnacl "~0.14.0"
+stable@^0.1.6:
+ version "0.1.6"
+ resolved "https://registry.yarnpkg.com/stable/-/stable-0.1.6.tgz#910f5d2aed7b520c6e777499c1f32e139fdecb10"
+
statuses@1:
version "1.3.1"
resolved "https://registry.yarnpkg.com/statuses/-/statuses-1.3.1.tgz#faf51b9eb74aaef3b3acf4ad5f61abf24cb7b93e"
|