diff options
-rw-r--r-- | issue_template.md | 4 | ||||
-rw-r--r-- | mitmproxy/console/common.py | 8 | ||||
-rw-r--r-- | mitmproxy/flow.py | 19 | ||||
-rw-r--r-- | mitmproxy/protocol/base.py | 11 | ||||
-rw-r--r-- | mitmproxy/protocol/http.py | 14 | ||||
-rw-r--r-- | mitmproxy/protocol/tls.py | 178 | ||||
-rw-r--r-- | mitmproxy/proxy/root_context.py | 2 | ||||
-rw-r--r-- | mitmproxy/utils.py | 23 | ||||
-rw-r--r-- | netlib/http/__init__.py | 4 | ||||
-rw-r--r-- | netlib/http/message.py | 69 | ||||
-rw-r--r-- | netlib/http/request.py | 59 | ||||
-rw-r--r-- | netlib/http/response.py | 20 | ||||
-rw-r--r-- | netlib/multidict.py | 49 | ||||
-rw-r--r-- | netlib/odict.py | 81 | ||||
-rw-r--r-- | pathod/utils.py | 23 | ||||
-rw-r--r-- | test/mitmproxy/test_examples.py | 3 | ||||
-rw-r--r-- | test/mitmproxy/tutils.py | 3 | ||||
-rw-r--r-- | test/netlib/http/test_request.py | 2 | ||||
-rw-r--r-- | test/netlib/test_multidict.py | 26 | ||||
-rw-r--r-- | test/netlib/test_odict.py | 10 | ||||
-rw-r--r-- | test/pathod/test_pathoc.py | 2 | ||||
-rw-r--r-- | test/pathod/test_pathod.py | 7 | ||||
-rw-r--r-- | test/pathod/tutils.py | 2 |
23 files changed, 298 insertions, 321 deletions
diff --git a/issue_template.md b/issue_template.md index 3f9be788..08d390e4 100644 --- a/issue_template.md +++ b/issue_template.md @@ -10,10 +10,10 @@ ##### What went wrong? -##### Any other comments? +##### Any other comments? What have you tried so far? --- Mitmproxy Version: -Operating System:
\ No newline at end of file +Operating System: diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py index 4e472fb6..25658dfa 100644 --- a/mitmproxy/console/common.py +++ b/mitmproxy/console/common.py @@ -154,7 +154,7 @@ def raw_format_flow(f, focus, extended): if f["intercepted"] and not f["acked"]: uc = "intercept" - elif f["resp_code"] or f["err_msg"]: + elif "resp_code" in f or "err_msg" in f: uc = "text" else: uc = "title" @@ -173,7 +173,7 @@ def raw_format_flow(f, focus, extended): ("fixed", preamble, urwid.Text("")) ) - if f["resp_code"]: + if "resp_code" in f: codes = { 2: "code_200", 3: "code_300", @@ -185,6 +185,8 @@ def raw_format_flow(f, focus, extended): if f["resp_is_replay"]: resp.append(fcol(SYMBOL_REPLAY, "replay")) resp.append(fcol(f["resp_code"], ccol)) + if extended: + resp.append(fcol(f["resp_reason"], ccol)) if f["intercepted"] and f["resp_code"] and not f["acked"]: rc = "intercept" else: @@ -412,7 +414,6 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False): req_http_version = f.request.http_version, err_msg = f.error.msg if f.error else None, - resp_code = f.response.status_code if f.response else None, marked = marked, ) @@ -430,6 +431,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False): d.update(dict( resp_code = f.response.status_code, + resp_reason = f.response.reason, resp_is_replay = f.response.is_replay, resp_clen = contentdesc, roundtrip = roundtrip, diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 647ebf68..a9018e16 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -13,6 +13,8 @@ from six.moves import http_cookies, http_cookiejar, urllib import os import re +from typing import List, Optional, Set + from netlib import wsgi, odict from netlib.exceptions import HttpException from netlib.http import Headers, http1, cookies @@ -376,8 +378,11 @@ class StickyAuthState: f.request.headers["authorization"] = self.hosts[host] +@six.add_metaclass(ABCMeta) class FlowList(object): - __metaclass__ = ABCMeta + + def __init__(self): + self._list = [] # type: List[Flow] def __iter__(self): return iter(self._list) @@ -416,7 +421,7 @@ class FlowList(object): class FlowView(FlowList): def __init__(self, store, filt=None): - self._list = [] + super(FlowView, self).__init__() if not filt: filt = lambda flow: True self._build(store, filt) @@ -458,7 +463,7 @@ class FlowStore(FlowList): """ def __init__(self): - self._list = [] + super(FlowStore, self).__init__() self._set = set() # Used for O(1) lookups self.views = [] self._recalculate_views() @@ -649,18 +654,18 @@ class FlowMaster(controller.ServerMaster): self.server_playback = None self.client_playback = None self.kill_nonreplay = False - self.scripts = [] + self.scripts = [] # type: List[script.Script] self.pause_scripts = False - self.stickycookie_state = False + self.stickycookie_state = None # type: Optional[StickyCookieState] self.stickycookie_txt = None - self.stickyauth_state = False + self.stickyauth_state = False # type: Optional[StickyAuthState] self.stickyauth_txt = None self.anticache = False self.anticomp = False - self.stream_large_bodies = False + self.stream_large_bodies = None # type: Optional[StreamLargeBodies] self.refresh_server_playback = False self.replacehooks = ReplaceHooks() self.setheaders = SetHeaders() diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py index 536f2753..c8e58d1b 100644 --- a/mitmproxy/protocol/base.py +++ b/mitmproxy/protocol/base.py @@ -133,24 +133,15 @@ class ServerConnectionMixin(object): "The proxy shall not connect to itself.".format(repr(address)) ) - def set_server(self, address, server_tls=None, sni=None): + def set_server(self, address): """ Sets a new server address. If there is an existing connection, it will be closed. - - Raises: - ~mitmproxy.exceptions.ProtocolException: - if ``server_tls`` is ``True``, but there was no TLS layer on the - protocol stack which could have processed this. """ if self.server_conn: self.disconnect() self.log("Set new server address: " + repr(address), "debug") self.server_conn.address = address self.__check_self_connect() - if server_tls: - raise ProtocolException( - "Cannot upgrade to TLS, no TLS layer on the protocol stack." - ) def disconnect(self): """ diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index 9cb35176..d9111303 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -120,7 +120,7 @@ class UpstreamConnectLayer(Layer): if address != self.server_conn.via.address: self.ctx.set_server(address) - def set_server(self, address, server_tls=None, sni=None): + def set_server(self, address): if self.ctx.server_conn: self.ctx.disconnect() address = tcp.Address.wrap(address) @@ -128,11 +128,6 @@ class UpstreamConnectLayer(Layer): self.connect_request.port = address.port self.server_conn.address = address - if server_tls: - raise ProtocolException( - "Cannot upgrade to TLS, no TLS layer on the protocol stack." - ) - class HttpLayer(Layer): @@ -149,7 +144,7 @@ class HttpLayer(Layer): def __call__(self): if self.mode == "transparent": - self.__initial_server_tls = self._server_tls + self.__initial_server_tls = self.server_tls self.__initial_server_conn = self.server_conn while True: try: @@ -360,8 +355,9 @@ class HttpLayer(Layer): if self.mode == "regular" or self.mode == "transparent": # If there's an existing connection that doesn't match our expectations, kill it. - if address != self.server_conn.address or tls != self.server_conn.tls_established: - self.set_server(address, tls, address.host) + if address != self.server_conn.address or tls != self.server_tls: + self.set_server(address) + self.set_server_tls(tls, address.host) # Establish connection is neccessary. if not self.server_conn: self.connect() diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index 26c3f9d2..74c55ab4 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -266,18 +266,22 @@ class TlsClientHello(object): return self._client_hello @property - def client_cipher_suites(self): + def cipher_suites(self): return self._client_hello.cipher_suites.cipher_suites @property - def client_sni(self): + def sni(self): for extension in self._client_hello.extensions: - if (extension.type == 0x00 and len(extension.server_names) == 1 - and extension.server_names[0].type == 0): + is_valid_sni_extension = ( + extension.type == 0x00 + and len(extension.server_names) == 1 + and extension.server_names[0].type == 0 + ) + if is_valid_sni_extension: return extension.server_names[0].name @property - def client_alpn_protocols(self): + def alpn_protocols(self): for extension in self._client_hello.extensions: if extension.type == 0x10: return list(extension.alpn_protocols) @@ -304,55 +308,78 @@ class TlsClientHello(object): def __repr__(self): return "TlsClientHello( sni: %s alpn_protocols: %s, cipher_suites: %s)" % \ - (self.client_sni, self.client_alpn_protocols, self.client_cipher_suites) + (self.sni, self.alpn_protocols, self.cipher_suites) class TlsLayer(Layer): + """ + The TLS layer implements transparent TLS connections. - def __init__(self, ctx, client_tls, server_tls): - self.client_sni = None - self.client_alpn_protocols = None - self.client_ciphers = [] + It exposes the following API to child layers: + + - :py:meth:`set_server_tls` to modify TLS settings for the server connection. + - :py:attr:`server_tls`, :py:attr:`server_sni` as read-only attributes describing the current TLS settings for + the server connection. + """ + def __init__(self, ctx, client_tls, server_tls): super(TlsLayer, self).__init__(ctx) self._client_tls = client_tls self._server_tls = server_tls - self._sni_from_server_change = None + self._custom_server_sni = None + self._client_hello = None # type: TlsClientHello def __call__(self): """ - The strategy for establishing SSL is as follows: + The strategy for establishing TLS is as follows: First, we determine whether we need the server cert to establish ssl with the client. If so, we first connect to the server and then to the client. - If not, we only connect to the client and do the server_ssl lazily on a Connect message. - - An additional complexity is that establish ssl with the server may require a SNI value from - the client. In an ideal world, we'd do the following: - 1. Start the SSL handshake with the client - 2. Check if the client sends a SNI. - 3. Pause the client handshake, establish SSL with the server. - 4. Finish the client handshake with the certificate from the server. - There's just one issue: We cannot get a callback from OpenSSL if the client doesn't send a SNI. :( - Thus, we manually peek into the connection and parse the ClientHello message to obtain both SNI and ALPN values. - - Further notes: - - OpenSSL 1.0.2 introduces a callback that would help here: - https://www.openssl.org/docs/ssl/SSL_CTX_set_cert_cb.html - - The original mitmproxy issue is https://github.com/mitmproxy/mitmproxy/issues/427 - """ - - client_tls_requires_server_cert = ( - self._client_tls and self._server_tls and not self.config.no_upstream_cert - ) + If not, we only connect to the client and do the server handshake lazily. + An additional complexity is that we need to mirror SNI and ALPN from the client when connecting to the server. + We manually peek into the connection and parse the ClientHello message to obtain these values. + """ if self._client_tls: - self._parse_client_hello() + # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values. + try: + self._client_hello = TlsClientHello.from_client_conn(self.client_conn) + except TlsProtocolException as e: + self.log("Cannot parse Client Hello: %s" % repr(e), "error") + + # Do we need to do a server handshake now? + # There are two reasons why we would want to establish TLS with the server now: + # 1. If we already have an existing server connection and server_tls is True, + # we need to establish TLS now because .connect() will not be called anymore. + # 2. We may need information from the server connection for the client handshake. + # + # A couple of factors influence (2): + # 2.1 There actually is (or will be) a TLS-enabled upstream connection + # 2.2 An upstream connection is not wanted by the user if --no-upstream-cert is passed. + # 2.3 An upstream connection is implied by add_upstream_certs_to_client_chain + # 2.4 The client wants to negotiate an alternative protocol in its handshake, we need to find out + # what is supported by the server + # 2.5 The client did not sent a SNI value, we don't know the certificate subject. + client_tls_requires_server_connection = ( + self._server_tls + and not self.config.no_upstream_cert + and ( + self.config.add_upstream_certs_to_client_chain + or self._client_hello.alpn_protocols + or not self._client_hello.sni + ) + ) + establish_server_tls_now = ( + (self.server_conn and self._server_tls) + or client_tls_requires_server_connection + ) - if client_tls_requires_server_cert: + if self._client_tls and establish_server_tls_now: self._establish_tls_with_client_and_server() elif self._client_tls: self._establish_tls_with_client() + elif establish_server_tls_now: + self._establish_tls_with_server() layer = self.ctx.next_layer(self) layer() @@ -367,47 +394,48 @@ class TlsLayer(Layer): else: return "TlsLayer(inactive)" - def _parse_client_hello(self): - """ - Peek into the connection, read the initial client hello and parse it to obtain ALPN values. - """ - try: - parsed = TlsClientHello.from_client_conn(self.client_conn) - self.client_sni = parsed.client_sni - self.client_alpn_protocols = parsed.client_alpn_protocols - self.client_ciphers = parsed.client_cipher_suites - except TlsProtocolException as e: - self.log("Cannot parse Client Hello: %s" % repr(e), "error") - def connect(self): if not self.server_conn: self.ctx.connect() if self._server_tls and not self.server_conn.tls_established: self._establish_tls_with_server() - def set_server(self, address, server_tls=None, sni=None): - if server_tls is not None: - self._sni_from_server_change = sni - self._server_tls = server_tls - self.ctx.set_server(address, None, None) + def set_server_tls(self, server_tls, sni=None): + """ + Set the TLS settings for the next server connection that will be established. + This function will not alter an existing connection. + + Args: + server_tls: Shall we establish TLS with the server? + sni: ``bytes`` for a custom SNI value, + ``None`` for the client SNI value, + ``False`` if no SNI value should be sent. + """ + self._server_tls = server_tls + self._custom_server_sni = sni + + @property + def server_tls(self): + """ + ``True``, if the next server connection that will be established should be upgraded to TLS. + """ + return self._server_tls @property - def sni_for_server_connection(self): - if self._sni_from_server_change is False: + def server_sni(self): + """ + The Server Name Indication we want to send with the next server TLS handshake. + """ + if self._custom_server_sni is False: return None else: - return self._sni_from_server_change or self.client_sni + return self._custom_server_sni or self._client_hello.sni @property def alpn_for_client_connection(self): return self.server_conn.get_alpn_proto_negotiated() def __alpn_select_callback(self, conn_, options): - """ - Once the client signals the alternate protocols it supports, - we reconnect upstream with the same list and pass the server's choice down to the client. - """ - # This gets triggered if we haven't established an upstream connection yet. default_alpn = b'http/1.1' # alpn_preference = b'h2' @@ -422,12 +450,12 @@ class TlsLayer(Layer): return choice def _establish_tls_with_client_and_server(self): - # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless - # to send an error message over TLS. try: self.ctx.connect() self._establish_tls_with_server() except Exception: + # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless + # to send an error message over TLS. try: self._establish_tls_with_client() except: @@ -466,9 +494,9 @@ class TlsLayer(Layer): ClientHandshakeException, ClientHandshakeException( "Cannot establish TLS with client (sni: {sni}): {e}".format( - sni=self.client_sni, e=repr(e) + sni=self._client_hello.sni, e=repr(e) ), - self.client_sni or repr(self.server_conn.address) + self._client_hello.sni or repr(self.server_conn.address) ), sys.exc_info()[2] ) @@ -480,8 +508,8 @@ class TlsLayer(Layer): # If the server only supports spdy (next to http/1.1), it may select that # and mitmproxy would enter TCP passthrough mode, which we want to avoid. deprecated_http2_variant = lambda x: x.startswith(b"h2-") or x.startswith(b"spdy") - if self.client_alpn_protocols: - alpn = [x for x in self.client_alpn_protocols if not deprecated_http2_variant(x)] + if self._client_hello.alpn_protocols: + alpn = [x for x in self._client_hello.alpn_protocols if not deprecated_http2_variant(x)] else: alpn = None if alpn and b"h2" in alpn and not self.config.http2: @@ -490,14 +518,14 @@ class TlsLayer(Layer): ciphers_server = self.config.ciphers_server if not ciphers_server: ciphers_server = [] - for id in self.client_ciphers: + for id in self._client_hello.cipher_suites: if id in CIPHER_ID_NAME_MAP.keys(): ciphers_server.append(CIPHER_ID_NAME_MAP[id]) ciphers_server = ':'.join(ciphers_server) self.server_conn.establish_ssl( self.config.clientcerts, - self.sni_for_server_connection, + self.server_sni, method=self.config.openssl_method_server, options=self.config.openssl_options_server, verify_options=self.config.openssl_verification_mode_server, @@ -524,7 +552,7 @@ class TlsLayer(Layer): TlsProtocolException, TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( address=repr(self.server_conn.address), - sni=self.sni_for_server_connection, + sni=self.server_sni, e=repr(e), )), sys.exc_info()[2] @@ -534,7 +562,7 @@ class TlsLayer(Layer): TlsProtocolException, TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( address=repr(self.server_conn.address), - sni=self.sni_for_server_connection, + sni=self.server_sni, e=repr(e), )), sys.exc_info()[2] @@ -569,13 +597,13 @@ class TlsLayer(Layer): sans.add(host) host = upstream_cert.cn.decode("utf8").encode("idna") # Also add SNI values. - if self.client_sni: - sans.add(self.client_sni) - if self._sni_from_server_change: - sans.add(self._sni_from_server_change) + if self._client_hello.sni: + sans.add(self._client_hello.sni) + if self._custom_server_sni: + sans.add(self._custom_server_sni) - # Some applications don't consider the CN and expect the hostname to be in the SANs. - # For example, Thunderbird 38 will display a warning if the remote host is only the CN. + # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity. + # In other words, the Common Name is irrelevant then. if host: sans.add(host) return self.config.certstore.get_cert(host, list(sans)) diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 9caae02a..c55105ec 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -63,7 +63,7 @@ class RootContext(object): except TlsProtocolException as e: self.log("Cannot parse Client Hello: %s" % repr(e), "error") else: - ignore = self.config.check_ignore((client_hello.client_sni, 443)) + ignore = self.config.check_ignore((client_hello.sni, 443)) if ignore: return RawTCPLayer(top_layer, logging=False) diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py index 5fd062ea..cda5bba6 100644 --- a/mitmproxy/utils.py +++ b/mitmproxy/utils.py @@ -7,6 +7,9 @@ import json import importlib import inspect +import netlib.utils + + def timestamp(): """ Returns a serializable UTC timestamp. @@ -73,25 +76,7 @@ def pretty_duration(secs): return "{:.0f}ms".format(secs * 1000) -class Data: - - def __init__(self, name): - m = importlib.import_module(name) - dirname = os.path.dirname(inspect.getsourcefile(m)) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module.Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath -pkg_data = Data(__name__) +pkg_data = netlib.utils.Data(__name__) class LRUCache: diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9fafa28f..c4eb1d58 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division from .request import Request from .response import Response from .headers import Headers -from .message import MultiDictView, decoded +from .message import decoded from . import http1, http2, status_codes __all__ = [ "Request", "Response", "Headers", - "MultiDictView", "decoded", + "decoded", "http1", "http2", "status_codes", ] diff --git a/netlib/http/message.py b/netlib/http/message.py index 76affeec..028f43a1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -242,72 +242,3 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: self.message.encode(self.ce) - - -class MultiDictView(MultiDict): - """ - Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict. - It behaves mostly like an ordered dict but it can have several values for the same key. - - The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`. - That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves. - - For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header. - Any change to ``request.cookies`` will also modify the ``Cookie`` header. - Any change to the ``Cookie`` header will also modify ``request.cookies``. - - Example: - - .. code-block:: python - - # Cookies are represented as a MultiDict. - >>> request.cookies - MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] - - # MultiDicts mostly behave like a normal dict. - >>> request.cookies["name"] - "value" - - # If there is more than one value, only the first value is returned. - >>> request.cookies["a"] - "false" - - # `.get_all(key)` returns a list of all values. - >>> request.cookies.get_all("a") - ["false", "42"] - - # Changes to the headers are immediately reflected in the cookies. - >>> request.cookies - MultiDictView[("name", "value"), ...] - >>> del request.headers["Cookie"] - >>> request.cookies - MultiDictView[] # empty now - """ - - def __init__(self, attr, message): - if False: # pragma: no cover - # We do not want to call the parent constructor here as that - # would cause an unnecessary parse/unparse pass. - # This is here to silence linters. Message - super(MultiDictView, self).__init__(None) - self._attr = attr - self._message = message # type: Message - - @staticmethod - def _kconv(key): - # All request-attributes are case-sensitive. - return key - - @staticmethod - def _reduce_values(values): - # We just return the first element if - # multiple elements exist with the same key. - return values[0] - - @property - def fields(self): - return getattr(self._message, "_" + self._attr) - - @fields.setter - def fields(self, value): - setattr(self._message, self._attr, value) diff --git a/netlib/http/request.py b/netlib/http/request.py index ae28084b..056a2d93 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,8 +10,9 @@ from netlib import utils from netlib.http import cookies from netlib.odict import ODict from .. import encoding +from ..multidict import MultiDictView from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData, MultiDictView +from .message import Message, _native, _always_bytes, MessageData # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -228,20 +229,25 @@ class Request(Message): """ The request query string as an :py:class:`MultiDictView` object. """ - return MultiDictView("query", self) + return MultiDictView( + self._get_query, + self._set_query + ) - @property - def _query(self): + def _get_query(self): _, _, _, _, query, _ = urllib.parse.urlparse(self.url) return tuple(utils.urldecode(query)) - @query.setter - def query(self, value): + def _set_query(self, value): query = utils.urlencode(value) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) _, _, _, self.path = utils.parse_url( urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + @query.setter + def query(self, value): + self._set_query(value) + @property def cookies(self): # type: () -> MultiDictView @@ -250,16 +256,21 @@ class Request(Message): An empty :py:class:`MultiDictView` object if the cookie monster ate them all. """ - return MultiDictView("cookies", self) + return MultiDictView( + self._get_cookies, + self._set_cookies + ) - @property - def _cookies(self): + def _get_cookies(self): h = self.headers.get_all("Cookie") return tuple(cookies.parse_cookie_headers(h)) + def _set_cookies(self, value): + self.headers["cookie"] = cookies.format_cookie_header(value) + @cookies.setter def cookies(self, value): - self.headers["cookie"] = cookies.format_cookie_header(value) + self._set_cookies(value) @property def path_components(self): @@ -322,17 +333,18 @@ class Request(Message): An empty MultiDictView if the content-type indicates non-form data or the content could not be parsed. """ - return MultiDictView("urlencoded_form", self) + return MultiDictView( + self._get_urlencoded_form, + self._set_urlencoded_form + ) - @property - def _urlencoded_form(self): + def _get_urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() if is_valid_content_type: return tuple(utils.urldecode(self.content)) return () - @urlencoded_form.setter - def urlencoded_form(self, value): + def _set_urlencoded_form(self, value): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. @@ -340,21 +352,30 @@ class Request(Message): self.headers["content-type"] = "application/x-www-form-urlencoded" self.content = utils.urlencode(value) + @urlencoded_form.setter + def urlencoded_form(self, value): + self._set_urlencoded_form(value) + @property def multipart_form(self): """ The multipart form data as an :py:class:`MultipartFormDict` object. None if the content-type indicates non-form data. """ - return MultiDictView("multipart_form", self) + return MultiDictView( + self._get_multipart_form, + self._set_multipart_form + ) - @property - def _multipart_form(self): + def _get_multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() if is_valid_content_type: return utils.multipartdecode(self.headers, self.content) return () + def _set_multipart_form(self, value): + raise NotImplementedError() + @multipart_form.setter def multipart_form(self, value): - raise NotImplementedError() + self._set_multipart_form(value) diff --git a/netlib/http/response.py b/netlib/http/response.py index 6d56fc1f..7d272e10 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -5,7 +5,8 @@ import time from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData, MultiDictView +from .message import Message, _native, _always_bytes, MessageData +from ..multidict import MultiDictView from .. import utils @@ -80,21 +81,26 @@ class Response(Message): Caveats: Updating the attr """ - return MultiDictView("cookies", self) + return MultiDictView( + self._get_cookies, + self._set_cookies + ) - @property - def _cookies(self): + def _get_cookies(self): h = self.headers.get_all("set-cookie") return tuple(cookies.parse_set_cookie_headers(h)) - @cookies.setter - def cookies(self, all_cookies): + def _set_cookies(self, value): cookie_headers = [] - for k, v in all_cookies: + for k, v in value: header = cookies.format_set_cookie_header(k, v[0], v[1]) cookie_headers.append(header) self.headers.set_all("set-cookie", cookie_headers) + @cookies.setter + def cookies(self, value): + self._set_cookies(value) + def refresh(self, now=None): """ This fairly complex and heuristic function refreshes a server diff --git a/netlib/multidict.py b/netlib/multidict.py index ec1b24d8..248acdec 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -15,13 +15,7 @@ from .utils import Serializable @six.add_metaclass(ABCMeta) -class MultiDict(MutableMapping, Serializable): - def __init__(self, fields=None): - - # it is important for us that .fields is immutable, so that we can easily - # detect changes to it. - self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] - +class _MultiDict(MutableMapping, Serializable): def __repr__(self): fields = tuple( repr(field) @@ -100,7 +94,7 @@ class MultiDict(MutableMapping, Serializable): value for k, value in self.fields if self._kconv(k) == key - ] + ] def set_all(self, key, values): """ @@ -176,7 +170,7 @@ class MultiDict(MutableMapping, Serializable): if multi: return self.fields else: - return super(MultiDict, self).items() + return super(_MultiDict, self).items() def to_dict(self): """ @@ -216,6 +210,12 @@ class MultiDict(MutableMapping, Serializable): return cls(tuple(x) for x in state) +class MultiDict(_MultiDict): + def __init__(self, fields=None): + super(MultiDict, self).__init__() + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + @six.add_metaclass(ABCMeta) class ImmutableMultiDict(MultiDict): def _immutable(self, *_): @@ -249,3 +249,34 @@ class ImmutableMultiDict(MultiDict): ret = self.copy() super(ImmutableMultiDict, ret).insert(index, key, value) return ret + + +class MultiDictView(_MultiDict): + """ + The MultiDictView provides the MultiDict interface over calculated data. + The view itself contains no state - data is retrieved from the parent on + request, and stored back to the parent on change. + """ + def __init__(self, getter, setter): + self._getter = getter + self._setter = setter + super(MultiDictView, self).__init__() + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return self._getter() + + @fields.setter + def fields(self, value): + return self._setter(value) diff --git a/netlib/odict.py b/netlib/odict.py index 461192f7..8a638dab 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@ from __future__ import (absolute_import, print_function, division) import copy + import six from .utils import Serializable, safe_subn @@ -27,27 +28,24 @@ class ODict(Serializable): def __iter__(self): return self.lst.__iter__() - def __getitem__(self, k): + def __getitem__(self, key): """ Returns a list of values matching key. """ - ret = [] - k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - ret.append(i[1]) - return ret - def keys(self): - return list(set([self._kconv(i[0]) for i in self.lst])) + key = self._kconv(key) + return [ + v + for k, v in self.lst + if self._kconv(k) == key + ] - def _filter_lst(self, k, lst): - k = self._kconv(k) - new = [] - for i in lst: - if self._kconv(i[0]) != k: - new.append(i) - return new + def keys(self): + return list( + set( + self._kconv(k) for k, _ in self.lst + ) + ) def __len__(self): """ @@ -81,14 +79,19 @@ class ODict(Serializable): """ Delete all items matching k. """ - self.lst = self._filter_lst(k, self.lst) - - def __contains__(self, k): k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - return True - return False + self.lst = [ + i + for i in self.lst + if self._kconv(i[0]) != k + ] + + def __contains__(self, key): + key = self._kconv(key) + return any( + self._kconv(k) == key + for k, _ in self.lst + ) def add(self, key, value, prepend=False): if prepend: @@ -127,40 +130,24 @@ class ODict(Serializable): def __repr__(self): return repr(self.lst) - def in_any(self, key, value, caseless=False): - """ - Do any of the values matching key contain value? - - If caseless is true, value comparison is case-insensitive. - """ - if caseless: - value = value.lower() - for i in self[key]: - if caseless: - i = i.lower() - if value in i: - return True - return False - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both keys and - values. Encoded content will be decoded before replacement, and - re-encoded afterwards. + values. Returns the number of replacements made. """ - nlst, count = [], 0 - for i in self.lst: - k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) + new, count = [], 0 + for k, v in self.lst: + k, c = safe_subn(pattern, repl, k, *args, **kwargs) count += c - v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) + v, c = safe_subn(pattern, repl, v, *args, **kwargs) count += c - nlst.append([k, v]) - self.lst = nlst + new.append([k, v]) + self.lst = new return count - # Implement the StateObject protocol from mitmproxy + # Implement Serializable def get_state(self): return [tuple(i) for i in self.lst] diff --git a/pathod/utils.py b/pathod/utils.py index 1e5bd9a4..d1e2dd00 100644 --- a/pathod/utils.py +++ b/pathod/utils.py @@ -1,5 +1,6 @@ import os import sys +import netlib.utils SIZE_UNITS = dict( @@ -75,27 +76,7 @@ def escape_unprintables(s): return s -class Data(object): - - def __init__(self, name): - m = __import__(name) - dirname, _ = os.path.split(m.__file__) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module.Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath - - -data = Data(__name__) +data = netlib.utils.Data(__name__) def daemonize(stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'): # pragma: no cover diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index ac79b093..c4b06f4b 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -5,11 +5,12 @@ from contextlib import contextmanager from mitmproxy import utils, script from mitmproxy.proxy import config +import netlib.utils from netlib import tutils as netutils from netlib.http import Headers from . import tservers, tutils -example_dir = utils.Data(__name__).path("../../examples") +example_dir = netlib.utils.Data(__name__).path("../../examples") class DummyContext(object): diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index d51ac185..2dfd710e 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -8,6 +8,7 @@ from contextlib import contextmanager from unittest.case import SkipTest +import netlib.utils import netlib.tutils from mitmproxy import utils, controller from mitmproxy.models import ( @@ -163,4 +164,4 @@ def capture_stderr(command, *args, **kwargs): sys.stderr = out -test_data = utils.Data(__name__) +test_data = netlib.utils.Data(__name__) diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index eefdc091..fae7aefe 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -3,9 +3,7 @@ from __future__ import absolute_import, print_function, division import six -from netlib import utils from netlib.http import Headers -from netlib.odict import ODict from netlib.tutils import treq, raises from .test_message import _test_decoded_attr, _test_passthrough_attr diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py index ceea3806..5bb65e3f 100644 --- a/test/netlib/test_multidict.py +++ b/test/netlib/test_multidict.py @@ -1,5 +1,5 @@ from netlib import tutils -from netlib.multidict import MultiDict, ImmutableMultiDict +from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView class _TMulti(object): @@ -214,4 +214,26 @@ class TestImmutableMultiDict(object): def test_with_insert(self): md = TImmutableMultiDict() assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) - assert md.fields == ()
\ No newline at end of file + + +class TParent(object): + def __init__(self): + self.vals = tuple() + + def setter(self, vals): + self.vals = vals + + def getter(self): + return self.vals + + +class TestMultiDictView(object): + def test_modify(self): + p = TParent() + tv = MultiDictView(p.getter, p.setter) + assert len(tv) == 0 + tv["a"] = "b" + assert p.vals == (("a", "b"),) + tv["c"] = "b" + assert p.vals == (("a", "b"), ("c", "b")) + assert tv["a"] == "b" diff --git a/test/netlib/test_odict.py b/test/netlib/test_odict.py index f0985ef6..b6fd6401 100644 --- a/test/netlib/test_odict.py +++ b/test/netlib/test_odict.py @@ -27,16 +27,6 @@ class TestODict(object): b.set_state(state) assert b == od - def test_in_any(self): - od = odict.ODict() - od["one"] = ["atwoa", "athreea"] - assert od.in_any("one", "two") - assert od.in_any("one", "three") - assert not od.in_any("one", "four") - assert not od.in_any("nonexistent", "foo") - assert not od.in_any("one", "TWO") - assert od.in_any("one", "TWO", True) - def test_iter(self): od = odict.ODict() assert not [i for i in od] diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index 8d0f92ac..4e8c89c5 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -211,7 +211,7 @@ class TestDaemon(_TestDaemon): c.stop() @skip_windows - @pytest.mark.xfail + @pytest.mark.skip(reason="race condition") def test_wait_finish(self): c = pathoc.Pathoc( ("127.0.0.1", self.d.port), diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index 1718cc0b..05a3962e 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -129,7 +129,7 @@ class CommonTests(tutils.DaemonTests): l = self.d.last_log() # FIXME: Other binary data elements - @pytest.mark.skip + @pytest.mark.skip(reason="race condition") def test_sizelimit(self): r = self.get("200:b@1g") assert r.status_code == 800 @@ -143,7 +143,7 @@ class CommonTests(tutils.DaemonTests): def test_info(self): assert tuple(self.d.info()["version"]) == version.IVERSION - @pytest.mark.skip + @pytest.mark.skip(reason="race condition") def test_logs(self): assert self.d.clear_log() assert not self.d.last_log() @@ -223,7 +223,7 @@ class CommonTests(tutils.DaemonTests): ) assert r[1].payload == "test" - @pytest.mark.skip + @pytest.mark.skip(reason="race condition") def test_websocket_frame_reflect_error(self): r, _ = self.pathoc( ["ws:/p/", "wf:-mask:knone:f'wf:b@10':i13,'a'"], @@ -233,6 +233,7 @@ class CommonTests(tutils.DaemonTests): # FIXME: Race Condition? assert "Parse error" in self.d.text_log() + @pytest.mark.skip(reason="race condition") def test_websocket_frame_disconnect_error(self): self.pathoc(["ws:/p/", "wf:b@10:d3"], ws_read_limit=0) assert self.d.last_log() diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py index 9739afde..f6ed3efb 100644 --- a/test/pathod/tutils.py +++ b/test/pathod/tutils.py @@ -116,7 +116,7 @@ tmpdir = netlib.tutils.tmpdir raises = netlib.tutils.raises -test_data = utils.Data(__name__) +test_data = netlib.utils.Data(__name__) def render(r, settings=language.Settings()): |