author | Aldo Cortesi <aldo@nullcube.com> | 2016-10-20 11:56:38 +1300
committer | Aldo Cortesi <aldo@nullcube.com> | 2016-10-20 11:56:38 +1300
commit | 8430f857b504a3e7406dc36e54dc32783569d0dd (patch)
tree | d3116cd540faf01f272a0892fc6a9b83b4f6de8a /netlib
parent | 853e03a5e753354fad3a3fa5384ef3a09384ef43 (diff)
download | mitmproxy-8430f857b504a3e7406dc36e54dc32783569d0dd.tar.gz mitmproxy-8430f857b504a3e7406dc36e54dc32783569d0dd.tar.bz2 mitmproxy-8430f857b504a3e7406dc36e54dc32783569d0dd.zip
The final piece: netlib -> mitmproxy.net
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/__init__.py | 0
-rw-r--r-- | netlib/check.py | 22
-rw-r--r-- | netlib/http/__init__.py | 15
-rw-r--r-- | netlib/http/authentication.py | 176
-rw-r--r-- | netlib/http/cookies.py | 384
-rw-r--r-- | netlib/http/encoding.py | 175
-rw-r--r-- | netlib/http/headers.py | 221
-rw-r--r-- | netlib/http/http1/__init__.py | 24
-rw-r--r-- | netlib/http/http1/assemble.py | 100
-rw-r--r-- | netlib/http/http1/read.py | 377
-rw-r--r-- | netlib/http/http2/__init__.py | 8
-rw-r--r-- | netlib/http/http2/framereader.py | 25
-rw-r--r-- | netlib/http/http2/utils.py | 37
-rw-r--r-- | netlib/http/message.py | 300
-rw-r--r-- | netlib/http/multipart.py | 32
-rw-r--r-- | netlib/http/request.py | 405
-rw-r--r-- | netlib/http/response.py | 192
-rw-r--r-- | netlib/http/status_codes.py | 104
-rw-r--r-- | netlib/http/url.py | 127
-rw-r--r-- | netlib/http/user_agents.py | 50
-rw-r--r-- | netlib/socks.py | 234
-rw-r--r-- | netlib/tcp.py | 989
-rw-r--r-- | netlib/websockets/__init__.py | 35
-rw-r--r-- | netlib/websockets/frame.py | 274
-rw-r--r-- | netlib/websockets/masker.py | 25
-rw-r--r-- | netlib/websockets/utils.py | 90
-rw-r--r-- | netlib/wsgi.py | 166
27 files changed, 0 insertions, 4587 deletions
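This commit is pure deletion: every module listed above is removed from `netlib` because it now lives under `mitmproxy.net`, completing the package merge. For downstream code the practical effect is an import-path change. Below is a minimal sketch of the update, assuming the module layout carried over one-to-one (e.g. `netlib/http/` becoming `mitmproxy/net/http/`, `netlib/tcp.py` becoming `mitmproxy/net/tcp.py`); the `Request` constructor arguments follow the `RequestData` signature visible in the deleted `netlib/http/request.py` further down this diff.

```python
# Old imports (valid before this commit):
#   from netlib.http import Request, Response, Headers
#   from netlib import tcp

# New imports (after netlib -> mitmproxy.net):
from mitmproxy.net.http import Request, Headers
from mitmproxy.net import tcp

# Construct a request using the positional signature from request.py:
# (first_line_format, method, scheme, host, port, path, http_version, ...)
req = Request(
    "relative",
    b"GET", b"http", b"example.com", 80, b"/", b"HTTP/1.1",
    Headers(host="example.com"),  # headers
    b"",                          # content
)
print(req.pretty_url)  # -> "http://example.com/"
```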
diff --git a/netlib/__init__.py b/netlib/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/netlib/__init__.py +++ /dev/null diff --git a/netlib/check.py b/netlib/check.py deleted file mode 100644 index 7b007cb5..00000000 --- a/netlib/check.py +++ /dev/null @@ -1,22 +0,0 @@ -import re - -_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) - - -def is_valid_host(host: bytes) -> bool: - """ - Checks if a hostname is valid. - """ - try: - host.decode("idna") - except ValueError: - return False - if len(host) > 255: - return False - if host and host[-1:] == b".": - host = host[:-1] - return all(_label_valid.match(x) for x in host.split(b".")) - - -def is_valid_port(port): - return 0 <= port <= 65535 diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py deleted file mode 100644 index 315f61ac..00000000 --- a/netlib/http/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from netlib.http.request import Request -from netlib.http.response import Response -from netlib.http.message import Message -from netlib.http.headers import Headers, parse_content_type -from netlib.http.message import decoded -from netlib.http import http1, http2, status_codes, multipart - -__all__ = [ - "Request", - "Response", - "Message", - "Headers", "parse_content_type", - "decoded", - "http1", "http2", "status_codes", "multipart", -] diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py deleted file mode 100644 index a65279e4..00000000 --- a/netlib/http/authentication.py +++ /dev/null @@ -1,176 +0,0 @@ -import argparse -import binascii - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]).decode("utf8", "replace") - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") - return scheme + " " + v - - -class NullProxyAuth: - - """ - No proxy auth at all (returns empty challange headers) - """ - - def __init__(self, password_manager): - self.password_manager = password_manager - - def clean(self, headers_): - """ - Clean up authentication headers, so they're not passed upstream. 
- """ - - def authenticate(self, headers_): - """ - Tests that the user is allowed to use the proxy - """ - return True - - def auth_challenge_headers(self): - """ - Returns a dictionary containing the headers require to challenge the user - """ - return {} - - -class BasicAuth(NullProxyAuth): - CHALLENGE_HEADER = None - AUTH_HEADER = None - - def __init__(self, password_manager, realm): - NullProxyAuth.__init__(self, password_manager) - self.realm = realm - - def clean(self, headers): - del headers[self.AUTH_HEADER] - - def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER) - if not auth_value: - return False - parts = parse_http_basic_auth(auth_value) - if not parts: - return False - scheme, username, password = parts - if scheme.lower() != 'basic': - return False - if not self.password_manager.test(username, password): - return False - self.username = username - return True - - def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class BasicWebsiteAuth(BasicAuth): - CHALLENGE_HEADER = 'WWW-Authenticate' - AUTH_HEADER = 'Authorization' - - -class BasicProxyAuth(BasicAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' - - -class PassMan: - - def test(self, username_, password_token_): - return False - - -class PassManNonAnon(PassMan): - - """ - Ensure the user specifies a username, accept any password. - """ - - def test(self, username, password_token_): - if username: - return True - return False - - -class PassManHtpasswd(PassMan): - - """ - Read usernames and passwords from an htpasswd file - """ - - def __init__(self, path): - """ - Raises ValueError if htpasswd file is invalid. - """ - import passlib.apache - self.htpasswd = passlib.apache.HtpasswdFile(path) - - def test(self, username, password_token): - return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - - def __init__(self, username, password): - self.username, self.password = username, password - - def test(self, username, password_token): - return self.username == username and self.password == password_token - - -class AuthAction(argparse.Action): - - """ - Helper class to allow seamless integration int argparse. Example usage: - parser.add_argument( - "--nonanonymous", - action=NonanonymousAuthAction, nargs=0, - help="Allow access to any user long as a credentials are specified." - ) - """ - - def __call__(self, parser, namespace, values, option_string=None): - passman = self.getPasswordManager(values) - authenticator = BasicProxyAuth(passman, "mitmproxy") - setattr(namespace, self.dest, authenticator) - - def getPasswordManager(self, s): # pragma: no cover - raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - - def getPasswordManager(self, s): - if len(s.split(':')) != 2: - raise argparse.ArgumentTypeError( - "Invalid single-user specification. 
Please use the format username:password" - ) - username, password = s.split(':') - return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManHtpasswd(s) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py deleted file mode 100644 index 9f32fa5e..00000000 --- a/netlib/http/cookies.py +++ /dev/null @@ -1,384 +0,0 @@ -import collections -import email.utils -import re -import time - -from mitmproxy.types import multidict - -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We -also parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Serialization follows RFC6265. - - http://tools.ietf.org/html/rfc6265 - http://tools.ietf.org/html/rfc2109 - http://tools.ietf.org/html/rfc2965 -""" - -_cookie_params = set(( - 'expires', 'path', 'comment', 'max-age', - 'secure', 'httponly', 'version', -)) - -ESCAPE = re.compile(r"([\"\\])") - - -class CookieAttrs(multidict.ImmutableMultiDict): - @staticmethod - def _kconv(key): - return key.lower() - - @staticmethod - def _reduce_values(values): - # See the StickyCookieTest for a weird cookie that only makes sense - # if we take the last part. - return values[-1] - -SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) - - -def _read_until(s, start, term): - """ - Read until one of the characters in term is reached. - """ - if start == len(s): - return "", start + 1 - for i in range(start, len(s)): - if s[i] in term: - return s[start:i], i - return s[start:i + 1], i + 1 - - -def _read_quoted_string(s, start): - """ - start: offset to the first quote of the string to be read - - A sort of loose super-set of the various quoted string specifications. - - RFC6265 disallows backslashes or double quotes within quoted strings. - Prior RFCs use backslashes to escape. This leaves us free to apply - backslash escaping by default and be compatible with everything. - """ - escaping = False - ret = [] - # Skip the first quote - i = start # initialize in case the loop doesn't run. - for i in range(start + 1, len(s)): - if escaping: - ret.append(s[i]) - escaping = False - elif s[i] == '"': - break - elif s[i] == "\\": - escaping = True - else: - ret.append(s[i]) - return "".join(ret), i + 1 - - -def _read_key(s, start, delims=";="): - """ - Read a key - the LHS of a token/value pair in a cookie. - """ - return _read_until(s, start, delims) - - -def _read_value(s, start, delims): - """ - Reads a value - the RHS of a token/value pair in a cookie. - """ - if start >= len(s): - return "", start - elif s[start] == '"': - return _read_quoted_string(s, start) - else: - return _read_until(s, start, delims) - - -def _read_cookie_pairs(s, off=0): - """ - Read pairs of lhs=rhs values from Cookie headers. 
- - off: start offset - """ - pairs = [] - - while True: - lhs, off = _read_key(s, off) - lhs = lhs.lstrip() - - if lhs: - rhs = None - if off < len(s) and s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - - pairs.append([lhs, rhs]) - - off += 1 - - if not off < len(s): - break - - return pairs, off - - -def _read_set_cookie_pairs(s, off=0): - """ - Read pairs of lhs=rhs values from SetCookie headers while handling multiple cookies. - - off: start offset - specials: attributes that are treated specially - """ - cookies = [] - pairs = [] - - while True: - lhs, off = _read_key(s, off, ";=,") - lhs = lhs.lstrip() - - if lhs: - rhs = None - if off < len(s) and s[off] == "=": - rhs, off = _read_value(s, off + 1, ";,") - - # Special handliing of attributes - if lhs.lower() == "expires": - # 'expires' values can contain commas in them so they need to - # be handled separately. - - # We actually bank on the fact that the expires value WILL - # contain a comma. Things will fail, if they don't. - - # '3' is just a heuristic we use to determine whether we've - # only read a part of the expires value and we should read more. - if len(rhs) <= 3: - trail, off = _read_value(s, off + 1, ";,") - rhs = rhs + "," + trail - - pairs.append([lhs, rhs]) - - # comma marks the beginning of a new cookie - if off < len(s) and s[off] == ",": - cookies.append(pairs) - pairs = [] - - off += 1 - - if not off < len(s): - break - - if pairs or not cookies: - cookies.append(pairs) - - return cookies, off - - -def _has_special(s): - for i in s: - if i in '",;\\': - return True - o = ord(i) - if o < 0x21 or o > 0x7e: - return True - return False - - -def _format_pairs(pairs, specials=(), sep="; "): - """ - specials: A lower-cased list of keys that will not be quoted. - """ - vals = [] - for k, v in pairs: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) - return sep.join(vals) - - -def _format_set_cookie_pairs(lst): - return _format_pairs( - lst, - specials=("expires", "path") - ) - - -def parse_cookie_header(line): - """ - Parse a Cookie header value. - Returns a list of (lhs, rhs) tuples. - """ - pairs, off_ = _read_cookie_pairs(line) - return pairs - - -def parse_cookie_headers(cookie_headers): - cookie_list = [] - for header in cookie_headers: - cookie_list.extend(parse_cookie_header(header)) - return cookie_list - - -def format_cookie_header(lst): - """ - Formats a Cookie header value. - """ - return _format_pairs(lst) - - -def parse_set_cookie_header(line): - """ - Parse a Set-Cookie header value - - Returns a list of (name, value, attrs) tuples, where attrs is a - CookieAttrs dict of attributes. No attempt is made to parse attribute - values - they are treated purely as strings. - """ - cookie_pairs, off = _read_set_cookie_pairs(line) - cookies = [ - (pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])) - for pairs in cookie_pairs if pairs - ] - return cookies - - -def parse_set_cookie_headers(headers): - rv = [] - for header in headers: - cookies = parse_set_cookie_header(header) - if cookies: - for name, value, attrs in cookies: - rv.append((name, SetCookie(value, attrs))) - return rv - - -def format_set_cookie_header(set_cookies): - """ - Formats a Set-Cookie header value. 
- """ - - rv = [] - - for set_cookie in set_cookies: - name, value, attrs = set_cookie - - pairs = [(name, value)] - pairs.extend( - attrs.fields if hasattr(attrs, "fields") else attrs - ) - - rv.append(_format_set_cookie_pairs(pairs)) - - return ", ".join(rv) - - -def refresh_set_cookie_header(c, delta): - """ - Args: - c: A Set-Cookie string - delta: Time delta in seconds - Returns: - A refreshed Set-Cookie string - """ - - name, value, attrs = parse_set_cookie_header(c)[0] - if not name or not value: - raise ValueError("Invalid Cookie") - - if "expires" in attrs: - e = email.utils.parsedate_tz(attrs["expires"]) - if e: - f = email.utils.mktime_tz(e) + delta - attrs = attrs.with_set_all("expires", [email.utils.formatdate(f)]) - else: - # This can happen when the expires tag is invalid. - # reddit.com sends a an expires tag like this: "Thu, 31 Dec - # 2037 23:59:59 GMT", which is valid RFC 1123, but not - # strictly correct according to the cookie spec. Browsers - # appear to parse this tolerantly - maybe we should too. - # For now, we just ignore this. - attrs = attrs.with_delitem("expires") - - rv = format_set_cookie_header([(name, value, attrs)]) - if not rv: - raise ValueError("Invalid Cookie") - return rv - - -def get_expiration_ts(cookie_attrs): - """ - Determines the time when the cookie will be expired. - - Considering both 'expires' and 'max-age' parameters. - - Returns: timestamp of when the cookie will expire. - None, if no expiration time is set. - """ - if 'expires' in cookie_attrs: - e = email.utils.parsedate_tz(cookie_attrs["expires"]) - if e: - return email.utils.mktime_tz(e) - - elif 'max-age' in cookie_attrs: - try: - max_age = int(cookie_attrs['Max-Age']) - except ValueError: - pass - else: - now_ts = time.time() - return now_ts + max_age - - return None - - -def is_expired(cookie_attrs): - """ - Determines whether a cookie has expired. - - Returns: boolean - """ - - exp_ts = get_expiration_ts(cookie_attrs) - now_ts = time.time() - - # If no expiration information was provided with the cookie - if exp_ts is None: - return False - else: - return exp_ts <= now_ts - - -def group_cookies(pairs): - """ - Converts a list of pairs to a (name, value, attrs) for each cookie. - """ - - if not pairs: - return [] - - cookie_list = [] - - # First pair is always a new cookie - name, value = pairs[0] - attrs = [] - - for k, v in pairs[1:]: - if k.lower() in _cookie_params: - attrs.append((k, v)) - else: - cookie_list.append((name, value, CookieAttrs(attrs))) - name, value, attrs = k, v, [] - - cookie_list.append((name, value, CookieAttrs(attrs))) - return cookie_list diff --git a/netlib/http/encoding.py b/netlib/http/encoding.py deleted file mode 100644 index e123a033..00000000 --- a/netlib/http/encoding.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Utility functions for decoding response bodies. -""" - -import codecs -import collections -from io import BytesIO - -import gzip -import zlib -import brotli - -from typing import Union - - -# We have a shared single-element cache for encoding and decoding. -# This is quite useful in practice, e.g. 
-# flow.request.content = flow.request.content.replace(b"foo", b"bar") -# does not require an .encode() call if content does not contain b"foo" -CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded") -_cache = CachedDecode(None, None, None, None) - - -def decode(encoded: Union[str, bytes], encoding: str, errors: str='strict') -> Union[str, bytes]: - """ - Decode the given input object - - Returns: - The decoded value - - Raises: - ValueError, if decoding fails. - """ - if len(encoded) == 0: - return encoded - - global _cache - cached = ( - isinstance(encoded, bytes) and - _cache.encoded == encoded and - _cache.encoding == encoding and - _cache.errors == errors - ) - if cached: - return _cache.decoded - try: - try: - decoded = custom_decode[encoding](encoded) - except KeyError: - decoded = codecs.decode(encoded, encoding, errors) - if encoding in ("gzip", "deflate", "br"): - _cache = CachedDecode(encoded, encoding, errors, decoded) - return decoded - except TypeError: - raise - except Exception as e: - raise ValueError("{} when decoding {} with {}: {}".format( - type(e).__name__, - repr(encoded)[:10], - repr(encoding), - repr(e), - )) - - -def encode(decoded: Union[str, bytes], encoding: str, errors: str='strict') -> Union[str, bytes]: - """ - Encode the given input object - - Returns: - The encoded value - - Raises: - ValueError, if encoding fails. - """ - if len(decoded) == 0: - return decoded - - global _cache - cached = ( - isinstance(decoded, bytes) and - _cache.decoded == decoded and - _cache.encoding == encoding and - _cache.errors == errors - ) - if cached: - return _cache.encoded - try: - try: - value = decoded - if isinstance(value, str): - value = decoded.encode() - encoded = custom_encode[encoding](value) - except KeyError: - encoded = codecs.encode(decoded, encoding, errors) - if encoding in ("gzip", "deflate", "br"): - _cache = CachedDecode(encoded, encoding, errors, decoded) - return encoded - except TypeError: - raise - except Exception as e: - raise ValueError("{} when encoding {} with {}: {}".format( - type(e).__name__, - repr(decoded)[:10], - repr(encoding), - repr(e), - )) - - -def identity(content): - """ - Returns content unchanged. Identity is the default value of - Accept-Encoding headers. - """ - return content - - -def decode_gzip(content): - gfile = gzip.GzipFile(fileobj=BytesIO(content)) - return gfile.read() - - -def encode_gzip(content): - s = BytesIO() - gf = gzip.GzipFile(fileobj=s, mode='wb') - gf.write(content) - gf.close() - return s.getvalue() - - -def decode_brotli(content): - return brotli.decompress(content) - - -def encode_brotli(content): - return brotli.compress(content) - - -def decode_deflate(content): - """ - Returns decompressed data for DEFLATE. Some servers may respond with - compressed data without a zlib header or checksum. An undocumented - feature of zlib permits the lenient decompression of data missing both - values. - - http://bugs.python.org/issue5784 - """ - try: - return zlib.decompress(content) - except zlib.error: - return zlib.decompress(content, -15) - - -def encode_deflate(content): - """ - Returns compressed content, always including zlib header and checksum. 
- """ - return zlib.compress(content) - - -custom_decode = { - "none": identity, - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, - "br": decode_brotli, -} -custom_encode = { - "none": identity, - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, - "br": encode_brotli, -} - -__all__ = ["encode", "decode"] diff --git a/netlib/http/headers.py b/netlib/http/headers.py deleted file mode 100644 index 8fc0cd43..00000000 --- a/netlib/http/headers.py +++ /dev/null @@ -1,221 +0,0 @@ -import re - -import collections -from mitmproxy.types import multidict -from mitmproxy.utils import strutils - -# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ - - -# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. -def _native(x): - return x.decode("utf-8", "surrogateescape") - - -def _always_bytes(x): - return strutils.always_bytes(x, "utf-8", "surrogateescape") - - -class Headers(multidict.MultiDict): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create headers with keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Headers can also be created from a list of raw (header_name, header_value) byte tuples - >>> h = Headers([ - (b"Host",b"example.com"), - (b"Accept",b"text/html"), - (b"accept",b"application/xml") - ]) - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # bytes(h) returns a HTTP1 header block. - >>> print(bytes(h)) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - def __init__(self, fields=(), **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header byte tuples, - e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. 
- """ - super().__init__(fields) - - for key, value in self.fields: - if not isinstance(key, bytes) or not isinstance(value, bytes): - raise TypeError("Header fields must be bytes.") - - # content_type -> content-type - headers = { - _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) - for name, value in headers.items() - } - self.update(headers) - - @staticmethod - def _reduce_values(values): - # Headers can be folded - return ", ".join(values) - - @staticmethod - def _kconv(key): - # Headers are case-insensitive - return key.lower() - - def __bytes__(self): - if self.fields: - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" - else: - return b"" - - def __delitem__(self, key): - key = _always_bytes(key) - super().__delitem__(key) - - def __iter__(self): - for x in super().__iter__(): - yield _native(x) - - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name = _always_bytes(name) - return [ - _native(x) for x in - super().get_all(name) - ] - - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. - See: :py:meth:`get_all` - """ - name = _always_bytes(name) - values = [_always_bytes(x) for x in values] - return super().set_all(name, values) - - def insert(self, index, key, value): - key = _always_bytes(key) - value = _always_bytes(value) - super().insert(index, key, value) - - def items(self, multi=False): - if multi: - return ( - (_native(k), _native(v)) - for k, v in self.fields - ) - else: - return super().items() - - def replace(self, pattern, repl, flags=0, count=0): - """ - Replaces a regular expression pattern with repl in each "name: value" - header line. - - Returns: - The number of replacements made. - """ - if isinstance(pattern, str): - pattern = strutils.escaped_str_to_bytes(pattern) - if isinstance(repl, str): - repl = strutils.escaped_str_to_bytes(repl) - pattern = re.compile(pattern, flags) - replacements = 0 - flag_count = count > 0 - fields = [] - for name, value in self.fields: - line, n = pattern.subn(repl, name + b": " + value, count=count) - try: - name, value = line.split(b": ", 1) - except ValueError: - # We get a ValueError if the replacement removed the ": " - # There's not much we can do about this, so we just keep the header as-is. - pass - else: - replacements += n - if flag_count: - count -= n - if count == 0: - break - fields.append((name, value)) - self.fields = tuple(fields) - return replacements - - -def parse_content_type(c): - """ - A simple parser for content-type values. Returns a (type, subtype, - parameters) tuple, where type and subtype are strings, and parameters - is a dict. If the string could not be parsed, return None. - - E.g. 
the following string: - - text/html; charset=UTF-8 - - Returns: - - ("text", "html", {"charset": "UTF-8"}) - """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) - if len(ts) != 2: - return None - d = collections.OrderedDict() - if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) - if len(clause) == 2: - d[clause[0].strip()] = clause[1].strip() - return ts[0].lower(), ts[1].lower(), d - - -def assemble_content_type(type, subtype, parameters): - if not parameters: - return "{}/{}".format(type, subtype) - params = "; ".join( - "{}={}".format(k, v) - for k, v in parameters.items() - ) - return "{}/{}; {}".format( - type, subtype, params - ) diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py deleted file mode 100644 index e4bf01c5..00000000 --- a/netlib/http/http1/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from .read import ( - read_request, read_request_head, - read_response, read_response_head, - read_body, - connection_close, - expected_http_body_size, -) -from .assemble import ( - assemble_request, assemble_request_head, - assemble_response, assemble_response_head, - assemble_body, -) - - -__all__ = [ - "read_request", "read_request_head", - "read_response", "read_response_head", - "read_body", - "connection_close", - "expected_http_body_size", - "assemble_request", "assemble_request_head", - "assemble_response", "assemble_response_head", - "assemble_body", -] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py deleted file mode 100644 index e0a91ad8..00000000 --- a/netlib/http/http1/assemble.py +++ /dev/null @@ -1,100 +0,0 @@ -import netlib.http.url -from mitmproxy import exceptions - - -def assemble_request(request): - if request.data.content is None: - raise exceptions.HttpException("Cannot assemble flow with missing content") - head = assemble_request_head(request) - body = b"".join(assemble_body(request.data.headers, [request.data.content])) - return head + body - - -def assemble_request_head(request): - first_line = _assemble_request_line(request.data) - headers = _assemble_request_headers(request.data) - return b"%s\r\n%s\r\n" % (first_line, headers) - - -def assemble_response(response): - if response.data.content is None: - raise exceptions.HttpException("Cannot assemble flow with missing content") - head = assemble_response_head(response) - body = b"".join(assemble_body(response.data.headers, [response.data.content])) - return head + body - - -def assemble_response_head(response): - first_line = _assemble_response_line(response.data) - headers = _assemble_response_headers(response.data) - return b"%s\r\n%s\r\n" % (first_line, headers) - - -def assemble_body(headers, body_chunks): - if "chunked" in headers.get("transfer-encoding", "").lower(): - for chunk in body_chunks: - if chunk: - yield b"%x\r\n%s\r\n" % (len(chunk), chunk) - yield b"0\r\n\r\n" - else: - for chunk in body_chunks: - yield chunk - - -def _assemble_request_line(request_data): - """ - Args: - request_data (netlib.http.request.RequestData) - """ - form = request_data.first_line_format - if form == "relative": - return b"%s %s %s" % ( - request_data.method, - request_data.path, - request_data.http_version - ) - elif form == "authority": - return b"%s %s:%d %s" % ( - request_data.method, - request_data.host, - request_data.port, - request_data.http_version - ) - elif form == "absolute": - return b"%s %s://%s:%d%s %s" % ( - request_data.method, - request_data.scheme, - request_data.host, - request_data.port, - request_data.path, - 
request_data.http_version - ) - else: - raise RuntimeError("Invalid request form") - - -def _assemble_request_headers(request_data): - """ - Args: - request_data (netlib.http.request.RequestData) - """ - headers = request_data.headers.copy() - if "host" not in headers and request_data.scheme and request_data.host and request_data.port: - headers["host"] = netlib.http.url.hostport( - request_data.scheme, - request_data.host, - request_data.port - ) - return bytes(headers) - - -def _assemble_response_line(response_data): - return b"%s %d %s" % ( - response_data.http_version, - response_data.status_code, - response_data.reason, - ) - - -def _assemble_response_headers(response): - return bytes(response.headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py deleted file mode 100644 index e6b22863..00000000 --- a/netlib/http/http1/read.py +++ /dev/null @@ -1,377 +0,0 @@ -import time -import sys -import re - -from netlib.http import request -from netlib.http import response -from netlib.http import headers -from netlib.http import url -from netlib import check -from mitmproxy import exceptions - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - if key not in headers: - return [] - tokens = headers[key].split(",") - return [token.strip() for token in tokens] - - -def read_request(rfile, body_size_limit=None): - request = read_request_head(rfile) - expected_body_size = expected_http_body_size(request) - request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) - request.timestamp_end = time.time() - return request - - -def read_request_head(rfile): - """ - Parse an HTTP request head (request line + headers) from an input stream - - Args: - rfile: The input stream - - Returns: - The HTTP request object (without body) - - Raises: - exceptions.HttpReadDisconnect: No bytes can be read from rfile. - exceptions.HttpSyntaxException: The input is malformed HTTP. - exceptions.HttpException: Any other error occured. - """ - timestamp_start = time.time() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - form, method, scheme, host, port, path, http_version = _read_request_line(rfile) - headers = _read_headers(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - return request.Request( - form, method, scheme, host, port, path, http_version, headers, None, timestamp_start - ) - - -def read_response(rfile, request, body_size_limit=None): - response = read_response_head(rfile) - expected_body_size = expected_http_body_size(request, response) - response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) - response.timestamp_end = time.time() - return response - - -def read_response_head(rfile): - """ - Parse an HTTP response head (response line + headers) from an input stream - - Args: - rfile: The input stream - - Returns: - The HTTP request object (without body) - - Raises: - exceptions.HttpReadDisconnect: No bytes can be read from rfile. - exceptions.HttpSyntaxException: The input is malformed HTTP. - exceptions.HttpException: Any other error occured. 
- """ - - timestamp_start = time.time() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - http_version, status_code, message = _read_response_line(rfile) - headers = _read_headers(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - return response.Response(http_version, status_code, message, headers, None, timestamp_start) - - -def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): - """ - Read an HTTP message body - - Args: - rfile: The input stream - expected_size: The expected body size (see :py:meth:`expected_body_size`) - limit: Maximum body size - max_chunk_size: Maximium chunk size that gets yielded - - Returns: - A generator that yields byte chunks of the content. - - Raises: - exceptions.HttpException, if an error occurs - - Caveats: - max_chunk_size is not considered if the transfer encoding is chunked. - """ - if not limit or limit < 0: - limit = sys.maxsize - if not max_chunk_size: - max_chunk_size = limit - - if expected_size is None: - for x in _read_chunked(rfile, limit): - yield x - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise exceptions.HttpException( - "HTTP Body too large. " - "Limit is {}, content length was advertised as {}".format(limit, expected_size) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if len(content) < chunk_size: - raise exceptions.HttpException("Unexpected EOF") - yield content - bytes_left -= chunk_size - else: - bytes_left = limit - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise exceptions.HttpException("HTTP body too large. Limit is {}.".format(limit)) - - -def connection_close(http_version, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - tokens = get_header_tokens(headers, "connection") - if "close" in tokens: - return True - elif "keep-alive" in tokens: - return False - - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case. - - -def expected_http_body_size(request, response=None): - """ - Returns: - The expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) - - -1, if all data should be read until end of stream. 
- - Raises: - exceptions.HttpSyntaxException, if the content length header is invalid - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if not response: - headers = request.headers - response_code = None - is_request = True - else: - headers = response.headers - response_code = response.status_code - is_request = False - - if is_request: - if headers.get("expect", "").lower() == "100-continue": - return 0 - else: - if request.method.upper() == "HEAD": - return 0 - if 100 <= response_code <= 199: - return 0 - if response_code == 200 and request.method.upper() == "CONNECT": - return 0 - if response_code in (204, 304): - return 0 - - if "chunked" in headers.get("transfer-encoding", "").lower(): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"]) - if size < 0: - raise ValueError() - return size - except ValueError: - raise exceptions.HttpSyntaxException("Unparseable Content Length") - if is_request: - return 0 - return -1 - - -def _get_first_line(rfile): - try: - line = rfile.readline() - if line == b"\r\n" or line == b"\n": - # Possible leftover from previous message - line = rfile.readline() - except exceptions.TcpDisconnect: - raise exceptions.HttpReadDisconnect("Remote disconnected") - if not line: - raise exceptions.HttpReadDisconnect("Remote disconnected") - return line.strip() - - -def _read_request_line(rfile): - try: - line = _get_first_line(rfile) - except exceptions.HttpReadDisconnect: - # We want to provide a better error message. - raise exceptions.HttpReadDisconnect("Client disconnected") - - try: - method, path, http_version = line.split() - - if path == b"*" or path.startswith(b"/"): - form = "relative" - scheme, host, port = None, None, None - elif method == b"CONNECT": - form = "authority" - host, port = _parse_authority_form(path) - scheme, path = None, None - else: - form = "absolute" - scheme, host, port, path = url.parse(path) - - _check_http_version(http_version) - except ValueError: - raise exceptions.HttpSyntaxException("Bad HTTP request line: {}".format(line)) - - return form, method, scheme, host, port, path, http_version - - -def _parse_authority_form(hostport): - """ - Returns (host, port) if hostport is a valid authority-form host specification. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - - Raises: - ValueError, if the input is malformed - """ - try: - host, port = hostport.split(b":") - port = int(port) - if not check.is_valid_host(host) or not check.is_valid_port(port): - raise ValueError() - except ValueError: - raise exceptions.HttpSyntaxException("Invalid host specification: {}".format(hostport)) - - return host, port - - -def _read_response_line(rfile): - try: - line = _get_first_line(rfile) - except exceptions.HttpReadDisconnect: - # We want to provide a better error message. 
- raise exceptions.HttpReadDisconnect("Server disconnected") - - try: - parts = line.split(None, 2) - if len(parts) == 2: # handle missing message gracefully - parts.append(b"") - - http_version, status_code, message = parts - status_code = int(status_code) - _check_http_version(http_version) - - except ValueError: - raise exceptions.HttpSyntaxException("Bad HTTP response line: {}".format(line)) - - return http_version, status_code, message - - -def _check_http_version(http_version): - if not re.match(br"^HTTP/\d\.\d$", http_version): - raise exceptions.HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) - - -def _read_headers(rfile): - """ - Read a set of headers. - Stop once a blank line is reached. - - Returns: - A headers object - - Raises: - exceptions.HttpSyntaxException - """ - ret = [] - while True: - line = rfile.readline() - if not line or line == b"\r\n" or line == b"\n": - break - if line[0] in b" \t": - if not ret: - raise exceptions.HttpSyntaxException("Invalid headers") - # continued header - ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) - else: - try: - name, value = line.split(b":", 1) - value = value.strip() - if not name: - raise ValueError() - ret.append((name, value)) - except ValueError: - raise exceptions.HttpSyntaxException( - "Invalid header line: %s" % repr(line) - ) - return headers.Headers(ret) - - -def _read_chunked(rfile, limit=sys.maxsize): - """ - Read a HTTP body with chunked transfer encoding. - - Args: - rfile: the input file - limit: A positive integer - """ - total = 0 - while True: - line = rfile.readline(128) - if line == b"": - raise exceptions.HttpException("Connection closed prematurely") - if line != b"\r\n" and line != b"\n": - try: - length = int(line, 16) - except ValueError: - raise exceptions.HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) - total += length - if total > limit: - raise exceptions.HttpException( - "HTTP Body too large. 
Limit is {}, " - "chunked content longer than {}".format(limit, total) - ) - chunk = rfile.read(length) - suffix = rfile.readline(5) - if suffix != b"\r\n": - raise exceptions.HttpSyntaxException("Malformed chunked body") - if length == 0: - return - yield chunk diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py deleted file mode 100644 index 20cc63a0..00000000 --- a/netlib/http/http2/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from netlib.http.http2.framereader import read_raw_frame, parse_frame -from netlib.http.http2.utils import parse_headers - -__all__ = [ - "read_raw_frame", - "parse_frame", - "parse_headers", -] diff --git a/netlib/http/http2/framereader.py b/netlib/http/http2/framereader.py deleted file mode 100644 index 6a164919..00000000 --- a/netlib/http/http2/framereader.py +++ /dev/null @@ -1,25 +0,0 @@ -import codecs - -import hyperframe -from mitmproxy import exceptions - - -def read_raw_frame(rfile): - header = rfile.safe_read(9) - length = int(codecs.encode(header[:3], 'hex_codec'), 16) - - if length == 4740180: - raise exceptions.HttpException("Length field looks more like HTTP/1.1:\n{}".format(rfile.read(-1))) - - body = rfile.safe_read(length) - return [header, body] - - -def parse_frame(header, body=None): - if body is None: - body = header[9:] - header = header[:9] - - frame, length = hyperframe.frame.Frame.parse_frame_header(header) - frame.parse_body(memoryview(body)) - return frame diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py deleted file mode 100644 index 164bacc8..00000000 --- a/netlib/http/http2/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -from netlib.http import url - - -def parse_headers(headers): - authority = headers.get(':authority', '').encode() - method = headers.get(':method', 'GET').encode() - scheme = headers.get(':scheme', 'https').encode() - path = headers.get(':path', '/').encode() - - headers.pop(":method", None) - headers.pop(":scheme", None) - headers.pop(":path", None) - - host = None - port = None - - if path == b'*' or path.startswith(b"/"): - first_line_format = "relative" - elif method == b'CONNECT': # pragma: no cover - raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") - else: # pragma: no cover - first_line_format = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = url.parse(path) - - if authority: - host, _, port = authority.partition(b':') - - if not host: - host = b'localhost' - - if not port: - port = 443 if scheme == b'https' else 80 - - port = int(port) - - return first_line_format, method, scheme, host, port, path diff --git a/netlib/http/message.py b/netlib/http/message.py deleted file mode 100644 index 772a124e..00000000 --- a/netlib/http/message.py +++ /dev/null @@ -1,300 +0,0 @@ -import re -import warnings -from typing import Optional - -from mitmproxy.utils import strutils -from netlib.http import encoding -from mitmproxy.types import serializable -from netlib.http import headers - - -# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. 
-def _native(x): - return x.decode("utf-8", "surrogateescape") - - -def _always_bytes(x): - return strutils.always_bytes(x, "utf-8", "surrogateescape") - - -class MessageData(serializable.Serializable): - def __eq__(self, other): - if isinstance(other, MessageData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def set_state(self, state): - for k, v in state.items(): - if k == "headers": - v = headers.Headers.from_state(v) - setattr(self, k, v) - - def get_state(self): - state = vars(self).copy() - state["headers"] = state["headers"].get_state() - return state - - @classmethod - def from_state(cls, state): - state["headers"] = headers.Headers.from_state(state["headers"]) - return cls(**state) - - -class Message(serializable.Serializable): - def __eq__(self, other): - if isinstance(other, Message): - return self.data == other.data - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def get_state(self): - return self.data.get_state() - - def set_state(self, state): - self.data.set_state(state) - - @classmethod - def from_state(cls, state): - state["headers"] = headers.Headers.from_state(state["headers"]) - return cls(**state) - - @property - def headers(self): - """ - Message headers object - - Returns: - netlib.http.Headers - """ - return self.data.headers - - @headers.setter - def headers(self, h): - self.data.headers = h - - @property - def raw_content(self) -> bytes: - """ - The raw (encoded) HTTP message body - - See also: :py:attr:`content`, :py:class:`text` - """ - return self.data.content - - @raw_content.setter - def raw_content(self, content): - self.data.content = content - - def get_content(self, strict: bool=True) -> bytes: - """ - The HTTP message body decoded with the content-encoding header (e.g. gzip) - - Raises: - ValueError, when the content-encoding is invalid and strict is True. - - See also: :py:class:`raw_content`, :py:attr:`text` - """ - if self.raw_content is None: - return None - ce = self.headers.get("content-encoding") - if ce: - try: - return encoding.decode(self.raw_content, ce) - except ValueError: - if strict: - raise - return self.raw_content - else: - return self.raw_content - - def set_content(self, value): - if value is None: - self.raw_content = None - return - if not isinstance(value, bytes): - raise TypeError( - "Message content must be bytes, not {}. " - "Please use .text if you want to assign a str." - .format(type(value).__name__) - ) - ce = self.headers.get("content-encoding") - try: - self.raw_content = encoding.encode(value, ce or "identity") - except ValueError: - # So we have an invalid content-encoding? - # Let's remove it! - del self.headers["content-encoding"] - self.raw_content = value - self.headers["content-length"] = str(len(self.raw_content)) - - content = property(get_content, set_content) - - @property - def http_version(self): - """ - Version string, e.g. 
"HTTP/1.1" - """ - return _native(self.data.http_version) - - @http_version.setter - def http_version(self, http_version): - self.data.http_version = _always_bytes(http_version) - - @property - def timestamp_start(self): - """ - First byte timestamp - """ - return self.data.timestamp_start - - @timestamp_start.setter - def timestamp_start(self, timestamp_start): - self.data.timestamp_start = timestamp_start - - @property - def timestamp_end(self): - """ - Last byte timestamp - """ - return self.data.timestamp_end - - @timestamp_end.setter - def timestamp_end(self, timestamp_end): - self.data.timestamp_end = timestamp_end - - def _get_content_type_charset(self) -> Optional[str]: - ct = headers.parse_content_type(self.headers.get("content-type", "")) - if ct: - return ct[2].get("charset") - - def _guess_encoding(self) -> str: - enc = self._get_content_type_charset() - if enc: - return enc - - if "json" in self.headers.get("content-type", ""): - return "utf8" - else: - # We may also want to check for HTML meta tags here at some point. - return "latin-1" - - def get_text(self, strict: bool=True) -> str: - """ - The HTTP message body decoded with both content-encoding header (e.g. gzip) - and content-type header charset. - - Raises: - ValueError, when either content-encoding or charset is invalid and strict is True. - - See also: :py:attr:`content`, :py:class:`raw_content` - """ - if self.raw_content is None: - return None - enc = self._guess_encoding() - - content = self.get_content(strict) - try: - return encoding.decode(content, enc) - except ValueError: - if strict: - raise - return content.decode("utf8", "surrogateescape") - - def set_text(self, text): - if text is None: - self.content = None - return - enc = self._guess_encoding() - - try: - self.content = encoding.encode(text, enc) - except ValueError: - # Fall back to UTF-8 and update the content-type header. - ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) - ct[2]["charset"] = "utf-8" - self.headers["content-type"] = headers.assemble_content_type(*ct) - enc = "utf8" - self.content = text.encode(enc, "surrogateescape") - - text = property(get_text, set_text) - - def decode(self, strict=True): - """ - Decodes body based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. - - Raises: - ValueError, when the content-encoding is invalid and strict is True. - """ - self.raw_content = self.get_content(strict) - self.headers.pop("content-encoding", None) - - def encode(self, e): - """ - Encodes body with the encoding e, where e is "gzip", "deflate", "identity", or "br". - Any existing content-encodings are overwritten, - the content is not decoded beforehand. - - Raises: - ValueError, when the specified content-encoding is invalid. - """ - self.headers["content-encoding"] = e - self.content = self.raw_content - if "content-encoding" not in self.headers: - raise ValueError("Invalid content encoding {}".format(repr(e))) - - def replace(self, pattern, repl, flags=0, count=0): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the message. Encoded body will be decoded - before replacement, and re-encoded afterwards. - - Returns: - The number of replacements made. 
- """ - if isinstance(pattern, str): - pattern = strutils.escaped_str_to_bytes(pattern) - if isinstance(repl, str): - repl = strutils.escaped_str_to_bytes(repl) - replacements = 0 - if self.content: - self.content, replacements = re.subn( - pattern, repl, self.content, flags=flags, count=count - ) - replacements += self.headers.replace(pattern, repl, flags=flags, count=count) - return replacements - - # Legacy - - @property - def body(self): # pragma: no cover - warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) - return self.content - - @body.setter - def body(self, body): # pragma: no cover - warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) - self.content = body - - -class decoded: - """ - Deprecated: You can now directly use :py:attr:`content`. - :py:attr:`raw_content` has the encoded content. - """ - - def __init__(self, message): # pragma no cover - warnings.warn("decoded() is deprecated, you can now directly use .content instead. " - ".raw_content has the encoded content.", DeprecationWarning) - - def __enter__(self): # pragma no cover - pass - - def __exit__(self, type, value, tb): # pragma no cover - pass diff --git a/netlib/http/multipart.py b/netlib/http/multipart.py deleted file mode 100644 index 536b2809..00000000 --- a/netlib/http/multipart.py +++ /dev/null @@ -1,32 +0,0 @@ -import re - -from netlib.http import headers - - -def decode(hdrs, content): - """ - Takes a multipart boundary encoded string and returns list of (key, value) tuples. - """ - v = hdrs.get("content-type") - if v: - v = headers.parse_content_type(v) - if not v: - return [] - try: - boundary = v[2]["boundary"].encode("ascii") - except (KeyError, UnicodeError): - return [] - - rx = re.compile(br'\bname="([^"]+)"') - r = [] - - for i in content.split(b"--" + boundary): - parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != b"--": - match = rx.search(parts[1]) - if match: - key = match.group(1) - value = b"".join(parts[3 + parts[2:].index(b""):]) - r.append((key, value)) - return r - return [] diff --git a/netlib/http/request.py b/netlib/http/request.py deleted file mode 100644 index 16b0c986..00000000 --- a/netlib/http/request.py +++ /dev/null @@ -1,405 +0,0 @@ -import re -import urllib - -from mitmproxy.types import multidict -from mitmproxy.utils import strutils -from netlib.http import multipart -from netlib.http import cookies -from netlib.http import headers as nheaders -from netlib.http import message -import netlib.http.url - -# This regex extracts & splits the host header into host and port. -# Handles the edge case of IPv6 addresses containing colons. 
-# https://bugzilla.mozilla.org/show_bug.cgi?id=45891 -host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") - - -class RequestData(message.MessageData): - def __init__( - self, - first_line_format, - method, - scheme, - host, - port, - path, - http_version, - headers=(), - content=None, - timestamp_start=None, - timestamp_end=None - ): - if isinstance(method, str): - method = method.encode("ascii", "strict") - if isinstance(scheme, str): - scheme = scheme.encode("ascii", "strict") - if isinstance(host, str): - host = host.encode("idna", "strict") - if isinstance(path, str): - path = path.encode("ascii", "strict") - if isinstance(http_version, str): - http_version = http_version.encode("ascii", "strict") - if not isinstance(headers, nheaders.Headers): - headers = nheaders.Headers(headers) - if isinstance(content, str): - raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) - - self.first_line_format = first_line_format - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.http_version = http_version - self.headers = headers - self.content = content - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - -class Request(message.Message): - """ - An HTTP request. - """ - def __init__(self, *args, **kwargs): - super().__init__() - self.data = RequestData(*args, **kwargs) - - def __repr__(self): - if self.host and self.port: - hostport = "{}:{}".format(self.host, self.port) - else: - hostport = "" - path = self.path or "" - return "Request({} {}{})".format( - self.method, hostport, path - ) - - def replace(self, pattern, repl, flags=0, count=0): - """ - Replaces a regular expression pattern with repl in the headers, the - request path and the body of the request. Encoded content will be - decoded before replacement, and re-encoded afterwards. - - Returns: - The number of replacements made. - """ - if isinstance(pattern, str): - pattern = strutils.escaped_str_to_bytes(pattern) - if isinstance(repl, str): - repl = strutils.escaped_str_to_bytes(repl) - - c = super().replace(pattern, repl, flags, count) - self.path, pc = re.subn( - pattern, repl, self.data.path, flags=flags, count=count - ) - c += pc - return c - - @property - def first_line_format(self): - """ - HTTP request form as defined in `RFC7230 <https://tools.ietf.org/html/rfc7230#section-5.3>`_. - - origin-form and asterisk-form are subsumed as "relative". - """ - return self.data.first_line_format - - @first_line_format.setter - def first_line_format(self, first_line_format): - self.data.first_line_format = first_line_format - - @property - def method(self): - """ - HTTP request method, e.g. "GET". - """ - return message._native(self.data.method).upper() - - @method.setter - def method(self, method): - self.data.method = message._always_bytes(method) - - @property - def scheme(self): - """ - HTTP request scheme, which should be "http" or "https". - """ - if not self.data.scheme: - return self.data.scheme - return message._native(self.data.scheme) - - @scheme.setter - def scheme(self, scheme): - self.data.scheme = message._always_bytes(scheme) - - @property - def host(self): - """ - Target host. This may be parsed from the raw request - (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) - or inferred from the proxy mode (e.g. an IP in transparent mode). - - Setting the host attribute also updates the host header, if present. 
- """ - if not self.data.host: - return self.data.host - try: - return self.data.host.decode("idna") - except UnicodeError: - return self.data.host.decode("utf8", "surrogateescape") - - @host.setter - def host(self, host): - if isinstance(host, str): - try: - # There's no non-strict mode for IDNA encoding. - # We don't want this operation to fail though, so we try - # utf8 as a last resort. - host = host.encode("idna", "strict") - except UnicodeError: - host = host.encode("utf8", "surrogateescape") - - self.data.host = host - - # Update host header - if "host" in self.headers: - if host: - self.headers["host"] = host - else: - self.headers.pop("host") - - @property - def port(self): - """ - Target port - """ - return self.data.port - - @port.setter - def port(self, port): - self.data.port = port - - @property - def path(self): - """ - HTTP request path, e.g. "/index.html". - Guaranteed to start with a slash, except for OPTIONS requests, which may just be "*". - """ - if self.data.path is None: - return None - else: - return message._native(self.data.path) - - @path.setter - def path(self, path): - self.data.path = message._always_bytes(path) - - @property - def url(self): - """ - The URL string, constructed from the request's URL components - """ - if self.first_line_format == "authority": - return "%s:%d" % (self.host, self.port) - return netlib.http.url.unparse(self.scheme, self.host, self.port, self.path) - - @url.setter - def url(self, url): - self.scheme, self.host, self.port, self.path = netlib.http.url.parse(url) - - def _parse_host_header(self): - """Extract the host and port from Host header""" - if "host" not in self.headers: - return None, None - host, port = self.headers["host"], None - m = host_header_re.match(host) - if m: - host = m.group("host").strip("[]") - if m.group("port"): - port = int(m.group("port")) - return host, port - - @property - def pretty_host(self): - """ - Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. - This is useful in transparent mode where :py:attr:`host` is only an IP address, - but may not reflect the actual destination as the Host header could be spoofed. - """ - host, port = self._parse_host_header() - if not host: - return self.host - if not port: - port = 443 if self.scheme == 'https' else 80 - # Prefer the original address if host header has an unexpected form - return host if port == self.port else self.host - - @property - def pretty_url(self): - """ - Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. - """ - if self.first_line_format == "authority": - return "%s:%d" % (self.pretty_host, self.port) - return netlib.http.url.unparse(self.scheme, self.pretty_host, self.port, self.path) - - @property - def query(self) -> multidict.MultiDictView: - """ - The request query string as an :py:class:`~netlib.multidict.MultiDictView` object. - """ - return multidict.MultiDictView( - self._get_query, - self._set_query - ) - - def _get_query(self): - query = urllib.parse.urlparse(self.url).query - return tuple(netlib.http.url.decode(query)) - - def _set_query(self, query_data): - query = netlib.http.url.encode(query_data) - _, _, path, params, _, fragment = urllib.parse.urlparse(self.url) - self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) - - @query.setter - def query(self, value): - self._set_query(value) - - @property - def cookies(self) -> multidict.MultiDictView: - """ - The request cookies. 
- - An empty :py:class:`~netlib.multidict.MultiDictView` object if the cookie monster ate them all. - """ - return multidict.MultiDictView( - self._get_cookies, - self._set_cookies - ) - - def _get_cookies(self): - h = self.headers.get_all("Cookie") - return tuple(cookies.parse_cookie_headers(h)) - - def _set_cookies(self, value): - self.headers["cookie"] = cookies.format_cookie_header(value) - - @cookies.setter - def cookies(self, value): - self._set_cookies(value) - - @property - def path_components(self): - """ - The URL's path components as a tuple of strings. - Components are unquoted. - """ - path = urllib.parse.urlparse(self.url).path - # This needs to be a tuple so that it's immutable. - # Otherwise, this would fail silently: - # request.path_components.append("foo") - return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i) - - @path_components.setter - def path_components(self, components): - components = map(lambda x: netlib.http.url.quote(x, safe=""), components) - path = "/" + "/".join(components) - _, _, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in {"gzip", "identity", "deflate", "br"} - if e in accept_encoding - ) - ) - - @property - def urlencoded_form(self): - """ - The URL-encoded form data as an :py:class:`~netlib.multidict.MultiDictView` object. - An empty multidict.MultiDictView if the content-type indicates non-form data - or the content could not be parsed. - """ - return multidict.MultiDictView( - self._get_urlencoded_form, - self._set_urlencoded_form - ) - - def _get_urlencoded_form(self): - is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if is_valid_content_type: - try: - return tuple(netlib.http.url.decode(self.content)) - except ValueError: - pass - return () - - def _set_urlencoded_form(self, form_data): - """ - Sets the body to the URL-encoded form data, and adds the appropriate content-type header. - This will overwrite the existing content if there is one. - """ - self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(form_data).encode() - - @urlencoded_form.setter - def urlencoded_form(self, value): - self._set_urlencoded_form(value) - - @property - def multipart_form(self): - """ - The multipart form data as an :py:class:`~netlib.multidict.MultiDictView` object. - None if the content-type indicates non-form data. 
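-        (In the current implementation the getter returns an empty tuple when
-        the content type does not match or parsing fails, and the setter is
-        not implemented yet.)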
- """ - return multidict.MultiDictView( - self._get_multipart_form, - self._set_multipart_form - ) - - def _get_multipart_form(self): - is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if is_valid_content_type: - try: - return multipart.decode(self.headers, self.content) - except ValueError: - pass - return () - - def _set_multipart_form(self, value): - raise NotImplementedError() - - @multipart_form.setter - def multipart_form(self, value): - self._set_multipart_form(value) diff --git a/netlib/http/response.py b/netlib/http/response.py deleted file mode 100644 index 4d1d5d24..00000000 --- a/netlib/http/response.py +++ /dev/null @@ -1,192 +0,0 @@ -import time -from email.utils import parsedate_tz, formatdate, mktime_tz -from mitmproxy.utils import human -from mitmproxy.types import multidict -from netlib.http import cookies -from netlib.http import headers as nheaders -from netlib.http import message -from netlib.http import status_codes -from typing import AnyStr -from typing import Dict -from typing import Iterable -from typing import Tuple -from typing import Union - - -class ResponseData(message.MessageData): - def __init__( - self, - http_version, - status_code, - reason=None, - headers=(), - content=None, - timestamp_start=None, - timestamp_end=None - ): - if isinstance(http_version, str): - http_version = http_version.encode("ascii", "strict") - if isinstance(reason, str): - reason = reason.encode("ascii", "strict") - if not isinstance(headers, nheaders.Headers): - headers = nheaders.Headers(headers) - if isinstance(content, str): - raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) - - self.http_version = http_version - self.status_code = status_code - self.reason = reason - self.headers = headers - self.content = content - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - -class Response(message.Message): - """ - An HTTP response. - """ - def __init__(self, *args, **kwargs): - super().__init__() - self.data = ResponseData(*args, **kwargs) - - def __repr__(self): - if self.raw_content: - details = "{}, {}".format( - self.headers.get("content-type", "unknown content type"), - human.pretty_size(len(self.raw_content)) - ) - else: - details = "no content" - return "Response({status_code} {reason}, {details})".format( - status_code=self.status_code, - reason=self.reason, - details=details - ) - - @classmethod - def make( - cls, - status_code: int=200, - content: AnyStr=b"", - headers: Union[Dict[AnyStr, AnyStr], Iterable[Tuple[bytes, bytes]]]=() - ): - """ - Simplified API for creating response objects. - """ - resp = cls( - b"HTTP/1.1", - status_code, - status_codes.RESPONSES.get(status_code, "").encode(), - (), - None - ) - - # Headers can be list or dict, we differentiate here. - if isinstance(headers, dict): - resp.headers = nheaders.Headers(**headers) - elif isinstance(headers, Iterable): - resp.headers = nheaders.Headers(headers) - else: - raise TypeError("Expected headers to be an iterable or dict, but is {}.".format( - type(headers).__name__ - )) - - # Assign this manually to update the content-length header. - if isinstance(content, bytes): - resp.content = content - elif isinstance(content, str): - resp.text = content - else: - raise TypeError("Expected content to be str or bytes, but is {}.".format( - type(content).__name__ - )) - - return resp - - @property - def status_code(self): - """ - HTTP Status Code, e.g. ``200``. 
- """ - return self.data.status_code - - @status_code.setter - def status_code(self, status_code): - self.data.status_code = status_code - - @property - def reason(self): - """ - HTTP Reason Phrase, e.g. "Not Found". - This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. - """ - return message._native(self.data.reason) - - @reason.setter - def reason(self, reason): - self.data.reason = message._always_bytes(reason) - - @property - def cookies(self) -> multidict.MultiDictView: - """ - The response cookies. A possibly empty - :py:class:`~netlib.multidict.MultiDictView`, where the keys are cookie - name strings, and values are (value, attr) tuples. Value is a string, - and attr is an MultiDictView containing cookie attributes. Within - attrs, unary attributes (e.g. HTTPOnly) are indicated by a Null value. - - Caveats: - Updating the attr - """ - return multidict.MultiDictView( - self._get_cookies, - self._set_cookies - ) - - def _get_cookies(self): - h = self.headers.get_all("set-cookie") - return tuple(cookies.parse_set_cookie_headers(h)) - - def _set_cookies(self, value): - cookie_headers = [] - for k, v in value: - header = cookies.format_set_cookie_header([(k, v[0], v[1])]) - cookie_headers.append(header) - self.headers.set_all("set-cookie", cookie_headers) - - @cookies.setter - def cookies(self, value): - self._set_cookies(value) - - def refresh(self, now=None): - """ - This fairly complex and heuristic function refreshes a server - response for replay. - - - It adjusts date, expires and last-modified headers. - - It adjusts cookie expiration. - """ - if not now: - now = time.time() - delta = now - self.timestamp_start - refresh_headers = [ - "date", - "expires", - "last-modified", - ] - for i in refresh_headers: - if i in self.headers: - d = parsedate_tz(self.headers[i]) - if d: - new = mktime_tz(d) + delta - self.headers[i] = formatdate(new) - c = [] - for set_cookie_header in self.headers.get_all("set-cookie"): - try: - refreshed = cookies.refresh_set_cookie_header(set_cookie_header, delta) - except ValueError: - refreshed = set_cookie_header - c.append(refreshed) - if c: - self.headers.set_all("set-cookie", c) diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py deleted file mode 100644 index 5a83cd73..00000000 --- a/netlib/http/status_codes.py +++ /dev/null @@ -1,104 +0,0 @@ -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 -IM_A_TEAPOT = 418 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { - # 100 - CONTINUE: "Continue", - SWITCHING: "Switching Protocols", - - # 200 - OK: "OK", - CREATED: "Created", - ACCEPTED: "Accepted", - 
NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", - NO_CONTENT: "No Content", - RESET_CONTENT: "Reset Content.", - PARTIAL_CONTENT: "Partial Content", - MULTI_STATUS: "Multi-Status", - - # 300 - MULTIPLE_CHOICE: "Multiple Choices", - MOVED_PERMANENTLY: "Moved Permanently", - FOUND: "Found", - SEE_OTHER: "See Other", - NOT_MODIFIED: "Not Modified", - USE_PROXY: "Use Proxy", - # 306 not defined?? - TEMPORARY_REDIRECT: "Temporary Redirect", - - # 400 - BAD_REQUEST: "Bad Request", - UNAUTHORIZED: "Unauthorized", - PAYMENT_REQUIRED: "Payment Required", - FORBIDDEN: "Forbidden", - NOT_FOUND: "Not Found", - NOT_ALLOWED: "Method Not Allowed", - NOT_ACCEPTABLE: "Not Acceptable", - PROXY_AUTH_REQUIRED: "Proxy Authentication Required", - REQUEST_TIMEOUT: "Request Time-out", - CONFLICT: "Conflict", - GONE: "Gone", - LENGTH_REQUIRED: "Length Required", - PRECONDITION_FAILED: "Precondition Failed", - REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", - REQUEST_URI_TOO_LONG: "Request-URI Too Long", - UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", - REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", - EXPECTATION_FAILED: "Expectation Failed", - IM_A_TEAPOT: "I'm a teapot", - - # 500 - INTERNAL_SERVER_ERROR: "Internal Server Error", - NOT_IMPLEMENTED: "Not Implemented", - BAD_GATEWAY: "Bad Gateway", - SERVICE_UNAVAILABLE: "Service Unavailable", - GATEWAY_TIMEOUT: "Gateway Time-out", - HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", - INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", - NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/http/url.py b/netlib/http/url.py deleted file mode 100644 index 3ca58120..00000000 --- a/netlib/http/url.py +++ /dev/null @@ -1,127 +0,0 @@ -import urllib -from typing import Sequence -from typing import Tuple - -from netlib import check - - -# PY2 workaround -def decode_parse_result(result, enc): - if hasattr(result, "decode"): - return result.decode(enc) - else: - return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) - - -# PY2 workaround -def encode_parse_result(result, enc): - if hasattr(result, "encode"): - return result.encode(enc) - else: - return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) - - -def parse(url): - """ - URL-parsing function that checks that - - port is an integer 0-65535 - - host is a valid IDNA-encoded hostname with no null-bytes - - path is valid ASCII - - Args: - A URL (as bytes or as unicode) - - Returns: - A (scheme, host, port, path) tuple - - Raises: - ValueError, if the URL is not properly formatted. - """ - parsed = urllib.parse.urlparse(url) - - if not parsed.hostname: - raise ValueError("No hostname given") - - if isinstance(url, bytes): - host = parsed.hostname - - # this should not raise a ValueError, - # but we try to be very forgiving here and accept just everything. 
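-        # For example, parse(b"http://example.com/foo") yields
-        # (b"http", b"example.com", 80, b"/foo").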
- # decode_parse_result(parsed, "ascii") - else: - host = parsed.hostname.encode("idna") - parsed = encode_parse_result(parsed, "ascii") - - port = parsed.port - if not port: - port = 443 if parsed.scheme == b"https" else 80 - - full_path = urllib.parse.urlunparse( - (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) - ) - if not full_path.startswith(b"/"): - full_path = b"/" + full_path - - if not check.is_valid_host(host): - raise ValueError("Invalid Host") - if not check.is_valid_port(port): - raise ValueError("Invalid Port") - - return parsed.scheme, host, port, full_path - - -def unparse(scheme, host, port, path=""): - """ - Returns a URL string, constructed from the specified components. - - Args: - All args must be str. - """ - if path == "*": - path = "" - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) - - -def encode(s: Sequence[Tuple[str, str]]) -> str: - """ - Takes a list of (key, value) tuples and returns a urlencoded string. - """ - return urllib.parse.urlencode(s, False, errors="surrogateescape") - - -def decode(s): - """ - Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples. - """ - return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape') - - -def quote(b: str, safe: str="/") -> str: - """ - Returns: - An ascii-encodable str. - """ - return urllib.parse.quote(b, safe=safe, errors="surrogateescape") - - -def unquote(s: str) -> str: - """ - Args: - s: A surrogate-escaped str - Returns: - A surrogate-escaped str - """ - return urllib.parse.unquote(s, errors="surrogateescape") - - -def hostport(scheme, host, port): - """ - Returns the host component, with a port specifcation if needed. - """ - if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: - return host - else: - if isinstance(host, bytes): - return b"%s:%d" % (host, port) - else: - return "%s:%d" % (host, port) diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py deleted file mode 100644 index d0ca2f21..00000000 --- a/netlib/http/user_agents.py +++ /dev/null @@ -1,50 +0,0 @@ -""" - A small collection of useful user-agent header strings. These should be - kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. 
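-# Entries are looked up via get_by_shortcut() below, e.g. the shortcut
-# "c" resolves to the Chrome entry.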
- -UASTRINGS = [ - ("android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa - ("blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa - ("bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa - ("chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa - ("firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa - ("googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa - ("ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa - ("ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa - ("iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa - ("safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa -] - - -def get_by_shortcut(s): - """ - Retrieve a user agent entry by shortcut. - """ - for i in UASTRINGS: - if s == i[1]: - return i diff --git a/netlib/socks.py b/netlib/socks.py deleted file mode 100644 index 377308a8..00000000 --- a/netlib/socks.py +++ /dev/null @@ -1,234 +0,0 @@ -import struct -import array -import ipaddress - -from netlib import tcp -from netlib import check -from mitmproxy.types import bidi - - -class SocksError(Exception): - def __init__(self, code, message): - super().__init__(message) - self.code = code - -VERSION = bidi.BiDi( - SOCKS4=0x04, - SOCKS5=0x05 -) - -CMD = bidi.BiDi( - CONNECT=0x01, - BIND=0x02, - UDP_ASSOCIATE=0x03 -) - -ATYP = bidi.BiDi( - IPV4_ADDRESS=0x01, - DOMAINNAME=0x03, - IPV6_ADDRESS=0x04 -) - -REP = bidi.BiDi( - SUCCEEDED=0x00, - GENERAL_SOCKS_SERVER_FAILURE=0x01, - CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, - NETWORK_UNREACHABLE=0x03, - HOST_UNREACHABLE=0x04, - CONNECTION_REFUSED=0x05, - TTL_EXPIRED=0x06, - COMMAND_NOT_SUPPORTED=0x07, - ADDRESS_TYPE_NOT_SUPPORTED=0x08, -) - -METHOD = bidi.BiDi( - NO_AUTHENTICATION_REQUIRED=0x00, - GSSAPI=0x01, - USERNAME_PASSWORD=0x02, - NO_ACCEPTABLE_METHODS=0xFF -) - -USERNAME_PASSWORD_VERSION = bidi.BiDi( - DEFAULT=0x01 -) - - -class ClientGreeting: - __slots__ = ("ver", "methods") - - def __init__(self, ver, methods): - self.ver = ver - self.methods = array.array("B") - self.methods.extend(methods) - - def assert_socks5(self): - if self.ver != VERSION.SOCKS5: - if self.ver == ord("G") and len(self.methods) == ord("E"): - guess = "Probably not a SOCKS request but a regular HTTP request. " - else: - guess = "" - - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f, fail_early=False): - """ - :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. 
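-        A minimal SOCKS5 greeting that offers only NO_AUTHENTICATION_REQUIRED
-        is the three bytes 0x05 0x01 0x00 (version, method count, method).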
- """ - ver, nmethods = struct.unpack("!BB", f.safe_read(2)) - client_greeting = cls(ver, []) - if fail_early: - client_greeting.assert_socks5() - client_greeting.methods.fromstring(f.safe_read(nmethods)) - return client_greeting - - def to_file(self, f): - f.write(struct.pack("!BB", self.ver, len(self.methods))) - f.write(self.methods.tostring()) - - -class ServerGreeting: - __slots__ = ("ver", "method") - - def __init__(self, ver, method): - self.ver = ver - self.method = method - - def assert_socks5(self): - if self.ver != VERSION.SOCKS5: - if self.ver == ord("H") and self.method == ord("T"): - guess = "Probably not a SOCKS request but a regular HTTP response. " - else: - guess = "" - - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f): - ver, method = struct.unpack("!BB", f.safe_read(2)) - return cls(ver, method) - - def to_file(self, f): - f.write(struct.pack("!BB", self.ver, self.method)) - - -class UsernamePasswordAuth: - __slots__ = ("ver", "username", "password") - - def __init__(self, ver, username, password): - self.ver = ver - self.username = username - self.password = password - - def assert_authver1(self): - if self.ver != USERNAME_PASSWORD_VERSION.DEFAULT: - raise SocksError( - 0, - "Invalid auth version. Expected 0x01, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f): - ver, ulen = struct.unpack("!BB", f.safe_read(2)) - username = f.safe_read(ulen) - plen, = struct.unpack("!B", f.safe_read(1)) - password = f.safe_read(plen) - return cls(ver, username.decode(), password.decode()) - - def to_file(self, f): - f.write(struct.pack("!BB", self.ver, len(self.username))) - f.write(self.username.encode()) - f.write(struct.pack("!B", len(self.password))) - f.write(self.password.encode()) - - -class UsernamePasswordAuthResponse: - __slots__ = ("ver", "status") - - def __init__(self, ver, status): - self.ver = ver - self.status = status - - def assert_authver1(self): - if self.ver != USERNAME_PASSWORD_VERSION.DEFAULT: - raise SocksError( - 0, - "Invalid auth version. Expected 0x01, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f): - ver, status = struct.unpack("!BB", f.safe_read(2)) - return cls(ver, status) - - def to_file(self, f): - f.write(struct.pack("!BB", self.ver, self.status)) - - -class Message: - __slots__ = ("ver", "msg", "atyp", "addr") - - def __init__(self, ver, msg, atyp, addr): - self.ver = ver - self.msg = msg - self.atyp = atyp - self.addr = tcp.Address.wrap(addr) - - def assert_socks5(self): - if self.ver != VERSION.SOCKS5: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) - if rsv != 0x00: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Socks Request: Invalid reserved byte: %s" % rsv - ) - if atyp == ATYP.IPV4_ADDRESS: - # We use tnoa here as ntop is not commonly available on Windows. 
- host = ipaddress.IPv4Address(f.safe_read(4)).compressed - use_ipv6 = False - elif atyp == ATYP.IPV6_ADDRESS: - host = ipaddress.IPv6Address(f.safe_read(16)).compressed - use_ipv6 = True - elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", f.safe_read(1)) - host = f.safe_read(length) - if not check.is_valid_host(host): - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) - host = host.decode("idna") - use_ipv6 = False - else: - raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, - "Socks Request: Unknown ATYP: %s" % atyp) - - port, = struct.unpack("!H", f.safe_read(2)) - addr = tcp.Address((host, port), use_ipv6=use_ipv6) - return cls(ver, msg, atyp, addr) - - def to_file(self, f): - f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) - if self.atyp == ATYP.IPV4_ADDRESS: - f.write(ipaddress.IPv4Address(self.addr.host).packed) - elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(ipaddress.IPv6Address(self.addr.host).packed) - elif self.atyp == ATYP.DOMAINNAME: - f.write(struct.pack("!B", len(self.addr.host))) - f.write(self.addr.host.encode("idna")) - else: - raise SocksError( - REP.ADDRESS_TYPE_NOT_SUPPORTED, - "Unknown ATYP: %s" % self.atyp - ) - f.write(struct.pack("!H", self.addr.port)) diff --git a/netlib/tcp.py b/netlib/tcp.py deleted file mode 100644 index ac368a9c..00000000 --- a/netlib/tcp.py +++ /dev/null @@ -1,989 +0,0 @@ -import os -import select -import socket -import sys -import threading -import time -import traceback - -import binascii - -from typing import Optional # noqa - -from mitmproxy.utils import strutils - -import certifi -from backports import ssl_match_hostname -import OpenSSL -from OpenSSL import SSL - -from mitmproxy import certs -from mitmproxy.utils import version_check -from mitmproxy.types import serializable -from mitmproxy import exceptions -from mitmproxy.types import basethread - -# This is a rather hackish way to make sure that -# the latest version of pyOpenSSL is actually installed. -version_check.check_pyopenssl_version() - -socket_fileobject = socket.SocketIO - -EINTR = 4 -if os.environ.get("NO_ALPN"): - HAS_ALPN = False -else: - HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN - -# To enable all SSL methods use: SSLv23 -# then add options to disable certain methods -# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 -SSL_BASIC_OPTIONS = ( - SSL.OP_CIPHER_SERVER_PREFERENCE -) -if hasattr(SSL, "OP_NO_COMPRESSION"): - SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION - -SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD -SSL_DEFAULT_OPTIONS = ( - SSL.OP_NO_SSLv2 | - SSL.OP_NO_SSLv3 | - SSL_BASIC_OPTIONS -) -if hasattr(SSL, "OP_NO_COMPRESSION"): - SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION - -""" -Map a reasonable SSL version specification into the format OpenSSL expects. -Don't ask... 
-https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 -""" -sslversion_choices = { - "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), - # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ - # TLSv1_METHOD would be TLS 1.0 only - "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), - "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), - "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), - "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), - "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), - "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), -} - - -class SSLKeyLogger: - - def __init__(self, filename): - self.filename = filename - self.f = None - self.lock = threading.Lock() - - # required for functools.wraps, which pyOpenSSL uses. - __name__ = "SSLKeyLogger" - - def __call__(self, connection, where, ret): - if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: - with self.lock: - if not self.f: - d = os.path.dirname(self.filename) - if not os.path.isdir(d): - os.makedirs(d) - self.f = open(self.filename, "ab") - self.f.write(b"\r\n") - client_random = binascii.hexlify(connection.client_random()) - masterkey = binascii.hexlify(connection.master_key()) - self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) - self.f.flush() - - def close(self): - with self.lock: - if self.f: - self.f.close() - - @staticmethod - def create_logfun(filename): - if filename: - return SSLKeyLogger(filename) - return False - -log_ssl_key = SSLKeyLogger.create_logfun( - os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) - - -class _FileLike: - BLOCKSIZE = 1024 * 32 - - def __init__(self, o): - self.o = o - self._log = None - self.first_byte_timestamp = None - - def set_descriptor(self, o): - self.o = o - - def __getattr__(self, attr): - return getattr(self.o, attr) - - def start_log(self): - """ - Starts or resets the log. - - This will store all bytes read or written. - """ - self._log = [] - - def stop_log(self): - """ - Stops the log. - """ - self._log = None - - def is_logging(self): - return self._log is not None - - def get_log(self): - """ - Returns the log as a string. - """ - if not self.is_logging(): - raise ValueError("Not logging!") - return b"".join(self._log) - - def add_log(self, v): - if self.is_logging(): - self._log.append(v) - - def reset_timestamps(self): - self.first_byte_timestamp = None - - -class Writer(_FileLike): - - def flush(self): - """ - May raise exceptions.TcpDisconnect - """ - if hasattr(self.o, "flush"): - try: - self.o.flush() - except (socket.error, IOError) as v: - raise exceptions.TcpDisconnect(str(v)) - - def write(self, v): - """ - May raise exceptions.TcpDisconnect - """ - if v: - self.first_byte_timestamp = self.first_byte_timestamp or time.time() - try: - if hasattr(self.o, "sendall"): - self.add_log(v) - return self.o.sendall(v) - else: - r = self.o.write(v) - self.add_log(v[:r]) - return r - except (SSL.Error, socket.error) as e: - raise exceptions.TcpDisconnect(str(e)) - - -class Reader(_FileLike): - - def read(self, length): - """ - If length is -1, we read until connection closes. 
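-        May return fewer than ``length`` bytes if the connection closes
-        first; use safe_read() when exactly ``length`` bytes are required.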
- """ - result = b'' - start = time.time() - while length == -1 or length > 0: - if length == -1 or length > self.BLOCKSIZE: - rlen = self.BLOCKSIZE - else: - rlen = length - try: - data = self.o.read(rlen) - except SSL.ZeroReturnError: - # TLS connection was shut down cleanly - break - except (SSL.WantWriteError, SSL.WantReadError): - # From the OpenSSL docs: - # If the underlying BIO is non-blocking, SSL_read() will also return when the - # underlying BIO could not satisfy the needs of SSL_read() to continue the - # operation. In this case a call to SSL_get_error with the return value of - # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. - if (time.time() - start) < self.o.gettimeout(): - time.sleep(0.1) - continue - else: - raise exceptions.TcpTimeout() - except socket.timeout: - raise exceptions.TcpTimeout() - except socket.error as e: - raise exceptions.TcpDisconnect(str(e)) - except SSL.SysCallError as e: - if e.args == (-1, 'Unexpected EOF'): - break - raise exceptions.TlsException(str(e)) - except SSL.Error as e: - raise exceptions.TlsException(str(e)) - self.first_byte_timestamp = self.first_byte_timestamp or time.time() - if not data: - break - result += data - if length != -1: - length -= len(data) - self.add_log(result) - return result - - def readline(self, size=None): - result = b'' - bytes_read = 0 - while True: - if size is not None and bytes_read >= size: - break - ch = self.read(1) - bytes_read += 1 - if not ch: - break - else: - result += ch - if ch == b'\n': - break - return result - - def safe_read(self, length): - """ - Like .read, but is guaranteed to either return length bytes, or - raise an exception. - """ - result = self.read(length) - if length != -1 and len(result) != length: - if not result: - raise exceptions.TcpDisconnect() - else: - raise exceptions.TcpReadIncomplete( - "Expected %s bytes, got %s" % (length, len(result)) - ) - return result - - def peek(self, length): - """ - Tries to peek into the underlying file object. - - Returns: - Up to the next N bytes if peeking is successful. - - Raises: - exceptions.TcpException if there was an error with the socket - TlsException if there was an error with pyOpenSSL. - NotImplementedError if the underlying file object is not a [pyOpenSSL] socket - """ - if isinstance(self.o, socket_fileobject): - try: - return self.o._sock.recv(length, socket.MSG_PEEK) - except socket.error as e: - raise exceptions.TcpException(repr(e)) - elif isinstance(self.o, SSL.Connection): - try: - return self.o.recv(length, socket.MSG_PEEK) - except SSL.Error as e: - raise exceptions.TlsException(str(e)) - else: - raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") - - -class Address(serializable.Serializable): - - """ - This class wraps an IPv4/IPv6 tuple to provide named attributes and - ipv6 information. 
- """ - - def __init__(self, address, use_ipv6=False): - self.address = tuple(address) - self.use_ipv6 = use_ipv6 - - def get_state(self): - return { - "address": self.address, - "use_ipv6": self.use_ipv6 - } - - def set_state(self, state): - self.address = state["address"] - self.use_ipv6 = state["use_ipv6"] - - @classmethod - def from_state(cls, state): - return Address(**state) - - @classmethod - def wrap(cls, t): - if isinstance(t, cls): - return t - else: - return cls(t) - - def __call__(self): - return self.address - - @property - def host(self): - return self.address[0] - - @property - def port(self): - return self.address[1] - - @property - def use_ipv6(self): - return self.family == socket.AF_INET6 - - @use_ipv6.setter - def use_ipv6(self, b): - self.family = socket.AF_INET6 if b else socket.AF_INET - - def __repr__(self): - return "{}:{}".format(self.host, self.port) - - def __eq__(self, other): - if not other: - return False - other = Address.wrap(other) - return (self.address, self.family) == (other.address, other.family) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self.address) ^ 42 # different hash than the tuple alone. - - -def ssl_read_select(rlist, timeout): - """ - This is a wrapper around select.select() which also works for SSL.Connections - by taking ssl_connection.pending() into account. - - Caveats: - If .pending() > 0 for any of the connections in rlist, we avoid the select syscall - and **will not include any other connections which may or may not be ready**. - - Args: - rlist: wait until ready for reading - - Returns: - subset of rlist which is ready for reading. - """ - return [ - conn for conn in rlist - if isinstance(conn, SSL.Connection) and conn.pending() > 0 - ] or select.select(rlist, (), (), timeout)[0] - - -def close_socket(sock): - """ - Does a hard close of a socket, without emitting a RST. - """ - try: - # We already indicate that we close our end. - # may raise "Transport endpoint is not connected" on Linux - sock.shutdown(socket.SHUT_WR) - - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending - # readable data could lead to an immediate RST being sent (which is the - # case on Windows). - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - # - # This in turn results in the following issue: If we send an error page - # to the client and then close the socket, the RST may be received by - # the client before the error page and the users sees a connection - # error rather than the error page. Thus, we try to empty the read - # buffer on Windows first. (see - # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) - # - - if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique - # proposed by the page above: Some remote machines just don't send - # a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. As a workaround, we set a timeout - # here even if we are in blocking mode. - sock.settimeout(sock.gettimeout() or 20) - - # limit at a megabyte so that we don't read infinitely - for _ in range(1024 ** 3 // 4096): - # may raise a timeout/disconnect exception. - if not sock.recv(4096): - break - - # Now we can close the other half as well. 
- sock.shutdown(socket.SHUT_RD) - - except socket.error: - pass - - sock.close() - - -class _Connection: - - rbufsize = -1 - wbufsize = -1 - - def _makefile(self): - """ - Set up .rfile and .wfile attributes from .connection - """ - # Ideally, we would use the Buffered IO in Python 3 by default. - # Unfortunately, the implementation of .peek() is broken for n>1 bytes, - # as it may just return what's left in the buffer and not all the bytes we want. - # As a workaround, we just use unbuffered sockets directly. - # https://mail.python.org/pipermail/python-dev/2009-June/089986.html - self.rfile = Reader(socket.SocketIO(self.connection, "rb")) - self.wfile = Writer(socket.SocketIO(self.connection, "wb")) - - def __init__(self, connection): - if connection: - self.connection = connection - self.ip_address = Address(connection.getpeername()) - self._makefile() - else: - self.connection = None - self.ip_address = None - self.rfile = None - self.wfile = None - - self.ssl_established = False - self.finished = False - - def get_current_cipher(self): - if not self.ssl_established: - return None - - name = self.connection.get_cipher_name() - bits = self.connection.get_cipher_bits() - version = self.connection.get_cipher_version() - return name, bits, version - - def finish(self): - self.finished = True - # If we have an SSL connection, wfile.close == connection.close - # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close - # then. - if not isinstance(self.connection, SSL.Connection): - if not getattr(self.wfile, "closed", False): - try: - self.wfile.flush() - self.wfile.close() - except exceptions.TcpDisconnect: - pass - - self.rfile.close() - else: - try: - self.connection.shutdown() - except SSL.Error: - pass - - def _create_ssl_context(self, - method=SSL_DEFAULT_METHOD, - options=SSL_DEFAULT_OPTIONS, - verify_options=SSL.VERIFY_NONE, - ca_path=None, - ca_pemfile=None, - cipher_list=None, - alpn_protos=None, - alpn_select=None, - alpn_select_callback=None, - sni=None, - ): - """ - Creates an SSL Context. 
- - :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD - :param options: A bit field consisting of OpenSSL.SSL.OP_* values - :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values - :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool - :param ca_pemfile: Path to a PEM formatted trusted CA certificate - :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html - :rtype : SSL.Context - """ - context = SSL.Context(method) - # Options (NO_SSLv2/3) - if options is not None: - context.set_options(options) - - # Verify Options (NONE/PEER and trusted CAs) - if verify_options is not None: - def verify_cert(conn, x509, errno, err_depth, is_cert_verified): - if not is_cert_verified: - self.ssl_verification_error = exceptions.InvalidCertificateException( - "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format( - sni, - strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"), - errno, - err_depth - ) - ) - return is_cert_verified - - context.set_verify(verify_options, verify_cert) - if ca_path is None and ca_pemfile is None: - ca_pemfile = certifi.where() - context.load_verify_locations(ca_pemfile, ca_path) - - # Workaround for - # https://github.com/pyca/pyopenssl/issues/190 - # https://github.com/mitmproxy/mitmproxy/issues/472 - # Options already set before are not cleared. - context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) - - # Cipher List - if cipher_list: - try: - context.set_cipher_list(cipher_list) - - # TODO: maybe change this to with newer pyOpenSSL APIs - context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) - except SSL.Error as v: - raise exceptions.TlsException("SSL cipher specification error: %s" % str(v)) - - # SSLKEYLOGFILE - if log_ssl_key: - context.set_info_callback(log_ssl_key) - - if HAS_ALPN: - if alpn_protos is not None: - # advertise application layer protocols - context.set_alpn_protos(alpn_protos) - elif alpn_select is not None and alpn_select_callback is None: - # select application layer protocol - def alpn_select_callback(conn_, options): - if alpn_select in options: - return bytes(alpn_select) - else: # pragma no cover - return options[0] - context.set_alpn_select_callback(alpn_select_callback) - elif alpn_select_callback is not None and alpn_select is None: - context.set_alpn_select_callback(alpn_select_callback) - elif alpn_select_callback is not None and alpn_select is not None: - raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") - - return context - - -class ConnectionCloser: - def __init__(self, conn): - self.conn = conn - self._canceled = False - - def pop(self): - """ - Cancel the current closer, and return a fresh one. 
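-        This lets a caller take over ownership: the connection is closed on
-        ``__exit__`` unless pop() was called, in which case the returned
-        closer becomes responsible for closing it.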
- """ - self._canceled = True - return ConnectionCloser(self.conn) - - def __enter__(self): - return self - - def __exit__(self, *args): - if not self._canceled: - self.conn.close() - - -class TCPClient(_Connection): - - def __init__(self, address, source_address=None, spoof_source_address=None): - super().__init__(None) - self.address = address - self.source_address = source_address - self.cert = None - self.server_certs = [] - self.ssl_verification_error = None # type: Optional[exceptions.InvalidCertificateException] - self.sni = None - self.spoof_source_address = spoof_source_address - - @property - def address(self): - return self.__address - - @address.setter - def address(self, address): - if address: - self.__address = Address.wrap(address) - else: - self.__address = None - - @property - def source_address(self): - return self.__source_address - - @source_address.setter - def source_address(self, source_address): - if source_address: - self.__source_address = Address.wrap(source_address) - else: - self.__source_address = None - - def close(self): - # Make sure to close the real socket, not the SSL proxy. - # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, - # it tries to renegotiate... - if isinstance(self.connection, SSL.Connection): - close_socket(self.connection._socket) - else: - close_socket(self.connection) - - def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): - context = self._create_ssl_context( - alpn_protos=alpn_protos, - **sslctx_kwargs) - # Client Certs - if cert: - try: - context.use_privatekey_file(cert) - context.use_certificate_file(cert) - except SSL.Error as v: - raise exceptions.TlsException("SSL client certificate error: %s" % str(v)) - return context - - def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): - """ - cert: Path to a file containing both client cert and private key. 
- - options: A bit field consisting of OpenSSL.SSL.OP_* values - verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values - ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool - ca_pemfile: Path to a PEM formatted trusted CA certificate - """ - verification_mode = sslctx_kwargs.get('verify_options', None) - if verification_mode == SSL.VERIFY_PEER and not sni: - raise exceptions.TlsException("Cannot validate certificate hostname without SNI") - - context = self.create_ssl_context( - alpn_protos=alpn_protos, - sni=sni, - **sslctx_kwargs - ) - self.connection = SSL.Connection(context, self.connection) - if sni: - self.sni = sni - self.connection.set_tlsext_host_name(sni.encode("idna")) - self.connection.set_connect_state() - try: - self.connection.do_handshake() - except SSL.Error as v: - if self.ssl_verification_error: - raise self.ssl_verification_error - else: - raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) - else: - # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on - # certificate validation failure - if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error: - raise self.ssl_verification_error - - self.cert = certs.SSLCert(self.connection.get_peer_certificate()) - - # Keep all server certificates in a list - for i in self.connection.get_peer_cert_chain(): - self.server_certs.append(certs.SSLCert(i)) - - # Validate TLS Hostname - try: - crt = dict( - subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] - ) - if self.cert.cn: - crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] - if sni: - hostname = sni - else: - hostname = "no-hostname" - ssl_match_hostname.match_hostname(crt, hostname) - except (ValueError, ssl_match_hostname.CertificateError) as e: - self.ssl_verification_error = exceptions.InvalidCertificateException( - "Certificate Verification Error for {}: {}".format( - sni or repr(self.address), - str(e) - ) - ) - if verification_mode == SSL.VERIFY_PEER: - raise self.ssl_verification_error - - self.ssl_established = True - self.rfile.set_descriptor(self.connection) - self.wfile.set_descriptor(self.connection) - - def makesocket(self): - # some parties (cuckoo sandbox) need to hook this - return socket.socket(self.address.family, socket.SOCK_STREAM) - - def connect(self): - try: - connection = self.makesocket() - - if self.spoof_source_address: - try: - # 19 is `IP_TRANSPARENT`, which is only available on Python 3.3+ on some OSes - if not connection.getsockopt(socket.SOL_IP, 19): - connection.setsockopt(socket.SOL_IP, 19, 1) - except socket.error as e: - raise exceptions.TcpException( - "Failed to spoof the source address: " + e.strerror - ) - if self.source_address: - connection.bind(self.source_address()) - connection.connect(self.address()) - self.source_address = Address(connection.getsockname()) - except (socket.error, IOError) as err: - raise exceptions.TcpException( - 'Error connecting to "%s": %s' % - (self.address.host, err) - ) - self.connection = connection - self.ip_address = Address(connection.getpeername()) - self._makefile() - return ConnectionCloser(self) - - def settimeout(self, n): - self.connection.settimeout(n) - - def gettimeout(self): - return self.connection.gettimeout() - - def get_alpn_proto_negotiated(self): - if HAS_ALPN and self.ssl_established: - return self.connection.get_alpn_proto_negotiated() - else: - return b"" - - -class BaseHandler(_Connection): - - """ - The instantiator is expected to call the 
handle() and finish() methods. - """ - - def __init__(self, connection, address, server): - super().__init__(connection) - self.address = Address.wrap(address) - self.server = server - self.clientcert = None - - def create_ssl_context(self, - cert, key, - handle_sni=None, - request_client_cert=None, - chain_file=None, - dhparams=None, - extra_chain_certs=None, - **sslctx_kwargs): - """ - cert: A certs.SSLCert object or the path to a certificate - chain file. - - handle_sni: SNI handler, should take a connection object. Server - name can be retrieved like this: - - connection.get_servername() - - And you can specify the connection keys as follows: - - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) - - The request_client_cert argument requires some explanation. We're - supposed to be able to do this with no negative effects - if the - client has no cert to present, we're notified and proceed as usual. - Unfortunately, Android seems to have a bug (tested on 4.2.2) - when - an Android client is asked to present a certificate it does not - have, it hangs up, which is frankly bogus. Some time down the track - we may be able to make the proper behaviour the default again, but - until then we're conservative. - """ - - context = self._create_ssl_context(ca_pemfile=chain_file, **sslctx_kwargs) - - context.use_privatekey(key) - if isinstance(cert, certs.SSLCert): - context.use_certificate(cert.x509) - else: - context.use_certificate_chain_file(cert) - - if extra_chain_certs: - for i in extra_chain_certs: - context.add_extra_chain_cert(i.x509) - - if handle_sni: - # SNI callback happens during do_handshake() - context.set_tlsext_servername_callback(handle_sni) - - if request_client_cert: - def save_cert(conn_, cert, errno_, depth_, preverify_ok_): - self.clientcert = certs.SSLCert(cert) - # Return true to prevent cert verification error - return True - context.set_verify(SSL.VERIFY_PEER, save_cert) - - if dhparams: - SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) - - return context - - def convert_to_ssl(self, cert, key, **sslctx_kwargs): - """ - Convert connection to SSL. - For a list of parameters, see BaseHandler._create_ssl_context(...) 
- """ - - context = self.create_ssl_context( - cert, - key, - **sslctx_kwargs) - self.connection = SSL.Connection(context, self.connection) - self.connection.set_accept_state() - try: - self.connection.do_handshake() - except SSL.Error as v: - raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) - self.ssl_established = True - self.rfile.set_descriptor(self.connection) - self.wfile.set_descriptor(self.connection) - - def handle(self): # pragma: no cover - raise NotImplementedError - - def settimeout(self, n): - self.connection.settimeout(n) - - def get_alpn_proto_negotiated(self): - if HAS_ALPN and self.ssl_established: - return self.connection.get_alpn_proto_negotiated() - else: - return b"" - - -class Counter: - def __init__(self): - self._count = 0 - self._lock = threading.Lock() - - @property - def count(self): - with self._lock: - return self._count - - def __enter__(self): - with self._lock: - self._count += 1 - - def __exit__(self, *args): - with self._lock: - self._count -= 1 - - -class TCPServer: - request_queue_size = 20 - - def __init__(self, address): - self.address = Address.wrap(address) - self.__is_shut_down = threading.Event() - self.__shutdown_request = False - self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.address()) - self.address = Address.wrap(self.socket.getsockname()) - self.socket.listen(self.request_queue_size) - self.handler_counter = Counter() - - def connection_thread(self, connection, client_address): - with self.handler_counter: - client_address = Address(client_address) - try: - self.handle_client_connection(connection, client_address) - except: - self.handle_error(connection, client_address) - finally: - close_socket(connection) - - def serve_forever(self, poll_interval=0.1): - self.__is_shut_down.clear() - try: - while not self.__shutdown_request: - try: - r, w_, e_ = select.select( - [self.socket], [], [], poll_interval) - except select.error as ex: # pragma: no cover - if ex[0] == EINTR: - continue - else: - raise - if self.socket in r: - connection, client_address = self.socket.accept() - t = basethread.BaseThread( - "TCPConnectionHandler (%s: %s:%s -> %s:%s)" % ( - self.__class__.__name__, - client_address[0], - client_address[1], - self.address.host, - self.address.port - ), - target=self.connection_thread, - args=(connection, client_address), - ) - t.setDaemon(1) - try: - t.start() - except threading.ThreadError: - self.handle_error(connection, Address(client_address)) - connection.close() - finally: - self.__shutdown_request = False - self.__is_shut_down.set() - - def shutdown(self): - self.__shutdown_request = True - self.__is_shut_down.wait() - self.socket.close() - self.handle_shutdown() - - def handle_error(self, connection_, client_address, fp=sys.stderr): - """ - Called when handle_client_connection raises an exception. - """ - # If a thread has persisted after interpreter exit, the module might be - # none. - if traceback: - exc = str(traceback.format_exc()) - print(u'-' * 40, file=fp) - print( - u"Error in processing of request from %s" % repr(client_address), file=fp) - print(exc, file=fp) - print(u'-' * 40, file=fp) - - def handle_client_connection(self, conn, client_address): # pragma: no cover - """ - Called after client connection. - """ - raise NotImplementedError - - def handle_shutdown(self): - """ - Called after server shutdown. 
- """ - - def wait_for_silence(self, timeout=5): - start = time.time() - while 1: - if time.time() - start >= timeout: - raise exceptions.Timeout( - "%s service threads still alive" % - self.handler_counter.count - ) - if self.handler_counter.count == 0: - return diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py deleted file mode 100644 index 2d6f0a0c..00000000 --- a/netlib/websockets/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from .frame import FrameHeader -from .frame import Frame -from .frame import OPCODE -from .frame import CLOSE_REASON -from .masker import Masker -from .utils import MAGIC -from .utils import VERSION -from .utils import client_handshake_headers -from .utils import server_handshake_headers -from .utils import check_handshake -from .utils import check_client_version -from .utils import create_server_nonce -from .utils import get_extensions -from .utils import get_protocol -from .utils import get_client_key -from .utils import get_server_accept - -__all__ = [ - "FrameHeader", - "Frame", - "OPCODE", - "CLOSE_REASON", - "Masker", - "MAGIC", - "VERSION", - "client_handshake_headers", - "server_handshake_headers", - "check_handshake", - "check_client_version", - "create_server_nonce", - "get_extensions", - "get_protocol", - "get_client_key", - "get_server_accept", -] diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py deleted file mode 100644 index bc4ae43a..00000000 --- a/netlib/websockets/frame.py +++ /dev/null @@ -1,274 +0,0 @@ -import os -import struct -import io - -from netlib import tcp -from mitmproxy.utils import strutils -from mitmproxy.utils import bits -from mitmproxy.utils import human -from mitmproxy.types import bidi -from .masker import Masker - - -MAX_16_BIT_INT = (1 << 16) -MAX_64_BIT_INT = (1 << 64) - -DEFAULT = object() - -# RFC 6455, Section 5.2 - Base Framing Protocol -OPCODE = bidi.BiDi( - CONTINUE=0x00, - TEXT=0x01, - BINARY=0x02, - CLOSE=0x08, - PING=0x09, - PONG=0x0a -) - -# RFC 6455, Section 7.4.1 - Defined Status Codes -CLOSE_REASON = bidi.BiDi( - NORMAL_CLOSURE=1000, - GOING_AWAY=1001, - PROTOCOL_ERROR=1002, - UNSUPPORTED_DATA=1003, - RESERVED=1004, - RESERVED_NO_STATUS=1005, - RESERVED_ABNORMAL_CLOSURE=1006, - INVALID_PAYLOAD_DATA=1007, - POLICY_VIOLATION=1008, - MESSAGE_TOO_BIG=1009, - MANDATORY_EXTENSION=1010, - INTERNAL_ERROR=1011, - RESERVED_TLS_HANDHSAKE_FAILED=1015, -) - - -class FrameHeader: - - def __init__( - self, - opcode=OPCODE.TEXT, - payload_length=0, - fin=False, - rsv1=False, - rsv2=False, - rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT - ): - if not 0 <= opcode < 2 ** 4: - raise ValueError("opcode must be 0-16") - self.opcode = opcode - self.payload_length = payload_length - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - - if length_code is DEFAULT: - self.length_code = self._make_length_code(self.payload_length) - else: - self.length_code = length_code - - if mask is DEFAULT and masking_key is DEFAULT: - self.mask = False - self.masking_key = b"" - elif mask is DEFAULT: - self.mask = 1 - self.masking_key = masking_key - elif masking_key is DEFAULT: - self.mask = mask - self.masking_key = os.urandom(4) - else: - self.mask = mask - self.masking_key = masking_key - - if self.masking_key and len(self.masking_key) != 4: - raise ValueError("Masking key must be 4 bytes.") - - @classmethod - def _make_length_code(self, length): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the 
actual length if length code is - larger than 125 - """ - if length <= 125: - return length - elif length >= 126 and length <= 65535: - return 126 - else: - return 127 - - def __repr__(self): - vals = [ - "ws frame:", - OPCODE.get_name(self.opcode, hex(self.opcode)).lower() - ] - flags = [] - for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: - if getattr(self, i): - flags.append(i) - if flags: - vals.extend([":", "|".join(flags)]) - if self.masking_key: - vals.append(":key=%s" % repr(self.masking_key)) - if self.payload_length: - vals.append(" %s" % human.pretty_size(self.payload_length)) - return "".join(vals) - - def __bytes__(self): - first_byte = bits.setbit(0, 7, self.fin) - first_byte = bits.setbit(first_byte, 6, self.rsv1) - first_byte = bits.setbit(first_byte, 5, self.rsv2) - first_byte = bits.setbit(first_byte, 4, self.rsv3) - first_byte = first_byte | self.opcode - - second_byte = bits.setbit(self.length_code, 7, self.mask) - - b = bytes([first_byte, second_byte]) - - if self.payload_length < 126: - pass - elif self.payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.payload_length) - elif self.payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.payload_length) - else: - raise ValueError("Payload length exceeds 64bit integer") - - if self.masking_key: - b += self.masking_key - return b - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame header - """ - first_byte, second_byte = fp.safe_read(2) - fin = bits.getbit(first_byte, 7) - rsv1 = bits.getbit(first_byte, 6) - rsv2 = bits.getbit(first_byte, 5) - rsv3 = bits.getbit(first_byte, 4) - opcode = first_byte & 0xF - mask_bit = bits.getbit(second_byte, 7) - length_code = second_byte & 0x7F - - # payload_length > 125 indicates you need to read more bytes - # to get the actual payload length - if length_code <= 125: - payload_length = length_code - elif length_code == 126: - payload_length, = struct.unpack("!H", fp.safe_read(2)) - else: # length_code == 127: - payload_length, = struct.unpack("!Q", fp.safe_read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.safe_read(4) - else: - masking_key = None - - return cls( - fin=fin, - rsv1=rsv1, - rsv2=rsv2, - rsv3=rsv3, - opcode=opcode, - mask=mask_bit, - length_code=length_code, - payload_length=payload_length, - masking_key=masking_key, - ) - - def __eq__(self, other): - if isinstance(other, FrameHeader): - return bytes(self) == bytes(other) - return False - - -class Frame: - """ - Represents a single WebSockets frame. - Constructor takes human readable forms of the frame components. - from_bytes() reads from a file-like object to create a new Frame. 
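-    For example, ``bytes(Frame(payload=b"hello", fin=1))`` yields a
-    wire-format frame that Frame.from_bytes() parses back into an equal
-    Frame.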
- - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - - def __init__(self, payload=b"", **kwargs): - self.payload = payload - kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) - self.header = FrameHeader(**kwargs) - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - - def __repr__(self): - ret = repr(self.header) - if self.payload: - ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) - return ret - - def __bytes__(self): - """ - Serialize the frame to wire format. Returns a string. - """ - b = bytes(self.header) - if self.header.masking_key: - b += Masker(self.header.masking_key)(self.payload) - else: - b += self.payload - return b - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame sent by a server or client - - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - header = FrameHeader.from_file(fp) - payload = fp.safe_read(header.payload_length) - - if header.mask == 1 and header.masking_key: - payload = Masker(header.masking_key)(payload) - - frame = cls(payload) - frame.header = header - return frame - - def __eq__(self, other): - if isinstance(other, Frame): - return bytes(self) == bytes(other) - return False diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py deleted file mode 100644 index 47b1a688..00000000 --- a/netlib/websockets/masker.py +++ /dev/null @@ -1,25 +0,0 @@ -class Masker: - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns. - - Servers do not have to mask data they send to the client. 
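-    (Per RFC 6455 it is data sent by the *client* that must be masked;
-    the masking itself XORs every payload byte with the 4-byte key,
-    cycling: ``result[i] = data[i] ^ key[(offset + i) % 4]``.)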
diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py
deleted file mode 100644
index 98043662..00000000
--- a/netlib/websockets/utils.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""
-Collection of WebSockets Protocol utility functions (RFC6455)
-Spec: https://tools.ietf.org/html/rfc6455
-"""
-
-
-import base64
-import hashlib
-import os
-
-from netlib import http
-from mitmproxy.utils import strutils
-
-MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
-VERSION = "13"
-
-
-def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
-    """
-    Create the headers for a valid HTTP upgrade request. If key is not
-    specified, it is generated, and can be found under sec-websocket-key
-    in the returned header set.
-
-    Returns an instance of http.Headers.
-    """
-    if version is None:
-        version = VERSION
-    if key is None:
-        key = base64.b64encode(os.urandom(16)).decode('ascii')
-    h = http.Headers(
-        connection="upgrade",
-        upgrade="websocket",
-        sec_websocket_version=version,
-        sec_websocket_key=key,
-    )
-    if protocol is not None:
-        h['sec-websocket-protocol'] = protocol
-    if extensions is not None:
-        h['sec-websocket-extensions'] = extensions
-    return h
-
-
-def server_handshake_headers(client_key, protocol=None, extensions=None):
-    """
-    The server response is a valid HTTP 101 response.
-
-    Returns an instance of http.Headers.
-    """
-    h = http.Headers(
-        connection="upgrade",
-        upgrade="websocket",
-        sec_websocket_accept=create_server_nonce(client_key),
-    )
-    if protocol is not None:
-        h['sec-websocket-protocol'] = protocol
-    if extensions is not None:
-        h['sec-websocket-extensions'] = extensions
-    return h
-
-
-def check_handshake(headers):
-    return (
-        "upgrade" in headers.get("connection", "").lower() and
-        headers.get("upgrade", "").lower() == "websocket" and
-        (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
-    )
-
-
-def create_server_nonce(client_nonce):
-    return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
-
-
-def check_client_version(headers):
-    return headers.get("sec-websocket-version", "") == VERSION
-
-
-def get_extensions(headers):
-    return headers.get("sec-websocket-extensions", None)
-
-
-def get_protocol(headers):
-    return headers.get("sec-websocket-protocol", None)
-
-
-def get_client_key(headers):
-    return headers.get("sec-websocket-key", None)
-
-
-def get_server_accept(headers):
-    return headers.get("sec-websocket-accept", None)
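
create_server_nonce() is the heart of the handshake: the server proves it understood the upgrade request by echoing back base64(SHA-1(key + MAGIC)). A stdlib-only sketch, checked against the worked example in RFC6455 section 1.3:

    import base64
    import hashlib

    MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

    def create_server_nonce(client_nonce: bytes) -> bytes:
        # Sec-WebSocket-Accept = base64(sha1(Sec-WebSocket-Key + MAGIC))
        return base64.b64encode(hashlib.sha1(client_nonce + MAGIC).digest())

    # Test vector from RFC6455 section 1.3:
    assert create_server_nonce(b"dGhlIHNhbXBsZSBub25jZQ==") == b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
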
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
deleted file mode 100644
index 5a54cd70..00000000
--- a/netlib/wsgi.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import time
-import traceback
-import urllib
-import io
-
-from netlib import http
-from netlib import tcp
-from mitmproxy.utils import strutils
-
-
-class ClientConn:
-
-    def __init__(self, address):
-        self.address = tcp.Address.wrap(address)
-
-
-class Flow:
-
-    def __init__(self, address, request):
-        self.client_conn = ClientConn(address)
-        self.request = request
-
-
-class Request:
-
-    def __init__(self, scheme, method, path, http_version, headers, content):
-        self.scheme, self.method, self.path = scheme, method, path
-        self.headers, self.content = headers, content
-        self.http_version = http_version
-
-
-def date_time_string():
-    """Return the current date and time formatted for a message header."""
-    WEEKDAYS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
-    MONTHS = [
-        None,
-        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
-        'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'
-    ]
-    now = time.time()
-    year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now)
-    s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
-        WEEKDAYS[wd],
-        day, MONTHS[month], year,
-        hh, mm, ss
-    )
-    return s
-
-
-class WSGIAdaptor:
-
-    def __init__(self, app, domain, port, sversion):
-        self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
-
-    def make_environ(self, flow, errsoc, **extra):
-        """
-        Raises:
-            ValueError, if the content-encoding is invalid.
-        """
-        path = strutils.native(flow.request.path, "latin-1")
-        if '?' in path:
-            path_info, query = path.split('?', 1)
-        else:
-            path_info = path
-            query = ''
-        environ = {
-            'wsgi.version': (1, 0),
-            'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"),
-            'wsgi.input': io.BytesIO(flow.request.content or b""),
-            'wsgi.errors': errsoc,
-            'wsgi.multithread': True,
-            'wsgi.multiprocess': False,
-            'wsgi.run_once': False,
-            'SERVER_SOFTWARE': self.sversion,
-            'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"),
-            'SCRIPT_NAME': '',
-            'PATH_INFO': urllib.parse.unquote(path_info),
-            'QUERY_STRING': query,
-            'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"),
-            'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"),
-            'SERVER_NAME': self.domain,
-            'SERVER_PORT': str(self.port),
-            'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"),
-        }
-        environ.update(extra)
-        if flow.client_conn.address:
-            environ["REMOTE_ADDR"] = strutils.native(flow.client_conn.address.host, "latin-1")
-            environ["REMOTE_PORT"] = flow.client_conn.address.port
-
-        for key, value in flow.request.headers.items():
-            key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_')
-            if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
-                environ[key] = value
-        return environ
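
The header loop above follows the CGI convention: each request header becomes an HTTP_*-prefixed environ key, except Content-Type and Content-Length, which WSGI exposes unprefixed. The rule in isolation, as a stdlib-only sketch with hypothetical sample data:

    headers = {
        "Host": "example.com",
        "User-Agent": "test-client",
        "Content-Type": "text/plain",
    }

    environ = {"CONTENT_TYPE": headers.get("Content-Type", "")}
    for key, value in headers.items():
        key = 'HTTP_' + key.upper().replace('-', '_')
        if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
            environ[key] = value

    assert environ["HTTP_HOST"] == "example.com"
    assert environ["HTTP_USER_AGENT"] == "test-client"
    assert "HTTP_CONTENT_TYPE" not in environ  # exposed as CONTENT_TYPE instead
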
-    def error_page(self, soc, headers_sent, s):
-        """
-        Make a best-effort attempt to write an error page. If headers are
-        already sent, we just bung the error into the page.
-        """
-        c = """
-            <html>
-                <h1>Internal Server Error</h1>
-                <pre>{err}</pre>
-            </html>
-        """.format(err=s).strip().encode()
-
-        if not headers_sent:
-            soc.write(b"HTTP/1.1 500 Internal Server Error\r\n")
-            soc.write(b"Content-Type: text/html\r\n")
-            soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode())
-            soc.write(b"\r\n")
-        soc.write(c)
-
-    def serve(self, request, soc, **env):
-        state = dict(
-            response_started=False,
-            headers_sent=False,
-            status=None,
-            headers=None
-        )
-
-        def write(data):
-            if not state["headers_sent"]:
-                soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode())
-                headers = state["headers"]
-                if 'server' not in headers:
-                    headers["Server"] = self.sversion
-                if 'date' not in headers:
-                    headers["Date"] = date_time_string()
-                soc.write(bytes(headers))
-                soc.write(b"\r\n")
-                state["headers_sent"] = True
-            if data:
-                soc.write(data)
-            soc.flush()
-
-        def start_response(status, headers, exc_info=None):
-            if exc_info:
-                if state["headers_sent"]:
-                    raise exc_info[1]
-            elif state["status"]:
-                raise AssertionError('Response already started')
-            state["status"] = status
-            state["headers"] = http.Headers([[strutils.always_bytes(k), strutils.always_bytes(v)] for k, v in headers])
-            if exc_info:
-                self.error_page(soc, state["headers_sent"], "".join(traceback.format_tb(exc_info[2])))
-                state["headers_sent"] = True
-
-        errs = io.BytesIO()
-        try:
-            dataiter = self.app(
-                self.make_environ(request, errs, **env), start_response
-            )
-            for i in dataiter:
-                write(i)
-            if not state["headers_sent"]:
-                write(b"")
-        except Exception:
-            try:
-                s = traceback.format_exc()
-                errs.write(s.encode("utf-8", "replace"))
-                self.error_page(soc, state["headers_sent"], s)
-            except Exception:  # pragma: no cover
-                pass
-        return errs.getvalue()
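
For context, a sketch of how this adaptor was typically driven, assuming the netlib tree above is still importable. FakeSocket is a hypothetical stand-in for netlib's buffered tcp.Writer, providing only the write() and flush() calls that serve() uses:

    import io
    from netlib import http, wsgi

    def hello_app(environ, start_response):
        # A minimal WSGI application.
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"Hello, " + environ["PATH_INFO"].encode("latin-1") + b"!"]

    class FakeSocket:
        def __init__(self):
            self.buf = io.BytesIO()
        def write(self, data):
            self.buf.write(data)
        def flush(self):
            pass

    req = wsgi.Request(
        scheme=b"http", method=b"GET", path=b"/world",
        http_version=b"HTTP/1.1", headers=http.Headers(host="example.com"),
        content=b"",
    )
    flow = wsgi.Flow(("127.0.0.1", 54321), req)

    adaptor = wsgi.WSGIAdaptor(hello_app, "example.com", 80, "netlib-test/0.1")
    sock = FakeSocket()
    adaptor.serve(flow, sock)
    # sock.buf now holds a complete HTTP/1.1 response, ending in
    # b"\r\n\r\nHello, /world!"
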