diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-09-22 01:48:35 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-09-22 01:48:35 +0200 |
commit | f93752277395d201fabefed8fae6d412f13da699 (patch) | |
tree | 7f3b217b89b7d6b78725ea1a6d0185b13ab2876a | |
parent | 9fbeac50ce3f6ae49b0f0270c508b6e81a1eaf17 (diff) | |
download | mitmproxy-f93752277395d201fabefed8fae6d412f13da699.tar.gz mitmproxy-f93752277395d201fabefed8fae6d412f13da699.tar.bz2 mitmproxy-f93752277395d201fabefed8fae6d412f13da699.zip |
Headers: return str on all Python versions
-rw-r--r-- | netlib/http/__init__.py | 6 | ||||
-rw-r--r-- | netlib/http/authentication.py | 10 | ||||
-rw-r--r-- | netlib/http/headers.py | 205 | ||||
-rw-r--r-- | netlib/http/http1/assemble.py | 6 | ||||
-rw-r--r-- | netlib/http/http1/read.py | 14 | ||||
-rw-r--r-- | netlib/http/models.py | 215 | ||||
-rw-r--r-- | netlib/utils.py | 17 | ||||
-rw-r--r-- | netlib/websockets/protocol.py | 14 | ||||
-rw-r--r-- | test/http/http1/test_assemble.py | 6 | ||||
-rw-r--r-- | test/http/http1/test_read.py | 22 | ||||
-rw-r--r-- | test/http/test_authentication.py | 12 | ||||
-rw-r--r-- | test/http/test_headers.py | 149 | ||||
-rw-r--r-- | test/http/test_models.py | 152 | ||||
-rw-r--r-- | test/test_utils.py | 20 | ||||
-rw-r--r-- | test/websockets/test_websockets.py | 13 |
15 files changed, 443 insertions, 418 deletions
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index d72884b3..0ccf6b32 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, print_function, division -from .models import Request, Response, Headers +from .headers import Headers +from .models import Request, Response from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ - "Request", "Response", "Headers", + "Headers", + "Request", "Response", "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", "http1", "http2", diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 5831660b..d769abe5 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -9,18 +9,18 @@ def parse_http_basic_auth(s): return None scheme = words[0] try: - user = binascii.a2b_base64(words[1]) + user = binascii.a2b_base64(words[1]).decode("utf8", "replace") except binascii.Error: return None - parts = user.split(b':') + parts = user.split(':') if len(parts) != 2: return None return scheme, parts[0], parts[1] def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + b":" + password) - return scheme + b" " + v + v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") + return scheme + " " + v class NullProxyAuth(object): @@ -69,7 +69,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower() != b'basic': + if scheme.lower() != 'basic': return False if not self.password_manager.test(username, password): return False diff --git a/netlib/http/headers.py b/netlib/http/headers.py new file mode 100644 index 00000000..1511ea2d --- /dev/null +++ b/netlib/http/headers.py @@ -0,0 +1,205 @@ +""" + +Unicode Handling +---------------- +See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ +""" +from __future__ import absolute_import, print_function, division +import copy +try: + from collections.abc import MutableMapping +except ImportError: # Workaround for Python < 3.3 + from collections import MutableMapping + + +import six + +from netlib.utils import always_byte_args + +if six.PY2: + _native = lambda x: x + _asbytes = lambda x: x + _always_byte_args = lambda x: x +else: + # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _asbytes = lambda x: x.encode("utf-8", "surrogateescape") + _always_byte_args = always_byte_args("utf-8", "surrogateescape") + + +class Headers(MutableMapping, object): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header fields can be accessed + >>> h.fields + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + @_always_byte_args + def __init__(self, fields=None, **headers): + """ + Args: + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. + """ + self.fields = fields or [] + + for name, value in self.fields: + if not isinstance(name, bytes) or not isinstance(value, bytes): + raise ValueError("Headers passed as fields must be bytes.") + + # content_type -> content-type + headers = { + _asbytes(name).replace(b"_", b"-"): value + for name, value in six.iteritems(headers) + } + self.update(headers) + + def __bytes__(self): + if self.fields: + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + else: + return b"" + + if six.PY2: + __str__ = __bytes__ + + @_always_byte_args + def __getitem__(self, name): + values = self.get_all(name) + if not values: + raise KeyError(name) + return ", ".join(values) + + @_always_byte_args + def __setitem__(self, name, value): + idx = self._index(name) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[name] + self.fields.insert(idx, [name, value]) + else: + self.fields.append([name, value]) + + @_always_byte_args + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield _native(name) + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + # __hash__ = object.__hash__ + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: + return i + return None + + def __eq__(self, other): + if isinstance(other, Headers): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @_always_byte_args + def get_all(self, name): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + name_lower = name.lower() + values = [_native(value) for n, value in self.fields if n.lower() == name_lower] + return values + + @_always_byte_args + def set_all(self, name, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + values = map(_asbytes, values) # _always_byte_args does not fix lists + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state])
\ No newline at end of file diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index c2b60a0f..88aeac05 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -35,7 +35,7 @@ def assemble_response_head(response): def assemble_body(headers, body_chunks): - if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + if "chunked" in headers.get("transfer-encoding", "").lower(): for chunk in body_chunks: if chunk: yield b"%x\r\n%s\r\n" % (len(chunk), chunk) @@ -76,8 +76,8 @@ def _assemble_request_line(request, form=None): def _assemble_request_headers(request): headers = request.headers.copy() - if b"host" not in headers and request.scheme and request.host and request.port: - headers[b"Host"] = utils.hostport( + if "host" not in headers and request.scheme and request.host and request.port: + headers["host"] = utils.hostport( request.scheme, request.host, request.port diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index c6760ff3..4c898348 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -146,11 +146,11 @@ def connection_close(http_version, headers): according to RFC 2616 Section 8.1. """ # At first, check if we have an explicit Connection header. - if b"connection" in headers: + if "connection" in headers: tokens = utils.get_header_tokens(headers, "connection") - if b"close" in tokens: + if "close" in tokens: return True - elif b"keep-alive" in tokens: + elif "keep-alive" in tokens: return False # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -181,7 +181,7 @@ def expected_http_body_size(request, response=None): is_request = False if is_request: - if headers.get(b"expect", b"").lower() == b"100-continue": + if headers.get("expect", "").lower() == "100-continue": return 0 else: if request.method.upper() == b"HEAD": @@ -193,11 +193,11 @@ def expected_http_body_size(request, response=None): if response_code in (204, 304): return 0 - if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + if "chunked" in headers.get("transfer-encoding", "").lower(): return None - if b"content-length" in headers: + if "content-length" in headers: try: - size = int(headers[b"content-length"]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size diff --git a/netlib/http/models.py b/netlib/http/models.py index 512a764d..55664533 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -1,201 +1,22 @@ -from __future__ import absolute_import, print_function, division -import copy + from ..odict import ODict from .. import utils, encoding -from ..utils import always_bytes, always_byte_args, native +from ..utils import always_bytes, native from . import cookies +from .headers import Headers -import six from six.moves import urllib -try: - from collections import MutableMapping -except ImportError: - from collections.abc import MutableMapping # TODO: Move somewhere else? ALPN_PROTO_HTTP1 = b'http/1.1' ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = b"multipart/form-data" +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class Headers(MutableMapping, object): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # str(h) returns a HTTP1 header block. - >>> print(h) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - @always_byte_args("ascii") - def __init__(self, fields=None, **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header tuples, - e.g. ``[("Host","example.com")]``. All names and values must be bytes. - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. - """ - self.fields = fields or [] - - # content_type -> content-type - headers = { - name.encode("ascii").replace(b"_", b"-"): value - for name, value in six.iteritems(headers) - } - self.update(headers) - - def __bytes__(self): - if self.fields: - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" - else: - return b"" - - if six.PY2: - __str__ = __bytes__ - - @always_byte_args("ascii") - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return b", ".join(values) - - @always_byte_args("ascii") - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @always_byte_args("ascii") - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] - - def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield name - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @always_byte_args("ascii") - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name_lower = name.lower() - values = [value for n, value in self.fields if n.lower() == name_lower] - return values - - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. - See: :py:meth:`get_all` - """ - name = always_bytes(name, "ascii") - values = (always_bytes(value, "ascii") for value in values) - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def copy(self): - return Headers(copy.copy(self.fields)) - - # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): - return tuple(tuple(field) for field in self.fields) - - def load_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) - - class Message(object): def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): self.http_version = http_version @@ -216,7 +37,7 @@ class Message(object): def body(self, body): self._body = body if isinstance(body, bytes): - self.headers[b"content-length"] = str(len(body)).encode() + self.headers["content-length"] = str(len(body)).encode() content = body @@ -268,8 +89,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"if-modified-since", - b"if-none-match", + "if-modified-since", + "if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -279,14 +100,14 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = b"identity" + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = native(self.headers.get("accept-encoding"), "ascii") + accept_encoding = self.headers.get("accept-encoding") if accept_encoding: self.headers["accept-encoding"] = ( ', '.join( @@ -309,9 +130,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): + if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +142,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): return ODict( utils.multipartdecode( self.headers, @@ -341,7 +162,7 @@ class Request(Message): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers[b"content-type"] = HDR_FORM_URLENCODED + self.headers["content-type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -400,7 +221,7 @@ class Request(Message): """ if hostheader and "host" in self.headers: try: - return self.headers["host"].decode("idna") + return self.headers["host"] except ValueError: pass if self.host: @@ -420,7 +241,7 @@ class Request(Message): """ ret = ODict() for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) + ret.extend(cookies.parse_cookie_header(i)) return ret def set_cookies(self, odict): @@ -499,7 +320,7 @@ class Response(Message): """ ret = [] for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(native(header, "ascii")) + v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v ret.append([name, [value, attrs]]) diff --git a/netlib/utils.py b/netlib/utils.py index b9848038..d5b30128 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -269,7 +269,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(b",") + tokens = headers[key].split(",") return [token.strip() for token in tokens] @@ -320,14 +320,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(b";", 1) - ts = parts[0].split(b"/", 1) + parts = c.split(";", 1) + ts = parts[0].split("/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(b";"): - clause = i.split(b"=", 1) + for i in parts[1].split(";"): + clause = i.split("=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -337,13 +337,14 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get(b"Content-Type") + v = headers.get("Content-Type") if v: v = parse_content_type(v) if not v: return [] - boundary = v[2].get(b"boundary") - if not boundary: + try: + boundary = v[2]["boundary"].encode("ascii") + except (KeyError, UnicodeError): return [] rx = re.compile(br'\bname="([^"]+)"') diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 778fe7e7..e62f8df6 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -80,7 +80,7 @@ class WebsocketsProtocol(object): Returns an instance of Headers """ if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') + key = base64.b64encode(os.urandom(16)).decode('ascii') return Headers(**{ HEADER_WEBSOCKET_KEY: key, HEADER_WEBSOCKET_VERSION: version, @@ -95,27 +95,25 @@ class WebsocketsProtocol(object): """ return Headers(**{ HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), - "Connection": "Upgrade", - "Upgrade": "websocket", + "connection": "Upgrade", + "upgrade": "websocket", }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != b"websocket": + if headers.get("upgrade") != "websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != b"websocket": + if headers.get("upgrade") != "websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @classmethod def create_server_nonce(self, client_nonce): - return base64.b64encode( - binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) - ) + return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest()) diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 2d250909..963e7549 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -77,16 +77,16 @@ def test_assemble_request_line(): def test_assemble_request_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 r = treq(body=b"") - r.headers[b"Transfer-Encoding"] = b"chunked" + r.headers["Transfer-Encoding"] = "chunked" c = _assemble_request_headers(r) assert b"Transfer-Encoding" in c - assert b"Host" in _assemble_request_headers(treq(headers=Headers())) + assert b"host" in _assemble_request_headers(treq(headers=Headers())) def test_assemble_response_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 r = tresp(body=b"") - r.headers["Transfer-Encoding"] = b"chunked" + r.headers["Transfer-Encoding"] = "chunked" c = _assemble_response_headers(r) assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 55def2a5..9eb02a24 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -1,9 +1,7 @@ from __future__ import absolute_import, print_function, division from io import BytesIO import textwrap - from mock import Mock - from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect from netlib.http import Headers from netlib.http.http1.read import ( @@ -35,7 +33,7 @@ def test_read_request_head(): rfile.first_byte_timestamp = 42 r = read_request_head(rfile) assert r.method == b"GET" - assert r.headers["Content-Length"] == b"4" + assert r.headers["Content-Length"] == "4" assert r.body is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 @@ -62,7 +60,7 @@ def test_read_response_head(): rfile.first_byte_timestamp = 42 r = read_response_head(rfile) assert r.status_code == 418 - assert r.headers["Content-Length"] == b"4" + assert r.headers["Content-Length"] == "4" assert r.body is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 @@ -76,14 +74,12 @@ class TestReadBody(object): assert body == b"foo" assert rfile.read() == b"bar" - def test_known_size(self): rfile = BytesIO(b"foobar") body = b"".join(read_body(rfile, 3)) assert body == b"foo" assert rfile.read() == b"bar" - def test_known_size_limit(self): rfile = BytesIO(b"foobar") with raises(HttpException): @@ -99,7 +95,6 @@ class TestReadBody(object): body = b"".join(read_body(rfile, -1)) assert body == b"foobar" - def test_unknown_size_limit(self): rfile = BytesIO(b"foobar") with raises(HttpException): @@ -121,13 +116,13 @@ def test_connection_close(): def test_expected_http_body_size(): # Expect: 100-continue assert expected_http_body_size( - treq(headers=Headers(expect=b"100-continue", content_length=b"42")) + treq(headers=Headers(expect="100-continue", content_length="42")) ) == 0 # http://tools.ietf.org/html/rfc7230#section-3.3 assert expected_http_body_size( treq(method=b"HEAD"), - tresp(headers=Headers(content_length=b"42")) + tresp(headers=Headers(content_length="42")) ) == 0 assert expected_http_body_size( treq(method=b"CONNECT"), @@ -141,17 +136,17 @@ def test_expected_http_body_size(): # chunked assert expected_http_body_size( - treq(headers=Headers(transfer_encoding=b"chunked")), + treq(headers=Headers(transfer_encoding="chunked")), ) is None # explicit length - for l in (b"foo", b"-7"): + for val in (b"foo", b"-7"): with raises(HttpSyntaxException): expected_http_body_size( - treq(headers=Headers(content_length=l)) + treq(headers=Headers(content_length=val)) ) assert expected_http_body_size( - treq(headers=Headers(content_length=b"42")) + treq(headers=Headers(content_length="42")) ) == 42 # no length @@ -286,6 +281,7 @@ class TestReadHeaders(object): with raises(HttpSyntaxException): self._read(data) + def test_read_chunked(): req = treq(body=None) req.headers["Transfer-Encoding"] = "chunked" diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index a2aa774a..1df7cd9c 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -5,13 +5,13 @@ from netlib.http import authentication, Headers def test_parse_http_basic_auth(): - vals = (b"basic", b"foo", b"bar") + vals = ("basic", "foo", "bar") assert authentication.parse_http_basic_auth( authentication.assemble_http_basic_auth(*vals) ) == vals assert not authentication.parse_http_basic_auth("") assert not authentication.parse_http_basic_auth("foo bar") - v = b"basic " + binascii.b2a_base64(b"foo") + v = "basic " + binascii.b2a_base64(b"foo").decode("ascii") assert not authentication.parse_http_basic_auth(v) @@ -34,7 +34,7 @@ class TestPassManHtpasswd: def test_simple(self): pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) - vals = (b"basic", b"test", b"test") + vals = ("basic", "test", "test") authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") @@ -73,7 +73,7 @@ class TestBasicProxyAuth: ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") headers = Headers() - vals = (b"basic", b"foo", b"bar") + vals = ("basic", "foo", "bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert ba.authenticate(headers) @@ -86,12 +86,12 @@ class TestBasicProxyAuth: headers[ba.AUTH_HEADER] = "foo" assert not ba.authenticate(headers) - vals = (b"foo", b"foo", b"bar") + vals = ("foo", "foo", "bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert not ba.authenticate(headers) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") - vals = (b"basic", b"foo", b"bar") + vals = ("basic", "foo", "bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert not ba.authenticate(headers) diff --git a/test/http/test_headers.py b/test/http/test_headers.py new file mode 100644 index 00000000..f1af1feb --- /dev/null +++ b/test/http/test_headers.py @@ -0,0 +1,149 @@ +from netlib.http import Headers +from netlib.tutils import raises + + +class TestHeaders(object): + def _2host(self): + return Headers( + [ + [b"Host", b"example.com"], + [b"host", b"example.org"] + ] + ) + + def test_init(self): + headers = Headers() + assert len(headers) == 0 + + headers = Headers([[b"Host", b"example.com"]]) + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers(Host="example.com") + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers( + [[b"Host", b"invalid"]], + Host="example.com" + ) + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers( + [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], + Host="example.com" + ) + assert len(headers) == 2 + assert headers["Host"] == "example.com" + assert headers["Accept"] == "text/plain" + + def test_getitem(self): + headers = Headers(Host="example.com") + assert headers["Host"] == "example.com" + assert headers["host"] == "example.com" + with raises(KeyError): + _ = headers["Accept"] + + headers = self._2host() + assert headers["Host"] == "example.com, example.org" + + def test_str(self): + headers = Headers(Host="example.com") + assert bytes(headers) == b"Host: example.com\r\n" + + headers = Headers([ + [b"Host", b"example.com"], + [b"Accept", b"text/plain"] + ]) + assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" + + headers = Headers() + assert bytes(headers) == b"" + + def test_setitem(self): + headers = Headers() + headers["Host"] = "example.com" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.com" + + headers["host"] = "example.org" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.org" + + headers["accept"] = "text/plain" + assert len(headers) == 2 + assert "Accept" in headers + assert "Host" in headers + + headers = self._2host() + assert len(headers.fields) == 2 + headers["Host"] = "example.com" + assert len(headers.fields) == 1 + assert "Host" in headers + + def test_delitem(self): + headers = Headers(Host="example.com") + assert len(headers) == 1 + del headers["host"] + assert len(headers) == 0 + try: + del headers["host"] + except KeyError: + assert True + else: + assert False + + headers = self._2host() + del headers["Host"] + assert len(headers) == 0 + + def test_keys(self): + headers = Headers(Host="example.com") + assert list(headers.keys()) == ["Host"] + + headers = self._2host() + assert list(headers.keys()) == ["Host"] + + def test_eq_ne(self): + headers1 = Headers(Host="example.com") + headers2 = Headers(host="example.com") + assert not (headers1 == headers2) + assert headers1 != headers2 + + headers1 = Headers(Host="example.com") + headers2 = Headers(Host="example.com") + assert headers1 == headers2 + assert not (headers1 != headers2) + + assert headers1 != 42 + + def test_get_all(self): + headers = self._2host() + assert headers.get_all("host") == ["example.com", "example.org"] + assert headers.get_all("accept") == [] + + def test_set_all(self): + headers = Headers(Host="example.com") + headers.set_all("Accept", ["text/plain"]) + assert len(headers) == 2 + assert "accept" in headers + + headers = self._2host() + headers.set_all("Host", ["example.org"]) + assert headers["host"] == "example.org" + + headers.set_all("Host", ["example.org", "example.net"]) + assert headers["host"] == "example.org, example.net" + + def test_state(self): + headers = self._2host() + assert len(headers.get_state()) == 2 + assert headers == Headers.from_state(headers.get_state()) + + headers2 = Headers() + assert headers != headers2 + headers2.load_state(headers.get_state()) + assert headers == headers2 diff --git a/test/http/test_models.py b/test/http/test_models.py index d420b22b..10e0795a 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -58,20 +58,20 @@ class TestRequest(object): req = tutils.treq() req.headers["Accept-Encoding"] = "foobar" req.anticomp() - assert req.headers["Accept-Encoding"] == b"identity" + assert req.headers["Accept-Encoding"] == "identity" def test_constrain_encoding(self): req = tutils.treq() req.headers["Accept-Encoding"] = "identity, gzip, foo" req.constrain_encoding() - assert b"foo" not in req.headers["Accept-Encoding"] + assert "foo" not in req.headers["Accept-Encoding"] def test_update_host(self): req = tutils.treq() req.headers["Host"] = "" req.host = "foobar" req.update_host_header() - assert req.headers["Host"] == b"foobar" + assert req.headers["Host"] == "foobar" def test_get_form(self): req = tutils.treq() @@ -393,149 +393,3 @@ class TestResponse(object): v = resp.get_cookies() assert len(v) == 1 assert v["foo"] == [["bar", ODictCaseless()]] - - -class TestHeaders(object): - def _2host(self): - return Headers( - [ - [b"Host", b"example.com"], - [b"host", b"example.org"] - ] - ) - - def test_init(self): - headers = Headers() - assert len(headers) == 0 - - headers = Headers([[b"Host", b"example.com"]]) - assert len(headers) == 1 - assert headers["Host"] == b"example.com" - - headers = Headers(Host="example.com") - assert len(headers) == 1 - assert headers["Host"] == b"example.com" - - headers = Headers( - [[b"Host", b"invalid"]], - Host="example.com" - ) - assert len(headers) == 1 - assert headers["Host"] == b"example.com" - - headers = Headers( - [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], - Host="example.com" - ) - assert len(headers) == 2 - assert headers["Host"] == b"example.com" - assert headers["Accept"] == b"text/plain" - - def test_getitem(self): - headers = Headers(Host="example.com") - assert headers["Host"] == b"example.com" - assert headers["host"] == b"example.com" - tutils.raises(KeyError, headers.__getitem__, "Accept") - - headers = self._2host() - assert headers["Host"] == b"example.com, example.org" - - def test_str(self): - headers = Headers(Host="example.com") - assert bytes(headers) == b"Host: example.com\r\n" - - headers = Headers([ - [b"Host", b"example.com"], - [b"Accept", b"text/plain"] - ]) - assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" - - headers = Headers() - assert bytes(headers) == b"" - - def test_setitem(self): - headers = Headers() - headers["Host"] = "example.com" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == b"example.com" - - headers["host"] = "example.org" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == b"example.org" - - headers["accept"] = "text/plain" - assert len(headers) == 2 - assert "Accept" in headers - assert "Host" in headers - - headers = self._2host() - assert len(headers.fields) == 2 - headers["Host"] = "example.com" - assert len(headers.fields) == 1 - assert "Host" in headers - - def test_delitem(self): - headers = Headers(Host="example.com") - assert len(headers) == 1 - del headers["host"] - assert len(headers) == 0 - try: - del headers["host"] - except KeyError: - assert True - else: - assert False - - headers = self._2host() - del headers["Host"] - assert len(headers) == 0 - - def test_keys(self): - headers = Headers(Host="example.com") - assert list(headers.keys()) == [b"Host"] - - headers = self._2host() - assert list(headers.keys()) == [b"Host"] - - def test_eq_ne(self): - headers1 = Headers(Host="example.com") - headers2 = Headers(host="example.com") - assert not (headers1 == headers2) - assert headers1 != headers2 - - headers1 = Headers(Host="example.com") - headers2 = Headers(Host="example.com") - assert headers1 == headers2 - assert not (headers1 != headers2) - - assert headers1 != 42 - - def test_get_all(self): - headers = self._2host() - assert headers.get_all("host") == [b"example.com", b"example.org"] - assert headers.get_all("accept") == [] - - def test_set_all(self): - headers = Headers(Host="example.com") - headers.set_all("Accept", ["text/plain"]) - assert len(headers) == 2 - assert "accept" in headers - - headers = self._2host() - headers.set_all("Host", ["example.org"]) - assert headers["host"] == b"example.org" - - headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == b"example.org, example.net" - - def test_state(self): - headers = self._2host() - assert len(headers.get_state()) == 2 - assert headers == Headers.from_state(headers.get_state()) - - headers2 = Headers() - assert headers != headers2 - headers2.load_state(headers.get_state()) - assert headers == headers2 diff --git a/test/test_utils.py b/test/test_utils.py index 8f4b4059..17636cc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -103,17 +103,17 @@ def test_get_header_tokens(): headers = Headers() assert utils.get_header_tokens(headers, "foo") == [] headers["foo"] = "bar" - assert utils.get_header_tokens(headers, "foo") == [b"bar"] + assert utils.get_header_tokens(headers, "foo") == ["bar"] headers["foo"] = "bar, voing" - assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing"] + assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"] headers.set_all("foo", ["bar, voing", "oink"]) - assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing", b"oink"] + assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] def test_multipartdecode(): - boundary = b'somefancyboundary' + boundary = 'somefancyboundary' headers = Headers( - content_type=b'multipart/form-data; boundary=' + boundary + content_type='multipart/form-data; boundary=' + boundary ) content = ( "--{0}\n" @@ -122,7 +122,7 @@ def test_multipartdecode(): "--{0}\n" "Content-Disposition: form-data; name=\"field2\"\n\n" "value2\n" - "--{0}--".format(boundary.decode()).encode() + "--{0}--".format(boundary).encode() ) form = utils.multipartdecode(headers, content) @@ -134,8 +134,8 @@ def test_multipartdecode(): def test_parse_content_type(): p = utils.parse_content_type - assert p(b"text/html") == (b"text", b"html", {}) - assert p(b"text") is None + assert p("text/html") == ("text", "html", {}) + assert p("text") is None - v = p(b"text/html; charset=UTF-8") - assert v == (b'text', b'html', {b'charset': b'UTF-8'}) + v = p("text/html; charset=UTF-8") + assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 48acc2d6..4ae4cf45 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -64,15 +64,14 @@ class WebSocketsClient(tcp.TCPClient): preamble = b'GET / HTTP/1.1' self.wfile.write(preamble + b"\r\n") headers = self.protocol.client_handshake_headers() - self.client_nonce = headers["sec-websocket-key"] + self.client_nonce = headers["sec-websocket-key"].encode("ascii") self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() resp = read_response(self.rfile, treq(method="GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) - if not server_nonce == self.protocol.create_server_nonce( - self.client_nonce): + if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): self.close() def read_next_message(self): @@ -207,14 +206,14 @@ class TestFrameHeader: fin=True, payload_length=10 ) - assert f.human_readable() + assert repr(f) f = websockets.FrameHeader() - assert f.human_readable() + assert repr(f) def test_funky(self): f = websockets.FrameHeader(masking_key=b"test", mask=False) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + raw = bytes(f) + f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) assert not f2.mask def test_violations(self): |