diff options
40 files changed, 1015 insertions, 642 deletions
diff --git a/docs/dev/models.rst b/docs/dev/models.rst index 8c4e6825..f2ddf242 100644 --- a/docs/dev/models.rst +++ b/docs/dev/models.rst @@ -56,6 +56,17 @@ Datastructures :special-members: :no-undoc-members: + .. autoclass:: MultiDictView + + .. automethod:: get_all + .. automethod:: set_all + .. automethod:: add + .. automethod:: insert + .. automethod:: keys + .. automethod:: values + .. automethod:: items + .. automethod:: to_dict + .. autoclass:: decoded .. automodule:: mitmproxy.models diff --git a/examples/modify_form.py b/examples/modify_form.py index 86188781..3fe0cf96 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,5 +1,8 @@ def request(context, flow): - form = flow.request.urlencoded_form - if form is not None: - form["mitmproxy"] = ["rocks"] - flow.request.urlencoded_form = form + if flow.request.urlencoded_form: + flow.request.urlencoded_form["mitmproxy"] = "rocks" + else: + # This sets the proper content type and overrides the body. + flow.request.urlencoded_form = [ + ("foo", "bar") + ] diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index d682df69..b89e5c8d 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,5 +1,2 @@ def request(context, flow): - q = flow.request.query - if q: - q["mitmproxy"] = ["rocks"] - flow.request.query = q + flow.request.query["mitmproxy"] = "rocks" diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index b2ebe49e..2010cecd 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -6,8 +6,7 @@ import sys import math import urwid -from netlib import odict -from netlib.http import Headers +from netlib.http import Headers, status_codes from . import common, grideditor, signals, searchable, tabs from . import flowdetailview from .. import utils, controller, contentviews @@ -187,7 +186,7 @@ class FlowView(tabs.Tabs): viewmode, message, limit, - (bytes(message.headers), message.content) # Cache invalidation + message # Cache invalidation ) def _get_content_view(self, viewmode, message, max_lines, _): @@ -316,21 +315,18 @@ class FlowView(tabs.Tabs): return "Invalid URL." signals.flow_change.send(self, flow = self.flow) - def set_resp_code(self, code): - response = self.flow.response + def set_resp_status_code(self, status_code): try: - response.status_code = int(code) + status_code = int(status_code) except ValueError: return None - import BaseHTTPServer - if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses: - response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[ - int(code)][0] + self.flow.response.status_code = status_code + if status_code in status_codes.RESPONSES: + self.flow.response.reason = status_codes.RESPONSES[status_code] signals.flow_change.send(self, flow = self.flow) - def set_resp_msg(self, msg): - response = self.flow.response - response.msg = msg + def set_resp_reason(self, reason): + self.flow.response.reason = reason signals.flow_change.send(self, flow = self.flow) def set_headers(self, fields, conn): @@ -338,22 +334,22 @@ class FlowView(tabs.Tabs): signals.flow_change.send(self, flow = self.flow) def set_query(self, lst, conn): - conn.set_query(odict.ODict(lst)) + conn.query = lst signals.flow_change.send(self, flow = self.flow) def set_path_components(self, lst, conn): - conn.set_path_components(lst) + conn.path_components = lst signals.flow_change.send(self, flow = self.flow) def set_form(self, lst, conn): - conn.set_form_urlencoded(odict.ODict(lst)) + conn.urlencoded_form = lst signals.flow_change.send(self, flow = self.flow) def edit_form(self, conn): self.master.view_grideditor( grideditor.URLEncodedFormEditor( self.master, - conn.get_form_urlencoded().lst, + conn.urlencoded_form.items(multi=True), self.set_form, conn ) @@ -364,7 +360,7 @@ class FlowView(tabs.Tabs): self.edit_form(conn) def set_cookies(self, lst, conn): - conn.cookies = odict.ODict(lst) + conn.cookies = lst signals.flow_change.send(self, flow = self.flow) def set_setcookies(self, data, conn): @@ -388,7 +384,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.CookieEditor( self.master, - message.cookies.lst, + message.cookies.items(multi=True), self.set_cookies, message ) @@ -397,7 +393,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.SetCookieEditor( self.master, - message.cookies, + message.cookies.items(multi=True), self.set_setcookies, message ) @@ -413,7 +409,7 @@ class FlowView(tabs.Tabs): c = self.master.spawn_editor(message.content or "") message.content = c.rstrip("\n") elif part == "f": - if not message.get_form_urlencoded() and message.content: + if not message.urlencoded_form and message.content: signals.status_prompt_onekey.send( prompt = "Existing body is not a URL-encoded form. Clear and edit?", keys = [ @@ -435,7 +431,7 @@ class FlowView(tabs.Tabs): ) ) elif part == "p": - p = message.get_path_components() + p = message.path_components self.master.view_grideditor( grideditor.PathEditor( self.master, @@ -448,7 +444,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.QueryEditor( self.master, - message.get_query().lst, + message.query.items(multi=True), self.set_query, message ) ) @@ -458,7 +454,7 @@ class FlowView(tabs.Tabs): text = message.url, callback = self.set_url ) - elif part == "m": + elif part == "m" and message == self.flow.request: signals.status_prompt_onekey.send( prompt = "Method", keys = common.METHOD_OPTIONS, @@ -468,13 +464,13 @@ class FlowView(tabs.Tabs): signals.status_prompt.send( prompt = "Code", text = str(message.status_code), - callback = self.set_resp_code + callback = self.set_resp_status_code ) - elif part == "m": + elif part == "m" and message == self.flow.response: signals.status_prompt.send( prompt = "Message", - text = message.msg, - callback = self.set_resp_msg + text = message.reason, + callback = self.set_resp_reason ) signals.flow_change.send(self, flow = self.flow) diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py index 46ff348e..11ce7d02 100644 --- a/mitmproxy/console/grideditor.py +++ b/mitmproxy/console/grideditor.py @@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor): def data_in(self, data): flattened = [] - for k, v in data.items(): - flattened.append([k, v[0], v[1].lst]) + for key, (value, attrs) in data: + flattened.append([key, value, attrs.items(multi=True)]) return flattened def data_out(self, data): vals = [] - for i in data: + for key, value, attrs in data: vals.append( [ - i[0], - [i[1], odict.ODictCaseless(i[2])] + key, + (value, attrs) ] ) - return odict.ODict(vals) + return vals diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index ccedd1d4..a9018e16 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -158,9 +158,9 @@ class SetHeaders: for _, header, value, cpatt in self.lst: if cpatt(f): if f.response: - f.response.headers.fields.append((header, value)) + f.response.headers.add(header, value) else: - f.request.headers.fields.append((header, value)) + f.request.headers.add(header, value) class StreamLargeBodies(object): @@ -265,7 +265,7 @@ class ServerPlaybackState: form_contents = r.urlencoded_form or r.multipart_form if self.ignore_payload_params and form_contents: key.extend( - p for p in form_contents + p for p in form_contents.items(multi=True) if p[0] not in self.ignore_payload_params ) else: @@ -321,10 +321,10 @@ class StickyCookieState: """ domain = f.request.host path = "/" - if attrs["domain"]: - domain = attrs["domain"][-1] - if attrs["path"]: - path = attrs["path"][-1] + if "domain" in attrs: + domain = attrs["domain"] + if "path" in attrs: + path = attrs["path"] return (domain, f.request.port, path) def domain_match(self, a, b): @@ -335,28 +335,26 @@ class StickyCookieState: return False def handle_response(self, f): - for i in f.response.headers.get_all("set-cookie"): + for name, (value, attrs) in f.response.cookies.items(multi=True): # FIXME: We now know that Cookie.py screws up some cookies with # valid RFC 822/1123 datetime specifications for expiry. Sigh. - name, value, attrs = cookies.parse_set_cookie_header(str(i)) a = self.ckey(attrs, f) if self.domain_match(f.request.host, a[0]): - b = attrs.lst - b.insert(0, [name, value]) - self.jar[a][name] = odict.ODictCaseless(b) + b = attrs.with_insert(0, name, value) + self.jar[a][name] = b def handle_request(self, f): l = [] if f.match(self.flt): - for i in self.jar.keys(): + for domain, port, path in self.jar.keys(): match = [ - self.domain_match(f.request.host, i[0]), - f.request.port == i[1], - f.request.path.startswith(i[2]) + self.domain_match(f.request.host, domain), + f.request.port == port, + f.request.path.startswith(path) ] if all(match): - c = self.jar[i] - l.extend([cookies.format_cookie_header(c[name]) for name in c.keys()]) + c = self.jar[(domain, port, path)] + l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()]) if l: f.request.stickycookie = True f.request.headers["cookie"] = "; ".join(l) diff --git a/mitmproxy/flow_export.py b/mitmproxy/flow_export.py index d8e65704..ae282fce 100644 --- a/mitmproxy/flow_export.py +++ b/mitmproxy/flow_export.py @@ -51,7 +51,7 @@ def python_code(flow): params = "" if flow.request.query: - lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] + lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()] params = "\nparams = {\n%s}\n" % "".join(lines) args += "\n params=params," @@ -140,7 +140,7 @@ def locust_code(flow): params = "" if flow.request.query: - lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] + lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()] params = "\n params = {\n%s }\n" % "".join(lines) args += "\n params=params," diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py index 536f2753..c8e58d1b 100644 --- a/mitmproxy/protocol/base.py +++ b/mitmproxy/protocol/base.py @@ -133,24 +133,15 @@ class ServerConnectionMixin(object): "The proxy shall not connect to itself.".format(repr(address)) ) - def set_server(self, address, server_tls=None, sni=None): + def set_server(self, address): """ Sets a new server address. If there is an existing connection, it will be closed. - - Raises: - ~mitmproxy.exceptions.ProtocolException: - if ``server_tls`` is ``True``, but there was no TLS layer on the - protocol stack which could have processed this. """ if self.server_conn: self.disconnect() self.log("Set new server address: " + repr(address), "debug") self.server_conn.address = address self.__check_self_connect() - if server_tls: - raise ProtocolException( - "Cannot upgrade to TLS, no TLS layer on the protocol stack." - ) def disconnect(self): """ diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index 9cb35176..d9111303 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -120,7 +120,7 @@ class UpstreamConnectLayer(Layer): if address != self.server_conn.via.address: self.ctx.set_server(address) - def set_server(self, address, server_tls=None, sni=None): + def set_server(self, address): if self.ctx.server_conn: self.ctx.disconnect() address = tcp.Address.wrap(address) @@ -128,11 +128,6 @@ class UpstreamConnectLayer(Layer): self.connect_request.port = address.port self.server_conn.address = address - if server_tls: - raise ProtocolException( - "Cannot upgrade to TLS, no TLS layer on the protocol stack." - ) - class HttpLayer(Layer): @@ -149,7 +144,7 @@ class HttpLayer(Layer): def __call__(self): if self.mode == "transparent": - self.__initial_server_tls = self._server_tls + self.__initial_server_tls = self.server_tls self.__initial_server_conn = self.server_conn while True: try: @@ -360,8 +355,9 @@ class HttpLayer(Layer): if self.mode == "regular" or self.mode == "transparent": # If there's an existing connection that doesn't match our expectations, kill it. - if address != self.server_conn.address or tls != self.server_conn.tls_established: - self.set_server(address, tls, address.host) + if address != self.server_conn.address or tls != self.server_tls: + self.set_server(address) + self.set_server_tls(tls, address.host) # Establish connection is neccessary. if not self.server_conn: self.connect() diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index 26c3f9d2..74c55ab4 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -266,18 +266,22 @@ class TlsClientHello(object): return self._client_hello @property - def client_cipher_suites(self): + def cipher_suites(self): return self._client_hello.cipher_suites.cipher_suites @property - def client_sni(self): + def sni(self): for extension in self._client_hello.extensions: - if (extension.type == 0x00 and len(extension.server_names) == 1 - and extension.server_names[0].type == 0): + is_valid_sni_extension = ( + extension.type == 0x00 + and len(extension.server_names) == 1 + and extension.server_names[0].type == 0 + ) + if is_valid_sni_extension: return extension.server_names[0].name @property - def client_alpn_protocols(self): + def alpn_protocols(self): for extension in self._client_hello.extensions: if extension.type == 0x10: return list(extension.alpn_protocols) @@ -304,55 +308,78 @@ class TlsClientHello(object): def __repr__(self): return "TlsClientHello( sni: %s alpn_protocols: %s, cipher_suites: %s)" % \ - (self.client_sni, self.client_alpn_protocols, self.client_cipher_suites) + (self.sni, self.alpn_protocols, self.cipher_suites) class TlsLayer(Layer): + """ + The TLS layer implements transparent TLS connections. - def __init__(self, ctx, client_tls, server_tls): - self.client_sni = None - self.client_alpn_protocols = None - self.client_ciphers = [] + It exposes the following API to child layers: + + - :py:meth:`set_server_tls` to modify TLS settings for the server connection. + - :py:attr:`server_tls`, :py:attr:`server_sni` as read-only attributes describing the current TLS settings for + the server connection. + """ + def __init__(self, ctx, client_tls, server_tls): super(TlsLayer, self).__init__(ctx) self._client_tls = client_tls self._server_tls = server_tls - self._sni_from_server_change = None + self._custom_server_sni = None + self._client_hello = None # type: TlsClientHello def __call__(self): """ - The strategy for establishing SSL is as follows: + The strategy for establishing TLS is as follows: First, we determine whether we need the server cert to establish ssl with the client. If so, we first connect to the server and then to the client. - If not, we only connect to the client and do the server_ssl lazily on a Connect message. - - An additional complexity is that establish ssl with the server may require a SNI value from - the client. In an ideal world, we'd do the following: - 1. Start the SSL handshake with the client - 2. Check if the client sends a SNI. - 3. Pause the client handshake, establish SSL with the server. - 4. Finish the client handshake with the certificate from the server. - There's just one issue: We cannot get a callback from OpenSSL if the client doesn't send a SNI. :( - Thus, we manually peek into the connection and parse the ClientHello message to obtain both SNI and ALPN values. - - Further notes: - - OpenSSL 1.0.2 introduces a callback that would help here: - https://www.openssl.org/docs/ssl/SSL_CTX_set_cert_cb.html - - The original mitmproxy issue is https://github.com/mitmproxy/mitmproxy/issues/427 - """ - - client_tls_requires_server_cert = ( - self._client_tls and self._server_tls and not self.config.no_upstream_cert - ) + If not, we only connect to the client and do the server handshake lazily. + An additional complexity is that we need to mirror SNI and ALPN from the client when connecting to the server. + We manually peek into the connection and parse the ClientHello message to obtain these values. + """ if self._client_tls: - self._parse_client_hello() + # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values. + try: + self._client_hello = TlsClientHello.from_client_conn(self.client_conn) + except TlsProtocolException as e: + self.log("Cannot parse Client Hello: %s" % repr(e), "error") + + # Do we need to do a server handshake now? + # There are two reasons why we would want to establish TLS with the server now: + # 1. If we already have an existing server connection and server_tls is True, + # we need to establish TLS now because .connect() will not be called anymore. + # 2. We may need information from the server connection for the client handshake. + # + # A couple of factors influence (2): + # 2.1 There actually is (or will be) a TLS-enabled upstream connection + # 2.2 An upstream connection is not wanted by the user if --no-upstream-cert is passed. + # 2.3 An upstream connection is implied by add_upstream_certs_to_client_chain + # 2.4 The client wants to negotiate an alternative protocol in its handshake, we need to find out + # what is supported by the server + # 2.5 The client did not sent a SNI value, we don't know the certificate subject. + client_tls_requires_server_connection = ( + self._server_tls + and not self.config.no_upstream_cert + and ( + self.config.add_upstream_certs_to_client_chain + or self._client_hello.alpn_protocols + or not self._client_hello.sni + ) + ) + establish_server_tls_now = ( + (self.server_conn and self._server_tls) + or client_tls_requires_server_connection + ) - if client_tls_requires_server_cert: + if self._client_tls and establish_server_tls_now: self._establish_tls_with_client_and_server() elif self._client_tls: self._establish_tls_with_client() + elif establish_server_tls_now: + self._establish_tls_with_server() layer = self.ctx.next_layer(self) layer() @@ -367,47 +394,48 @@ class TlsLayer(Layer): else: return "TlsLayer(inactive)" - def _parse_client_hello(self): - """ - Peek into the connection, read the initial client hello and parse it to obtain ALPN values. - """ - try: - parsed = TlsClientHello.from_client_conn(self.client_conn) - self.client_sni = parsed.client_sni - self.client_alpn_protocols = parsed.client_alpn_protocols - self.client_ciphers = parsed.client_cipher_suites - except TlsProtocolException as e: - self.log("Cannot parse Client Hello: %s" % repr(e), "error") - def connect(self): if not self.server_conn: self.ctx.connect() if self._server_tls and not self.server_conn.tls_established: self._establish_tls_with_server() - def set_server(self, address, server_tls=None, sni=None): - if server_tls is not None: - self._sni_from_server_change = sni - self._server_tls = server_tls - self.ctx.set_server(address, None, None) + def set_server_tls(self, server_tls, sni=None): + """ + Set the TLS settings for the next server connection that will be established. + This function will not alter an existing connection. + + Args: + server_tls: Shall we establish TLS with the server? + sni: ``bytes`` for a custom SNI value, + ``None`` for the client SNI value, + ``False`` if no SNI value should be sent. + """ + self._server_tls = server_tls + self._custom_server_sni = sni + + @property + def server_tls(self): + """ + ``True``, if the next server connection that will be established should be upgraded to TLS. + """ + return self._server_tls @property - def sni_for_server_connection(self): - if self._sni_from_server_change is False: + def server_sni(self): + """ + The Server Name Indication we want to send with the next server TLS handshake. + """ + if self._custom_server_sni is False: return None else: - return self._sni_from_server_change or self.client_sni + return self._custom_server_sni or self._client_hello.sni @property def alpn_for_client_connection(self): return self.server_conn.get_alpn_proto_negotiated() def __alpn_select_callback(self, conn_, options): - """ - Once the client signals the alternate protocols it supports, - we reconnect upstream with the same list and pass the server's choice down to the client. - """ - # This gets triggered if we haven't established an upstream connection yet. default_alpn = b'http/1.1' # alpn_preference = b'h2' @@ -422,12 +450,12 @@ class TlsLayer(Layer): return choice def _establish_tls_with_client_and_server(self): - # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless - # to send an error message over TLS. try: self.ctx.connect() self._establish_tls_with_server() except Exception: + # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless + # to send an error message over TLS. try: self._establish_tls_with_client() except: @@ -466,9 +494,9 @@ class TlsLayer(Layer): ClientHandshakeException, ClientHandshakeException( "Cannot establish TLS with client (sni: {sni}): {e}".format( - sni=self.client_sni, e=repr(e) + sni=self._client_hello.sni, e=repr(e) ), - self.client_sni or repr(self.server_conn.address) + self._client_hello.sni or repr(self.server_conn.address) ), sys.exc_info()[2] ) @@ -480,8 +508,8 @@ class TlsLayer(Layer): # If the server only supports spdy (next to http/1.1), it may select that # and mitmproxy would enter TCP passthrough mode, which we want to avoid. deprecated_http2_variant = lambda x: x.startswith(b"h2-") or x.startswith(b"spdy") - if self.client_alpn_protocols: - alpn = [x for x in self.client_alpn_protocols if not deprecated_http2_variant(x)] + if self._client_hello.alpn_protocols: + alpn = [x for x in self._client_hello.alpn_protocols if not deprecated_http2_variant(x)] else: alpn = None if alpn and b"h2" in alpn and not self.config.http2: @@ -490,14 +518,14 @@ class TlsLayer(Layer): ciphers_server = self.config.ciphers_server if not ciphers_server: ciphers_server = [] - for id in self.client_ciphers: + for id in self._client_hello.cipher_suites: if id in CIPHER_ID_NAME_MAP.keys(): ciphers_server.append(CIPHER_ID_NAME_MAP[id]) ciphers_server = ':'.join(ciphers_server) self.server_conn.establish_ssl( self.config.clientcerts, - self.sni_for_server_connection, + self.server_sni, method=self.config.openssl_method_server, options=self.config.openssl_options_server, verify_options=self.config.openssl_verification_mode_server, @@ -524,7 +552,7 @@ class TlsLayer(Layer): TlsProtocolException, TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( address=repr(self.server_conn.address), - sni=self.sni_for_server_connection, + sni=self.server_sni, e=repr(e), )), sys.exc_info()[2] @@ -534,7 +562,7 @@ class TlsLayer(Layer): TlsProtocolException, TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( address=repr(self.server_conn.address), - sni=self.sni_for_server_connection, + sni=self.server_sni, e=repr(e), )), sys.exc_info()[2] @@ -569,13 +597,13 @@ class TlsLayer(Layer): sans.add(host) host = upstream_cert.cn.decode("utf8").encode("idna") # Also add SNI values. - if self.client_sni: - sans.add(self.client_sni) - if self._sni_from_server_change: - sans.add(self._sni_from_server_change) + if self._client_hello.sni: + sans.add(self._client_hello.sni) + if self._custom_server_sni: + sans.add(self._custom_server_sni) - # Some applications don't consider the CN and expect the hostname to be in the SANs. - # For example, Thunderbird 38 will display a warning if the remote host is only the CN. + # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity. + # In other words, the Common Name is irrelevant then. if host: sans.add(host) return self.config.certstore.get_cert(host, list(sans)) diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 9caae02a..c55105ec 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -63,7 +63,7 @@ class RootContext(object): except TlsProtocolException as e: self.log("Cannot parse Client Hello: %s" % repr(e), "error") else: - ignore = self.config.check_ignore((client_hello.client_sni, 443)) + ignore = self.config.check_ignore((client_hello.sni, 443)) if ignore: return RawTCPLayer(top_layer, logging=False) diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py index 5fd062ea..cda5bba6 100644 --- a/mitmproxy/utils.py +++ b/mitmproxy/utils.py @@ -7,6 +7,9 @@ import json import importlib import inspect +import netlib.utils + + def timestamp(): """ Returns a serializable UTC timestamp. @@ -73,25 +76,7 @@ def pretty_duration(secs): return "{:.0f}ms".format(secs * 1000) -class Data: - - def __init__(self, name): - m = importlib.import_module(name) - dirname = os.path.dirname(inspect.getsourcefile(m)) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module.Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath -pkg_data = Data(__name__) +pkg_data = netlib.utils.Data(__name__) class LRUCache: diff --git a/netlib/encoding.py b/netlib/encoding.py index 14479e00..98502451 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib -from .utils import always_byte_args ENCODINGS = {"identity", "gzip", "deflate"} diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 917080f7..c4eb1d58 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -3,12 +3,12 @@ from .request import Request from .response import Response from .headers import Headers from .message import decoded -from . import http1, http2 +from . import http1, http2, status_codes __all__ = [ "Request", "Response", "Headers", "decoded", - "http1", "http2", + "http1", "http2", "status_codes", ] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 4451f1da..88c76870 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,8 +1,8 @@ -from six.moves import http_cookies as Cookie +import collections import re -import string from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib.multidict import ImmutableMultiDict from .. import odict """ @@ -157,42 +157,76 @@ def _parse_set_cookie_pairs(s): return pairs +def parse_set_cookie_headers(headers): + ret = [] + for header in headers: + v = parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append((name, SetCookie(value, attrs))) + return ret + + +class CookieAttrs(ImmutableMultiDict): + @staticmethod + def _kconv(key): + return key.lower() + + @staticmethod + def _reduce_values(values): + # See the StickyCookieTest for a weird cookie that only makes sense + # if we take the last part. + return values[-1] + + +SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) + + def parse_set_cookie_header(line): """ Parse a Set-Cookie header value Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute + CookieAttrs dict of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ pairs = _parse_set_cookie_pairs(line) if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:]) def format_set_cookie_header(name, value, attrs): """ Formats a Set-Cookie header value. """ - pairs = [[name, value]] - pairs.extend(attrs.lst) + pairs = [(name, value)] + pairs.extend( + attrs.fields if hasattr(attrs, "fields") else attrs + ) return _format_set_cookie_pairs(pairs) +def parse_cookie_headers(cookie_headers): + cookie_list = [] + for header in cookie_headers: + cookie_list.extend(parse_cookie_header(header)) + return cookie_list + + def parse_cookie_header(line): """ Parse a Cookie header value. - Returns a (possibly empty) ODict object. + Returns a list of (lhs, rhs) tuples. """ pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) + return pairs -def format_cookie_header(od): +def format_cookie_header(lst): """ Formats a Cookie header value. """ - return _format_pairs(od.lst) + return _format_pairs(lst) def refresh_set_cookie_header(c, delta): @@ -209,10 +243,10 @@ def refresh_set_cookie_header(c, delta): raise ValueError("Invalid Cookie") if "expires" in attrs: - e = parsedate_tz(attrs["expires"][-1]) + e = parsedate_tz(attrs["expires"]) if e: f = mktime_tz(e) + delta - attrs["expires"] = [formatdate(f)] + attrs = attrs.with_set_all("expires", [formatdate(f)]) else: # This can happen when the expires tag is invalid. # reddit.com sends a an expires tag like this: "Thu, 31 Dec @@ -220,7 +254,7 @@ def refresh_set_cookie_header(c, delta): # strictly correct according to the cookie spec. Browsers # appear to parse this tolerantly - maybe we should too. # For now, we just ignore this. - del attrs["expires"] + attrs = attrs.with_delitem("expires") ret = format_set_cookie_header(name, value, attrs) if not ret: diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 72739f90..60d3f429 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,9 +1,3 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -""" from __future__ import absolute_import, print_function, division import re @@ -13,23 +7,22 @@ try: except ImportError: # pragma: no cover from collections import MutableMapping # Workaround for Python < 3.3 - import six +from ..multidict import MultiDict +from ..utils import always_bytes -from netlib.utils import always_byte_args, always_bytes, Serializable +# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ if six.PY2: # pragma: no cover _native = lambda x: x _always_bytes = lambda x: x - _always_byte_args = lambda x: x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") - _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, Serializable): +class Headers(MultiDict): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -49,11 +42,11 @@ class Headers(MutableMapping, Serializable): >>> h["host"] "example.com" - # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + # Headers can also be created from a list of raw (header_name, header_value) byte tuples >>> h = Headers([ - [b"Host",b"example.com"], - [b"Accept",b"text/html"], - [b"accept",b"application/xml"] + (b"Host",b"example.com"), + (b"Accept",b"text/html"), + (b"accept",b"application/xml") ]) # Multiple headers are folded into a single header as per RFC7230 @@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ - @_always_byte_args def __init__(self, fields=None, **headers): """ Args: @@ -89,19 +81,29 @@ class Headers(MutableMapping, Serializable): If ``**headers`` contains multiple keys that have equal ``.lower()`` s, the behavior is undefined. """ - self.fields = fields or [] + super(Headers, self).__init__(fields) - for name, value in self.fields: - if not isinstance(name, bytes) or not isinstance(value, bytes): - raise ValueError("Headers passed as fields must be bytes.") + for key, value in self.fields: + if not isinstance(key, bytes) or not isinstance(value, bytes): + raise TypeError("Header fields must be bytes.") # content_type -> content-type headers = { - _always_bytes(name).replace(b"_", b"-"): value + _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) for name, value in six.iteritems(headers) } self.update(headers) + @staticmethod + def _reduce_values(values): + # Headers can be folded + return ", ".join(values) + + @staticmethod + def _kconv(key): + # Headers are case-insensitive + return key.lower() + def __bytes__(self): if self.fields: return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" @@ -111,98 +113,40 @@ class Headers(MutableMapping, Serializable): if six.PY2: # pragma: no cover __str__ = __bytes__ - @_always_byte_args - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return ", ".join(values) - - @_always_byte_args - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @_always_byte_args - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] + def __delitem__(self, key): + key = _always_bytes(key) + super(Headers, self).__delitem__(key) def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield _native(name) - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @_always_byte_args + for x in super(Headers, self).__iter__(): + yield _native(x) + def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - name_lower = name.lower() - values = [_native(value) for n, value in self.fields if n.lower() == name_lower] - return values + name = _always_bytes(name) + return [ + _native(x) for x in + super(Headers, self).get_all(name) + ] - @_always_byte_args def set_all(self, name, values): """ Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - values = map(_always_bytes, values) # _always_byte_args does not fix lists - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def get_state(self): - return tuple(tuple(field) for field in self.fields) - - def set_state(self, state): - self.fields = [list(field) for field in state] + name = _always_bytes(name) + values = [_always_bytes(x) for x in values] + return super(Headers, self).set_all(name, values) - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) + def insert(self, index, key, value): + key = _always_bytes(key) + value = _always_bytes(value) + super(Headers, self).insert(index, key, value) - @_always_byte_args def replace(self, pattern, repl, flags=0): """ Replaces a regular expression pattern with repl in each "name: value" @@ -211,6 +155,8 @@ class Headers(MutableMapping, Serializable): Returns: The number of replacements made. """ + pattern = _always_bytes(pattern) + repl = _always_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 6e3a1b93..d30976bd 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -316,14 +316,14 @@ def _read_headers(rfile): if not ret: raise HttpSyntaxException("Invalid headers") # continued header - ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) else: try: name, value = line.split(b":", 1) value = value.strip() if not name: raise ValueError() - ret.append([name, value]) + ret.append((name, value)) except ValueError: raise HttpSyntaxException("Invalid headers") return Headers(ret) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index f900b67c..6643b6b9 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -201,13 +201,13 @@ class HTTP2Protocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.fields.insert(0, (b':authority', authority.encode('ascii'))) + headers.insert(0, b':authority', authority.encode('ascii')) if ':scheme' not in headers: - headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) + headers.insert(0, b':scheme', request.scheme.encode('ascii')) if ':path' not in headers: - headers.fields.insert(0, (b':path', request.path.encode('ascii'))) + headers.insert(0, b':path', request.path.encode('ascii')) if ':method' not in headers: - headers.fields.insert(0, (b':method', request.method.encode('ascii'))) + headers.insert(0, b':method', request.method.encode('ascii')) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -224,7 +224,7 @@ class HTTP2Protocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) + headers.insert(0, b':status', str(response.status_code).encode('ascii')) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -420,7 +420,7 @@ class HTTP2Protocol(object): self._handle_unexpected_frame(frm) headers = Headers( - [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] + (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks) ) return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py index da9681a0..028f43a1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,6 +4,7 @@ import warnings import six +from ..multidict import MultiDict from .headers import Headers from .. import encoding, utils @@ -25,6 +26,9 @@ class MessageData(utils.Serializable): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return hash(frozenset(self.__dict__.items())) + def set_state(self, state): for k, v in state.items(): if k == "headers": @@ -51,6 +55,9 @@ class Message(utils.Serializable): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return hash(self.data) ^ 1 + def get_state(self): return self.data.get_state() diff --git a/netlib/http/request.py b/netlib/http/request.py index a42150ff..056a2d93 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,6 +10,7 @@ from netlib import utils from netlib.http import cookies from netlib.odict import ODict from .. import encoding +from ..multidict import MultiDictView from .headers import Headers from .message import Message, _native, _always_bytes, MessageData @@ -224,45 +225,64 @@ class Request(Message): @property def query(self): + # type: () -> MultiDictView """ - The request query string as an :py:class:`ODict` object. - None, if there is no query. + The request query string as an :py:class:`MultiDictView` object. """ + return MultiDictView( + self._get_query, + self._set_query + ) + + def _get_query(self): _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None + return tuple(utils.urldecode(query)) - @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) + def _set_query(self, value): + query = utils.urlencode(value) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) _, _, _, self.path = utils.parse_url( urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + @query.setter + def query(self, value): + self._set_query(value) + @property def cookies(self): + # type: () -> MultiDictView """ The request cookies. - An empty :py:class:`ODict` object if the cookie monster ate them all. + + An empty :py:class:`MultiDictView` object if the cookie monster ate them all. """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret + return MultiDictView( + self._get_cookies, + self._set_cookies + ) + + def _get_cookies(self): + h = self.headers.get_all("Cookie") + return tuple(cookies.parse_cookie_headers(h)) + + def _set_cookies(self, value): + self.headers["cookie"] = cookies.format_cookie_header(value) @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) + def cookies(self, value): + self._set_cookies(value) @property def path_components(self): """ - The URL's path components as a list of strings. + The URL's path components as a tuple of strings. Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] + # This needs to be a tuple so that it's immutable. + # Otherwise, this would fail silently: + # request.path_components.append("foo") + return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): @@ -309,64 +329,53 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The URL-encoded form data as an :py:class:`MultiDictView` object. + An empty MultiDictView if the content-type indicates non-form data + or the content could not be parsed. """ + return MultiDictView( + self._get_urlencoded_form, + self._set_urlencoded_form + ) + + def _get_urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.urldecode(self.content)) - return None + if is_valid_content_type: + return tuple(utils.urldecode(self.content)) + return () - @urlencoded_form.setter - def urlencoded_form(self, odict): + def _set_urlencoded_form(self, value): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = utils.urlencode(odict.lst) + self.content = utils.urlencode(value) + + @urlencoded_form.setter + def urlencoded_form(self, value): + self._set_urlencoded_form(value) @property def multipart_form(self): """ - The multipart form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The multipart form data as an :py:class:`MultipartFormDict` object. + None if the content-type indicates non-form data. """ + return MultiDictView( + self._get_multipart_form, + self._set_multipart_form + ) + + def _get_multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.multipartdecode(self.headers,self.content)) - return None + if is_valid_content_type: + return utils.multipartdecode(self.headers, self.content) + return () - @multipart_form.setter - def multipart_form(self, value): + def _set_multipart_form(self, value): raise NotImplementedError() - # Legacy - - def get_query(self): # pragma: no cover - warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) - return self.query or ODict([]) - - def set_query(self, odict): # pragma: no cover - warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) - self.query = odict - - def get_path_components(self): # pragma: no cover - warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) - return self.path_components - - def set_path_components(self, lst): # pragma: no cover - warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) - self.path_components = lst - - def get_form_urlencoded(self): # pragma: no cover - warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - return self.urlencoded_form or ODict([]) - - def set_form_urlencoded(self, odict): # pragma: no cover - warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - self.urlencoded_form = odict - - def get_form_multipart(self): # pragma: no cover - warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) - return self.multipart_form or ODict([]) + @multipart_form.setter + def multipart_form(self, value): + self._set_multipart_form(value) diff --git a/netlib/http/response.py b/netlib/http/response.py index 2f06149e..7d272e10 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,14 +1,13 @@ from __future__ import absolute_import, print_function, division -import warnings from email.utils import parsedate_tz, formatdate, mktime_tz import time from . import cookies from .headers import Headers from .message import Message, _native, _always_bytes, MessageData +from ..multidict import MultiDictView from .. import utils -from ..odict import ODict class ResponseData(MessageData): @@ -72,29 +71,35 @@ class Response(Message): @property def cookies(self): + # type: () -> MultiDictView """ - Get the contents of all Set-Cookie headers. + The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are + cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly) + are indicated by a Null value. - A possibly empty :py:class:`ODict`, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. + Caveats: + Updating the attr """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) + return MultiDictView( + self._get_cookies, + self._set_cookies + ) + + def _get_cookies(self): + h = self.headers.get_all("set-cookie") + return tuple(cookies.parse_set_cookie_headers(h)) + + def _set_cookies(self, value): + cookie_headers = [] + for k, v in value: + header = cookies.format_set_cookie_header(k, v[0], v[1]) + cookie_headers.append(header) + self.headers.set_all("set-cookie", cookie_headers) @cookies.setter - def cookies(self, odict): - values = [] - for i in odict.lst: - header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) - values.append(header) - self.headers.set_all("set-cookie", values) + def cookies(self, value): + self._set_cookies(value) def refresh(self, now=None): """ diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..248acdec --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,282 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class _MultiDict(MutableMapping, Serializable): + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ + + @staticmethod + @abstractmethod + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.fields) + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super(_MultiDict, self).items() + + def to_dict(self): + """ + Get the MultiDict as a plain Python dict. + Keys with multiple values are returned as lists. + + Example: + + .. code-block:: python + + # Simple dict with duplicate values. + >>> d + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d.to_dict() + { + "name": "value", + "a": ["false", "42"] + } + """ + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) + + +class MultiDict(_MultiDict): + def __init__(self, fields=None): + super(MultiDict, self).__init__() + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): + def _immutable(self, *_): + raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + + __delitem__ = set_all = insert = _immutable + + def with_delitem(self, key): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).__delitem__(key) + return ret + + def with_set_all(self, key, values): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).set_all(key, values) + return ret + + def with_insert(self, index, key, value): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).insert(index, key, value) + return ret + + +class MultiDictView(_MultiDict): + """ + The MultiDictView provides the MultiDict interface over calculated data. + The view itself contains no state - data is retrieved from the parent on + request, and stored back to the parent on change. + """ + def __init__(self, getter, setter): + self._getter = getter + self._setter = setter + super(MultiDictView, self).__init__() + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return self._getter() + + @fields.setter + def fields(self, value): + return self._setter(value) diff --git a/netlib/utils.py b/netlib/utils.py index be2701a0..7499f71f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args): return unicode_or_bytes -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - def native(s, *encoding_opts): """ Convert :py:class:`bytes` or :py:class:`unicode` to the native diff --git a/pathod/utils.py b/pathod/utils.py index 1e5bd9a4..d1e2dd00 100644 --- a/pathod/utils.py +++ b/pathod/utils.py @@ -1,5 +1,6 @@ import os import sys +import netlib.utils SIZE_UNITS = dict( @@ -75,27 +76,7 @@ def escape_unprintables(s): return s -class Data(object): - - def __init__(self, name): - m = __import__(name) - dirname, _ = os.path.split(m.__file__) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module.Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath - - -data = Data(__name__) +data = netlib.utils.Data(__name__) def daemonize(stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'): # pragma: no cover diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index c401a6b9..c4b06f4b 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -5,11 +5,12 @@ from contextlib import contextmanager from mitmproxy import utils, script from mitmproxy.proxy import config +import netlib.utils from netlib import tutils as netutils from netlib.http import Headers from . import tservers, tutils -example_dir = utils.Data(__name__).path("../../examples") +example_dir = netlib.utils.Data(__name__).path("../../examples") class DummyContext(object): @@ -94,14 +95,22 @@ def test_modify_form(): flow = tutils.tflow(req=netutils.treq(headers=form_header)) with example("modify_form.py") as ex: ex.run("request", flow) - assert flow.request.urlencoded_form["mitmproxy"] == ["rocks"] + assert flow.request.urlencoded_form["mitmproxy"] == "rocks" + + flow.request.headers["content-type"] = "" + ex.run("request", flow) + assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")] def test_modify_querystring(): flow = tutils.tflow(req=netutils.treq(path="/search?q=term")) with example("modify_querystring.py") as ex: ex.run("request", flow) - assert flow.request.query["mitmproxy"] == ["rocks"] + assert flow.request.query["mitmproxy"] == "rocks" + + flow.request.path = "/" + ex.run("request", flow) + assert flow.request.query["mitmproxy"] == "rocks" def test_modify_response_body(): diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index b9c6a2f6..bf417423 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -1067,60 +1067,6 @@ class TestRequest: assert r.url == "https://address:22/path" assert r.pretty_url == "https://foo.com:22/path" - def test_path_components(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.path = "/" - assert r.get_path_components() == [] - r.path = "/foo/bar" - assert r.get_path_components() == ["foo", "bar"] - q = odict.ODict() - q["test"] = ["123"] - r.set_query(q) - assert r.get_path_components() == ["foo", "bar"] - - r.set_path_components([]) - assert r.get_path_components() == [] - r.set_path_components(["foo"]) - assert r.get_path_components() == ["foo"] - r.set_path_components(["/oo"]) - assert r.get_path_components() == ["/oo"] - assert "%2F" in r.path - - def test_getset_form_urlencoded(self): - d = odict.ODict([("one", "two"), ("three", "four")]) - r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst))) - r.headers["content-type"] = "application/x-www-form-urlencoded" - assert r.get_form_urlencoded() == d - - d = odict.ODict([("x", "y")]) - r.set_form_urlencoded(d) - assert r.get_form_urlencoded() == d - - r.headers["content-type"] = "foo" - assert not r.get_form_urlencoded() - - def test_getset_query(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.path = "/foo?x=y&a=b" - q = r.get_query() - assert q.lst == [("x", "y"), ("a", "b")] - - r.path = "/" - q = r.get_query() - assert not q - - r.path = "/?adsfa" - q = r.get_query() - assert q.lst == [("adsfa", "")] - - r.path = "/foo?x=y&a=b" - assert r.get_query() - r.set_query(odict.ODict([])) - assert not r.get_query() - qv = odict.ODict([("a", "b"), ("c", "d")]) - r.set_query(qv) - assert r.get_query() == qv - def test_anticache(self): r = HTTPRequest.wrap(netlib.tutils.treq()) r.headers = Headers() diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py index 035f07b7..c252c5bd 100644 --- a/test/mitmproxy/test_flow_export.py +++ b/test/mitmproxy/test_flow_export.py @@ -21,7 +21,7 @@ def python_equals(testdata, text): assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip() -req_get = lambda: netlib.tutils.treq(method='GET', content='') +req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/path?a=foo&a=bar&b=baz") req_post = lambda: netlib.tutils.treq(method='POST', headers=None) @@ -31,7 +31,7 @@ req_patch = lambda: netlib.tutils.treq(method='PATCH', path=b"/path?query=param" class TestExportCurlCommand(): def test_get(self): flow = tutils.tflow(req=req_get()) - result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path'""" + result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path?a=foo&a=bar&b=baz'""" assert flow_export.curl_command(flow) == result def test_post(self): @@ -70,7 +70,7 @@ class TestRawRequest(): def test_get(self): flow = tutils.tflow(req=req_get()) result = dedent(""" - GET /path HTTP/1.1\r + GET /path?a=foo&a=bar&b=baz HTTP/1.1\r header: qvalue\r content-length: 7\r host: address:22\r diff --git a/test/mitmproxy/test_flow_export/locust_get.py b/test/mitmproxy/test_flow_export/locust_get.py index 72d5932a..632d5d53 100644 --- a/test/mitmproxy/test_flow_export/locust_get.py +++ b/test/mitmproxy/test_flow_export/locust_get.py @@ -14,10 +14,16 @@ class UserBehavior(TaskSet): 'content-length': '7', } + params = { + 'a': ['foo', 'bar'], + 'b': 'baz', + } + self.response = self.client.request( method='GET', url=url, headers=headers, + params=params, ) ### Additional tasks can go here ### diff --git a/test/mitmproxy/test_flow_export/locust_task_get.py b/test/mitmproxy/test_flow_export/locust_task_get.py index 76f144fa..03821cd8 100644 --- a/test/mitmproxy/test_flow_export/locust_task_get.py +++ b/test/mitmproxy/test_flow_export/locust_task_get.py @@ -7,8 +7,14 @@ 'content-length': '7', } + params = { + 'a': ['foo', 'bar'], + 'b': 'baz', + } + self.response = self.client.request( method='GET', url=url, headers=headers, + params=params, ) diff --git a/test/mitmproxy/test_flow_export/python_get.py b/test/mitmproxy/test_flow_export/python_get.py index ee3f48eb..af8f7c81 100644 --- a/test/mitmproxy/test_flow_export/python_get.py +++ b/test/mitmproxy/test_flow_export/python_get.py @@ -7,10 +7,16 @@ headers = { 'content-length': '7', } +params = { + 'a': ['foo', 'bar'], + 'b': 'baz', +} + response = requests.request( method='GET', url=url, headers=headers, + params=params, ) print(response.text) diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index d51ac185..2dfd710e 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -8,6 +8,7 @@ from contextlib import contextmanager from unittest.case import SkipTest +import netlib.utils import netlib.tutils from mitmproxy import utils, controller from mitmproxy.models import ( @@ -163,4 +164,4 @@ def capture_stderr(command, *args, **kwargs): sys.stderr = out -test_data = utils.Data(__name__) +test_data = netlib.utils.Data(__name__) diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py index 90234070..d8106904 100644 --- a/test/netlib/http/http1/test_read.py +++ b/test/netlib/http/http1/test_read.py @@ -261,7 +261,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two")) def test_read_multi(self): data = ( @@ -270,7 +270,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + assert headers.fields == ((b"Header", b"one"), (b"Header", b"two")) def test_read_continued(self): data = ( @@ -280,7 +280,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three")) def test_read_continued_err(self): data = b"\tfoo: bar\r\n" @@ -300,7 +300,7 @@ class TestReadHeaders(object): def test_read_empty_value(self): data = b"bar:" headers = self._read(data) - assert headers.fields == [[b"bar", b""]] + assert headers.fields == ((b"bar", b""),) def test_read_chunked(): req = treq(content=None) diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py index 7b003067..7d240c0e 100644 --- a/test/netlib/http/http2/test_connections.py +++ b/test/netlib/http/http2/test_connections.py @@ -312,7 +312,7 @@ class TestReadRequest(tservers.ServerTestBase): req = protocol.read_request(NotImplemented) assert req.stream_id - assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] + assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https')) assert req.content == b'foobar' @@ -418,7 +418,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.reason == '' - assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.content == b'foobar' assert resp.timestamp_end @@ -445,7 +445,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.reason == '' - assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.content == b'' diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index da28850f..6f84c4ce 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -128,10 +128,10 @@ def test_cookie_roundtrips(): ] for s, lst in pairs: ret = cookies.parse_cookie_header(s) - assert ret.lst == lst + assert ret == lst s2 = cookies.format_cookie_header(ret) ret = cookies.parse_cookie_header(s2) - assert ret.lst == lst + assert ret == lst def test_parse_set_cookie_pairs(): @@ -197,24 +197,28 @@ def test_parse_set_cookie_header(): ], [ "one=uno", - ("one", "uno", []) + ("one", "uno", ()) ], [ "one=uno; foo=bar", - ("one", "uno", [["foo", "bar"]]) - ] + ("one", "uno", (("foo", "bar"),)) + ], + [ + "one=uno; foo=bar; foo=baz", + ("one", "uno", (("foo", "bar"), ("foo", "baz"))) + ], ] for s, expected in vals: ret = cookies.parse_set_cookie_header(s) if expected: assert ret[0] == expected[0] assert ret[1] == expected[1] - assert ret[2].lst == expected[2] + assert ret[2].items(multi=True) == expected[2] s2 = cookies.format_set_cookie_header(*ret) ret2 = cookies.parse_set_cookie_header(s2) assert ret2[0] == expected[0] assert ret2[1] == expected[1] - assert ret2[2].lst == expected[2] + assert ret2[2].items(multi=True) == expected[2] else: assert ret is None diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 8c1db9dc..cd2ca9d1 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -5,10 +5,10 @@ from netlib.tutils import raises class TestHeaders(object): def _2host(self): return Headers( - [ - [b"Host", b"example.com"], - [b"host", b"example.org"] - ] + ( + (b"Host", b"example.com"), + (b"host", b"example.org") + ) ) def test_init(self): @@ -38,20 +38,10 @@ class TestHeaders(object): assert headers["Host"] == "example.com" assert headers["Accept"] == "text/plain" - with raises(ValueError): + with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) - def test_getitem(self): - headers = Headers(Host="example.com") - assert headers["Host"] == "example.com" - assert headers["host"] == "example.com" - with raises(KeyError): - _ = headers["Accept"] - - headers = self._2host() - assert headers["Host"] == "example.com, example.org" - - def test_str(self): + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" @@ -64,93 +54,6 @@ class TestHeaders(object): headers = Headers() assert bytes(headers) == b"" - def test_setitem(self): - headers = Headers() - headers["Host"] = "example.com" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.com" - - headers["host"] = "example.org" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.org" - - headers["accept"] = "text/plain" - assert len(headers) == 2 - assert "Accept" in headers - assert "Host" in headers - - headers = self._2host() - assert len(headers.fields) == 2 - headers["Host"] = "example.com" - assert len(headers.fields) == 1 - assert "Host" in headers - - def test_delitem(self): - headers = Headers(Host="example.com") - assert len(headers) == 1 - del headers["host"] - assert len(headers) == 0 - try: - del headers["host"] - except KeyError: - assert True - else: - assert False - - headers = self._2host() - del headers["Host"] - assert len(headers) == 0 - - def test_keys(self): - headers = Headers(Host="example.com") - assert list(headers.keys()) == ["Host"] - - headers = self._2host() - assert list(headers.keys()) == ["Host"] - - def test_eq_ne(self): - headers1 = Headers(Host="example.com") - headers2 = Headers(host="example.com") - assert not (headers1 == headers2) - assert headers1 != headers2 - - headers1 = Headers(Host="example.com") - headers2 = Headers(Host="example.com") - assert headers1 == headers2 - assert not (headers1 != headers2) - - assert headers1 != 42 - - def test_get_all(self): - headers = self._2host() - assert headers.get_all("host") == ["example.com", "example.org"] - assert headers.get_all("accept") == [] - - def test_set_all(self): - headers = Headers(Host="example.com") - headers.set_all("Accept", ["text/plain"]) - assert len(headers) == 2 - assert "accept" in headers - - headers = self._2host() - headers.set_all("Host", ["example.org"]) - assert headers["host"] == "example.org" - - headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == "example.org, example.net" - - def test_state(self): - headers = self._2host() - assert len(headers.get_state()) == 2 - assert headers == Headers.from_state(headers.get_state()) - - headers2 = Headers() - assert headers != headers2 - headers2.set_state(headers.get_state()) - assert headers == headers2 - def test_replace_simple(self): headers = Headers(Host="example.com", Accept="text/plain") replacements = headers.replace("Host: ", "X-Host: ") diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index 7ed6bd0f..fae7aefe 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -3,16 +3,14 @@ from __future__ import absolute_import, print_function, division import six -from netlib import utils from netlib.http import Headers -from netlib.odict import ODict from netlib.tutils import treq, raises from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(ValueError if six.PY2 else TypeError): + with raises(ValueError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) @@ -158,16 +156,17 @@ class TestRequestUtils(object): def test_get_query(self): request = treq() - assert request.query is None + assert not request.query request.url = "http://localhost:80/foo?bar=42" - assert request.query.lst == [("bar", "42")] + assert dict(request.query) == {"bar": "42"} def test_set_query(self): - request = treq(host=b"foo", headers = Headers(host=b"bar")) - request.query = ODict([]) - assert request.host == "foo" - assert request.headers["host"] == "bar" + request = treq() + assert not request.query + request.query["foo"] = "bar" + assert request.query["foo"] == "bar" + assert request.path == "/path?foo=bar" def test_get_cookies_none(self): request = treq() @@ -177,47 +176,50 @@ class TestRequestUtils(object): def test_get_cookies_single(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue") - result = request.cookies - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] + assert len(request.cookies) == 1 + assert request.cookies['cookiename'] == 'cookievalue' def test_get_cookies_double(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") result = request.cookies assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] + assert result['cookiename'] == 'cookievalue' + assert result['othercookiename'] == 'othercookievalue' def test_get_cookies_withequalsign(self): request = treq() request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") result = request.cookies assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] + assert result['cookiename'] == 'coo=kievalue' + assert result['othercookiename'] == 'othercookievalue' def test_set_cookies(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue") result = request.cookies - result["cookiename"] = ["foo"] - request.cookies = result - assert request.cookies["cookiename"] == ["foo"] + result["cookiename"] = "foo" + assert request.cookies["cookiename"] == "foo" def test_get_path_components(self): request = treq(path=b"/foo/bar") - assert request.path_components == ["foo", "bar"] + assert request.path_components == ("foo", "bar") def test_set_path_components(self): - request = treq(host=b"foo", headers = Headers(host=b"bar")) + request = treq() request.path_components = ["foo", "baz"] assert request.path == "/foo/baz" + request.path_components = [] assert request.path == "/" - request.query = ODict([]) - assert request.host == "foo" - assert request.headers["host"] == "bar" + + request.path_components = ["foo", "baz"] + request.query["hello"] = "hello" + assert request.path_components == ("foo", "baz") + + request.path_components = ["abc"] + assert request.path == "/abc?hello=hello" def test_anticache(self): request = treq() @@ -246,26 +248,21 @@ class TestRequestUtils(object): assert "gzip" in request.headers["Accept-Encoding"] def test_get_urlencoded_form(self): - request = treq(content="foobar") - assert request.urlencoded_form is None + request = treq(content="foobar=baz") + assert not request.urlencoded_form request.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert request.urlencoded_form == ODict(utils.urldecode(request.content)) + assert list(request.urlencoded_form.items()) == [("foobar", "baz")] def test_set_urlencoded_form(self): request = treq() - request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) + request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')] assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" assert request.content def test_get_multipart_form(self): request = treq(content="foobar") - assert request.multipart_form is None + assert not request.multipart_form request.headers["Content-Type"] = "multipart/form-data" - assert request.multipart_form == ODict( - utils.multipartdecode( - request.headers, - request.content - ) - ) + assert list(request.multipart_form.items()) == [] diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index 5440176c..cfd093d4 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -6,6 +6,7 @@ import six import time from netlib.http import Headers +from netlib.http.cookies import CookieAttrs from netlib.odict import ODict, ODictCaseless from netlib.tutils import raises, tresp from .test_message import _test_passthrough_attr, _test_decoded_attr @@ -13,7 +14,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(ValueError if six.PY2 else TypeError): + with raises(ValueError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) @@ -56,7 +57,7 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] + assert result["cookiename"] == ("cookievalue", CookieAttrs()) def test_get_cookies_with_parameters(self): resp = tresp() @@ -64,13 +65,13 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0][0] == "cookievalue" - attrs = result["cookiename"][0][1] + assert result["cookiename"][0] == "cookievalue" + attrs = result["cookiename"][1] assert len(attrs) == 4 - assert attrs["domain"] == ["example.com"] - assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] - assert attrs["path"] == ["/"] - assert attrs["httponly"] == [None] + assert attrs["domain"] == "example.com" + assert attrs["expires"] == "Wed Oct 21 16:29:41 2015" + assert attrs["path"] == "/" + assert attrs["httponly"] is None def test_get_cookies_no_value(self): resp = tresp() @@ -78,8 +79,8 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0][0] == "" - assert len(result["cookiename"][0][1]) == 2 + assert result["cookiename"][0] == "" + assert len(result["cookiename"][1]) == 2 def test_get_cookies_twocookies(self): resp = tresp() @@ -90,19 +91,16 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 2 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] + assert result["cookiename"] == ("cookievalue", CookieAttrs()) assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", ODict()] + assert result["othercookie"] == ("othervalue", CookieAttrs()) def test_set_cookies(self): resp = tresp() - v = resp.cookies - v.add("foo", ["bar", ODictCaseless()]) - resp.cookies = v + resp.cookies["foo"] = ("bar", {}) - v = resp.cookies - assert len(v) == 1 - assert v["foo"] == [["bar", ODictCaseless()]] + assert len(resp.cookies) == 1 + assert resp.cookies["foo"] == ("bar", CookieAttrs()) def test_refresh(self): r = tresp() diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py new file mode 100644 index 00000000..5bb65e3f --- /dev/null +++ b/test/netlib/test_multidict.py @@ -0,0 +1,239 @@ +from netlib import tutils +from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView + + +class _TMulti(object): + @staticmethod + def _reduce_values(values): + return values[0] + + @staticmethod + def _kconv(key): + return key.lower() + + +class TMultiDict(_TMulti, MultiDict): + pass + + +class TImmutableMultiDict(_TMulti, ImmutableMultiDict): + pass + + +class TestMultiDict(object): + @staticmethod + def _multi(): + return TMultiDict(( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam") + )) + + def test_init(self): + md = TMultiDict() + assert len(md) == 0 + + md = TMultiDict([("foo", "bar")]) + assert len(md) == 1 + assert md.fields == (("foo", "bar"),) + + def test_repr(self): + assert repr(self._multi()) == ( + "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" + ) + + def test_getitem(self): + md = TMultiDict([("foo", "bar")]) + assert "foo" in md + assert "Foo" in md + assert md["foo"] == "bar" + + with tutils.raises(KeyError): + _ = md["bar"] + + md_multi = TMultiDict( + [("foo", "a"), ("foo", "b")] + ) + assert md_multi["foo"] == "a" + + def test_setitem(self): + md = TMultiDict() + md["foo"] = "bar" + assert md.fields == (("foo", "bar"),) + + md["foo"] = "baz" + assert md.fields == (("foo", "baz"),) + + md["bar"] = "bam" + assert md.fields == (("foo", "baz"), ("bar", "bam")) + + def test_delitem(self): + md = self._multi() + del md["foo"] + assert "foo" not in md + assert "bar" in md + + with tutils.raises(KeyError): + del md["foo"] + + del md["bar"] + assert md.fields == () + + def test_iter(self): + md = self._multi() + assert list(md.__iter__()) == ["foo", "bar"] + + def test_len(self): + md = TMultiDict() + assert len(md) == 0 + + md = self._multi() + assert len(md) == 2 + + def test_eq(self): + assert TMultiDict() == TMultiDict() + assert not (TMultiDict() == 42) + + md1 = self._multi() + md2 = self._multi() + assert md1 == md2 + md1.fields = md1.fields[1:] + md1.fields[:1] + assert not (md1 == md2) + + def test_ne(self): + assert not TMultiDict() != TMultiDict() + assert TMultiDict() != self._multi() + assert TMultiDict() != 42 + + def test_get_all(self): + md = self._multi() + assert md.get_all("foo") == ["bar"] + assert md.get_all("bar") == ["baz", "bam"] + assert md.get_all("baz") == [] + + def test_set_all(self): + md = TMultiDict() + md.set_all("foo", ["bar", "baz"]) + assert md.fields == (("foo", "bar"), ("foo", "baz")) + + md = TMultiDict(( + ("a", "b"), + ("x", "x"), + ("c", "d"), + ("X", "x"), + ("e", "f"), + )) + md.set_all("x", ["1", "2", "3"]) + assert md.fields == ( + ("a", "b"), + ("x", "1"), + ("c", "d"), + ("x", "2"), + ("e", "f"), + ("x", "3"), + ) + md.set_all("x", ["4"]) + assert md.fields == ( + ("a", "b"), + ("x", "4"), + ("c", "d"), + ("e", "f"), + ) + + def test_add(self): + md = self._multi() + md.add("foo", "foo") + assert md.fields == ( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam"), + ("foo", "foo") + ) + + def test_insert(self): + md = TMultiDict([("b", "b")]) + md.insert(0, "a", "a") + md.insert(2, "c", "c") + assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) + + def test_keys(self): + md = self._multi() + assert list(md.keys()) == ["foo", "bar"] + assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] + + def test_values(self): + md = self._multi() + assert list(md.values()) == ["bar", "baz"] + assert list(md.values(multi=True)) == ["bar", "baz", "bam"] + + def test_items(self): + md = self._multi() + assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] + assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] + + def test_to_dict(self): + md = self._multi() + assert md.to_dict() == { + "foo": "bar", + "bar": ["baz", "bam"] + } + + def test_state(self): + md = self._multi() + assert len(md.get_state()) == 3 + assert md == TMultiDict.from_state(md.get_state()) + + md2 = TMultiDict() + assert md != md2 + md2.set_state(md.get_state()) + assert md == md2 + + +class TestImmutableMultiDict(object): + def test_modify(self): + md = TImmutableMultiDict() + with tutils.raises(TypeError): + md["foo"] = "bar" + + with tutils.raises(TypeError): + del md["foo"] + + with tutils.raises(TypeError): + md.add("foo", "bar") + + def test_with_delitem(self): + md = TImmutableMultiDict([("foo", "bar")]) + assert md.with_delitem("foo").fields == () + assert md.fields == (("foo", "bar"),) + + def test_with_set_all(self): + md = TImmutableMultiDict() + assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),) + assert md.fields == () + + def test_with_insert(self): + md = TImmutableMultiDict() + assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) + + +class TParent(object): + def __init__(self): + self.vals = tuple() + + def setter(self, vals): + self.vals = vals + + def getter(self): + return self.vals + + +class TestMultiDictView(object): + def test_modify(self): + p = TParent() + tv = MultiDictView(p.getter, p.setter) + assert len(tv) == 0 + tv["a"] = "b" + assert p.vals == (("a", "b"),) + tv["c"] = "b" + assert p.vals == (("a", "b"), ("c", "b")) + assert tv["a"] == "b" diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index 10f3b5a3..05a3962e 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -233,6 +233,7 @@ class CommonTests(tutils.DaemonTests): # FIXME: Race Condition? assert "Parse error" in self.d.text_log() + @pytest.mark.skip(reason="race condition") def test_websocket_frame_disconnect_error(self): self.pathoc(["ws:/p/", "wf:b@10:d3"], ws_read_limit=0) assert self.d.last_log() diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py index 9739afde..f6ed3efb 100644 --- a/test/pathod/tutils.py +++ b/test/pathod/tutils.py @@ -116,7 +116,7 @@ tmpdir = netlib.tutils.tmpdir raises = netlib.tutils.raises -test_data = utils.Data(__name__) +test_data = netlib.utils.Data(__name__) def render(r, settings=language.Settings()): |