diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/debug.py | 45 | ||||
-rw-r--r-- | netlib/encoding.py | 97 | ||||
-rw-r--r-- | netlib/http/cookies.py | 29 | ||||
-rw-r--r-- | netlib/http/headers.py | 34 | ||||
-rw-r--r-- | netlib/http/http1/assemble.py | 4 | ||||
-rw-r--r-- | netlib/http/http1/read.py | 5 | ||||
-rw-r--r-- | netlib/http/http2/__init__.py | 2 | ||||
-rw-r--r-- | netlib/http/http2/utils.py | 37 | ||||
-rw-r--r-- | netlib/http/message.py | 257 | ||||
-rw-r--r-- | netlib/http/request.py | 40 | ||||
-rw-r--r-- | netlib/http/response.py | 12 | ||||
-rw-r--r-- | netlib/multidict.py | 24 | ||||
-rw-r--r-- | netlib/strutils.py | 125 | ||||
-rw-r--r-- | netlib/tcp.py | 4 | ||||
-rw-r--r-- | netlib/utils.py | 11 | ||||
-rw-r--r-- | netlib/websockets/frame.py | 2 | ||||
-rw-r--r-- | netlib/wsgi.py | 4 |
17 files changed, 507 insertions, 225 deletions
diff --git a/netlib/debug.py b/netlib/debug.py index a395afcb..29c7f655 100644 --- a/netlib/debug.py +++ b/netlib/debug.py @@ -7,8 +7,6 @@ import signal import platform import traceback -import psutil - from netlib import version from OpenSSL import SSL @@ -19,7 +17,7 @@ def sysinfo(): "Mitmproxy version: %s" % version.VERSION, "Python version: %s" % platform.python_version(), "Platform: %s" % platform.platform(), - "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION), + "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode(), ] d = platform.linux_distribution() t = "Linux distro: %s %s %s" % d @@ -40,15 +38,32 @@ def sysinfo(): def dump_info(sig, frm, file=sys.stdout): # pragma: no cover - p = psutil.Process() - print("****************************************************", file=file) print("Summary", file=file) print("=======", file=file) - print("num threads: ", p.num_threads(), file=file) - if hasattr(p, "num_fds"): - print("num fds: ", p.num_fds(), file=file) - print("memory: ", p.memory_info(), file=file) + + try: + import psutil + except: + print("(psutil not installed, skipping some debug info)", file=file) + else: + p = psutil.Process() + print("num threads: ", p.num_threads(), file=file) + if hasattr(p, "num_fds"): + print("num fds: ", p.num_fds(), file=file) + print("memory: ", p.memory_info(), file=file) + + print(file=file) + print("Files", file=file) + print("=====", file=file) + for i in p.open_files(): + print(i, file=file) + + print(file=file) + print("Connections", file=file) + print("===========", file=file) + for i in p.connections(): + print(i, file=file) print(file=file) print("Threads", file=file) @@ -63,18 +78,6 @@ def dump_info(sig, frm, file=sys.stdout): # pragma: no cover for i in bthreads: print(i._threadinfo(), file=file) - print(file=file) - print("Files", file=file) - print("=====", file=file) - for i in p.open_files(): - print(i, file=file) - - print(file=file) - print("Connections", file=file) - print("===========", file=file) - for i in p.connections(): - print(i, file=file) - print("****************************************************", file=file) diff --git a/netlib/encoding.py b/netlib/encoding.py index 98502451..8b67b543 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -1,39 +1,62 @@ """ - Utility functions for decoding response bodies. +Utility functions for decoding response bodies. """ from __future__ import absolute_import + +import codecs from io import BytesIO import gzip import zlib +from typing import Union # noqa + -ENCODINGS = {"identity", "gzip", "deflate"} +def decode(obj, encoding, errors='strict'): + # type: (Union[str, bytes], str) -> Union[str, bytes] + """ + Decode the given input object + Returns: + The decoded value -def decode(e, content): - if not isinstance(content, bytes): - return None - encoding_map = { - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, - } - if e not in encoding_map: - return None - return encoding_map[e](content) + Raises: + ValueError, if decoding fails. + """ + try: + try: + return custom_decode[encoding](obj) + except KeyError: + return codecs.decode(obj, encoding, errors) + except Exception as e: + raise ValueError("{} when decoding {} with {}".format( + type(e).__name__, + repr(obj)[:10], + repr(encoding), + )) + + +def encode(obj, encoding, errors='strict'): + # type: (Union[str, bytes], str) -> Union[str, bytes] + """ + Encode the given input object + Returns: + The encoded value -def encode(e, content): - if not isinstance(content, bytes): - return None - encoding_map = { - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, - } - if e not in encoding_map: - return None - return encoding_map[e](content) + Raises: + ValueError, if encoding fails. + """ + try: + try: + return custom_encode[encoding](obj) + except KeyError: + return codecs.encode(obj, encoding, errors) + except Exception as e: + raise ValueError("{} when encoding {} with {}".format( + type(e).__name__, + repr(obj)[:10], + repr(encoding), + )) def identity(content): @@ -46,10 +69,7 @@ def identity(content): def decode_gzip(content): gfile = gzip.GzipFile(fileobj=BytesIO(content)) - try: - return gfile.read() - except (IOError, EOFError): - return None + return gfile.read() def encode_gzip(content): @@ -70,12 +90,9 @@ def decode_deflate(content): http://bugs.python.org/issue5784 """ try: - try: - return zlib.decompress(content) - except zlib.error: - return zlib.decompress(content, -15) + return zlib.decompress(content) except zlib.error: - return None + return zlib.decompress(content, -15) def encode_deflate(content): @@ -84,4 +101,16 @@ def encode_deflate(content): """ return zlib.compress(content) -__all__ = ["ENCODINGS", "encode", "decode"] + +custom_decode = { + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, +} +custom_encode = { + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, +} + +__all__ = ["encode", "decode"] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 768a85df..dd0af99c 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,7 +1,8 @@ import collections +import email.utils import re +import time -import email.utils from netlib import multidict """ @@ -260,3 +261,29 @@ def refresh_set_cookie_header(c, delta): if not ret: raise ValueError("Invalid Cookie") return ret + + +def is_expired(cookie_attrs): + """ + Determines whether a cookie has expired. + + Returns: boolean + """ + + # See if 'expires' time is in the past + expires = False + if 'expires' in cookie_attrs: + e = email.utils.parsedate_tz(cookie_attrs["expires"]) + if e: + exp_ts = email.utils.mktime_tz(e) + now_ts = time.time() + expires = exp_ts < now_ts + + # or if Max-Age is 0 + max_age = False + try: + max_age = int(cookie_attrs.get('Max-Age', 1)) == 0 + except ValueError: + pass + + return expires or max_age diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 14888ea9..36e5060c 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division import re +import collections import six from netlib import multidict from netlib import strutils @@ -148,6 +149,15 @@ class Headers(multidict.MultiDict): value = _always_bytes(value) super(Headers, self).insert(index, key, value) + def items(self, multi=False): + if multi: + return ( + (_native(k), _native(v)) + for k, v in self.fields + ) + else: + return super(Headers, self).items() + def replace(self, pattern, repl, flags=0): """ Replaces a regular expression pattern with repl in each "name: value" @@ -156,8 +166,10 @@ class Headers(multidict.MultiDict): Returns: The number of replacements made. """ - pattern = _always_bytes(pattern) - repl = _always_bytes(repl) + if isinstance(pattern, six.text_type): + pattern = strutils.escaped_str_to_bytes(pattern) + if isinstance(repl, six.text_type): + repl = strutils.escaped_str_to_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 @@ -172,8 +184,8 @@ class Headers(multidict.MultiDict): pass else: replacements += n - fields.append([name, value]) - self.fields = fields + fields.append((name, value)) + self.fields = tuple(fields) return replacements @@ -195,10 +207,22 @@ def parse_content_type(c): ts = parts[0].split("/", 1) if len(ts) != 2: return None - d = {} + d = collections.OrderedDict() if len(parts) == 2: for i in parts[1].split(";"): clause = i.split("=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d + + +def assemble_content_type(type, subtype, parameters): + if not parameters: + return "{}/{}".format(type, subtype) + params = "; ".join( + "{}={}".format(k, v) + for k, v in parameters.items() + ) + return "{}/{}; {}".format( + type, subtype, params + ) diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 511328f1..e74732d2 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -5,7 +5,7 @@ from netlib import exceptions def assemble_request(request): - if request.content is None: + if request.data.content is None: raise exceptions.HttpException("Cannot assemble flow with missing content") head = assemble_request_head(request) body = b"".join(assemble_body(request.data.headers, [request.data.content])) @@ -19,7 +19,7 @@ def assemble_request_head(request): def assemble_response(response): - if response.content is None: + if response.data.content is None: raise exceptions.HttpException("Cannot assemble flow with missing content") head = assemble_response_head(response) body = b"".join(assemble_body(response.data.headers, [response.data.content])) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index a4c341fd..70fffbd4 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -244,7 +244,7 @@ def _read_request_line(rfile): raise exceptions.HttpReadDisconnect("Client disconnected") try: - method, path, http_version = line.split(b" ") + method, path, http_version = line.split() if path == b"*" or path.startswith(b"/"): form = "relative" @@ -291,8 +291,7 @@ def _read_response_line(rfile): raise exceptions.HttpReadDisconnect("Server disconnected") try: - - parts = line.split(b" ", 2) + parts = line.split(None, 2) if len(parts) == 2: # handle missing message gracefully parts.append(b"") diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 6a979a0d..60064190 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,6 +1,8 @@ from __future__ import absolute_import, print_function, division from netlib.http.http2 import framereader +from netlib.http.http2.utils import parse_headers __all__ = [ "framereader", + "parse_headers", ] diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py new file mode 100644 index 00000000..164bacc8 --- /dev/null +++ b/netlib/http/http2/utils.py @@ -0,0 +1,37 @@ +from netlib.http import url + + +def parse_headers(headers): + authority = headers.get(':authority', '').encode() + method = headers.get(':method', 'GET').encode() + scheme = headers.get(':scheme', 'https').encode() + path = headers.get(':path', '/').encode() + + headers.pop(":method", None) + headers.pop(":scheme", None) + headers.pop(":path", None) + + host = None + port = None + + if path == b'*' or path.startswith(b"/"): + first_line_format = "relative" + elif method == b'CONNECT': # pragma: no cover + raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") + else: # pragma: no cover + first_line_format = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = url.parse(path) + + if authority: + host, _, port = authority.partition(b':') + + if not host: + host = b'localhost' + + if not port: + port = 443 if scheme == b'https' else 80 + + port = int(port) + + return first_line_format, method, scheme, host, port, path diff --git a/netlib/http/message.py b/netlib/http/message.py index b633b671..34709f0a 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, print_function, division +import re import warnings import six @@ -51,7 +52,23 @@ class MessageData(basetypes.Serializable): return cls(**state) +class CachedDecode(object): + __slots__ = ["encoded", "encoding", "strict", "decoded"] + + def __init__(self, object, encoding, strict, decoded): + self.encoded = object + self.encoding = encoding + self.strict = strict + self.decoded = decoded + +no_cached_decode = CachedDecode(None, None, None, None) + + class Message(basetypes.Serializable): + def __init__(self): + self._content_cache = no_cached_decode # type: CachedDecode + self._text_cache = no_cached_decode # type: CachedDecode + def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -89,19 +106,82 @@ class Message(basetypes.Serializable): self.data.headers = h @property - def content(self): + def raw_content(self): + # type: () -> bytes """ The raw (encoded) HTTP message body - See also: :py:attr:`text` + See also: :py:attr:`content`, :py:class:`text` """ return self.data.content - @content.setter - def content(self, content): + @raw_content.setter + def raw_content(self, content): self.data.content = content - if isinstance(content, bytes): - self.headers["content-length"] = str(len(content)) + + def get_content(self, strict=True): + # type: (bool) -> bytes + """ + The HTTP message body decoded with the content-encoding header (e.g. gzip) + + Raises: + ValueError, when the content-encoding is invalid and strict is True. + + See also: :py:class:`raw_content`, :py:attr:`text` + """ + if self.raw_content is None: + return None + ce = self.headers.get("content-encoding") + cached = ( + self._content_cache.encoded == self.raw_content and + (self._content_cache.strict or not strict) and + self._content_cache.encoding == ce + ) + if not cached: + is_strict = True + if ce: + try: + decoded = encoding.decode(self.raw_content, ce) + except ValueError: + if strict: + raise + is_strict = False + decoded = self.raw_content + else: + decoded = self.raw_content + self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded) + return self._content_cache.decoded + + def set_content(self, value): + if value is None: + self.raw_content = None + return + if not isinstance(value, bytes): + raise TypeError( + "Message content must be bytes, not {}. " + "Please use .text if you want to assign a str." + .format(type(value).__name__) + ) + ce = self.headers.get("content-encoding") + cached = ( + self._content_cache.decoded == value and + self._content_cache.encoding == ce and + self._content_cache.strict + ) + if not cached: + try: + encoded = encoding.encode(value, ce or "identity") + except ValueError: + # So we have an invalid content-encoding? + # Let's remove it! + del self.headers["content-encoding"] + ce = None + encoded = value + self._content_cache = CachedDecode(encoded, ce, True, value) + self.raw_content = self._content_cache.encoded + self.headers["content-length"] = str(len(self.raw_content)) + + content = property(get_content, set_content) @property def http_version(self): @@ -136,56 +216,108 @@ class Message(basetypes.Serializable): def timestamp_end(self, timestamp_end): self.data.timestamp_end = timestamp_end - @property - def text(self): - """ - The decoded HTTP message body. - Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + def _get_content_type_charset(self): + # type: () -> Optional[str] + ct = headers.parse_content_type(self.headers.get("content-type", "")) + if ct: + return ct[2].get("charset") - .. note:: - This is not implemented yet. + def _guess_encoding(self): + # type: () -> str + enc = self._get_content_type_charset() + if enc: + return enc - See also: :py:attr:`content`, :py:class:`decoded` + if "json" in self.headers.get("content-type", ""): + return "utf8" + else: + # We may also want to check for HTML meta tags here at some point. + return "latin-1" + + def get_text(self, strict=True): + # type: (bool) -> six.text_type """ - # This attribute should be called text, because that's what requests does. - raise NotImplementedError() + The HTTP message body decoded with both content-encoding header (e.g. gzip) + and content-type header charset. - @text.setter - def text(self, text): - raise NotImplementedError() + Raises: + ValueError, when either content-encoding or charset is invalid and strict is True. - def decode(self): + See also: :py:attr:`content`, :py:class:`raw_content` + """ + if self.raw_content is None: + return None + enc = self._guess_encoding() + + content = self.get_content(strict) + cached = ( + self._text_cache.encoded == content and + (self._text_cache.strict or not strict) and + self._text_cache.encoding == enc + ) + if not cached: + is_strict = self._content_cache.strict + try: + decoded = encoding.decode(content, enc) + except ValueError: + if strict: + raise + is_strict = False + decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape") + self._text_cache = CachedDecode(content, enc, is_strict, decoded) + return self._text_cache.decoded + + def set_text(self, text): + if text is None: + self.content = None + return + enc = self._guess_encoding() + + cached = ( + self._text_cache.decoded == text and + self._text_cache.encoding == enc and + self._text_cache.strict + ) + if not cached: + try: + encoded = encoding.encode(text, enc) + except ValueError: + # Fall back to UTF-8 and update the content-type header. + ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) + ct[2]["charset"] = "utf-8" + self.headers["content-type"] = headers.assemble_content_type(*ct) + enc = "utf8" + encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape") + self._text_cache = CachedDecode(encoded, enc, True, text) + self.content = self._text_cache.encoded + + text = property(get_text, set_text) + + def decode(self, strict=True): """ - Decodes body based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. - Returns: - True, if decoding succeeded. - False, otherwise. + Raises: + ValueError, when the content-encoding is invalid and strict is True. """ - ce = self.headers.get("content-encoding") - data = encoding.decode(ce, self.content) - if data is None: - return False - self.content = data + self.raw_content = self.get_content(strict) self.headers.pop("content-encoding", None) - return True def encode(self, e): """ - Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + Any existing content-encodings are overwritten, + the content is not decoded beforehand. - Returns: - True, if decoding succeeded. - False, otherwise. + Raises: + ValueError, when the specified content-encoding is invalid. """ - data = encoding.encode(e, self.content) - if data is None: - return False - self.content = data self.headers["content-encoding"] = e - return True + self.content = self.raw_content + if "content-encoding" not in self.headers: + raise ValueError("Invalid content encoding {}".format(repr(e))) def replace(self, pattern, repl, flags=0): """ @@ -196,13 +328,15 @@ class Message(basetypes.Serializable): Returns: The number of replacements made. """ - # TODO: Proper distinction between text and bytes. + if isinstance(pattern, six.text_type): + pattern = strutils.escaped_str_to_bytes(pattern) + if isinstance(repl, six.text_type): + repl = strutils.escaped_str_to_bytes(repl) replacements = 0 if self.content: - with decoded(self): - self.content, replacements = strutils.safe_subn( - pattern, repl, self.content, flags=flags - ) + self.content, replacements = re.subn( + pattern, repl, self.content, flags=flags + ) replacements += self.headers.replace(pattern, repl, flags) return replacements @@ -221,29 +355,16 @@ class Message(basetypes.Serializable): class decoded(object): """ - A context manager that decodes a request or response, and then - re-encodes it with the same encoding after execution of the block. - - Example: - - .. code-block:: python - - with decoded(request): - request.content = request.content.replace("foo", "bar") + Deprecated: You can now directly use :py:attr:`content`. + :py:attr:`raw_content` has the encoded content. """ - def __init__(self, message): - self.message = message - ce = message.headers.get("content-encoding") - if ce in encoding.ENCODINGS: - self.ce = ce - else: - self.ce = None + def __init__(self, message): # pragma no cover + warnings.warn("decoded() is deprecated, you can now directly use .content instead. " + ".raw_content has the encoded content.", DeprecationWarning) - def __enter__(self): - if self.ce: - self.message.decode() + def __enter__(self): # pragma no cover + pass - def __exit__(self, type, value, tb): - if self.ce: - self.message.encode(self.ce) + def __exit__(self, type, value, tb): # pragma no cover + pass diff --git a/netlib/http/request.py b/netlib/http/request.py index 01801d42..ecaa9b79 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -5,7 +5,6 @@ import re import six from six.moves import urllib -from netlib import encoding from netlib import multidict from netlib import strutils from netlib.http import multipart @@ -23,8 +22,20 @@ host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") class RequestData(message.MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None, timestamp_start=None, timestamp_end=None): + if isinstance(method, six.text_type): + method = method.encode("ascii", "strict") + if isinstance(scheme, six.text_type): + scheme = scheme.encode("ascii", "strict") + if isinstance(host, six.text_type): + host = host.encode("idna", "strict") + if isinstance(path, six.text_type): + path = path.encode("ascii", "strict") + if isinstance(http_version, six.text_type): + http_version = http_version.encode("ascii", "strict") if not isinstance(headers, nheaders.Headers): headers = nheaders.Headers(headers) + if isinstance(content, six.text_type): + raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) self.first_line_format = first_line_format self.method = method @@ -44,6 +55,7 @@ class Request(message.Message): An HTTP request. """ def __init__(self, *args, **kwargs): + super(Request, self).__init__() self.data = RequestData(*args, **kwargs) def __repr__(self): @@ -65,10 +77,14 @@ class Request(message.Message): Returns: The number of replacements made. """ - # TODO: Proper distinction between text and bytes. + if isinstance(pattern, six.text_type): + pattern = strutils.escaped_str_to_bytes(pattern) + if isinstance(repl, six.text_type): + repl = strutils.escaped_str_to_bytes(repl) + c = super(Request, self).replace(pattern, repl, flags) - self.path, pc = strutils.safe_subn( - pattern, repl, self.path, flags=flags + self.path, pc = re.subn( + pattern, repl, self.data.path, flags=flags ) c += pc return c @@ -102,6 +118,8 @@ class Request(message.Message): """ HTTP request scheme, which should be "http" or "https". """ + if not self.data.scheme: + return self.data.scheme return message._native(self.data.scheme) @scheme.setter @@ -321,7 +339,7 @@ class Request(message.Message): self.headers["accept-encoding"] = ( ', '.join( e - for e in encoding.ENCODINGS + for e in {"gzip", "identity", "deflate"} if e in accept_encoding ) ) @@ -341,7 +359,10 @@ class Request(message.Message): def _get_urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() if is_valid_content_type: - return tuple(netlib.http.url.decode(self.content)) + try: + return tuple(netlib.http.url.decode(self.content)) + except ValueError: + pass return () def _set_urlencoded_form(self, value): @@ -350,7 +371,7 @@ class Request(message.Message): This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(value) + self.content = netlib.http.url.encode(value).encode() @urlencoded_form.setter def urlencoded_form(self, value): @@ -370,7 +391,10 @@ class Request(message.Message): def _get_multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() if is_valid_content_type: - return multipart.decode(self.headers, self.content) + try: + return multipart.decode(self.headers, self.content) + except ValueError: + pass return () def _set_multipart_form(self, value): diff --git a/netlib/http/response.py b/netlib/http/response.py index 17d69418..85f54940 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division from email.utils import parsedate_tz, formatdate, mktime_tz import time +import six from netlib.http import cookies from netlib.http import headers as nheaders @@ -13,8 +14,14 @@ from netlib import human class ResponseData(message.MessageData): def __init__(self, http_version, status_code, reason=None, headers=(), content=None, timestamp_start=None, timestamp_end=None): + if isinstance(http_version, six.text_type): + http_version = http_version.encode("ascii", "strict") + if isinstance(reason, six.text_type): + reason = reason.encode("ascii", "strict") if not isinstance(headers, nheaders.Headers): headers = nheaders.Headers(headers) + if isinstance(content, six.text_type): + raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) self.http_version = http_version self.status_code = status_code @@ -30,13 +37,14 @@ class Response(message.Message): An HTTP response. """ def __init__(self, *args, **kwargs): + super(Response, self).__init__() self.data = ResponseData(*args, **kwargs) def __repr__(self): - if self.content: + if self.raw_content: details = "{}, {}".format( self.headers.get("content-type", "unknown content type"), - human.pretty_size(len(self.content)) + human.pretty_size(len(self.raw_content)) ) else: details = "no content" diff --git a/netlib/multidict.py b/netlib/multidict.py index 50c879d9..51053ff6 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -170,18 +170,10 @@ class _MultiDict(MutableMapping, basetypes.Serializable): else: return super(_MultiDict, self).items() - def clear(self, key): - """ - Removes all items with the specified key, and does not raise an - exception if the key does not exist. - """ - if key in self: - del self[key] - def collect(self): """ Returns a list of (key, value) tuples, where values are either - singular if threre is only one matching item for a key, or a list + singular if there is only one matching item for a key, or a list if there are more than one. The order of the keys matches the order in the underlying fields list. """ @@ -204,18 +196,16 @@ class _MultiDict(MutableMapping, basetypes.Serializable): .. code-block:: python # Simple dict with duplicate values. - >>> d - MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d = MultiDict([("name", "value"), ("a", False), ("a", 42)]) >>> d.to_dict() { "name": "value", - "a": ["false", "42"] + "a": [False, 42] } """ - d = {} - for k, v in self.collect(): - d[k] = v - return d + return { + k: v for k, v in self.collect() + } def get_state(self): return self.fields @@ -307,4 +297,4 @@ class MultiDictView(_MultiDict): @fields.setter def fields(self, value): - return self._setter(value) + self._setter(value) diff --git a/netlib/strutils.py b/netlib/strutils.py index 5ad41c7e..32e77927 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -1,5 +1,5 @@ +from __future__ import absolute_import, print_function, division import re -import unicodedata import codecs import six @@ -20,68 +20,80 @@ def native(s, *encoding_opts): """ if not isinstance(s, (six.binary_type, six.text_type)): raise TypeError("%r is neither bytes nor unicode" % s) - if six.PY3: - if isinstance(s, six.binary_type): - return s.decode(*encoding_opts) - else: + if six.PY2: if isinstance(s, six.text_type): return s.encode(*encoding_opts) + else: + if isinstance(s, six.binary_type): + return s.decode(*encoding_opts) return s -def clean_bin(s, keep_spacing=True): - """ - Cleans binary data to make it safe to display. +# Translate control characters to "safe" characters. This implementation initially +# replaced them with the matching control pictures (http://unicode.org/charts/PDF/U2400.pdf), +# but that turned out to render badly with monospace fonts. We are back to "." therefore. +_control_char_trans = { + x: ord(".") # x + 0x2400 for unicode control group pictures + for x in range(32) +} +_control_char_trans[127] = ord(".") # 0x2421 +_control_char_trans_newline = _control_char_trans.copy() +for x in ("\r", "\n", "\t"): + del _control_char_trans_newline[ord(x)] - Args: - keep_spacing: If False, tabs and newlines will also be replaced. - """ - if isinstance(s, six.text_type): - if keep_spacing: - keep = u" \n\r\t" - else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." - for ch in s - ) - else: - if keep_spacing: - keep = (9, 10, 13) # \t, \n, \r, - else: - keep = () - return b"".join( - six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." - for ch in six.iterbytes(s) - ) +if six.PY2: + pass +else: + _control_char_trans = str.maketrans(_control_char_trans) + _control_char_trans_newline = str.maketrans(_control_char_trans_newline) -def safe_subn(pattern, repl, target, *args, **kwargs): + +def escape_control_characters(text, keep_spacing=True): """ - There are Unicode conversion problems with re.subn. We try to smooth - that over by casting the pattern and replacement to strings. We really - need a better solution that is aware of the actual content ecoding. + Replace all unicode C1 control characters from the given text with their respective control pictures. + For example, a null byte is replaced with the unicode character "\u2400". + + Args: + keep_spacing: If True, tabs and newlines will not be replaced. """ - return re.subn(str(pattern), str(repl), target, *args, **kwargs) + # type: (six.string_types) -> six.text_type + if not isinstance(text, six.string_types): + raise ValueError("text type must be unicode but is {}".format(type(text).__name__)) + + trans = _control_char_trans_newline if keep_spacing else _control_char_trans + if six.PY2: + return u"".join( + six.unichr(trans.get(ord(ch), ord(ch))) + for ch in text + ) + return text.translate(trans) -def bytes_to_escaped_str(data): +def bytes_to_escaped_str(data, keep_spacing=False): """ Take bytes and return a safe string that can be displayed to the user. Single quotes are always escaped, double quotes are never escaped: "'" + bytes_to_escaped_str(...) + "'" gives a valid Python string. + + Args: + keep_spacing: If True, tabs and newlines will not be escaped. """ - # TODO: We may want to support multi-byte characters without escaping them. - # One way to do would be calling .decode("utf8", "backslashreplace") first - # and then escaping UTF8 control chars (see clean_bin). if not isinstance(data, bytes): raise ValueError("data must be bytes, but is {}".format(data.__class__.__name__)) # We always insert a double-quote here so that we get a single-quoted string back # https://stackoverflow.com/questions/29019340/why-does-python-use-different-quotes-for-representing-strings-depending-on-their - return repr(b'"' + data).lstrip("b")[2:-1] + ret = repr(b'"' + data).lstrip("b")[2:-1] + if keep_spacing: + ret = re.sub( + r"(?<!\\)(\\\\)*\\([nrt])", + lambda m: (m.group(1) or "") + dict(n="\n", r="\r", t="\t")[m.group(2)], + ret + ) + return ret def escaped_str_to_bytes(data): @@ -103,24 +115,17 @@ def escaped_str_to_bytes(data): return codecs.escape_decode(data)[0] -def isBin(s): - """ - Does this string have any non-ASCII characters? - """ - for i in s: - i = ord(i) - if i < 9 or 13 < i < 32 or 126 < i: - return True - return False - - -def isMostlyBin(s): - s = s[:100] - return sum(isBin(ch) for ch in s) / len(s) > 0.3 +def is_mostly_bin(s): + # type: (bytes) -> bool + return sum( + i < 9 or 13 < i < 32 or 126 < i + for i in six.iterbytes(s[:100]) + ) / len(s[:100]) > 0.3 -def isXML(s): - return s.strip().startswith("<") +def is_xml(s): + # type: (bytes) -> bool + return s.strip().startswith(b"<") def clean_hanging_newline(t): @@ -141,8 +146,12 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = "{:0=10x}".format(i).encode() + offset = "{:0=10x}".format(i) part = s[i:i + 16] - x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) + x = " ".join("{:0=2x}".format(i) for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 - yield (offset, x, clean_bin(part, False)) + part_repr = native(escape_control_characters( + part.decode("ascii", "replace").replace(u"\ufffd", u"."), + False + )) + yield (offset, x, part_repr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 69dafc1f..cf099edd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -676,7 +676,7 @@ class TCPClient(_Connection): self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni - self.connection.set_tlsext_host_name(sni) + self.connection.set_tlsext_host_name(sni.encode("idna")) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -705,7 +705,7 @@ class TCPClient(_Connection): if self.cert.cn: crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] if sni: - hostname = sni.decode("ascii", "strict") + hostname = sni else: hostname = "no-hostname" ssl_match_hostname.match_hostname(crt, hostname) diff --git a/netlib/utils.py b/netlib/utils.py index 79340cbd..9eebf22c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -56,6 +56,13 @@ class Data(object): dirname = os.path.dirname(inspect.getsourcefile(m)) self.dirname = os.path.abspath(dirname) + def push(self, subpath): + """ + Change the data object to a path relative to the module. + """ + self.dirname = os.path.join(self.dirname, subpath) + return self + def path(self, path): """ Returns a path to the package data housed at 'path' under this @@ -73,11 +80,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) def is_valid_host(host): + # type: (bytes) -> bool """ Checks if a hostname is valid. - - Args: - host (bytes): The hostname """ try: host.decode("idna") diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 42196ffb..7d355699 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -255,7 +255,7 @@ class Frame(object): def __repr__(self): ret = repr(self.header) if self.payload: - ret = ret + "\nPayload:\n" + strutils.clean_bin(self.payload).decode("ascii") + ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) return ret def human_readable(self): diff --git a/netlib/wsgi.py b/netlib/wsgi.py index c66fddc2..0def75b5 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -54,6 +54,10 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): + """ + Raises: + ValueError, if the content-encoding is invalid. + """ path = strutils.native(flow.request.path, "latin-1") if '?' in path: path_info, query = strutils.native(path, "latin-1").split('?', 1) |