aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/dev/models.rst11
-rw-r--r--examples/modify_form.py11
-rw-r--r--examples/modify_querystring.py5
-rw-r--r--mitmproxy/console/flowview.py52
-rw-r--r--mitmproxy/console/grideditor.py12
-rw-r--r--mitmproxy/flow.py34
-rw-r--r--mitmproxy/flow_export.py4
-rw-r--r--mitmproxy/protocol/base.py11
-rw-r--r--mitmproxy/protocol/http.py14
-rw-r--r--mitmproxy/protocol/tls.py178
-rw-r--r--mitmproxy/proxy/root_context.py2
-rw-r--r--mitmproxy/utils.py23
-rw-r--r--netlib/encoding.py1
-rw-r--r--netlib/http/__init__.py4
-rw-r--r--netlib/http/cookies.py60
-rw-r--r--netlib/http/headers.py140
-rw-r--r--netlib/http/http1/read.py4
-rw-r--r--netlib/http/http2/connections.py12
-rw-r--r--netlib/http/message.py7
-rw-r--r--netlib/http/request.py131
-rw-r--r--netlib/http/response.py45
-rw-r--r--netlib/multidict.py282
-rw-r--r--netlib/utils.py11
-rw-r--r--pathod/utils.py23
-rw-r--r--test/mitmproxy/test_examples.py15
-rw-r--r--test/mitmproxy/test_flow.py54
-rw-r--r--test/mitmproxy/test_flow_export.py6
-rw-r--r--test/mitmproxy/test_flow_export/locust_get.py6
-rw-r--r--test/mitmproxy/test_flow_export/locust_task_get.py6
-rw-r--r--test/mitmproxy/test_flow_export/python_get.py6
-rw-r--r--test/mitmproxy/tutils.py3
-rw-r--r--test/netlib/http/http1/test_read.py8
-rw-r--r--test/netlib/http/http2/test_connections.py6
-rw-r--r--test/netlib/http/test_cookies.py18
-rw-r--r--test/netlib/http/test_headers.py109
-rw-r--r--test/netlib/http/test_request.py67
-rw-r--r--test/netlib/http/test_response.py34
-rw-r--r--test/netlib/test_multidict.py239
-rw-r--r--test/pathod/test_pathod.py1
-rw-r--r--test/pathod/tutils.py2
40 files changed, 1015 insertions, 642 deletions
diff --git a/docs/dev/models.rst b/docs/dev/models.rst
index 8c4e6825..f2ddf242 100644
--- a/docs/dev/models.rst
+++ b/docs/dev/models.rst
@@ -56,6 +56,17 @@ Datastructures
:special-members:
:no-undoc-members:
+ .. autoclass:: MultiDictView
+
+ .. automethod:: get_all
+ .. automethod:: set_all
+ .. automethod:: add
+ .. automethod:: insert
+ .. automethod:: keys
+ .. automethod:: values
+ .. automethod:: items
+ .. automethod:: to_dict
+
.. autoclass:: decoded
.. automodule:: mitmproxy.models
diff --git a/examples/modify_form.py b/examples/modify_form.py
index 86188781..3fe0cf96 100644
--- a/examples/modify_form.py
+++ b/examples/modify_form.py
@@ -1,5 +1,8 @@
def request(context, flow):
- form = flow.request.urlencoded_form
- if form is not None:
- form["mitmproxy"] = ["rocks"]
- flow.request.urlencoded_form = form
+ if flow.request.urlencoded_form:
+ flow.request.urlencoded_form["mitmproxy"] = "rocks"
+ else:
+ # This sets the proper content type and overrides the body.
+ flow.request.urlencoded_form = [
+ ("foo", "bar")
+ ]
diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py
index d682df69..b89e5c8d 100644
--- a/examples/modify_querystring.py
+++ b/examples/modify_querystring.py
@@ -1,5 +1,2 @@
def request(context, flow):
- q = flow.request.query
- if q:
- q["mitmproxy"] = ["rocks"]
- flow.request.query = q
+ flow.request.query["mitmproxy"] = "rocks"
diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py
index b2ebe49e..2010cecd 100644
--- a/mitmproxy/console/flowview.py
+++ b/mitmproxy/console/flowview.py
@@ -6,8 +6,7 @@ import sys
import math
import urwid
-from netlib import odict
-from netlib.http import Headers
+from netlib.http import Headers, status_codes
from . import common, grideditor, signals, searchable, tabs
from . import flowdetailview
from .. import utils, controller, contentviews
@@ -187,7 +186,7 @@ class FlowView(tabs.Tabs):
viewmode,
message,
limit,
- (bytes(message.headers), message.content) # Cache invalidation
+ message # Cache invalidation
)
def _get_content_view(self, viewmode, message, max_lines, _):
@@ -316,21 +315,18 @@ class FlowView(tabs.Tabs):
return "Invalid URL."
signals.flow_change.send(self, flow = self.flow)
- def set_resp_code(self, code):
- response = self.flow.response
+ def set_resp_status_code(self, status_code):
try:
- response.status_code = int(code)
+ status_code = int(status_code)
except ValueError:
return None
- import BaseHTTPServer
- if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses:
- response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[
- int(code)][0]
+ self.flow.response.status_code = status_code
+ if status_code in status_codes.RESPONSES:
+ self.flow.response.reason = status_codes.RESPONSES[status_code]
signals.flow_change.send(self, flow = self.flow)
- def set_resp_msg(self, msg):
- response = self.flow.response
- response.msg = msg
+ def set_resp_reason(self, reason):
+ self.flow.response.reason = reason
signals.flow_change.send(self, flow = self.flow)
def set_headers(self, fields, conn):
@@ -338,22 +334,22 @@ class FlowView(tabs.Tabs):
signals.flow_change.send(self, flow = self.flow)
def set_query(self, lst, conn):
- conn.set_query(odict.ODict(lst))
+ conn.query = lst
signals.flow_change.send(self, flow = self.flow)
def set_path_components(self, lst, conn):
- conn.set_path_components(lst)
+ conn.path_components = lst
signals.flow_change.send(self, flow = self.flow)
def set_form(self, lst, conn):
- conn.set_form_urlencoded(odict.ODict(lst))
+ conn.urlencoded_form = lst
signals.flow_change.send(self, flow = self.flow)
def edit_form(self, conn):
self.master.view_grideditor(
grideditor.URLEncodedFormEditor(
self.master,
- conn.get_form_urlencoded().lst,
+ conn.urlencoded_form.items(multi=True),
self.set_form,
conn
)
@@ -364,7 +360,7 @@ class FlowView(tabs.Tabs):
self.edit_form(conn)
def set_cookies(self, lst, conn):
- conn.cookies = odict.ODict(lst)
+ conn.cookies = lst
signals.flow_change.send(self, flow = self.flow)
def set_setcookies(self, data, conn):
@@ -388,7 +384,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor(
grideditor.CookieEditor(
self.master,
- message.cookies.lst,
+ message.cookies.items(multi=True),
self.set_cookies,
message
)
@@ -397,7 +393,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor(
grideditor.SetCookieEditor(
self.master,
- message.cookies,
+ message.cookies.items(multi=True),
self.set_setcookies,
message
)
@@ -413,7 +409,7 @@ class FlowView(tabs.Tabs):
c = self.master.spawn_editor(message.content or "")
message.content = c.rstrip("\n")
elif part == "f":
- if not message.get_form_urlencoded() and message.content:
+ if not message.urlencoded_form and message.content:
signals.status_prompt_onekey.send(
prompt = "Existing body is not a URL-encoded form. Clear and edit?",
keys = [
@@ -435,7 +431,7 @@ class FlowView(tabs.Tabs):
)
)
elif part == "p":
- p = message.get_path_components()
+ p = message.path_components
self.master.view_grideditor(
grideditor.PathEditor(
self.master,
@@ -448,7 +444,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor(
grideditor.QueryEditor(
self.master,
- message.get_query().lst,
+ message.query.items(multi=True),
self.set_query, message
)
)
@@ -458,7 +454,7 @@ class FlowView(tabs.Tabs):
text = message.url,
callback = self.set_url
)
- elif part == "m":
+ elif part == "m" and message == self.flow.request:
signals.status_prompt_onekey.send(
prompt = "Method",
keys = common.METHOD_OPTIONS,
@@ -468,13 +464,13 @@ class FlowView(tabs.Tabs):
signals.status_prompt.send(
prompt = "Code",
text = str(message.status_code),
- callback = self.set_resp_code
+ callback = self.set_resp_status_code
)
- elif part == "m":
+ elif part == "m" and message == self.flow.response:
signals.status_prompt.send(
prompt = "Message",
- text = message.msg,
- callback = self.set_resp_msg
+ text = message.reason,
+ callback = self.set_resp_reason
)
signals.flow_change.send(self, flow = self.flow)
diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py
index 46ff348e..11ce7d02 100644
--- a/mitmproxy/console/grideditor.py
+++ b/mitmproxy/console/grideditor.py
@@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor):
def data_in(self, data):
flattened = []
- for k, v in data.items():
- flattened.append([k, v[0], v[1].lst])
+ for key, (value, attrs) in data:
+ flattened.append([key, value, attrs.items(multi=True)])
return flattened
def data_out(self, data):
vals = []
- for i in data:
+ for key, value, attrs in data:
vals.append(
[
- i[0],
- [i[1], odict.ODictCaseless(i[2])]
+ key,
+ (value, attrs)
]
)
- return odict.ODict(vals)
+ return vals
diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py
index ccedd1d4..a9018e16 100644
--- a/mitmproxy/flow.py
+++ b/mitmproxy/flow.py
@@ -158,9 +158,9 @@ class SetHeaders:
for _, header, value, cpatt in self.lst:
if cpatt(f):
if f.response:
- f.response.headers.fields.append((header, value))
+ f.response.headers.add(header, value)
else:
- f.request.headers.fields.append((header, value))
+ f.request.headers.add(header, value)
class StreamLargeBodies(object):
@@ -265,7 +265,7 @@ class ServerPlaybackState:
form_contents = r.urlencoded_form or r.multipart_form
if self.ignore_payload_params and form_contents:
key.extend(
- p for p in form_contents
+ p for p in form_contents.items(multi=True)
if p[0] not in self.ignore_payload_params
)
else:
@@ -321,10 +321,10 @@ class StickyCookieState:
"""
domain = f.request.host
path = "/"
- if attrs["domain"]:
- domain = attrs["domain"][-1]
- if attrs["path"]:
- path = attrs["path"][-1]
+ if "domain" in attrs:
+ domain = attrs["domain"]
+ if "path" in attrs:
+ path = attrs["path"]
return (domain, f.request.port, path)
def domain_match(self, a, b):
@@ -335,28 +335,26 @@ class StickyCookieState:
return False
def handle_response(self, f):
- for i in f.response.headers.get_all("set-cookie"):
+ for name, (value, attrs) in f.response.cookies.items(multi=True):
# FIXME: We now know that Cookie.py screws up some cookies with
# valid RFC 822/1123 datetime specifications for expiry. Sigh.
- name, value, attrs = cookies.parse_set_cookie_header(str(i))
a = self.ckey(attrs, f)
if self.domain_match(f.request.host, a[0]):
- b = attrs.lst
- b.insert(0, [name, value])
- self.jar[a][name] = odict.ODictCaseless(b)
+ b = attrs.with_insert(0, name, value)
+ self.jar[a][name] = b
def handle_request(self, f):
l = []
if f.match(self.flt):
- for i in self.jar.keys():
+ for domain, port, path in self.jar.keys():
match = [
- self.domain_match(f.request.host, i[0]),
- f.request.port == i[1],
- f.request.path.startswith(i[2])
+ self.domain_match(f.request.host, domain),
+ f.request.port == port,
+ f.request.path.startswith(path)
]
if all(match):
- c = self.jar[i]
- l.extend([cookies.format_cookie_header(c[name]) for name in c.keys()])
+ c = self.jar[(domain, port, path)]
+ l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()])
if l:
f.request.stickycookie = True
f.request.headers["cookie"] = "; ".join(l)
diff --git a/mitmproxy/flow_export.py b/mitmproxy/flow_export.py
index d8e65704..ae282fce 100644
--- a/mitmproxy/flow_export.py
+++ b/mitmproxy/flow_export.py
@@ -51,7 +51,7 @@ def python_code(flow):
params = ""
if flow.request.query:
- lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query]
+ lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()]
params = "\nparams = {\n%s}\n" % "".join(lines)
args += "\n params=params,"
@@ -140,7 +140,7 @@ def locust_code(flow):
params = ""
if flow.request.query:
- lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query]
+ lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()]
params = "\n params = {\n%s }\n" % "".join(lines)
args += "\n params=params,"
diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py
index 536f2753..c8e58d1b 100644
--- a/mitmproxy/protocol/base.py
+++ b/mitmproxy/protocol/base.py
@@ -133,24 +133,15 @@ class ServerConnectionMixin(object):
"The proxy shall not connect to itself.".format(repr(address))
)
- def set_server(self, address, server_tls=None, sni=None):
+ def set_server(self, address):
"""
Sets a new server address. If there is an existing connection, it will be closed.
-
- Raises:
- ~mitmproxy.exceptions.ProtocolException:
- if ``server_tls`` is ``True``, but there was no TLS layer on the
- protocol stack which could have processed this.
"""
if self.server_conn:
self.disconnect()
self.log("Set new server address: " + repr(address), "debug")
self.server_conn.address = address
self.__check_self_connect()
- if server_tls:
- raise ProtocolException(
- "Cannot upgrade to TLS, no TLS layer on the protocol stack."
- )
def disconnect(self):
"""
diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py
index 9cb35176..d9111303 100644
--- a/mitmproxy/protocol/http.py
+++ b/mitmproxy/protocol/http.py
@@ -120,7 +120,7 @@ class UpstreamConnectLayer(Layer):
if address != self.server_conn.via.address:
self.ctx.set_server(address)
- def set_server(self, address, server_tls=None, sni=None):
+ def set_server(self, address):
if self.ctx.server_conn:
self.ctx.disconnect()
address = tcp.Address.wrap(address)
@@ -128,11 +128,6 @@ class UpstreamConnectLayer(Layer):
self.connect_request.port = address.port
self.server_conn.address = address
- if server_tls:
- raise ProtocolException(
- "Cannot upgrade to TLS, no TLS layer on the protocol stack."
- )
-
class HttpLayer(Layer):
@@ -149,7 +144,7 @@ class HttpLayer(Layer):
def __call__(self):
if self.mode == "transparent":
- self.__initial_server_tls = self._server_tls
+ self.__initial_server_tls = self.server_tls
self.__initial_server_conn = self.server_conn
while True:
try:
@@ -360,8 +355,9 @@ class HttpLayer(Layer):
if self.mode == "regular" or self.mode == "transparent":
# If there's an existing connection that doesn't match our expectations, kill it.
- if address != self.server_conn.address or tls != self.server_conn.tls_established:
- self.set_server(address, tls, address.host)
+ if address != self.server_conn.address or tls != self.server_tls:
+ self.set_server(address)
+ self.set_server_tls(tls, address.host)
# Establish connection is neccessary.
if not self.server_conn:
self.connect()
diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py
index 26c3f9d2..74c55ab4 100644
--- a/mitmproxy/protocol/tls.py
+++ b/mitmproxy/protocol/tls.py
@@ -266,18 +266,22 @@ class TlsClientHello(object):
return self._client_hello
@property
- def client_cipher_suites(self):
+ def cipher_suites(self):
return self._client_hello.cipher_suites.cipher_suites
@property
- def client_sni(self):
+ def sni(self):
for extension in self._client_hello.extensions:
- if (extension.type == 0x00 and len(extension.server_names) == 1
- and extension.server_names[0].type == 0):
+ is_valid_sni_extension = (
+ extension.type == 0x00
+ and len(extension.server_names) == 1
+ and extension.server_names[0].type == 0
+ )
+ if is_valid_sni_extension:
return extension.server_names[0].name
@property
- def client_alpn_protocols(self):
+ def alpn_protocols(self):
for extension in self._client_hello.extensions:
if extension.type == 0x10:
return list(extension.alpn_protocols)
@@ -304,55 +308,78 @@ class TlsClientHello(object):
def __repr__(self):
return "TlsClientHello( sni: %s alpn_protocols: %s, cipher_suites: %s)" % \
- (self.client_sni, self.client_alpn_protocols, self.client_cipher_suites)
+ (self.sni, self.alpn_protocols, self.cipher_suites)
class TlsLayer(Layer):
+ """
+ The TLS layer implements transparent TLS connections.
- def __init__(self, ctx, client_tls, server_tls):
- self.client_sni = None
- self.client_alpn_protocols = None
- self.client_ciphers = []
+ It exposes the following API to child layers:
+
+ - :py:meth:`set_server_tls` to modify TLS settings for the server connection.
+ - :py:attr:`server_tls`, :py:attr:`server_sni` as read-only attributes describing the current TLS settings for
+ the server connection.
+ """
+ def __init__(self, ctx, client_tls, server_tls):
super(TlsLayer, self).__init__(ctx)
self._client_tls = client_tls
self._server_tls = server_tls
- self._sni_from_server_change = None
+ self._custom_server_sni = None
+ self._client_hello = None # type: TlsClientHello
def __call__(self):
"""
- The strategy for establishing SSL is as follows:
+ The strategy for establishing TLS is as follows:
First, we determine whether we need the server cert to establish ssl with the client.
If so, we first connect to the server and then to the client.
- If not, we only connect to the client and do the server_ssl lazily on a Connect message.
-
- An additional complexity is that establish ssl with the server may require a SNI value from
- the client. In an ideal world, we'd do the following:
- 1. Start the SSL handshake with the client
- 2. Check if the client sends a SNI.
- 3. Pause the client handshake, establish SSL with the server.
- 4. Finish the client handshake with the certificate from the server.
- There's just one issue: We cannot get a callback from OpenSSL if the client doesn't send a SNI. :(
- Thus, we manually peek into the connection and parse the ClientHello message to obtain both SNI and ALPN values.
-
- Further notes:
- - OpenSSL 1.0.2 introduces a callback that would help here:
- https://www.openssl.org/docs/ssl/SSL_CTX_set_cert_cb.html
- - The original mitmproxy issue is https://github.com/mitmproxy/mitmproxy/issues/427
- """
-
- client_tls_requires_server_cert = (
- self._client_tls and self._server_tls and not self.config.no_upstream_cert
- )
+ If not, we only connect to the client and do the server handshake lazily.
+ An additional complexity is that we need to mirror SNI and ALPN from the client when connecting to the server.
+ We manually peek into the connection and parse the ClientHello message to obtain these values.
+ """
if self._client_tls:
- self._parse_client_hello()
+ # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values.
+ try:
+ self._client_hello = TlsClientHello.from_client_conn(self.client_conn)
+ except TlsProtocolException as e:
+ self.log("Cannot parse Client Hello: %s" % repr(e), "error")
+
+ # Do we need to do a server handshake now?
+ # There are two reasons why we would want to establish TLS with the server now:
+ # 1. If we already have an existing server connection and server_tls is True,
+ # we need to establish TLS now because .connect() will not be called anymore.
+ # 2. We may need information from the server connection for the client handshake.
+ #
+ # A couple of factors influence (2):
+ # 2.1 There actually is (or will be) a TLS-enabled upstream connection
+ # 2.2 An upstream connection is not wanted by the user if --no-upstream-cert is passed.
+ # 2.3 An upstream connection is implied by add_upstream_certs_to_client_chain
+ # 2.4 The client wants to negotiate an alternative protocol in its handshake, we need to find out
+ # what is supported by the server
+ # 2.5 The client did not sent a SNI value, we don't know the certificate subject.
+ client_tls_requires_server_connection = (
+ self._server_tls
+ and not self.config.no_upstream_cert
+ and (
+ self.config.add_upstream_certs_to_client_chain
+ or self._client_hello.alpn_protocols
+ or not self._client_hello.sni
+ )
+ )
+ establish_server_tls_now = (
+ (self.server_conn and self._server_tls)
+ or client_tls_requires_server_connection
+ )
- if client_tls_requires_server_cert:
+ if self._client_tls and establish_server_tls_now:
self._establish_tls_with_client_and_server()
elif self._client_tls:
self._establish_tls_with_client()
+ elif establish_server_tls_now:
+ self._establish_tls_with_server()
layer = self.ctx.next_layer(self)
layer()
@@ -367,47 +394,48 @@ class TlsLayer(Layer):
else:
return "TlsLayer(inactive)"
- def _parse_client_hello(self):
- """
- Peek into the connection, read the initial client hello and parse it to obtain ALPN values.
- """
- try:
- parsed = TlsClientHello.from_client_conn(self.client_conn)
- self.client_sni = parsed.client_sni
- self.client_alpn_protocols = parsed.client_alpn_protocols
- self.client_ciphers = parsed.client_cipher_suites
- except TlsProtocolException as e:
- self.log("Cannot parse Client Hello: %s" % repr(e), "error")
-
def connect(self):
if not self.server_conn:
self.ctx.connect()
if self._server_tls and not self.server_conn.tls_established:
self._establish_tls_with_server()
- def set_server(self, address, server_tls=None, sni=None):
- if server_tls is not None:
- self._sni_from_server_change = sni
- self._server_tls = server_tls
- self.ctx.set_server(address, None, None)
+ def set_server_tls(self, server_tls, sni=None):
+ """
+ Set the TLS settings for the next server connection that will be established.
+ This function will not alter an existing connection.
+
+ Args:
+ server_tls: Shall we establish TLS with the server?
+ sni: ``bytes`` for a custom SNI value,
+ ``None`` for the client SNI value,
+ ``False`` if no SNI value should be sent.
+ """
+ self._server_tls = server_tls
+ self._custom_server_sni = sni
+
+ @property
+ def server_tls(self):
+ """
+ ``True``, if the next server connection that will be established should be upgraded to TLS.
+ """
+ return self._server_tls
@property
- def sni_for_server_connection(self):
- if self._sni_from_server_change is False:
+ def server_sni(self):
+ """
+ The Server Name Indication we want to send with the next server TLS handshake.
+ """
+ if self._custom_server_sni is False:
return None
else:
- return self._sni_from_server_change or self.client_sni
+ return self._custom_server_sni or self._client_hello.sni
@property
def alpn_for_client_connection(self):
return self.server_conn.get_alpn_proto_negotiated()
def __alpn_select_callback(self, conn_, options):
- """
- Once the client signals the alternate protocols it supports,
- we reconnect upstream with the same list and pass the server's choice down to the client.
- """
-
# This gets triggered if we haven't established an upstream connection yet.
default_alpn = b'http/1.1'
# alpn_preference = b'h2'
@@ -422,12 +450,12 @@ class TlsLayer(Layer):
return choice
def _establish_tls_with_client_and_server(self):
- # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless
- # to send an error message over TLS.
try:
self.ctx.connect()
self._establish_tls_with_server()
except Exception:
+ # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless
+ # to send an error message over TLS.
try:
self._establish_tls_with_client()
except:
@@ -466,9 +494,9 @@ class TlsLayer(Layer):
ClientHandshakeException,
ClientHandshakeException(
"Cannot establish TLS with client (sni: {sni}): {e}".format(
- sni=self.client_sni, e=repr(e)
+ sni=self._client_hello.sni, e=repr(e)
),
- self.client_sni or repr(self.server_conn.address)
+ self._client_hello.sni or repr(self.server_conn.address)
),
sys.exc_info()[2]
)
@@ -480,8 +508,8 @@ class TlsLayer(Layer):
# If the server only supports spdy (next to http/1.1), it may select that
# and mitmproxy would enter TCP passthrough mode, which we want to avoid.
deprecated_http2_variant = lambda x: x.startswith(b"h2-") or x.startswith(b"spdy")
- if self.client_alpn_protocols:
- alpn = [x for x in self.client_alpn_protocols if not deprecated_http2_variant(x)]
+ if self._client_hello.alpn_protocols:
+ alpn = [x for x in self._client_hello.alpn_protocols if not deprecated_http2_variant(x)]
else:
alpn = None
if alpn and b"h2" in alpn and not self.config.http2:
@@ -490,14 +518,14 @@ class TlsLayer(Layer):
ciphers_server = self.config.ciphers_server
if not ciphers_server:
ciphers_server = []
- for id in self.client_ciphers:
+ for id in self._client_hello.cipher_suites:
if id in CIPHER_ID_NAME_MAP.keys():
ciphers_server.append(CIPHER_ID_NAME_MAP[id])
ciphers_server = ':'.join(ciphers_server)
self.server_conn.establish_ssl(
self.config.clientcerts,
- self.sni_for_server_connection,
+ self.server_sni,
method=self.config.openssl_method_server,
options=self.config.openssl_options_server,
verify_options=self.config.openssl_verification_mode_server,
@@ -524,7 +552,7 @@ class TlsLayer(Layer):
TlsProtocolException,
TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format(
address=repr(self.server_conn.address),
- sni=self.sni_for_server_connection,
+ sni=self.server_sni,
e=repr(e),
)),
sys.exc_info()[2]
@@ -534,7 +562,7 @@ class TlsLayer(Layer):
TlsProtocolException,
TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format(
address=repr(self.server_conn.address),
- sni=self.sni_for_server_connection,
+ sni=self.server_sni,
e=repr(e),
)),
sys.exc_info()[2]
@@ -569,13 +597,13 @@ class TlsLayer(Layer):
sans.add(host)
host = upstream_cert.cn.decode("utf8").encode("idna")
# Also add SNI values.
- if self.client_sni:
- sans.add(self.client_sni)
- if self._sni_from_server_change:
- sans.add(self._sni_from_server_change)
+ if self._client_hello.sni:
+ sans.add(self._client_hello.sni)
+ if self._custom_server_sni:
+ sans.add(self._custom_server_sni)
- # Some applications don't consider the CN and expect the hostname to be in the SANs.
- # For example, Thunderbird 38 will display a warning if the remote host is only the CN.
+ # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity.
+ # In other words, the Common Name is irrelevant then.
if host:
sans.add(host)
return self.config.certstore.get_cert(host, list(sans))
diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py
index 9caae02a..c55105ec 100644
--- a/mitmproxy/proxy/root_context.py
+++ b/mitmproxy/proxy/root_context.py
@@ -63,7 +63,7 @@ class RootContext(object):
except TlsProtocolException as e:
self.log("Cannot parse Client Hello: %s" % repr(e), "error")
else:
- ignore = self.config.check_ignore((client_hello.client_sni, 443))
+ ignore = self.config.check_ignore((client_hello.sni, 443))
if ignore:
return RawTCPLayer(top_layer, logging=False)
diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py
index 5fd062ea..cda5bba6 100644
--- a/mitmproxy/utils.py
+++ b/mitmproxy/utils.py
@@ -7,6 +7,9 @@ import json
import importlib
import inspect
+import netlib.utils
+
+
def timestamp():
"""
Returns a serializable UTC timestamp.
@@ -73,25 +76,7 @@ def pretty_duration(secs):
return "{:.0f}ms".format(secs * 1000)
-class Data:
-
- def __init__(self, name):
- m = importlib.import_module(name)
- dirname = os.path.dirname(inspect.getsourcefile(m))
- self.dirname = os.path.abspath(dirname)
-
- def path(self, path):
- """
- Returns a path to the package data housed at 'path' under this
- module.Path can be a path to a file, or to a directory.
-
- This function will raise ValueError if the path does not exist.
- """
- fullpath = os.path.join(self.dirname, path)
- if not os.path.exists(fullpath):
- raise ValueError("dataPath: %s does not exist." % fullpath)
- return fullpath
-pkg_data = Data(__name__)
+pkg_data = netlib.utils.Data(__name__)
class LRUCache:
diff --git a/netlib/encoding.py b/netlib/encoding.py
index 14479e00..98502451 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -5,7 +5,6 @@ from __future__ import absolute_import
from io import BytesIO
import gzip
import zlib
-from .utils import always_byte_args
ENCODINGS = {"identity", "gzip", "deflate"}
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
index 917080f7..c4eb1d58 100644
--- a/netlib/http/__init__.py
+++ b/netlib/http/__init__.py
@@ -3,12 +3,12 @@ from .request import Request
from .response import Response
from .headers import Headers
from .message import decoded
-from . import http1, http2
+from . import http1, http2, status_codes
__all__ = [
"Request",
"Response",
"Headers",
"decoded",
- "http1", "http2",
+ "http1", "http2", "status_codes",
]
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index 4451f1da..88c76870 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -1,8 +1,8 @@
-from six.moves import http_cookies as Cookie
+import collections
import re
-import string
from email.utils import parsedate_tz, formatdate, mktime_tz
+from netlib.multidict import ImmutableMultiDict
from .. import odict
"""
@@ -157,42 +157,76 @@ def _parse_set_cookie_pairs(s):
return pairs
+def parse_set_cookie_headers(headers):
+ ret = []
+ for header in headers:
+ v = parse_set_cookie_header(header)
+ if v:
+ name, value, attrs = v
+ ret.append((name, SetCookie(value, attrs)))
+ return ret
+
+
+class CookieAttrs(ImmutableMultiDict):
+ @staticmethod
+ def _kconv(key):
+ return key.lower()
+
+ @staticmethod
+ def _reduce_values(values):
+ # See the StickyCookieTest for a weird cookie that only makes sense
+ # if we take the last part.
+ return values[-1]
+
+
+SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"])
+
+
def parse_set_cookie_header(line):
"""
Parse a Set-Cookie header value
Returns a (name, value, attrs) tuple, or None, where attrs is an
- ODictCaseless set of attributes. No attempt is made to parse attribute
+ CookieAttrs dict of attributes. No attempt is made to parse attribute
values - they are treated purely as strings.
"""
pairs = _parse_set_cookie_pairs(line)
if pairs:
- return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:])
+ return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])
def format_set_cookie_header(name, value, attrs):
"""
Formats a Set-Cookie header value.
"""
- pairs = [[name, value]]
- pairs.extend(attrs.lst)
+ pairs = [(name, value)]
+ pairs.extend(
+ attrs.fields if hasattr(attrs, "fields") else attrs
+ )
return _format_set_cookie_pairs(pairs)
+def parse_cookie_headers(cookie_headers):
+ cookie_list = []
+ for header in cookie_headers:
+ cookie_list.extend(parse_cookie_header(header))
+ return cookie_list
+
+
def parse_cookie_header(line):
"""
Parse a Cookie header value.
- Returns a (possibly empty) ODict object.
+ Returns a list of (lhs, rhs) tuples.
"""
pairs, off_ = _read_pairs(line)
- return odict.ODict(pairs)
+ return pairs
-def format_cookie_header(od):
+def format_cookie_header(lst):
"""
Formats a Cookie header value.
"""
- return _format_pairs(od.lst)
+ return _format_pairs(lst)
def refresh_set_cookie_header(c, delta):
@@ -209,10 +243,10 @@ def refresh_set_cookie_header(c, delta):
raise ValueError("Invalid Cookie")
if "expires" in attrs:
- e = parsedate_tz(attrs["expires"][-1])
+ e = parsedate_tz(attrs["expires"])
if e:
f = mktime_tz(e) + delta
- attrs["expires"] = [formatdate(f)]
+ attrs = attrs.with_set_all("expires", [formatdate(f)])
else:
# This can happen when the expires tag is invalid.
# reddit.com sends a an expires tag like this: "Thu, 31 Dec
@@ -220,7 +254,7 @@ def refresh_set_cookie_header(c, delta):
# strictly correct according to the cookie spec. Browsers
# appear to parse this tolerantly - maybe we should too.
# For now, we just ignore this.
- del attrs["expires"]
+ attrs = attrs.with_delitem("expires")
ret = format_set_cookie_header(name, value, attrs)
if not ret:
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 72739f90..60d3f429 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -1,9 +1,3 @@
-"""
-
-Unicode Handling
-----------------
-See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
-"""
from __future__ import absolute_import, print_function, division
import re
@@ -13,23 +7,22 @@ try:
except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
-
import six
+from ..multidict import MultiDict
+from ..utils import always_bytes
-from netlib.utils import always_byte_args, always_bytes, Serializable
+# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
if six.PY2: # pragma: no cover
_native = lambda x: x
_always_bytes = lambda x: x
- _always_byte_args = lambda x: x
else:
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
_native = lambda x: x.decode("utf-8", "surrogateescape")
_always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape")
- _always_byte_args = always_byte_args("utf-8", "surrogateescape")
-class Headers(MutableMapping, Serializable):
+class Headers(MultiDict):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.
@@ -49,11 +42,11 @@ class Headers(MutableMapping, Serializable):
>>> h["host"]
"example.com"
- # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples
+ # Headers can also be created from a list of raw (header_name, header_value) byte tuples
>>> h = Headers([
- [b"Host",b"example.com"],
- [b"Accept",b"text/html"],
- [b"accept",b"application/xml"]
+ (b"Host",b"example.com"),
+ (b"Accept",b"text/html"),
+ (b"accept",b"application/xml")
])
# Multiple headers are folded into a single header as per RFC7230
@@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable):
For use with the "Set-Cookie" header, see :py:meth:`get_all`.
"""
- @_always_byte_args
def __init__(self, fields=None, **headers):
"""
Args:
@@ -89,19 +81,29 @@ class Headers(MutableMapping, Serializable):
If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
the behavior is undefined.
"""
- self.fields = fields or []
+ super(Headers, self).__init__(fields)
- for name, value in self.fields:
- if not isinstance(name, bytes) or not isinstance(value, bytes):
- raise ValueError("Headers passed as fields must be bytes.")
+ for key, value in self.fields:
+ if not isinstance(key, bytes) or not isinstance(value, bytes):
+ raise TypeError("Header fields must be bytes.")
# content_type -> content-type
headers = {
- _always_bytes(name).replace(b"_", b"-"): value
+ _always_bytes(name).replace(b"_", b"-"): _always_bytes(value)
for name, value in six.iteritems(headers)
}
self.update(headers)
+ @staticmethod
+ def _reduce_values(values):
+ # Headers can be folded
+ return ", ".join(values)
+
+ @staticmethod
+ def _kconv(key):
+ # Headers are case-insensitive
+ return key.lower()
+
def __bytes__(self):
if self.fields:
return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n"
@@ -111,98 +113,40 @@ class Headers(MutableMapping, Serializable):
if six.PY2: # pragma: no cover
__str__ = __bytes__
- @_always_byte_args
- def __getitem__(self, name):
- values = self.get_all(name)
- if not values:
- raise KeyError(name)
- return ", ".join(values)
-
- @_always_byte_args
- def __setitem__(self, name, value):
- idx = self._index(name)
-
- # To please the human eye, we insert at the same position the first existing header occured.
- if idx is not None:
- del self[name]
- self.fields.insert(idx, [name, value])
- else:
- self.fields.append([name, value])
-
- @_always_byte_args
- def __delitem__(self, name):
- if name not in self:
- raise KeyError(name)
- name = name.lower()
- self.fields = [
- field for field in self.fields
- if name != field[0].lower()
- ]
+ def __delitem__(self, key):
+ key = _always_bytes(key)
+ super(Headers, self).__delitem__(key)
def __iter__(self):
- seen = set()
- for name, _ in self.fields:
- name_lower = name.lower()
- if name_lower not in seen:
- seen.add(name_lower)
- yield _native(name)
-
- def __len__(self):
- return len(set(name.lower() for name, _ in self.fields))
-
- # __hash__ = object.__hash__
-
- def _index(self, name):
- name = name.lower()
- for i, field in enumerate(self.fields):
- if field[0].lower() == name:
- return i
- return None
-
- def __eq__(self, other):
- if isinstance(other, Headers):
- return self.fields == other.fields
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- @_always_byte_args
+ for x in super(Headers, self).__iter__():
+ yield _native(x)
+
def get_all(self, name):
"""
Like :py:meth:`get`, but does not fold multiple headers into a single one.
This is useful for Set-Cookie headers, which do not support folding.
-
See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
"""
- name_lower = name.lower()
- values = [_native(value) for n, value in self.fields if n.lower() == name_lower]
- return values
+ name = _always_bytes(name)
+ return [
+ _native(x) for x in
+ super(Headers, self).get_all(name)
+ ]
- @_always_byte_args
def set_all(self, name, values):
"""
Explicitly set multiple headers for the given key.
See: :py:meth:`get_all`
"""
- values = map(_always_bytes, values) # _always_byte_args does not fix lists
- if name in self:
- del self[name]
- self.fields.extend(
- [name, value] for value in values
- )
-
- def get_state(self):
- return tuple(tuple(field) for field in self.fields)
-
- def set_state(self, state):
- self.fields = [list(field) for field in state]
+ name = _always_bytes(name)
+ values = [_always_bytes(x) for x in values]
+ return super(Headers, self).set_all(name, values)
- @classmethod
- def from_state(cls, state):
- return cls([list(field) for field in state])
+ def insert(self, index, key, value):
+ key = _always_bytes(key)
+ value = _always_bytes(value)
+ super(Headers, self).insert(index, key, value)
- @_always_byte_args
def replace(self, pattern, repl, flags=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
@@ -211,6 +155,8 @@ class Headers(MutableMapping, Serializable):
Returns:
The number of replacements made.
"""
+ pattern = _always_bytes(pattern)
+ repl = _always_bytes(repl)
pattern = re.compile(pattern, flags)
replacements = 0
diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py
index 6e3a1b93..d30976bd 100644
--- a/netlib/http/http1/read.py
+++ b/netlib/http/http1/read.py
@@ -316,14 +316,14 @@ def _read_headers(rfile):
if not ret:
raise HttpSyntaxException("Invalid headers")
# continued header
- ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip()
+ ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
else:
try:
name, value = line.split(b":", 1)
value = value.strip()
if not name:
raise ValueError()
- ret.append([name, value])
+ ret.append((name, value))
except ValueError:
raise HttpSyntaxException("Invalid headers")
return Headers(ret)
diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py
index f900b67c..6643b6b9 100644
--- a/netlib/http/http2/connections.py
+++ b/netlib/http/http2/connections.py
@@ -201,13 +201,13 @@ class HTTP2Protocol(object):
headers = request.headers.copy()
if ':authority' not in headers:
- headers.fields.insert(0, (b':authority', authority.encode('ascii')))
+ headers.insert(0, b':authority', authority.encode('ascii'))
if ':scheme' not in headers:
- headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii')))
+ headers.insert(0, b':scheme', request.scheme.encode('ascii'))
if ':path' not in headers:
- headers.fields.insert(0, (b':path', request.path.encode('ascii')))
+ headers.insert(0, b':path', request.path.encode('ascii'))
if ':method' not in headers:
- headers.fields.insert(0, (b':method', request.method.encode('ascii')))
+ headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@@ -224,7 +224,7 @@ class HTTP2Protocol(object):
headers = response.headers.copy()
if ':status' not in headers:
- headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii')))
+ headers.insert(0, b':status', str(response.status_code).encode('ascii'))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
@@ -420,7 +420,7 @@ class HTTP2Protocol(object):
self._handle_unexpected_frame(frm)
headers = Headers(
- [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)]
+ (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
return stream_id, headers, body
diff --git a/netlib/http/message.py b/netlib/http/message.py
index da9681a0..028f43a1 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -4,6 +4,7 @@ import warnings
import six
+from ..multidict import MultiDict
from .headers import Headers
from .. import encoding, utils
@@ -25,6 +26,9 @@ class MessageData(utils.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
+ def __hash__(self):
+ return hash(frozenset(self.__dict__.items()))
+
def set_state(self, state):
for k, v in state.items():
if k == "headers":
@@ -51,6 +55,9 @@ class Message(utils.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
+ def __hash__(self):
+ return hash(self.data) ^ 1
+
def get_state(self):
return self.data.get_state()
diff --git a/netlib/http/request.py b/netlib/http/request.py
index a42150ff..056a2d93 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -10,6 +10,7 @@ from netlib import utils
from netlib.http import cookies
from netlib.odict import ODict
from .. import encoding
+from ..multidict import MultiDictView
from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData
@@ -224,45 +225,64 @@ class Request(Message):
@property
def query(self):
+ # type: () -> MultiDictView
"""
- The request query string as an :py:class:`ODict` object.
- None, if there is no query.
+ The request query string as an :py:class:`MultiDictView` object.
"""
+ return MultiDictView(
+ self._get_query,
+ self._set_query
+ )
+
+ def _get_query(self):
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
- if query:
- return ODict(utils.urldecode(query))
- return None
+ return tuple(utils.urldecode(query))
- @query.setter
- def query(self, odict):
- query = utils.urlencode(odict.lst)
+ def _set_query(self, value):
+ query = utils.urlencode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = utils.parse_url(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ @query.setter
+ def query(self, value):
+ self._set_query(value)
+
@property
def cookies(self):
+ # type: () -> MultiDictView
"""
The request cookies.
- An empty :py:class:`ODict` object if the cookie monster ate them all.
+
+ An empty :py:class:`MultiDictView` object if the cookie monster ate them all.
"""
- ret = ODict()
- for i in self.headers.get_all("Cookie"):
- ret.extend(cookies.parse_cookie_header(i))
- return ret
+ return MultiDictView(
+ self._get_cookies,
+ self._set_cookies
+ )
+
+ def _get_cookies(self):
+ h = self.headers.get_all("Cookie")
+ return tuple(cookies.parse_cookie_headers(h))
+
+ def _set_cookies(self, value):
+ self.headers["cookie"] = cookies.format_cookie_header(value)
@cookies.setter
- def cookies(self, odict):
- self.headers["cookie"] = cookies.format_cookie_header(odict)
+ def cookies(self, value):
+ self._set_cookies(value)
@property
def path_components(self):
"""
- The URL's path components as a list of strings.
+ The URL's path components as a tuple of strings.
Components are unquoted.
"""
_, _, path, _, _, _ = urllib.parse.urlparse(self.url)
- return [urllib.parse.unquote(i) for i in path.split("/") if i]
+ # This needs to be a tuple so that it's immutable.
+ # Otherwise, this would fail silently:
+ # request.path_components.append("foo")
+ return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)
@path_components.setter
def path_components(self, components):
@@ -309,64 +329,53 @@ class Request(Message):
@property
def urlencoded_form(self):
"""
- The URL-encoded form data as an :py:class:`ODict` object.
- None if there is no data or the content-type indicates non-form data.
+ The URL-encoded form data as an :py:class:`MultiDictView` object.
+ An empty MultiDictView if the content-type indicates non-form data
+ or the content could not be parsed.
"""
+ return MultiDictView(
+ self._get_urlencoded_form,
+ self._set_urlencoded_form
+ )
+
+ def _get_urlencoded_form(self):
is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
- if self.content and is_valid_content_type:
- return ODict(utils.urldecode(self.content))
- return None
+ if is_valid_content_type:
+ return tuple(utils.urldecode(self.content))
+ return ()
- @urlencoded_form.setter
- def urlencoded_form(self, odict):
+ def _set_urlencoded_form(self, value):
"""
Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
- self.content = utils.urlencode(odict.lst)
+ self.content = utils.urlencode(value)
+
+ @urlencoded_form.setter
+ def urlencoded_form(self, value):
+ self._set_urlencoded_form(value)
@property
def multipart_form(self):
"""
- The multipart form data as an :py:class:`ODict` object.
- None if there is no data or the content-type indicates non-form data.
+ The multipart form data as an :py:class:`MultipartFormDict` object.
+ None if the content-type indicates non-form data.
"""
+ return MultiDictView(
+ self._get_multipart_form,
+ self._set_multipart_form
+ )
+
+ def _get_multipart_form(self):
is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
- if self.content and is_valid_content_type:
- return ODict(utils.multipartdecode(self.headers,self.content))
- return None
+ if is_valid_content_type:
+ return utils.multipartdecode(self.headers, self.content)
+ return ()
- @multipart_form.setter
- def multipart_form(self, value):
+ def _set_multipart_form(self, value):
raise NotImplementedError()
- # Legacy
-
- def get_query(self): # pragma: no cover
- warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning)
- return self.query or ODict([])
-
- def set_query(self, odict): # pragma: no cover
- warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning)
- self.query = odict
-
- def get_path_components(self): # pragma: no cover
- warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning)
- return self.path_components
-
- def set_path_components(self, lst): # pragma: no cover
- warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning)
- self.path_components = lst
-
- def get_form_urlencoded(self): # pragma: no cover
- warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning)
- return self.urlencoded_form or ODict([])
-
- def set_form_urlencoded(self, odict): # pragma: no cover
- warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning)
- self.urlencoded_form = odict
-
- def get_form_multipart(self): # pragma: no cover
- warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning)
- return self.multipart_form or ODict([])
+ @multipart_form.setter
+ def multipart_form(self, value):
+ self._set_multipart_form(value)
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 2f06149e..7d272e10 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -1,14 +1,13 @@
from __future__ import absolute_import, print_function, division
-import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
from . import cookies
from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData
+from ..multidict import MultiDictView
from .. import utils
-from ..odict import ODict
class ResponseData(MessageData):
@@ -72,29 +71,35 @@ class Response(Message):
@property
def cookies(self):
+ # type: () -> MultiDictView
"""
- Get the contents of all Set-Cookie headers.
+ The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are
+ cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is
+ an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly)
+ are indicated by a Null value.
- A possibly empty :py:class:`ODict`, where keys are cookie name strings,
- and values are [value, attr] lists. Value is a string, and attr is
- an ODictCaseless containing cookie attributes. Within attrs, unary
- attributes (e.g. HTTPOnly) are indicated by a Null value.
+ Caveats:
+ Updating the attr
"""
- ret = []
- for header in self.headers.get_all("set-cookie"):
- v = cookies.parse_set_cookie_header(header)
- if v:
- name, value, attrs = v
- ret.append([name, [value, attrs]])
- return ODict(ret)
+ return MultiDictView(
+ self._get_cookies,
+ self._set_cookies
+ )
+
+ def _get_cookies(self):
+ h = self.headers.get_all("set-cookie")
+ return tuple(cookies.parse_set_cookie_headers(h))
+
+ def _set_cookies(self, value):
+ cookie_headers = []
+ for k, v in value:
+ header = cookies.format_set_cookie_header(k, v[0], v[1])
+ cookie_headers.append(header)
+ self.headers.set_all("set-cookie", cookie_headers)
@cookies.setter
- def cookies(self, odict):
- values = []
- for i in odict.lst:
- header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1])
- values.append(header)
- self.headers.set_all("set-cookie", values)
+ def cookies(self, value):
+ self._set_cookies(value)
def refresh(self, now=None):
"""
diff --git a/netlib/multidict.py b/netlib/multidict.py
new file mode 100644
index 00000000..248acdec
--- /dev/null
+++ b/netlib/multidict.py
@@ -0,0 +1,282 @@
+from __future__ import absolute_import, print_function, division
+
+from abc import ABCMeta, abstractmethod
+
+from typing import Tuple, TypeVar
+
+try:
+ from collections.abc import MutableMapping
+except ImportError: # pragma: no cover
+ from collections import MutableMapping # Workaround for Python < 3.3
+
+import six
+
+from .utils import Serializable
+
+
+@six.add_metaclass(ABCMeta)
+class _MultiDict(MutableMapping, Serializable):
+ def __repr__(self):
+ fields = tuple(
+ repr(field)
+ for field in self.fields
+ )
+ return "{cls}[{fields}]".format(
+ cls=type(self).__name__,
+ fields=", ".join(fields)
+ )
+
+ @staticmethod
+ @abstractmethod
+ def _reduce_values(values):
+ """
+ If a user accesses multidict["foo"], this method
+ reduces all values for "foo" to a single value that is returned.
+ For example, HTTP headers are folded, whereas we will just take
+ the first cookie we found with that name.
+ """
+
+ @staticmethod
+ @abstractmethod
+ def _kconv(key):
+ """
+ This method converts a key to its canonical representation.
+ For example, HTTP headers are case-insensitive, so this method returns key.lower().
+ """
+
+ def __getitem__(self, key):
+ values = self.get_all(key)
+ if not values:
+ raise KeyError(key)
+ return self._reduce_values(values)
+
+ def __setitem__(self, key, value):
+ self.set_all(key, [value])
+
+ def __delitem__(self, key):
+ if key not in self:
+ raise KeyError(key)
+ key = self._kconv(key)
+ self.fields = tuple(
+ field for field in self.fields
+ if key != self._kconv(field[0])
+ )
+
+ def __iter__(self):
+ seen = set()
+ for key, _ in self.fields:
+ key_kconv = self._kconv(key)
+ if key_kconv not in seen:
+ seen.add(key_kconv)
+ yield key
+
+ def __len__(self):
+ return len(set(self._kconv(key) for key, _ in self.fields))
+
+ def __eq__(self, other):
+ if isinstance(other, MultiDict):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash(self.fields)
+
+ def get_all(self, key):
+ """
+ Return the list of all values for a given key.
+ If that key is not in the MultiDict, the return value will be an empty list.
+ """
+ key = self._kconv(key)
+ return [
+ value
+ for k, value in self.fields
+ if self._kconv(k) == key
+ ]
+
+ def set_all(self, key, values):
+ """
+ Remove the old values for a key and add new ones.
+ """
+ key_kconv = self._kconv(key)
+
+ new_fields = []
+ for field in self.fields:
+ if self._kconv(field[0]) == key_kconv:
+ if values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ else:
+ new_fields.append(field)
+ while values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ self.fields = tuple(new_fields)
+
+ def add(self, key, value):
+ """
+ Add an additional value for the given key at the bottom.
+ """
+ self.insert(len(self.fields), key, value)
+
+ def insert(self, index, key, value):
+ """
+ Insert an additional value for the given key at the specified position.
+ """
+ item = (key, value)
+ self.fields = self.fields[:index] + (item,) + self.fields[index:]
+
+ def keys(self, multi=False):
+ """
+ Get all keys.
+
+ Args:
+ multi(bool):
+ If True, one key per value will be returned.
+ If False, duplicate keys will only be returned once.
+ """
+ return (
+ k
+ for k, _ in self.items(multi)
+ )
+
+ def values(self, multi=False):
+ """
+ Get all values.
+
+ Args:
+ multi(bool):
+ If True, all values will be returned.
+ If False, only the first value per key will be returned.
+ """
+ return (
+ v
+ for _, v in self.items(multi)
+ )
+
+ def items(self, multi=False):
+ """
+ Get all (key, value) tuples.
+
+ Args:
+ multi(bool):
+ If True, all (key, value) pairs will be returned
+ If False, only the first (key, value) pair per unique key will be returned.
+ """
+ if multi:
+ return self.fields
+ else:
+ return super(_MultiDict, self).items()
+
+ def to_dict(self):
+ """
+ Get the MultiDict as a plain Python dict.
+ Keys with multiple values are returned as lists.
+
+ Example:
+
+ .. code-block:: python
+
+ # Simple dict with duplicate values.
+ >>> d
+ MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
+ >>> d.to_dict()
+ {
+ "name": "value",
+ "a": ["false", "42"]
+ }
+ """
+ d = {}
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ d[key] = values[0]
+ else:
+ d[key] = values
+ return d
+
+ def get_state(self):
+ return self.fields
+
+ def set_state(self, state):
+ self.fields = tuple(tuple(x) for x in state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(tuple(x) for x in state)
+
+
+class MultiDict(_MultiDict):
+ def __init__(self, fields=None):
+ super(MultiDict, self).__init__()
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+
+@six.add_metaclass(ABCMeta)
+class ImmutableMultiDict(MultiDict):
+ def _immutable(self, *_):
+ raise TypeError('{} objects are immutable'.format(self.__class__.__name__))
+
+ __delitem__ = set_all = insert = _immutable
+
+ def with_delitem(self, key):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).__delitem__(key)
+ return ret
+
+ def with_set_all(self, key, values):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).set_all(key, values)
+ return ret
+
+ def with_insert(self, index, key, value):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).insert(index, key, value)
+ return ret
+
+
+class MultiDictView(_MultiDict):
+ """
+ The MultiDictView provides the MultiDict interface over calculated data.
+ The view itself contains no state - data is retrieved from the parent on
+ request, and stored back to the parent on change.
+ """
+ def __init__(self, getter, setter):
+ self._getter = getter
+ self._setter = setter
+ super(MultiDictView, self).__init__()
+
+ @staticmethod
+ def _kconv(key):
+ # All request-attributes are case-sensitive.
+ return key
+
+ @staticmethod
+ def _reduce_values(values):
+ # We just return the first element if
+ # multiple elements exist with the same key.
+ return values[0]
+
+ @property
+ def fields(self):
+ return self._getter()
+
+ @fields.setter
+ def fields(self, value):
+ return self._setter(value)
diff --git a/netlib/utils.py b/netlib/utils.py
index be2701a0..7499f71f 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args):
return unicode_or_bytes
-def always_byte_args(*encode_args):
- """Decorator that transparently encodes all arguments passed as unicode"""
- def decorator(fun):
- def _fun(*args, **kwargs):
- args = [always_bytes(arg, *encode_args) for arg in args]
- kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
- return fun(*args, **kwargs)
- return _fun
- return decorator
-
-
def native(s, *encoding_opts):
"""
Convert :py:class:`bytes` or :py:class:`unicode` to the native
diff --git a/pathod/utils.py b/pathod/utils.py
index 1e5bd9a4..d1e2dd00 100644
--- a/pathod/utils.py
+++ b/pathod/utils.py
@@ -1,5 +1,6 @@
import os
import sys
+import netlib.utils
SIZE_UNITS = dict(
@@ -75,27 +76,7 @@ def escape_unprintables(s):
return s
-class Data(object):
-
- def __init__(self, name):
- m = __import__(name)
- dirname, _ = os.path.split(m.__file__)
- self.dirname = os.path.abspath(dirname)
-
- def path(self, path):
- """
- Returns a path to the package data housed at 'path' under this
- module.Path can be a path to a file, or to a directory.
-
- This function will raise ValueError if the path does not exist.
- """
- fullpath = os.path.join(self.dirname, path)
- if not os.path.exists(fullpath):
- raise ValueError("dataPath: %s does not exist." % fullpath)
- return fullpath
-
-
-data = Data(__name__)
+data = netlib.utils.Data(__name__)
def daemonize(stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'): # pragma: no cover
diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py
index c401a6b9..c4b06f4b 100644
--- a/test/mitmproxy/test_examples.py
+++ b/test/mitmproxy/test_examples.py
@@ -5,11 +5,12 @@ from contextlib import contextmanager
from mitmproxy import utils, script
from mitmproxy.proxy import config
+import netlib.utils
from netlib import tutils as netutils
from netlib.http import Headers
from . import tservers, tutils
-example_dir = utils.Data(__name__).path("../../examples")
+example_dir = netlib.utils.Data(__name__).path("../../examples")
class DummyContext(object):
@@ -94,14 +95,22 @@ def test_modify_form():
flow = tutils.tflow(req=netutils.treq(headers=form_header))
with example("modify_form.py") as ex:
ex.run("request", flow)
- assert flow.request.urlencoded_form["mitmproxy"] == ["rocks"]
+ assert flow.request.urlencoded_form["mitmproxy"] == "rocks"
+
+ flow.request.headers["content-type"] = ""
+ ex.run("request", flow)
+ assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")]
def test_modify_querystring():
flow = tutils.tflow(req=netutils.treq(path="/search?q=term"))
with example("modify_querystring.py") as ex:
ex.run("request", flow)
- assert flow.request.query["mitmproxy"] == ["rocks"]
+ assert flow.request.query["mitmproxy"] == "rocks"
+
+ flow.request.path = "/"
+ ex.run("request", flow)
+ assert flow.request.query["mitmproxy"] == "rocks"
def test_modify_response_body():
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index b9c6a2f6..bf417423 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -1067,60 +1067,6 @@ class TestRequest:
assert r.url == "https://address:22/path"
assert r.pretty_url == "https://foo.com:22/path"
- def test_path_components(self):
- r = HTTPRequest.wrap(netlib.tutils.treq())
- r.path = "/"
- assert r.get_path_components() == []
- r.path = "/foo/bar"
- assert r.get_path_components() == ["foo", "bar"]
- q = odict.ODict()
- q["test"] = ["123"]
- r.set_query(q)
- assert r.get_path_components() == ["foo", "bar"]
-
- r.set_path_components([])
- assert r.get_path_components() == []
- r.set_path_components(["foo"])
- assert r.get_path_components() == ["foo"]
- r.set_path_components(["/oo"])
- assert r.get_path_components() == ["/oo"]
- assert "%2F" in r.path
-
- def test_getset_form_urlencoded(self):
- d = odict.ODict([("one", "two"), ("three", "four")])
- r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst)))
- r.headers["content-type"] = "application/x-www-form-urlencoded"
- assert r.get_form_urlencoded() == d
-
- d = odict.ODict([("x", "y")])
- r.set_form_urlencoded(d)
- assert r.get_form_urlencoded() == d
-
- r.headers["content-type"] = "foo"
- assert not r.get_form_urlencoded()
-
- def test_getset_query(self):
- r = HTTPRequest.wrap(netlib.tutils.treq())
- r.path = "/foo?x=y&a=b"
- q = r.get_query()
- assert q.lst == [("x", "y"), ("a", "b")]
-
- r.path = "/"
- q = r.get_query()
- assert not q
-
- r.path = "/?adsfa"
- q = r.get_query()
- assert q.lst == [("adsfa", "")]
-
- r.path = "/foo?x=y&a=b"
- assert r.get_query()
- r.set_query(odict.ODict([]))
- assert not r.get_query()
- qv = odict.ODict([("a", "b"), ("c", "d")])
- r.set_query(qv)
- assert r.get_query() == qv
-
def test_anticache(self):
r = HTTPRequest.wrap(netlib.tutils.treq())
r.headers = Headers()
diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py
index 035f07b7..c252c5bd 100644
--- a/test/mitmproxy/test_flow_export.py
+++ b/test/mitmproxy/test_flow_export.py
@@ -21,7 +21,7 @@ def python_equals(testdata, text):
assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip()
-req_get = lambda: netlib.tutils.treq(method='GET', content='')
+req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/path?a=foo&a=bar&b=baz")
req_post = lambda: netlib.tutils.treq(method='POST', headers=None)
@@ -31,7 +31,7 @@ req_patch = lambda: netlib.tutils.treq(method='PATCH', path=b"/path?query=param"
class TestExportCurlCommand():
def test_get(self):
flow = tutils.tflow(req=req_get())
- result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path'"""
+ result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path?a=foo&a=bar&b=baz'"""
assert flow_export.curl_command(flow) == result
def test_post(self):
@@ -70,7 +70,7 @@ class TestRawRequest():
def test_get(self):
flow = tutils.tflow(req=req_get())
result = dedent("""
- GET /path HTTP/1.1\r
+ GET /path?a=foo&a=bar&b=baz HTTP/1.1\r
header: qvalue\r
content-length: 7\r
host: address:22\r
diff --git a/test/mitmproxy/test_flow_export/locust_get.py b/test/mitmproxy/test_flow_export/locust_get.py
index 72d5932a..632d5d53 100644
--- a/test/mitmproxy/test_flow_export/locust_get.py
+++ b/test/mitmproxy/test_flow_export/locust_get.py
@@ -14,10 +14,16 @@ class UserBehavior(TaskSet):
'content-length': '7',
}
+ params = {
+ 'a': ['foo', 'bar'],
+ 'b': 'baz',
+ }
+
self.response = self.client.request(
method='GET',
url=url,
headers=headers,
+ params=params,
)
### Additional tasks can go here ###
diff --git a/test/mitmproxy/test_flow_export/locust_task_get.py b/test/mitmproxy/test_flow_export/locust_task_get.py
index 76f144fa..03821cd8 100644
--- a/test/mitmproxy/test_flow_export/locust_task_get.py
+++ b/test/mitmproxy/test_flow_export/locust_task_get.py
@@ -7,8 +7,14 @@
'content-length': '7',
}
+ params = {
+ 'a': ['foo', 'bar'],
+ 'b': 'baz',
+ }
+
self.response = self.client.request(
method='GET',
url=url,
headers=headers,
+ params=params,
)
diff --git a/test/mitmproxy/test_flow_export/python_get.py b/test/mitmproxy/test_flow_export/python_get.py
index ee3f48eb..af8f7c81 100644
--- a/test/mitmproxy/test_flow_export/python_get.py
+++ b/test/mitmproxy/test_flow_export/python_get.py
@@ -7,10 +7,16 @@ headers = {
'content-length': '7',
}
+params = {
+ 'a': ['foo', 'bar'],
+ 'b': 'baz',
+}
+
response = requests.request(
method='GET',
url=url,
headers=headers,
+ params=params,
)
print(response.text)
diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py
index d51ac185..2dfd710e 100644
--- a/test/mitmproxy/tutils.py
+++ b/test/mitmproxy/tutils.py
@@ -8,6 +8,7 @@ from contextlib import contextmanager
from unittest.case import SkipTest
+import netlib.utils
import netlib.tutils
from mitmproxy import utils, controller
from mitmproxy.models import (
@@ -163,4 +164,4 @@ def capture_stderr(command, *args, **kwargs):
sys.stderr = out
-test_data = utils.Data(__name__)
+test_data = netlib.utils.Data(__name__)
diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py
index 90234070..d8106904 100644
--- a/test/netlib/http/http1/test_read.py
+++ b/test/netlib/http/http1/test_read.py
@@ -261,7 +261,7 @@ class TestReadHeaders(object):
b"\r\n"
)
headers = self._read(data)
- assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]]
+ assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two"))
def test_read_multi(self):
data = (
@@ -270,7 +270,7 @@ class TestReadHeaders(object):
b"\r\n"
)
headers = self._read(data)
- assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]]
+ assert headers.fields == ((b"Header", b"one"), (b"Header", b"two"))
def test_read_continued(self):
data = (
@@ -280,7 +280,7 @@ class TestReadHeaders(object):
b"\r\n"
)
headers = self._read(data)
- assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]]
+ assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three"))
def test_read_continued_err(self):
data = b"\tfoo: bar\r\n"
@@ -300,7 +300,7 @@ class TestReadHeaders(object):
def test_read_empty_value(self):
data = b"bar:"
headers = self._read(data)
- assert headers.fields == [[b"bar", b""]]
+ assert headers.fields == ((b"bar", b""),)
def test_read_chunked():
req = treq(content=None)
diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py
index 7b003067..7d240c0e 100644
--- a/test/netlib/http/http2/test_connections.py
+++ b/test/netlib/http/http2/test_connections.py
@@ -312,7 +312,7 @@ class TestReadRequest(tservers.ServerTestBase):
req = protocol.read_request(NotImplemented)
assert req.stream_id
- assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']]
+ assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https'))
assert req.content == b'foobar'
@@ -418,7 +418,7 @@ class TestReadResponse(tservers.ServerTestBase):
assert resp.http_version == "HTTP/2.0"
assert resp.status_code == 200
assert resp.reason == ''
- assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']]
+ assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
assert resp.content == b'foobar'
assert resp.timestamp_end
@@ -445,7 +445,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
assert resp.http_version == "HTTP/2.0"
assert resp.status_code == 200
assert resp.reason == ''
- assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']]
+ assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
assert resp.content == b''
diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py
index da28850f..6f84c4ce 100644
--- a/test/netlib/http/test_cookies.py
+++ b/test/netlib/http/test_cookies.py
@@ -128,10 +128,10 @@ def test_cookie_roundtrips():
]
for s, lst in pairs:
ret = cookies.parse_cookie_header(s)
- assert ret.lst == lst
+ assert ret == lst
s2 = cookies.format_cookie_header(ret)
ret = cookies.parse_cookie_header(s2)
- assert ret.lst == lst
+ assert ret == lst
def test_parse_set_cookie_pairs():
@@ -197,24 +197,28 @@ def test_parse_set_cookie_header():
],
[
"one=uno",
- ("one", "uno", [])
+ ("one", "uno", ())
],
[
"one=uno; foo=bar",
- ("one", "uno", [["foo", "bar"]])
- ]
+ ("one", "uno", (("foo", "bar"),))
+ ],
+ [
+ "one=uno; foo=bar; foo=baz",
+ ("one", "uno", (("foo", "bar"), ("foo", "baz")))
+ ],
]
for s, expected in vals:
ret = cookies.parse_set_cookie_header(s)
if expected:
assert ret[0] == expected[0]
assert ret[1] == expected[1]
- assert ret[2].lst == expected[2]
+ assert ret[2].items(multi=True) == expected[2]
s2 = cookies.format_set_cookie_header(*ret)
ret2 = cookies.parse_set_cookie_header(s2)
assert ret2[0] == expected[0]
assert ret2[1] == expected[1]
- assert ret2[2].lst == expected[2]
+ assert ret2[2].items(multi=True) == expected[2]
else:
assert ret is None
diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py
index 8c1db9dc..cd2ca9d1 100644
--- a/test/netlib/http/test_headers.py
+++ b/test/netlib/http/test_headers.py
@@ -5,10 +5,10 @@ from netlib.tutils import raises
class TestHeaders(object):
def _2host(self):
return Headers(
- [
- [b"Host", b"example.com"],
- [b"host", b"example.org"]
- ]
+ (
+ (b"Host", b"example.com"),
+ (b"host", b"example.org")
+ )
)
def test_init(self):
@@ -38,20 +38,10 @@ class TestHeaders(object):
assert headers["Host"] == "example.com"
assert headers["Accept"] == "text/plain"
- with raises(ValueError):
+ with raises(TypeError):
Headers([[b"Host", u"not-bytes"]])
- def test_getitem(self):
- headers = Headers(Host="example.com")
- assert headers["Host"] == "example.com"
- assert headers["host"] == "example.com"
- with raises(KeyError):
- _ = headers["Accept"]
-
- headers = self._2host()
- assert headers["Host"] == "example.com, example.org"
-
- def test_str(self):
+ def test_bytes(self):
headers = Headers(Host="example.com")
assert bytes(headers) == b"Host: example.com\r\n"
@@ -64,93 +54,6 @@ class TestHeaders(object):
headers = Headers()
assert bytes(headers) == b""
- def test_setitem(self):
- headers = Headers()
- headers["Host"] = "example.com"
- assert "Host" in headers
- assert "host" in headers
- assert headers["Host"] == "example.com"
-
- headers["host"] = "example.org"
- assert "Host" in headers
- assert "host" in headers
- assert headers["Host"] == "example.org"
-
- headers["accept"] = "text/plain"
- assert len(headers) == 2
- assert "Accept" in headers
- assert "Host" in headers
-
- headers = self._2host()
- assert len(headers.fields) == 2
- headers["Host"] = "example.com"
- assert len(headers.fields) == 1
- assert "Host" in headers
-
- def test_delitem(self):
- headers = Headers(Host="example.com")
- assert len(headers) == 1
- del headers["host"]
- assert len(headers) == 0
- try:
- del headers["host"]
- except KeyError:
- assert True
- else:
- assert False
-
- headers = self._2host()
- del headers["Host"]
- assert len(headers) == 0
-
- def test_keys(self):
- headers = Headers(Host="example.com")
- assert list(headers.keys()) == ["Host"]
-
- headers = self._2host()
- assert list(headers.keys()) == ["Host"]
-
- def test_eq_ne(self):
- headers1 = Headers(Host="example.com")
- headers2 = Headers(host="example.com")
- assert not (headers1 == headers2)
- assert headers1 != headers2
-
- headers1 = Headers(Host="example.com")
- headers2 = Headers(Host="example.com")
- assert headers1 == headers2
- assert not (headers1 != headers2)
-
- assert headers1 != 42
-
- def test_get_all(self):
- headers = self._2host()
- assert headers.get_all("host") == ["example.com", "example.org"]
- assert headers.get_all("accept") == []
-
- def test_set_all(self):
- headers = Headers(Host="example.com")
- headers.set_all("Accept", ["text/plain"])
- assert len(headers) == 2
- assert "accept" in headers
-
- headers = self._2host()
- headers.set_all("Host", ["example.org"])
- assert headers["host"] == "example.org"
-
- headers.set_all("Host", ["example.org", "example.net"])
- assert headers["host"] == "example.org, example.net"
-
- def test_state(self):
- headers = self._2host()
- assert len(headers.get_state()) == 2
- assert headers == Headers.from_state(headers.get_state())
-
- headers2 = Headers()
- assert headers != headers2
- headers2.set_state(headers.get_state())
- assert headers == headers2
-
def test_replace_simple(self):
headers = Headers(Host="example.com", Accept="text/plain")
replacements = headers.replace("Host: ", "X-Host: ")
diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py
index 7ed6bd0f..fae7aefe 100644
--- a/test/netlib/http/test_request.py
+++ b/test/netlib/http/test_request.py
@@ -3,16 +3,14 @@ from __future__ import absolute_import, print_function, division
import six
-from netlib import utils
from netlib.http import Headers
-from netlib.odict import ODict
from netlib.tutils import treq, raises
from .test_message import _test_decoded_attr, _test_passthrough_attr
class TestRequestData(object):
def test_init(self):
- with raises(ValueError if six.PY2 else TypeError):
+ with raises(ValueError):
treq(headers="foobar")
assert isinstance(treq(headers=None).headers, Headers)
@@ -158,16 +156,17 @@ class TestRequestUtils(object):
def test_get_query(self):
request = treq()
- assert request.query is None
+ assert not request.query
request.url = "http://localhost:80/foo?bar=42"
- assert request.query.lst == [("bar", "42")]
+ assert dict(request.query) == {"bar": "42"}
def test_set_query(self):
- request = treq(host=b"foo", headers = Headers(host=b"bar"))
- request.query = ODict([])
- assert request.host == "foo"
- assert request.headers["host"] == "bar"
+ request = treq()
+ assert not request.query
+ request.query["foo"] = "bar"
+ assert request.query["foo"] == "bar"
+ assert request.path == "/path?foo=bar"
def test_get_cookies_none(self):
request = treq()
@@ -177,47 +176,50 @@ class TestRequestUtils(object):
def test_get_cookies_single(self):
request = treq()
request.headers = Headers(cookie="cookiename=cookievalue")
- result = request.cookies
- assert len(result) == 1
- assert result['cookiename'] == ['cookievalue']
+ assert len(request.cookies) == 1
+ assert request.cookies['cookiename'] == 'cookievalue'
def test_get_cookies_double(self):
request = treq()
request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
result = request.cookies
assert len(result) == 2
- assert result['cookiename'] == ['cookievalue']
- assert result['othercookiename'] == ['othercookievalue']
+ assert result['cookiename'] == 'cookievalue'
+ assert result['othercookiename'] == 'othercookievalue'
def test_get_cookies_withequalsign(self):
request = treq()
request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
result = request.cookies
assert len(result) == 2
- assert result['cookiename'] == ['coo=kievalue']
- assert result['othercookiename'] == ['othercookievalue']
+ assert result['cookiename'] == 'coo=kievalue'
+ assert result['othercookiename'] == 'othercookievalue'
def test_set_cookies(self):
request = treq()
request.headers = Headers(cookie="cookiename=cookievalue")
result = request.cookies
- result["cookiename"] = ["foo"]
- request.cookies = result
- assert request.cookies["cookiename"] == ["foo"]
+ result["cookiename"] = "foo"
+ assert request.cookies["cookiename"] == "foo"
def test_get_path_components(self):
request = treq(path=b"/foo/bar")
- assert request.path_components == ["foo", "bar"]
+ assert request.path_components == ("foo", "bar")
def test_set_path_components(self):
- request = treq(host=b"foo", headers = Headers(host=b"bar"))
+ request = treq()
request.path_components = ["foo", "baz"]
assert request.path == "/foo/baz"
+
request.path_components = []
assert request.path == "/"
- request.query = ODict([])
- assert request.host == "foo"
- assert request.headers["host"] == "bar"
+
+ request.path_components = ["foo", "baz"]
+ request.query["hello"] = "hello"
+ assert request.path_components == ("foo", "baz")
+
+ request.path_components = ["abc"]
+ assert request.path == "/abc?hello=hello"
def test_anticache(self):
request = treq()
@@ -246,26 +248,21 @@ class TestRequestUtils(object):
assert "gzip" in request.headers["Accept-Encoding"]
def test_get_urlencoded_form(self):
- request = treq(content="foobar")
- assert request.urlencoded_form is None
+ request = treq(content="foobar=baz")
+ assert not request.urlencoded_form
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- assert request.urlencoded_form == ODict(utils.urldecode(request.content))
+ assert list(request.urlencoded_form.items()) == [("foobar", "baz")]
def test_set_urlencoded_form(self):
request = treq()
- request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')])
+ request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')]
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
assert request.content
def test_get_multipart_form(self):
request = treq(content="foobar")
- assert request.multipart_form is None
+ assert not request.multipart_form
request.headers["Content-Type"] = "multipart/form-data"
- assert request.multipart_form == ODict(
- utils.multipartdecode(
- request.headers,
- request.content
- )
- )
+ assert list(request.multipart_form.items()) == []
diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py
index 5440176c..cfd093d4 100644
--- a/test/netlib/http/test_response.py
+++ b/test/netlib/http/test_response.py
@@ -6,6 +6,7 @@ import six
import time
from netlib.http import Headers
+from netlib.http.cookies import CookieAttrs
from netlib.odict import ODict, ODictCaseless
from netlib.tutils import raises, tresp
from .test_message import _test_passthrough_attr, _test_decoded_attr
@@ -13,7 +14,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr
class TestResponseData(object):
def test_init(self):
- with raises(ValueError if six.PY2 else TypeError):
+ with raises(ValueError):
tresp(headers="foobar")
assert isinstance(tresp(headers=None).headers, Headers)
@@ -56,7 +57,7 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", ODict()]
+ assert result["cookiename"] == ("cookievalue", CookieAttrs())
def test_get_cookies_with_parameters(self):
resp = tresp()
@@ -64,13 +65,13 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0][0] == "cookievalue"
- attrs = result["cookiename"][0][1]
+ assert result["cookiename"][0] == "cookievalue"
+ attrs = result["cookiename"][1]
assert len(attrs) == 4
- assert attrs["domain"] == ["example.com"]
- assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"]
- assert attrs["path"] == ["/"]
- assert attrs["httponly"] == [None]
+ assert attrs["domain"] == "example.com"
+ assert attrs["expires"] == "Wed Oct 21 16:29:41 2015"
+ assert attrs["path"] == "/"
+ assert attrs["httponly"] is None
def test_get_cookies_no_value(self):
resp = tresp()
@@ -78,8 +79,8 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0][0] == ""
- assert len(result["cookiename"][0][1]) == 2
+ assert result["cookiename"][0] == ""
+ assert len(result["cookiename"][1]) == 2
def test_get_cookies_twocookies(self):
resp = tresp()
@@ -90,19 +91,16 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 2
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", ODict()]
+ assert result["cookiename"] == ("cookievalue", CookieAttrs())
assert "othercookie" in result
- assert result["othercookie"][0] == ["othervalue", ODict()]
+ assert result["othercookie"] == ("othervalue", CookieAttrs())
def test_set_cookies(self):
resp = tresp()
- v = resp.cookies
- v.add("foo", ["bar", ODictCaseless()])
- resp.cookies = v
+ resp.cookies["foo"] = ("bar", {})
- v = resp.cookies
- assert len(v) == 1
- assert v["foo"] == [["bar", ODictCaseless()]]
+ assert len(resp.cookies) == 1
+ assert resp.cookies["foo"] == ("bar", CookieAttrs())
def test_refresh(self):
r = tresp()
diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py
new file mode 100644
index 00000000..5bb65e3f
--- /dev/null
+++ b/test/netlib/test_multidict.py
@@ -0,0 +1,239 @@
+from netlib import tutils
+from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView
+
+
+class _TMulti(object):
+ @staticmethod
+ def _reduce_values(values):
+ return values[0]
+
+ @staticmethod
+ def _kconv(key):
+ return key.lower()
+
+
+class TMultiDict(_TMulti, MultiDict):
+ pass
+
+
+class TImmutableMultiDict(_TMulti, ImmutableMultiDict):
+ pass
+
+
+class TestMultiDict(object):
+ @staticmethod
+ def _multi():
+ return TMultiDict((
+ ("foo", "bar"),
+ ("bar", "baz"),
+ ("Bar", "bam")
+ ))
+
+ def test_init(self):
+ md = TMultiDict()
+ assert len(md) == 0
+
+ md = TMultiDict([("foo", "bar")])
+ assert len(md) == 1
+ assert md.fields == (("foo", "bar"),)
+
+ def test_repr(self):
+ assert repr(self._multi()) == (
+ "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]"
+ )
+
+ def test_getitem(self):
+ md = TMultiDict([("foo", "bar")])
+ assert "foo" in md
+ assert "Foo" in md
+ assert md["foo"] == "bar"
+
+ with tutils.raises(KeyError):
+ _ = md["bar"]
+
+ md_multi = TMultiDict(
+ [("foo", "a"), ("foo", "b")]
+ )
+ assert md_multi["foo"] == "a"
+
+ def test_setitem(self):
+ md = TMultiDict()
+ md["foo"] = "bar"
+ assert md.fields == (("foo", "bar"),)
+
+ md["foo"] = "baz"
+ assert md.fields == (("foo", "baz"),)
+
+ md["bar"] = "bam"
+ assert md.fields == (("foo", "baz"), ("bar", "bam"))
+
+ def test_delitem(self):
+ md = self._multi()
+ del md["foo"]
+ assert "foo" not in md
+ assert "bar" in md
+
+ with tutils.raises(KeyError):
+ del md["foo"]
+
+ del md["bar"]
+ assert md.fields == ()
+
+ def test_iter(self):
+ md = self._multi()
+ assert list(md.__iter__()) == ["foo", "bar"]
+
+ def test_len(self):
+ md = TMultiDict()
+ assert len(md) == 0
+
+ md = self._multi()
+ assert len(md) == 2
+
+ def test_eq(self):
+ assert TMultiDict() == TMultiDict()
+ assert not (TMultiDict() == 42)
+
+ md1 = self._multi()
+ md2 = self._multi()
+ assert md1 == md2
+ md1.fields = md1.fields[1:] + md1.fields[:1]
+ assert not (md1 == md2)
+
+ def test_ne(self):
+ assert not TMultiDict() != TMultiDict()
+ assert TMultiDict() != self._multi()
+ assert TMultiDict() != 42
+
+ def test_get_all(self):
+ md = self._multi()
+ assert md.get_all("foo") == ["bar"]
+ assert md.get_all("bar") == ["baz", "bam"]
+ assert md.get_all("baz") == []
+
+ def test_set_all(self):
+ md = TMultiDict()
+ md.set_all("foo", ["bar", "baz"])
+ assert md.fields == (("foo", "bar"), ("foo", "baz"))
+
+ md = TMultiDict((
+ ("a", "b"),
+ ("x", "x"),
+ ("c", "d"),
+ ("X", "x"),
+ ("e", "f"),
+ ))
+ md.set_all("x", ["1", "2", "3"])
+ assert md.fields == (
+ ("a", "b"),
+ ("x", "1"),
+ ("c", "d"),
+ ("x", "2"),
+ ("e", "f"),
+ ("x", "3"),
+ )
+ md.set_all("x", ["4"])
+ assert md.fields == (
+ ("a", "b"),
+ ("x", "4"),
+ ("c", "d"),
+ ("e", "f"),
+ )
+
+ def test_add(self):
+ md = self._multi()
+ md.add("foo", "foo")
+ assert md.fields == (
+ ("foo", "bar"),
+ ("bar", "baz"),
+ ("Bar", "bam"),
+ ("foo", "foo")
+ )
+
+ def test_insert(self):
+ md = TMultiDict([("b", "b")])
+ md.insert(0, "a", "a")
+ md.insert(2, "c", "c")
+ assert md.fields == (("a", "a"), ("b", "b"), ("c", "c"))
+
+ def test_keys(self):
+ md = self._multi()
+ assert list(md.keys()) == ["foo", "bar"]
+ assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"]
+
+ def test_values(self):
+ md = self._multi()
+ assert list(md.values()) == ["bar", "baz"]
+ assert list(md.values(multi=True)) == ["bar", "baz", "bam"]
+
+ def test_items(self):
+ md = self._multi()
+ assert list(md.items()) == [("foo", "bar"), ("bar", "baz")]
+ assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")]
+
+ def test_to_dict(self):
+ md = self._multi()
+ assert md.to_dict() == {
+ "foo": "bar",
+ "bar": ["baz", "bam"]
+ }
+
+ def test_state(self):
+ md = self._multi()
+ assert len(md.get_state()) == 3
+ assert md == TMultiDict.from_state(md.get_state())
+
+ md2 = TMultiDict()
+ assert md != md2
+ md2.set_state(md.get_state())
+ assert md == md2
+
+
+class TestImmutableMultiDict(object):
+ def test_modify(self):
+ md = TImmutableMultiDict()
+ with tutils.raises(TypeError):
+ md["foo"] = "bar"
+
+ with tutils.raises(TypeError):
+ del md["foo"]
+
+ with tutils.raises(TypeError):
+ md.add("foo", "bar")
+
+ def test_with_delitem(self):
+ md = TImmutableMultiDict([("foo", "bar")])
+ assert md.with_delitem("foo").fields == ()
+ assert md.fields == (("foo", "bar"),)
+
+ def test_with_set_all(self):
+ md = TImmutableMultiDict()
+ assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),)
+ assert md.fields == ()
+
+ def test_with_insert(self):
+ md = TImmutableMultiDict()
+ assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),)
+
+
+class TParent(object):
+ def __init__(self):
+ self.vals = tuple()
+
+ def setter(self, vals):
+ self.vals = vals
+
+ def getter(self):
+ return self.vals
+
+
+class TestMultiDictView(object):
+ def test_modify(self):
+ p = TParent()
+ tv = MultiDictView(p.getter, p.setter)
+ assert len(tv) == 0
+ tv["a"] = "b"
+ assert p.vals == (("a", "b"),)
+ tv["c"] = "b"
+ assert p.vals == (("a", "b"), ("c", "b"))
+ assert tv["a"] == "b"
diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py
index 10f3b5a3..05a3962e 100644
--- a/test/pathod/test_pathod.py
+++ b/test/pathod/test_pathod.py
@@ -233,6 +233,7 @@ class CommonTests(tutils.DaemonTests):
# FIXME: Race Condition?
assert "Parse error" in self.d.text_log()
+ @pytest.mark.skip(reason="race condition")
def test_websocket_frame_disconnect_error(self):
self.pathoc(["ws:/p/", "wf:b@10:d3"], ws_read_limit=0)
assert self.d.last_log()
diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py
index 9739afde..f6ed3efb 100644
--- a/test/pathod/tutils.py
+++ b/test/pathod/tutils.py
@@ -116,7 +116,7 @@ tmpdir = netlib.tutils.tmpdir
raises = netlib.tutils.raises
-test_data = utils.Data(__name__)
+test_data = netlib.utils.Data(__name__)
def render(r, settings=language.Settings()):