aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--issue_template.md4
-rw-r--r--mitmproxy/console/common.py8
-rw-r--r--mitmproxy/flow.py19
-rw-r--r--mitmproxy/protocol/base.py11
-rw-r--r--mitmproxy/protocol/http.py14
-rw-r--r--mitmproxy/protocol/tls.py178
-rw-r--r--mitmproxy/proxy/root_context.py2
-rw-r--r--mitmproxy/utils.py23
-rw-r--r--netlib/http/__init__.py4
-rw-r--r--netlib/http/message.py69
-rw-r--r--netlib/http/request.py59
-rw-r--r--netlib/http/response.py20
-rw-r--r--netlib/multidict.py49
-rw-r--r--netlib/odict.py81
-rw-r--r--pathod/utils.py23
-rw-r--r--test/mitmproxy/test_examples.py3
-rw-r--r--test/mitmproxy/tutils.py3
-rw-r--r--test/netlib/http/test_request.py2
-rw-r--r--test/netlib/test_multidict.py26
-rw-r--r--test/netlib/test_odict.py10
-rw-r--r--test/pathod/test_pathoc.py2
-rw-r--r--test/pathod/test_pathod.py7
-rw-r--r--test/pathod/tutils.py2
23 files changed, 298 insertions, 321 deletions
diff --git a/issue_template.md b/issue_template.md
index 3f9be788..08d390e4 100644
--- a/issue_template.md
+++ b/issue_template.md
@@ -10,10 +10,10 @@
##### What went wrong?
-##### Any other comments?
+##### Any other comments? What have you tried so far?
---
Mitmproxy Version:
-Operating System: \ No newline at end of file
+Operating System:
diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py
index 4e472fb6..25658dfa 100644
--- a/mitmproxy/console/common.py
+++ b/mitmproxy/console/common.py
@@ -154,7 +154,7 @@ def raw_format_flow(f, focus, extended):
if f["intercepted"] and not f["acked"]:
uc = "intercept"
- elif f["resp_code"] or f["err_msg"]:
+ elif "resp_code" in f or "err_msg" in f:
uc = "text"
else:
uc = "title"
@@ -173,7 +173,7 @@ def raw_format_flow(f, focus, extended):
("fixed", preamble, urwid.Text(""))
)
- if f["resp_code"]:
+ if "resp_code" in f:
codes = {
2: "code_200",
3: "code_300",
@@ -185,6 +185,8 @@ def raw_format_flow(f, focus, extended):
if f["resp_is_replay"]:
resp.append(fcol(SYMBOL_REPLAY, "replay"))
resp.append(fcol(f["resp_code"], ccol))
+ if extended:
+ resp.append(fcol(f["resp_reason"], ccol))
if f["intercepted"] and f["resp_code"] and not f["acked"]:
rc = "intercept"
else:
@@ -412,7 +414,6 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False):
req_http_version = f.request.http_version,
err_msg = f.error.msg if f.error else None,
- resp_code = f.response.status_code if f.response else None,
marked = marked,
)
@@ -430,6 +431,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False):
d.update(dict(
resp_code = f.response.status_code,
+ resp_reason = f.response.reason,
resp_is_replay = f.response.is_replay,
resp_clen = contentdesc,
roundtrip = roundtrip,
diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py
index 647ebf68..a9018e16 100644
--- a/mitmproxy/flow.py
+++ b/mitmproxy/flow.py
@@ -13,6 +13,8 @@ from six.moves import http_cookies, http_cookiejar, urllib
import os
import re
+from typing import List, Optional, Set
+
from netlib import wsgi, odict
from netlib.exceptions import HttpException
from netlib.http import Headers, http1, cookies
@@ -376,8 +378,11 @@ class StickyAuthState:
f.request.headers["authorization"] = self.hosts[host]
+@six.add_metaclass(ABCMeta)
class FlowList(object):
- __metaclass__ = ABCMeta
+
+ def __init__(self):
+ self._list = [] # type: List[Flow]
def __iter__(self):
return iter(self._list)
@@ -416,7 +421,7 @@ class FlowList(object):
class FlowView(FlowList):
def __init__(self, store, filt=None):
- self._list = []
+ super(FlowView, self).__init__()
if not filt:
filt = lambda flow: True
self._build(store, filt)
@@ -458,7 +463,7 @@ class FlowStore(FlowList):
"""
def __init__(self):
- self._list = []
+ super(FlowStore, self).__init__()
self._set = set() # Used for O(1) lookups
self.views = []
self._recalculate_views()
@@ -649,18 +654,18 @@ class FlowMaster(controller.ServerMaster):
self.server_playback = None
self.client_playback = None
self.kill_nonreplay = False
- self.scripts = []
+ self.scripts = [] # type: List[script.Script]
self.pause_scripts = False
- self.stickycookie_state = False
+ self.stickycookie_state = None # type: Optional[StickyCookieState]
self.stickycookie_txt = None
- self.stickyauth_state = False
+ self.stickyauth_state = False # type: Optional[StickyAuthState]
self.stickyauth_txt = None
self.anticache = False
self.anticomp = False
- self.stream_large_bodies = False
+ self.stream_large_bodies = None # type: Optional[StreamLargeBodies]
self.refresh_server_playback = False
self.replacehooks = ReplaceHooks()
self.setheaders = SetHeaders()
diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py
index 536f2753..c8e58d1b 100644
--- a/mitmproxy/protocol/base.py
+++ b/mitmproxy/protocol/base.py
@@ -133,24 +133,15 @@ class ServerConnectionMixin(object):
"The proxy shall not connect to itself.".format(repr(address))
)
- def set_server(self, address, server_tls=None, sni=None):
+ def set_server(self, address):
"""
Sets a new server address. If there is an existing connection, it will be closed.
-
- Raises:
- ~mitmproxy.exceptions.ProtocolException:
- if ``server_tls`` is ``True``, but there was no TLS layer on the
- protocol stack which could have processed this.
"""
if self.server_conn:
self.disconnect()
self.log("Set new server address: " + repr(address), "debug")
self.server_conn.address = address
self.__check_self_connect()
- if server_tls:
- raise ProtocolException(
- "Cannot upgrade to TLS, no TLS layer on the protocol stack."
- )
def disconnect(self):
"""
diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py
index 9cb35176..d9111303 100644
--- a/mitmproxy/protocol/http.py
+++ b/mitmproxy/protocol/http.py
@@ -120,7 +120,7 @@ class UpstreamConnectLayer(Layer):
if address != self.server_conn.via.address:
self.ctx.set_server(address)
- def set_server(self, address, server_tls=None, sni=None):
+ def set_server(self, address):
if self.ctx.server_conn:
self.ctx.disconnect()
address = tcp.Address.wrap(address)
@@ -128,11 +128,6 @@ class UpstreamConnectLayer(Layer):
self.connect_request.port = address.port
self.server_conn.address = address
- if server_tls:
- raise ProtocolException(
- "Cannot upgrade to TLS, no TLS layer on the protocol stack."
- )
-
class HttpLayer(Layer):
@@ -149,7 +144,7 @@ class HttpLayer(Layer):
def __call__(self):
if self.mode == "transparent":
- self.__initial_server_tls = self._server_tls
+ self.__initial_server_tls = self.server_tls
self.__initial_server_conn = self.server_conn
while True:
try:
@@ -360,8 +355,9 @@ class HttpLayer(Layer):
if self.mode == "regular" or self.mode == "transparent":
# If there's an existing connection that doesn't match our expectations, kill it.
- if address != self.server_conn.address or tls != self.server_conn.tls_established:
- self.set_server(address, tls, address.host)
+ if address != self.server_conn.address or tls != self.server_tls:
+ self.set_server(address)
+ self.set_server_tls(tls, address.host)
# Establish connection is neccessary.
if not self.server_conn:
self.connect()
diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py
index 26c3f9d2..74c55ab4 100644
--- a/mitmproxy/protocol/tls.py
+++ b/mitmproxy/protocol/tls.py
@@ -266,18 +266,22 @@ class TlsClientHello(object):
return self._client_hello
@property
- def client_cipher_suites(self):
+ def cipher_suites(self):
return self._client_hello.cipher_suites.cipher_suites
@property
- def client_sni(self):
+ def sni(self):
for extension in self._client_hello.extensions:
- if (extension.type == 0x00 and len(extension.server_names) == 1
- and extension.server_names[0].type == 0):
+ is_valid_sni_extension = (
+ extension.type == 0x00
+ and len(extension.server_names) == 1
+ and extension.server_names[0].type == 0
+ )
+ if is_valid_sni_extension:
return extension.server_names[0].name
@property
- def client_alpn_protocols(self):
+ def alpn_protocols(self):
for extension in self._client_hello.extensions:
if extension.type == 0x10:
return list(extension.alpn_protocols)
@@ -304,55 +308,78 @@ class TlsClientHello(object):
def __repr__(self):
return "TlsClientHello( sni: %s alpn_protocols: %s, cipher_suites: %s)" % \
- (self.client_sni, self.client_alpn_protocols, self.client_cipher_suites)
+ (self.sni, self.alpn_protocols, self.cipher_suites)
class TlsLayer(Layer):
+ """
+ The TLS layer implements transparent TLS connections.
- def __init__(self, ctx, client_tls, server_tls):
- self.client_sni = None
- self.client_alpn_protocols = None
- self.client_ciphers = []
+ It exposes the following API to child layers:
+
+ - :py:meth:`set_server_tls` to modify TLS settings for the server connection.
+ - :py:attr:`server_tls`, :py:attr:`server_sni` as read-only attributes describing the current TLS settings for
+ the server connection.
+ """
+ def __init__(self, ctx, client_tls, server_tls):
super(TlsLayer, self).__init__(ctx)
self._client_tls = client_tls
self._server_tls = server_tls
- self._sni_from_server_change = None
+ self._custom_server_sni = None
+ self._client_hello = None # type: TlsClientHello
def __call__(self):
"""
- The strategy for establishing SSL is as follows:
+ The strategy for establishing TLS is as follows:
First, we determine whether we need the server cert to establish ssl with the client.
If so, we first connect to the server and then to the client.
- If not, we only connect to the client and do the server_ssl lazily on a Connect message.
-
- An additional complexity is that establish ssl with the server may require a SNI value from
- the client. In an ideal world, we'd do the following:
- 1. Start the SSL handshake with the client
- 2. Check if the client sends a SNI.
- 3. Pause the client handshake, establish SSL with the server.
- 4. Finish the client handshake with the certificate from the server.
- There's just one issue: We cannot get a callback from OpenSSL if the client doesn't send a SNI. :(
- Thus, we manually peek into the connection and parse the ClientHello message to obtain both SNI and ALPN values.
-
- Further notes:
- - OpenSSL 1.0.2 introduces a callback that would help here:
- https://www.openssl.org/docs/ssl/SSL_CTX_set_cert_cb.html
- - The original mitmproxy issue is https://github.com/mitmproxy/mitmproxy/issues/427
- """
-
- client_tls_requires_server_cert = (
- self._client_tls and self._server_tls and not self.config.no_upstream_cert
- )
+ If not, we only connect to the client and do the server handshake lazily.
+ An additional complexity is that we need to mirror SNI and ALPN from the client when connecting to the server.
+ We manually peek into the connection and parse the ClientHello message to obtain these values.
+ """
if self._client_tls:
- self._parse_client_hello()
+ # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values.
+ try:
+ self._client_hello = TlsClientHello.from_client_conn(self.client_conn)
+ except TlsProtocolException as e:
+ self.log("Cannot parse Client Hello: %s" % repr(e), "error")
+
+ # Do we need to do a server handshake now?
+ # There are two reasons why we would want to establish TLS with the server now:
+ # 1. If we already have an existing server connection and server_tls is True,
+ # we need to establish TLS now because .connect() will not be called anymore.
+ # 2. We may need information from the server connection for the client handshake.
+ #
+ # A couple of factors influence (2):
+ # 2.1 There actually is (or will be) a TLS-enabled upstream connection
+ # 2.2 An upstream connection is not wanted by the user if --no-upstream-cert is passed.
+ # 2.3 An upstream connection is implied by add_upstream_certs_to_client_chain
+ # 2.4 The client wants to negotiate an alternative protocol in its handshake, we need to find out
+ # what is supported by the server
+ # 2.5 The client did not sent a SNI value, we don't know the certificate subject.
+ client_tls_requires_server_connection = (
+ self._server_tls
+ and not self.config.no_upstream_cert
+ and (
+ self.config.add_upstream_certs_to_client_chain
+ or self._client_hello.alpn_protocols
+ or not self._client_hello.sni
+ )
+ )
+ establish_server_tls_now = (
+ (self.server_conn and self._server_tls)
+ or client_tls_requires_server_connection
+ )
- if client_tls_requires_server_cert:
+ if self._client_tls and establish_server_tls_now:
self._establish_tls_with_client_and_server()
elif self._client_tls:
self._establish_tls_with_client()
+ elif establish_server_tls_now:
+ self._establish_tls_with_server()
layer = self.ctx.next_layer(self)
layer()
@@ -367,47 +394,48 @@ class TlsLayer(Layer):
else:
return "TlsLayer(inactive)"
- def _parse_client_hello(self):
- """
- Peek into the connection, read the initial client hello and parse it to obtain ALPN values.
- """
- try:
- parsed = TlsClientHello.from_client_conn(self.client_conn)
- self.client_sni = parsed.client_sni
- self.client_alpn_protocols = parsed.client_alpn_protocols
- self.client_ciphers = parsed.client_cipher_suites
- except TlsProtocolException as e:
- self.log("Cannot parse Client Hello: %s" % repr(e), "error")
-
def connect(self):
if not self.server_conn:
self.ctx.connect()
if self._server_tls and not self.server_conn.tls_established:
self._establish_tls_with_server()
- def set_server(self, address, server_tls=None, sni=None):
- if server_tls is not None:
- self._sni_from_server_change = sni
- self._server_tls = server_tls
- self.ctx.set_server(address, None, None)
+ def set_server_tls(self, server_tls, sni=None):
+ """
+ Set the TLS settings for the next server connection that will be established.
+ This function will not alter an existing connection.
+
+ Args:
+ server_tls: Shall we establish TLS with the server?
+ sni: ``bytes`` for a custom SNI value,
+ ``None`` for the client SNI value,
+ ``False`` if no SNI value should be sent.
+ """
+ self._server_tls = server_tls
+ self._custom_server_sni = sni
+
+ @property
+ def server_tls(self):
+ """
+ ``True``, if the next server connection that will be established should be upgraded to TLS.
+ """
+ return self._server_tls
@property
- def sni_for_server_connection(self):
- if self._sni_from_server_change is False:
+ def server_sni(self):
+ """
+ The Server Name Indication we want to send with the next server TLS handshake.
+ """
+ if self._custom_server_sni is False:
return None
else:
- return self._sni_from_server_change or self.client_sni
+ return self._custom_server_sni or self._client_hello.sni
@property
def alpn_for_client_connection(self):
return self.server_conn.get_alpn_proto_negotiated()
def __alpn_select_callback(self, conn_, options):
- """
- Once the client signals the alternate protocols it supports,
- we reconnect upstream with the same list and pass the server's choice down to the client.
- """
-
# This gets triggered if we haven't established an upstream connection yet.
default_alpn = b'http/1.1'
# alpn_preference = b'h2'
@@ -422,12 +450,12 @@ class TlsLayer(Layer):
return choice
def _establish_tls_with_client_and_server(self):
- # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless
- # to send an error message over TLS.
try:
self.ctx.connect()
self._establish_tls_with_server()
except Exception:
+ # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless
+ # to send an error message over TLS.
try:
self._establish_tls_with_client()
except:
@@ -466,9 +494,9 @@ class TlsLayer(Layer):
ClientHandshakeException,
ClientHandshakeException(
"Cannot establish TLS with client (sni: {sni}): {e}".format(
- sni=self.client_sni, e=repr(e)
+ sni=self._client_hello.sni, e=repr(e)
),
- self.client_sni or repr(self.server_conn.address)
+ self._client_hello.sni or repr(self.server_conn.address)
),
sys.exc_info()[2]
)
@@ -480,8 +508,8 @@ class TlsLayer(Layer):
# If the server only supports spdy (next to http/1.1), it may select that
# and mitmproxy would enter TCP passthrough mode, which we want to avoid.
deprecated_http2_variant = lambda x: x.startswith(b"h2-") or x.startswith(b"spdy")
- if self.client_alpn_protocols:
- alpn = [x for x in self.client_alpn_protocols if not deprecated_http2_variant(x)]
+ if self._client_hello.alpn_protocols:
+ alpn = [x for x in self._client_hello.alpn_protocols if not deprecated_http2_variant(x)]
else:
alpn = None
if alpn and b"h2" in alpn and not self.config.http2:
@@ -490,14 +518,14 @@ class TlsLayer(Layer):
ciphers_server = self.config.ciphers_server
if not ciphers_server:
ciphers_server = []
- for id in self.client_ciphers:
+ for id in self._client_hello.cipher_suites:
if id in CIPHER_ID_NAME_MAP.keys():
ciphers_server.append(CIPHER_ID_NAME_MAP[id])
ciphers_server = ':'.join(ciphers_server)
self.server_conn.establish_ssl(
self.config.clientcerts,
- self.sni_for_server_connection,
+ self.server_sni,
method=self.config.openssl_method_server,
options=self.config.openssl_options_server,
verify_options=self.config.openssl_verification_mode_server,
@@ -524,7 +552,7 @@ class TlsLayer(Layer):
TlsProtocolException,
TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format(
address=repr(self.server_conn.address),
- sni=self.sni_for_server_connection,
+ sni=self.server_sni,
e=repr(e),
)),
sys.exc_info()[2]
@@ -534,7 +562,7 @@ class TlsLayer(Layer):
TlsProtocolException,
TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format(
address=repr(self.server_conn.address),
- sni=self.sni_for_server_connection,
+ sni=self.server_sni,
e=repr(e),
)),
sys.exc_info()[2]
@@ -569,13 +597,13 @@ class TlsLayer(Layer):
sans.add(host)
host = upstream_cert.cn.decode("utf8").encode("idna")
# Also add SNI values.
- if self.client_sni:
- sans.add(self.client_sni)
- if self._sni_from_server_change:
- sans.add(self._sni_from_server_change)
+ if self._client_hello.sni:
+ sans.add(self._client_hello.sni)
+ if self._custom_server_sni:
+ sans.add(self._custom_server_sni)
- # Some applications don't consider the CN and expect the hostname to be in the SANs.
- # For example, Thunderbird 38 will display a warning if the remote host is only the CN.
+ # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity.
+ # In other words, the Common Name is irrelevant then.
if host:
sans.add(host)
return self.config.certstore.get_cert(host, list(sans))
diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py
index 9caae02a..c55105ec 100644
--- a/mitmproxy/proxy/root_context.py
+++ b/mitmproxy/proxy/root_context.py
@@ -63,7 +63,7 @@ class RootContext(object):
except TlsProtocolException as e:
self.log("Cannot parse Client Hello: %s" % repr(e), "error")
else:
- ignore = self.config.check_ignore((client_hello.client_sni, 443))
+ ignore = self.config.check_ignore((client_hello.sni, 443))
if ignore:
return RawTCPLayer(top_layer, logging=False)
diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py
index 5fd062ea..cda5bba6 100644
--- a/mitmproxy/utils.py
+++ b/mitmproxy/utils.py
@@ -7,6 +7,9 @@ import json
import importlib
import inspect
+import netlib.utils
+
+
def timestamp():
"""
Returns a serializable UTC timestamp.
@@ -73,25 +76,7 @@ def pretty_duration(secs):
return "{:.0f}ms".format(secs * 1000)
-class Data:
-
- def __init__(self, name):
- m = importlib.import_module(name)
- dirname = os.path.dirname(inspect.getsourcefile(m))
- self.dirname = os.path.abspath(dirname)
-
- def path(self, path):
- """
- Returns a path to the package data housed at 'path' under this
- module.Path can be a path to a file, or to a directory.
-
- This function will raise ValueError if the path does not exist.
- """
- fullpath = os.path.join(self.dirname, path)
- if not os.path.exists(fullpath):
- raise ValueError("dataPath: %s does not exist." % fullpath)
- return fullpath
-pkg_data = Data(__name__)
+pkg_data = netlib.utils.Data(__name__)
class LRUCache:
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
index 9fafa28f..c4eb1d58 100644
--- a/netlib/http/__init__.py
+++ b/netlib/http/__init__.py
@@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division
from .request import Request
from .response import Response
from .headers import Headers
-from .message import MultiDictView, decoded
+from .message import decoded
from . import http1, http2, status_codes
__all__ = [
"Request",
"Response",
"Headers",
- "MultiDictView", "decoded",
+ "decoded",
"http1", "http2", "status_codes",
]
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 76affeec..028f43a1 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -242,72 +242,3 @@ class decoded(object):
def __exit__(self, type, value, tb):
if self.ce:
self.message.encode(self.ce)
-
-
-class MultiDictView(MultiDict):
- """
- Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict.
- It behaves mostly like an ordered dict but it can have several values for the same key.
-
- The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`.
- That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves.
-
- For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header.
- Any change to ``request.cookies`` will also modify the ``Cookie`` header.
- Any change to the ``Cookie`` header will also modify ``request.cookies``.
-
- Example:
-
- .. code-block:: python
-
- # Cookies are represented as a MultiDict.
- >>> request.cookies
- MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
-
- # MultiDicts mostly behave like a normal dict.
- >>> request.cookies["name"]
- "value"
-
- # If there is more than one value, only the first value is returned.
- >>> request.cookies["a"]
- "false"
-
- # `.get_all(key)` returns a list of all values.
- >>> request.cookies.get_all("a")
- ["false", "42"]
-
- # Changes to the headers are immediately reflected in the cookies.
- >>> request.cookies
- MultiDictView[("name", "value"), ...]
- >>> del request.headers["Cookie"]
- >>> request.cookies
- MultiDictView[] # empty now
- """
-
- def __init__(self, attr, message):
- if False: # pragma: no cover
- # We do not want to call the parent constructor here as that
- # would cause an unnecessary parse/unparse pass.
- # This is here to silence linters. Message
- super(MultiDictView, self).__init__(None)
- self._attr = attr
- self._message = message # type: Message
-
- @staticmethod
- def _kconv(key):
- # All request-attributes are case-sensitive.
- return key
-
- @staticmethod
- def _reduce_values(values):
- # We just return the first element if
- # multiple elements exist with the same key.
- return values[0]
-
- @property
- def fields(self):
- return getattr(self._message, "_" + self._attr)
-
- @fields.setter
- def fields(self, value):
- setattr(self._message, self._attr, value)
diff --git a/netlib/http/request.py b/netlib/http/request.py
index ae28084b..056a2d93 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -10,8 +10,9 @@ from netlib import utils
from netlib.http import cookies
from netlib.odict import ODict
from .. import encoding
+from ..multidict import MultiDictView
from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData, MultiDictView
+from .message import Message, _native, _always_bytes, MessageData
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
@@ -228,20 +229,25 @@ class Request(Message):
"""
The request query string as an :py:class:`MultiDictView` object.
"""
- return MultiDictView("query", self)
+ return MultiDictView(
+ self._get_query,
+ self._set_query
+ )
- @property
- def _query(self):
+ def _get_query(self):
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
return tuple(utils.urldecode(query))
- @query.setter
- def query(self, value):
+ def _set_query(self, value):
query = utils.urlencode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = utils.parse_url(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ @query.setter
+ def query(self, value):
+ self._set_query(value)
+
@property
def cookies(self):
# type: () -> MultiDictView
@@ -250,16 +256,21 @@ class Request(Message):
An empty :py:class:`MultiDictView` object if the cookie monster ate them all.
"""
- return MultiDictView("cookies", self)
+ return MultiDictView(
+ self._get_cookies,
+ self._set_cookies
+ )
- @property
- def _cookies(self):
+ def _get_cookies(self):
h = self.headers.get_all("Cookie")
return tuple(cookies.parse_cookie_headers(h))
+ def _set_cookies(self, value):
+ self.headers["cookie"] = cookies.format_cookie_header(value)
+
@cookies.setter
def cookies(self, value):
- self.headers["cookie"] = cookies.format_cookie_header(value)
+ self._set_cookies(value)
@property
def path_components(self):
@@ -322,17 +333,18 @@ class Request(Message):
An empty MultiDictView if the content-type indicates non-form data
or the content could not be parsed.
"""
- return MultiDictView("urlencoded_form", self)
+ return MultiDictView(
+ self._get_urlencoded_form,
+ self._set_urlencoded_form
+ )
- @property
- def _urlencoded_form(self):
+ def _get_urlencoded_form(self):
is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
return tuple(utils.urldecode(self.content))
return ()
- @urlencoded_form.setter
- def urlencoded_form(self, value):
+ def _set_urlencoded_form(self, value):
"""
Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one.
@@ -340,21 +352,30 @@ class Request(Message):
self.headers["content-type"] = "application/x-www-form-urlencoded"
self.content = utils.urlencode(value)
+ @urlencoded_form.setter
+ def urlencoded_form(self, value):
+ self._set_urlencoded_form(value)
+
@property
def multipart_form(self):
"""
The multipart form data as an :py:class:`MultipartFormDict` object.
None if the content-type indicates non-form data.
"""
- return MultiDictView("multipart_form", self)
+ return MultiDictView(
+ self._get_multipart_form,
+ self._set_multipart_form
+ )
- @property
- def _multipart_form(self):
+ def _get_multipart_form(self):
is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
return utils.multipartdecode(self.headers, self.content)
return ()
+ def _set_multipart_form(self, value):
+ raise NotImplementedError()
+
@multipart_form.setter
def multipart_form(self, value):
- raise NotImplementedError()
+ self._set_multipart_form(value)
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 6d56fc1f..7d272e10 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -5,7 +5,8 @@ import time
from . import cookies
from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData, MultiDictView
+from .message import Message, _native, _always_bytes, MessageData
+from ..multidict import MultiDictView
from .. import utils
@@ -80,21 +81,26 @@ class Response(Message):
Caveats:
Updating the attr
"""
- return MultiDictView("cookies", self)
+ return MultiDictView(
+ self._get_cookies,
+ self._set_cookies
+ )
- @property
- def _cookies(self):
+ def _get_cookies(self):
h = self.headers.get_all("set-cookie")
return tuple(cookies.parse_set_cookie_headers(h))
- @cookies.setter
- def cookies(self, all_cookies):
+ def _set_cookies(self, value):
cookie_headers = []
- for k, v in all_cookies:
+ for k, v in value:
header = cookies.format_set_cookie_header(k, v[0], v[1])
cookie_headers.append(header)
self.headers.set_all("set-cookie", cookie_headers)
+ @cookies.setter
+ def cookies(self, value):
+ self._set_cookies(value)
+
def refresh(self, now=None):
"""
This fairly complex and heuristic function refreshes a server
diff --git a/netlib/multidict.py b/netlib/multidict.py
index ec1b24d8..248acdec 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -15,13 +15,7 @@ from .utils import Serializable
@six.add_metaclass(ABCMeta)
-class MultiDict(MutableMapping, Serializable):
- def __init__(self, fields=None):
-
- # it is important for us that .fields is immutable, so that we can easily
- # detect changes to it.
- self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
-
+class _MultiDict(MutableMapping, Serializable):
def __repr__(self):
fields = tuple(
repr(field)
@@ -100,7 +94,7 @@ class MultiDict(MutableMapping, Serializable):
value
for k, value in self.fields
if self._kconv(k) == key
- ]
+ ]
def set_all(self, key, values):
"""
@@ -176,7 +170,7 @@ class MultiDict(MutableMapping, Serializable):
if multi:
return self.fields
else:
- return super(MultiDict, self).items()
+ return super(_MultiDict, self).items()
def to_dict(self):
"""
@@ -216,6 +210,12 @@ class MultiDict(MutableMapping, Serializable):
return cls(tuple(x) for x in state)
+class MultiDict(_MultiDict):
+ def __init__(self, fields=None):
+ super(MultiDict, self).__init__()
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+
@six.add_metaclass(ABCMeta)
class ImmutableMultiDict(MultiDict):
def _immutable(self, *_):
@@ -249,3 +249,34 @@ class ImmutableMultiDict(MultiDict):
ret = self.copy()
super(ImmutableMultiDict, ret).insert(index, key, value)
return ret
+
+
+class MultiDictView(_MultiDict):
+ """
+ The MultiDictView provides the MultiDict interface over calculated data.
+ The view itself contains no state - data is retrieved from the parent on
+ request, and stored back to the parent on change.
+ """
+ def __init__(self, getter, setter):
+ self._getter = getter
+ self._setter = setter
+ super(MultiDictView, self).__init__()
+
+ @staticmethod
+ def _kconv(key):
+ # All request-attributes are case-sensitive.
+ return key
+
+ @staticmethod
+ def _reduce_values(values):
+ # We just return the first element if
+ # multiple elements exist with the same key.
+ return values[0]
+
+ @property
+ def fields(self):
+ return self._getter()
+
+ @fields.setter
+ def fields(self, value):
+ return self._setter(value)
diff --git a/netlib/odict.py b/netlib/odict.py
index 461192f7..8a638dab 100644
--- a/netlib/odict.py
+++ b/netlib/odict.py
@@ -1,5 +1,6 @@
from __future__ import (absolute_import, print_function, division)
import copy
+
import six
from .utils import Serializable, safe_subn
@@ -27,27 +28,24 @@ class ODict(Serializable):
def __iter__(self):
return self.lst.__iter__()
- def __getitem__(self, k):
+ def __getitem__(self, key):
"""
Returns a list of values matching key.
"""
- ret = []
- k = self._kconv(k)
- for i in self.lst:
- if self._kconv(i[0]) == k:
- ret.append(i[1])
- return ret
- def keys(self):
- return list(set([self._kconv(i[0]) for i in self.lst]))
+ key = self._kconv(key)
+ return [
+ v
+ for k, v in self.lst
+ if self._kconv(k) == key
+ ]
- def _filter_lst(self, k, lst):
- k = self._kconv(k)
- new = []
- for i in lst:
- if self._kconv(i[0]) != k:
- new.append(i)
- return new
+ def keys(self):
+ return list(
+ set(
+ self._kconv(k) for k, _ in self.lst
+ )
+ )
def __len__(self):
"""
@@ -81,14 +79,19 @@ class ODict(Serializable):
"""
Delete all items matching k.
"""
- self.lst = self._filter_lst(k, self.lst)
-
- def __contains__(self, k):
k = self._kconv(k)
- for i in self.lst:
- if self._kconv(i[0]) == k:
- return True
- return False
+ self.lst = [
+ i
+ for i in self.lst
+ if self._kconv(i[0]) != k
+ ]
+
+ def __contains__(self, key):
+ key = self._kconv(key)
+ return any(
+ self._kconv(k) == key
+ for k, _ in self.lst
+ )
def add(self, key, value, prepend=False):
if prepend:
@@ -127,40 +130,24 @@ class ODict(Serializable):
def __repr__(self):
return repr(self.lst)
- def in_any(self, key, value, caseless=False):
- """
- Do any of the values matching key contain value?
-
- If caseless is true, value comparison is case-insensitive.
- """
- if caseless:
- value = value.lower()
- for i in self[key]:
- if caseless:
- i = i.lower()
- if value in i:
- return True
- return False
-
def replace(self, pattern, repl, *args, **kwargs):
"""
Replaces a regular expression pattern with repl in both keys and
- values. Encoded content will be decoded before replacement, and
- re-encoded afterwards.
+ values.
Returns the number of replacements made.
"""
- nlst, count = [], 0
- for i in self.lst:
- k, c = safe_subn(pattern, repl, i[0], *args, **kwargs)
+ new, count = [], 0
+ for k, v in self.lst:
+ k, c = safe_subn(pattern, repl, k, *args, **kwargs)
count += c
- v, c = safe_subn(pattern, repl, i[1], *args, **kwargs)
+ v, c = safe_subn(pattern, repl, v, *args, **kwargs)
count += c
- nlst.append([k, v])
- self.lst = nlst
+ new.append([k, v])
+ self.lst = new
return count
- # Implement the StateObject protocol from mitmproxy
+ # Implement Serializable
def get_state(self):
return [tuple(i) for i in self.lst]
diff --git a/pathod/utils.py b/pathod/utils.py
index 1e5bd9a4..d1e2dd00 100644
--- a/pathod/utils.py
+++ b/pathod/utils.py
@@ -1,5 +1,6 @@
import os
import sys
+import netlib.utils
SIZE_UNITS = dict(
@@ -75,27 +76,7 @@ def escape_unprintables(s):
return s
-class Data(object):
-
- def __init__(self, name):
- m = __import__(name)
- dirname, _ = os.path.split(m.__file__)
- self.dirname = os.path.abspath(dirname)
-
- def path(self, path):
- """
- Returns a path to the package data housed at 'path' under this
- module.Path can be a path to a file, or to a directory.
-
- This function will raise ValueError if the path does not exist.
- """
- fullpath = os.path.join(self.dirname, path)
- if not os.path.exists(fullpath):
- raise ValueError("dataPath: %s does not exist." % fullpath)
- return fullpath
-
-
-data = Data(__name__)
+data = netlib.utils.Data(__name__)
def daemonize(stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'): # pragma: no cover
diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py
index ac79b093..c4b06f4b 100644
--- a/test/mitmproxy/test_examples.py
+++ b/test/mitmproxy/test_examples.py
@@ -5,11 +5,12 @@ from contextlib import contextmanager
from mitmproxy import utils, script
from mitmproxy.proxy import config
+import netlib.utils
from netlib import tutils as netutils
from netlib.http import Headers
from . import tservers, tutils
-example_dir = utils.Data(__name__).path("../../examples")
+example_dir = netlib.utils.Data(__name__).path("../../examples")
class DummyContext(object):
diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py
index d51ac185..2dfd710e 100644
--- a/test/mitmproxy/tutils.py
+++ b/test/mitmproxy/tutils.py
@@ -8,6 +8,7 @@ from contextlib import contextmanager
from unittest.case import SkipTest
+import netlib.utils
import netlib.tutils
from mitmproxy import utils, controller
from mitmproxy.models import (
@@ -163,4 +164,4 @@ def capture_stderr(command, *args, **kwargs):
sys.stderr = out
-test_data = utils.Data(__name__)
+test_data = netlib.utils.Data(__name__)
diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py
index eefdc091..fae7aefe 100644
--- a/test/netlib/http/test_request.py
+++ b/test/netlib/http/test_request.py
@@ -3,9 +3,7 @@ from __future__ import absolute_import, print_function, division
import six
-from netlib import utils
from netlib.http import Headers
-from netlib.odict import ODict
from netlib.tutils import treq, raises
from .test_message import _test_decoded_attr, _test_passthrough_attr
diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py
index ceea3806..5bb65e3f 100644
--- a/test/netlib/test_multidict.py
+++ b/test/netlib/test_multidict.py
@@ -1,5 +1,5 @@
from netlib import tutils
-from netlib.multidict import MultiDict, ImmutableMultiDict
+from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView
class _TMulti(object):
@@ -214,4 +214,26 @@ class TestImmutableMultiDict(object):
def test_with_insert(self):
md = TImmutableMultiDict()
assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),)
- assert md.fields == () \ No newline at end of file
+
+
+class TParent(object):
+ def __init__(self):
+ self.vals = tuple()
+
+ def setter(self, vals):
+ self.vals = vals
+
+ def getter(self):
+ return self.vals
+
+
+class TestMultiDictView(object):
+ def test_modify(self):
+ p = TParent()
+ tv = MultiDictView(p.getter, p.setter)
+ assert len(tv) == 0
+ tv["a"] = "b"
+ assert p.vals == (("a", "b"),)
+ tv["c"] = "b"
+ assert p.vals == (("a", "b"), ("c", "b"))
+ assert tv["a"] == "b"
diff --git a/test/netlib/test_odict.py b/test/netlib/test_odict.py
index f0985ef6..b6fd6401 100644
--- a/test/netlib/test_odict.py
+++ b/test/netlib/test_odict.py
@@ -27,16 +27,6 @@ class TestODict(object):
b.set_state(state)
assert b == od
- def test_in_any(self):
- od = odict.ODict()
- od["one"] = ["atwoa", "athreea"]
- assert od.in_any("one", "two")
- assert od.in_any("one", "three")
- assert not od.in_any("one", "four")
- assert not od.in_any("nonexistent", "foo")
- assert not od.in_any("one", "TWO")
- assert od.in_any("one", "TWO", True)
-
def test_iter(self):
od = odict.ODict()
assert not [i for i in od]
diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py
index 8d0f92ac..4e8c89c5 100644
--- a/test/pathod/test_pathoc.py
+++ b/test/pathod/test_pathoc.py
@@ -211,7 +211,7 @@ class TestDaemon(_TestDaemon):
c.stop()
@skip_windows
- @pytest.mark.xfail
+ @pytest.mark.skip(reason="race condition")
def test_wait_finish(self):
c = pathoc.Pathoc(
("127.0.0.1", self.d.port),
diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py
index 1718cc0b..05a3962e 100644
--- a/test/pathod/test_pathod.py
+++ b/test/pathod/test_pathod.py
@@ -129,7 +129,7 @@ class CommonTests(tutils.DaemonTests):
l = self.d.last_log()
# FIXME: Other binary data elements
- @pytest.mark.skip
+ @pytest.mark.skip(reason="race condition")
def test_sizelimit(self):
r = self.get("200:b@1g")
assert r.status_code == 800
@@ -143,7 +143,7 @@ class CommonTests(tutils.DaemonTests):
def test_info(self):
assert tuple(self.d.info()["version"]) == version.IVERSION
- @pytest.mark.skip
+ @pytest.mark.skip(reason="race condition")
def test_logs(self):
assert self.d.clear_log()
assert not self.d.last_log()
@@ -223,7 +223,7 @@ class CommonTests(tutils.DaemonTests):
)
assert r[1].payload == "test"
- @pytest.mark.skip
+ @pytest.mark.skip(reason="race condition")
def test_websocket_frame_reflect_error(self):
r, _ = self.pathoc(
["ws:/p/", "wf:-mask:knone:f'wf:b@10':i13,'a'"],
@@ -233,6 +233,7 @@ class CommonTests(tutils.DaemonTests):
# FIXME: Race Condition?
assert "Parse error" in self.d.text_log()
+ @pytest.mark.skip(reason="race condition")
def test_websocket_frame_disconnect_error(self):
self.pathoc(["ws:/p/", "wf:b@10:d3"], ws_read_limit=0)
assert self.d.last_log()
diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py
index 9739afde..f6ed3efb 100644
--- a/test/pathod/tutils.py
+++ b/test/pathod/tutils.py
@@ -116,7 +116,7 @@ tmpdir = netlib.tutils.tmpdir
raises = netlib.tutils.raises
-test_data = utils.Data(__name__)
+test_data = netlib.utils.Data(__name__)
def render(r, settings=language.Settings()):