diff options
Diffstat (limited to 'netlib/http')
-rw-r--r-- | netlib/http/authentication.py | 4 | ||||
-rw-r--r-- | netlib/http/exceptions.py | 18 | ||||
-rw-r--r-- | netlib/http/http1/protocol.py | 41 | ||||
-rw-r--r-- | netlib/http/http2/protocol.py | 44 | ||||
-rw-r--r-- | netlib/http/semantics.py | 297 |
5 files changed, 270 insertions, 134 deletions
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 29b9eb3c..fe1f0d14 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -62,10 +62,10 @@ class BasicProxyAuth(NullProxyAuth): del headers[self.AUTH_HEADER] def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) + auth_value = headers.get(self.AUTH_HEADER) if not auth_value: return False - parts = parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 987a7908..8a2bbebc 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,3 @@ -from netlib import odict - - class HttpError(Exception): def __init__(self, code, message): @@ -10,18 +7,3 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass - - -class HttpAuthenticationError(Exception): - - def __init__(self, auth_headers=None): - super(HttpAuthenticationError, self).__init__( - "Proxy Authentication Required" - ) - if isinstance(auth_headers, dict): - auth_headers = odict.ODictCaseless(auth_headers.items()) - self.headers = auth_headers - self.code = 407 - - def __repr__(self): - return "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 50975818..bf33a18e 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -3,8 +3,8 @@ import string import sys import time -from netlib import odict, utils, tcp, http -from netlib.http import semantics +from ... import utils, tcp, http +from .. import semantics, Headers from ..exceptions import * @@ -96,7 +96,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): if headers is None: raise HttpError(400, "Invalid headers") - expect_header = headers.get_first("expect", "").lower() + expect_header = headers.get("expect", "").lower() if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' @@ -232,10 +232,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): Read a set of headers. Stop once a blank line is reached. - Return a ODictCaseless object, or None if headers are invalid. + Return a Header object, or None if headers are invalid. """ ret = [] - name = '' while True: line = self.tcp_handler.rfile.readline() if not line or line == '\r\n' or line == '\n': @@ -254,7 +253,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ret.append([name, value]) else: return None - return odict.ODictCaseless(ret) + return Headers(ret) def read_http_body(self, *args, **kwargs): @@ -272,7 +271,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): """ Read an HTTP message body: - headers: An ODictCaseless object + headers: A Header object limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise @@ -356,7 +355,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None if "content-length" in headers: try: - size = int(headers["content-length"][0]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size @@ -369,9 +368,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") - ] + return "chunked" in headers.get("transfer-encoding", "").lower() def _get_request_line(self): @@ -547,18 +544,20 @@ class HTTP1Protocol(semantics.ProtocolMixin): def _assemble_request_headers(self, request): headers = request.headers.copy() for k in request._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = [utils.hostport(request.scheme, - request.host, - request.port)] + headers["Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) # If content is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if request.body or request.body == "": - headers["Content-Length"] = [str(len(request.body))] + headers["Content-Length"] = str(len(request.body)) - return headers.format() + return str(headers) def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( @@ -575,13 +574,13 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): headers = response.headers.copy() for k in response._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if not preserve_transfer_encoding: - del headers['Transfer-Encoding'] + headers.pop('Transfer-Encoding', None) # If body is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if response.body or response.body == "": - headers["Content-Length"] = [str(len(response.body))] + headers["Content-Length"] = str(len(response.body)) - return headers.format() + return str(headers) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 4328ebdd..b6d376d3 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -3,7 +3,7 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils, odict +from netlib import http, utils from netlib.http import semantics from . import frame @@ -85,10 +85,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - authority = headers.get_first(':authority', '') - method = headers.get_first(':method', 'GET') - scheme = headers.get_first(':scheme', 'https') - path = headers.get_first(':path', '/') + authority = headers.get(':authority', '') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') host = None port = None @@ -161,7 +161,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - int(headers.get_first(':status')), + int(headers.get(':status', 502)), "", headers, body, @@ -181,16 +181,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if ':authority' not in headers.keys(): - headers.add(':authority', bytes(authority), prepend=True) - if ':scheme' not in headers.keys(): - headers.add(':scheme', bytes(request.scheme), prepend=True) - if ':path' not in headers.keys(): - headers.add(':path', bytes(request.path), prepend=True) - if ':method' not in headers.keys(): - headers.add(':method', bytes(request.method), prepend=True) - - headers = headers.items() + if ':authority' not in headers: + headers.fields.insert(0, (':authority', bytes(authority))) + if ':scheme' not in headers: + headers.fields.insert(0, (':scheme', bytes(request.scheme))) + if ':path' not in headers: + headers.fields.insert(0, (':path', bytes(request.path))) + if ':method' not in headers: + headers.fields.insert(0, (':method', bytes(request.method))) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -206,10 +204,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if ':status' not in headers.keys(): - headers.add(':status', bytes(str(response.status_code)), prepend=True) - - headers = headers.items() + if ':status' not in headers: + headers.fields.insert(0, (':status', bytes(str(response.status_code)))) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -336,7 +332,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: yield frame.ContinuationFrame, i - header_block_fragment = self.encoder.encode(headers) + header_block_fragment = self.encoder.encode(headers.fields) chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) @@ -409,8 +405,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - headers = odict.ODictCaseless() - for header, value in self.decoder.decode(header_block_fragment): - headers.add(header, value) + headers = http.Headers( + [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + ) return stream_id, headers, body diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 2b960483..edf5fc07 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,8 +1,10 @@ from __future__ import (absolute_import, print_function, division) +import UserDict +import copy import urllib import urlparse -from .. import utils, odict +from .. import odict from . import cookies, exceptions from netlib import utils, encoding @@ -12,8 +14,165 @@ HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class ProtocolMixin(object): +class Headers(UserDict.DictMixin): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header fields can be accessed + >>> h.fields + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + def __init__(self, fields=None, **headers): + """ + Args: + fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. + """ + self.fields = fields or [] + + # content_type -> content-type + headers = { + name.replace("_", "-"): value + for name, value in headers.iteritems() + } + self.update(headers) + + def __str__(self): + return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" + + def __getitem__(self, name): + values = self.get_all(name) + if not values: + raise KeyError(name) + else: + return ", ".join(values) + + def __setitem__(self, name, value): + idx = self._index(name) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[name] + self.fields.insert(idx, [name, value]) + else: + self.fields.append([name, value]) + + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: + return i + return None + + def keys(self): + seen = set() + names = [] + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + names.append(name) + return names + + def __eq__(self, other): + if isinstance(other, Headers): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, name, default=[]): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + name = name.lower() + values = [value for n, value in self.fields if n.lower() == name] + return values or default + + def set_all(self, name, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) + +class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover raise NotImplementedError @@ -47,23 +206,23 @@ class Request(object): ] def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, + form_out=None ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.form_in = form_in self.method = method @@ -80,8 +239,10 @@ class Request(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -134,30 +295,35 @@ class Request(object): "if-none-match", ] for i in delheaders: - del self.headers[i] + self.headers.pop(i, None) def anticomp(self): """ Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = ["identity"] + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( ', '.join( - e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) def update_host_header(self): """ Update the host header to reflect the current target. """ - self.headers["Host"] = [self.host] + self.headers["Host"] = self.host def get_form(self): """ @@ -166,9 +332,9 @@ class Request(object): indicates non-form data. """ if self.body: - if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return self.get_form_urlencoded() - elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return self.get_form_multipart() return odict.ODict([]) @@ -178,18 +344,12 @@ class Request(object): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_URLENCODED, - True): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return odict.ODict(utils.urldecode(self.body)) return odict.ODict([]) def get_form_multipart(self): - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_MULTIPART, - True): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return odict.ODict( utils.multipartdecode( self.headers, @@ -204,7 +364,7 @@ class Request(object): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.headers["Content-Type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -263,7 +423,7 @@ class Request(object): """ host = None if hostheader: - host = self.headers.get_first("host") + host = self.headers.get("Host") if not host: host = self.host if host: @@ -287,7 +447,7 @@ class Request(object): Returns a possibly empty netlib.odict.ODict object. """ ret = odict.ODict() - for i in self.headers["cookie"]: + for i in self.headers.get_all("cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -297,7 +457,7 @@ class Request(object): headers. """ v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = [v] + self.headers["Cookie"] = v @property def url(self): @@ -336,18 +496,17 @@ class Request(object): class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=(0, 0), + headers=None, + body="" ): super(EmptyRequest, self).__init__( form_in=form_in, @@ -357,7 +516,7 @@ class EmptyRequest(Request): port=port, path=path, httpversion=httpversion, - headers=(headers or odict.ODictCaseless()), + headers=headers, body=body, ) @@ -370,19 +529,19 @@ class Response(object): ] def __init__( - self, - httpversion, - status_code, - msg=None, - headers=None, - body=None, - sslinfo=None, - timestamp_start=None, - timestamp_end=None, + self, + httpversion, + status_code, + msg=None, + headers=None, + body=None, + sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.httpversion = httpversion self.status_code = status_code @@ -395,8 +554,10 @@ class Response(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -412,9 +573,7 @@ class Response(object): return "<Response: {status_code} {msg} ({contenttype}, {size})>".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get_first( - "content-type", - "unknown content type"), + contenttype=self.headers.get("content-type", "unknown content type"), size=size) def get_cookies(self): @@ -427,7 +586,7 @@ class Response(object): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers["set-cookie"]: + for header in self.headers.get_all("set-cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -450,7 +609,7 @@ class Response(object): i[1][1] ) ) - self.headers["Set-Cookie"] = values + self.headers.set_all("Set-Cookie", values) @property def content(self): # pragma: no cover |