From d33d3663ecb166461d9cb5a78a29b44ee7a8fbb7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 18 Feb 2016 13:03:40 +0100 Subject: combine projects --- netlib/README.rst | 35 -- netlib/__init__.py | 1 + netlib/certutils.py | 472 +++++++++++++++++ netlib/encoding.py | 88 +++ netlib/exceptions.py | 56 ++ netlib/http/__init__.py | 14 + netlib/http/authentication.py | 167 ++++++ netlib/http/cookies.py | 193 +++++++ netlib/http/headers.py | 204 +++++++ netlib/http/http1/__init__.py | 25 + netlib/http/http1/assemble.py | 104 ++++ netlib/http/http1/read.py | 362 +++++++++++++ netlib/http/http2/__init__.py | 6 + netlib/http/http2/connections.py | 426 +++++++++++++++ netlib/http/message.py | 222 ++++++++ netlib/http/request.py | 356 +++++++++++++ netlib/http/response.py | 116 ++++ netlib/http/status_codes.py | 106 ++++ netlib/http/user_agents.py | 52 ++ netlib/netlib/__init__.py | 1 - netlib/netlib/certutils.py | 472 ----------------- netlib/netlib/encoding.py | 88 --- netlib/netlib/exceptions.py | 56 -- netlib/netlib/http/__init__.py | 14 - netlib/netlib/http/authentication.py | 167 ------ netlib/netlib/http/cookies.py | 193 ------- netlib/netlib/http/headers.py | 204 ------- netlib/netlib/http/http1/__init__.py | 25 - netlib/netlib/http/http1/assemble.py | 104 ---- netlib/netlib/http/http1/read.py | 362 ------------- netlib/netlib/http/http2/__init__.py | 6 - netlib/netlib/http/http2/connections.py | 426 --------------- netlib/netlib/http/message.py | 222 -------- netlib/netlib/http/request.py | 356 ------------- netlib/netlib/http/response.py | 116 ---- netlib/netlib/http/status_codes.py | 106 ---- netlib/netlib/http/user_agents.py | 52 -- netlib/netlib/odict.py | 193 ------- netlib/netlib/socks.py | 176 ------ netlib/netlib/tcp.py | 911 -------------------------------- netlib/netlib/tutils.py | 133 ----- netlib/netlib/utils.py | 418 --------------- netlib/netlib/version.py | 6 - netlib/netlib/version_check.py | 60 --- netlib/netlib/websockets/__init__.py | 2 - netlib/netlib/websockets/frame.py | 316 ----------- netlib/netlib/websockets/protocol.py | 115 ---- netlib/netlib/wsgi.py | 164 ------ netlib/odict.py | 193 +++++++ netlib/setup.cfg | 2 - netlib/setup.py | 70 --- netlib/socks.py | 176 ++++++ netlib/tcp.py | 911 ++++++++++++++++++++++++++++++++ netlib/tutils.py | 133 +++++ netlib/utils.py | 418 +++++++++++++++ netlib/version.py | 6 + netlib/version_check.py | 60 +++ netlib/websockets/__init__.py | 2 + netlib/websockets/frame.py | 316 +++++++++++ netlib/websockets/protocol.py | 115 ++++ netlib/wsgi.py | 164 ++++++ 61 files changed, 5464 insertions(+), 5571 deletions(-) delete mode 100644 netlib/README.rst create mode 100644 netlib/__init__.py create mode 100644 netlib/certutils.py create mode 100644 netlib/encoding.py create mode 100644 netlib/exceptions.py create mode 100644 netlib/http/__init__.py create mode 100644 netlib/http/authentication.py create mode 100644 netlib/http/cookies.py create mode 100644 netlib/http/headers.py create mode 100644 netlib/http/http1/__init__.py create mode 100644 netlib/http/http1/assemble.py create mode 100644 netlib/http/http1/read.py create mode 100644 netlib/http/http2/__init__.py create mode 100644 netlib/http/http2/connections.py create mode 100644 netlib/http/message.py create mode 100644 netlib/http/request.py create mode 100644 netlib/http/response.py create mode 100644 netlib/http/status_codes.py create mode 100644 netlib/http/user_agents.py delete mode 100644 netlib/netlib/__init__.py delete mode 100644 netlib/netlib/certutils.py delete 
mode 100644 netlib/netlib/encoding.py
 delete mode 100644 netlib/netlib/exceptions.py
 delete mode 100644 netlib/netlib/http/__init__.py
 delete mode 100644 netlib/netlib/http/authentication.py
 delete mode 100644 netlib/netlib/http/cookies.py
 delete mode 100644 netlib/netlib/http/headers.py
 delete mode 100644 netlib/netlib/http/http1/__init__.py
 delete mode 100644 netlib/netlib/http/http1/assemble.py
 delete mode 100644 netlib/netlib/http/http1/read.py
 delete mode 100644 netlib/netlib/http/http2/__init__.py
 delete mode 100644 netlib/netlib/http/http2/connections.py
 delete mode 100644 netlib/netlib/http/message.py
 delete mode 100644 netlib/netlib/http/request.py
 delete mode 100644 netlib/netlib/http/response.py
 delete mode 100644 netlib/netlib/http/status_codes.py
 delete mode 100644 netlib/netlib/http/user_agents.py
 delete mode 100644 netlib/netlib/odict.py
 delete mode 100644 netlib/netlib/socks.py
 delete mode 100644 netlib/netlib/tcp.py
 delete mode 100644 netlib/netlib/tutils.py
 delete mode 100644 netlib/netlib/utils.py
 delete mode 100644 netlib/netlib/version.py
 delete mode 100644 netlib/netlib/version_check.py
 delete mode 100644 netlib/netlib/websockets/__init__.py
 delete mode 100644 netlib/netlib/websockets/frame.py
 delete mode 100644 netlib/netlib/websockets/protocol.py
 delete mode 100644 netlib/netlib/wsgi.py
 create mode 100644 netlib/odict.py
 delete mode 100644 netlib/setup.cfg
 delete mode 100644 netlib/setup.py
 create mode 100644 netlib/socks.py
 create mode 100644 netlib/tcp.py
 create mode 100644 netlib/tutils.py
 create mode 100644 netlib/utils.py
 create mode 100644 netlib/version.py
 create mode 100644 netlib/version_check.py
 create mode 100644 netlib/websockets/__init__.py
 create mode 100644 netlib/websockets/frame.py
 create mode 100644 netlib/websockets/protocol.py
 create mode 100644 netlib/wsgi.py

diff --git a/netlib/README.rst b/netlib/README.rst
deleted file mode 100644
index 16bd65a7..00000000
--- a/netlib/README.rst
+++ /dev/null
@@ -1,35 +0,0 @@
-|travis| |coveralls| |downloads| |latest_release| |python_versions|
-
-Netlib is a collection of network utility classes, used by the pathod and
-mitmproxy projects. It differs from other projects in some fundamental
-respects, because both pathod and mitmproxy often need to violate standards.
-This means that protocols are implemented as small, well-contained and flexible
-functions, and are designed to allow misbehaviour when needed.
-
-
-Development
------------
-
-If you'd like to work on netlib, check out the instructions in mitmproxy's README_.
-
-.. |travis| image:: https://shields.mitmproxy.org/travis/mitmproxy/netlib/master.svg
-    :target: https://travis-ci.org/mitmproxy/netlib
-    :alt: Build Status
-
-.. |coveralls| image:: https://shields.mitmproxy.org/coveralls/mitmproxy/netlib/master.svg
-    :target: https://coveralls.io/r/mitmproxy/netlib
-    :alt: Coverage Status
-
-.. |downloads| image:: https://shields.mitmproxy.org/pypi/dm/netlib.svg?color=orange
-    :target: https://pypi.python.org/pypi/netlib
-    :alt: Downloads
-
-.. |latest_release| image:: https://shields.mitmproxy.org/pypi/v/netlib.svg
-    :target: https://pypi.python.org/pypi/netlib
-    :alt: Latest Version
-
-.. |python_versions| image:: https://shields.mitmproxy.org/pypi/pyversions/netlib.svg
-    :target: https://pypi.python.org/pypi/netlib
-    :alt: Supported Python versions
-
-..
_README: https://github.com/mitmproxy/mitmproxy#hacking \ No newline at end of file diff --git a/netlib/__init__.py b/netlib/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/certutils.py b/netlib/certutils.py new file mode 100644 index 00000000..616a778e --- /dev/null +++ b/netlib/certutils.py @@ -0,0 +1,472 @@ +from __future__ import (absolute_import, print_function, division) +import os +import ssl +import time +import datetime +from six.moves import filter +import ipaddress + +import sys +from pyasn1.type import univ, constraint, char, namedtype, tag +from pyasn1.codec.der.decoder import decode +from pyasn1.error import PyAsn1Error +import OpenSSL + +from .utils import Serializable + +# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 + +DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 +# Generated with "openssl dhparam". It's too slow to generate this on startup. +DEFAULT_DHPARAM = b""" +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- +""" + + +def create_ca(o, cn, exp): + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) + cert = OpenSSL.crypto.X509() + cert.set_serial_number(int(time.time() * 10000)) + cert.set_version(2) + cert.get_subject().CN = cn + cert.get_subject().O = o + cert.gmtime_adj_notBefore(-3600 * 48) + cert.gmtime_adj_notAfter(exp) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.add_extensions([ + OpenSSL.crypto.X509Extension( + b"basicConstraints", + True, + b"CA:TRUE" + ), + OpenSSL.crypto.X509Extension( + b"nsCertType", + False, + b"sslCA" + ), + OpenSSL.crypto.X509Extension( + b"extendedKeyUsage", + False, + b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension( + b"keyUsage", + True, + b"keyCertSign, cRLSign" + ), + OpenSSL.crypto.X509Extension( + b"subjectKeyIdentifier", + False, + b"hash", + subject=cert + ), + ]) + cert.sign(key, "sha256") + return key, cert + + +def dummy_cert(privkey, cacert, commonname, sans): + """ + Generates a dummy certificate. + + privkey: CA private key + cacert: CA certificate + commonname: Common name for the generated certificate. + sans: A list of Subject Alternate Names. + + Returns cert if operation succeeded, None if not. 
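+
+    Usage sketch (illustrative values; assumes a CA created via create_ca above):
+
+        key, ca = create_ca(o="mitmproxy", cn="mitmproxy", exp=DEFAULT_EXP)
+        cert = dummy_cert(key, ca, b"example.com", [b"www.example.com"])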
+    """
+    ss = []
+    for i in sans:
+        try:
+            ipaddress.ip_address(i.decode("ascii"))
+        except ValueError:
+            ss.append(b"DNS: %s" % i)
+        else:
+            ss.append(b"IP: %s" % i)
+    ss = b", ".join(ss)
+
+    cert = OpenSSL.crypto.X509()
+    cert.gmtime_adj_notBefore(-3600 * 48)
+    cert.gmtime_adj_notAfter(DEFAULT_EXP)
+    cert.set_issuer(cacert.get_subject())
+    if commonname is not None:
+        cert.get_subject().CN = commonname
+    cert.set_serial_number(int(time.time() * 10000))
+    if ss:
+        cert.set_version(2)
+        cert.add_extensions(
+            [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)])
+    cert.set_pubkey(cacert.get_pubkey())
+    cert.sign(privkey, "sha256")
+    return SSLCert(cert)
+
+
+# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
+#
+# class _Node(UserDict.UserDict):
+#     def __init__(self):
+#         UserDict.UserDict.__init__(self)
+#         self.value = None
+#
+#
+# class DNTree:
+#     """
+#     Domain store that knows about wildcards. DNS wildcards are very
+#     restricted - the only valid variety is an asterisk on the left-most
+#     domain component, i.e.:
+#
+#         *.foo.com
+#     """
+#     def __init__(self):
+#         self.d = _Node()
+#
+#     def add(self, dn, cert):
+#         parts = dn.split(".")
+#         parts.reverse()
+#         current = self.d
+#         for i in parts:
+#             current = current.setdefault(i, _Node())
+#         current.value = cert
+#
+#     def get(self, dn):
+#         parts = dn.split(".")
+#         current = self.d
+#         for i in reversed(parts):
+#             if i in current:
+#                 current = current[i]
+#             elif "*" in current:
+#                 return current["*"].value
+#             else:
+#                 return None
+#         return current.value
+
+
+class CertStoreEntry(object):
+
+    def __init__(self, cert, privatekey, chain_file):
+        self.cert = cert
+        self.privatekey = privatekey
+        self.chain_file = chain_file
+
+
+class CertStore(object):
+
+    """
+    Implements an in-memory certificate store.
+    """
+
+    def __init__(
+            self,
+            default_privatekey,
+            default_ca,
+            default_chain_file,
+            dhparams):
+        self.default_privatekey = default_privatekey
+        self.default_ca = default_ca
+        self.default_chain_file = default_chain_file
+        self.dhparams = dhparams
+        self.certs = dict()
+
+    @staticmethod
+    def load_dhparam(path):
+
+        # netlib<=0.10 doesn't generate a dhparam file.
+        # Create it now if necessary.
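+        #
+        # Usage sketch (illustrative path; create_store below writes this
+        # file as "<basename>-dhparam.pem"):
+        #
+        #     dh = CertStore.load_dhparam("/path/to/mitmproxy-dhparam.pem")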
+        if not os.path.exists(path):
+            with open(path, "wb") as f:
+                f.write(DEFAULT_DHPARAM)
+
+        bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r")
+        if bio != OpenSSL.SSL._ffi.NULL:
+            bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
+            dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
+                bio,
+                OpenSSL.SSL._ffi.NULL,
+                OpenSSL.SSL._ffi.NULL,
+                OpenSSL.SSL._ffi.NULL)
+            dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
+            return dh
+
+    @classmethod
+    def from_store(cls, path, basename):
+        ca_path = os.path.join(path, basename + "-ca.pem")
+        if not os.path.exists(ca_path):
+            key, ca = cls.create_store(path, basename)
+        else:
+            with open(ca_path, "rb") as f:
+                raw = f.read()
+            ca = OpenSSL.crypto.load_certificate(
+                OpenSSL.crypto.FILETYPE_PEM,
+                raw)
+            key = OpenSSL.crypto.load_privatekey(
+                OpenSSL.crypto.FILETYPE_PEM,
+                raw)
+        dh_path = os.path.join(path, basename + "-dhparam.pem")
+        dh = cls.load_dhparam(dh_path)
+        return cls(key, ca, ca_path, dh)
+
+    @staticmethod
+    def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        o = o or basename
+        cn = cn or basename
+
+        key, ca = create_ca(o=o, cn=cn, exp=expiry)
+        # Dump the CA plus private key
+        with open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
+            f.write(
+                OpenSSL.crypto.dump_privatekey(
+                    OpenSSL.crypto.FILETYPE_PEM,
+                    key))
+            f.write(
+                OpenSSL.crypto.dump_certificate(
+                    OpenSSL.crypto.FILETYPE_PEM,
+                    ca))
+
+        # Dump the certificate in PEM format
+        with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
+            f.write(
+                OpenSSL.crypto.dump_certificate(
+                    OpenSSL.crypto.FILETYPE_PEM,
+                    ca))
+
+        # Create a .cer file with the same contents for Android
+        with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
+            f.write(
+                OpenSSL.crypto.dump_certificate(
+                    OpenSSL.crypto.FILETYPE_PEM,
+                    ca))
+
+        # Dump the certificate in PKCS12 format for Windows devices
+        with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
+            p12 = OpenSSL.crypto.PKCS12()
+            p12.set_certificate(ca)
+            p12.set_privatekey(key)
+            f.write(p12.export())
+
+        with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
+            f.write(DEFAULT_DHPARAM)
+
+        return key, ca
+
+    def add_cert_file(self, spec, path):
+        with open(path, "rb") as f:
+            raw = f.read()
+        cert = SSLCert(
+            OpenSSL.crypto.load_certificate(
+                OpenSSL.crypto.FILETYPE_PEM,
+                raw))
+        try:
+            privatekey = OpenSSL.crypto.load_privatekey(
+                OpenSSL.crypto.FILETYPE_PEM,
+                raw)
+        except Exception:
+            privatekey = self.default_privatekey
+        self.add_cert(
+            CertStoreEntry(cert, privatekey, path),
+            spec
+        )
+
+    def add_cert(self, entry, *names):
+        """
+        Adds a cert to the certstore. We register the CN in the cert plus
+        any SANs, and also the list of names provided as an argument.
+        """
+        if entry.cert.cn:
+            self.certs[entry.cert.cn] = entry
+        for i in entry.cert.altnames:
+            self.certs[i] = entry
+        for i in names:
+            self.certs[i] = entry
+
+    @staticmethod
+    def asterisk_forms(dn):
+        if dn is None:
+            return []
+        parts = dn.split(b".")
+        parts.reverse()
+        curr_dn = b""
+        dn_forms = [b"*"]
+        for part in parts[:-1]:
+            curr_dn = b"." + part + curr_dn  # .example.com
+            dn_forms.append(b"*" + curr_dn)  # *.example.com
+        if parts[-1] != b"*":
+            dn_forms.append(parts[-1] + curr_dn)
+        return dn_forms
+
+    def get_cert(self, commonname, sans):
+        """
+        Returns a (cert, privkey, cert_chain) tuple.
+
+        commonname: Common name for the generated certificate.
Must be a + valid, plain-ASCII, IDNA-encoded domain name. + + sans: A list of Subject Alternate Names. + """ + + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + + name = next( + filter(lambda key: key in self.certs, potential_keys), + None + ) + if name: + entry = self.certs[name] + else: + entry = CertStoreEntry( + cert=dummy_cert( + self.default_privatekey, + self.default_ca, + commonname, + sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file) + self.certs[(commonname, tuple(sans))] = entry + + return entry.cert, entry.privatekey, entry.chain_file + + +class _GeneralName(univ.Choice): + # We are only interested in dNSNames. We use a default handler to ignore + # other types. + # TODO: We should also handle iPAddresses. + componentType = namedtype.NamedTypes( + namedtype.NamedType('dNSName', char.IA5String().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) + ), + ) + + +class _GeneralNames(univ.SequenceOf): + componentType = _GeneralName() + sizeSpec = univ.SequenceOf.sizeSpec + \ + constraint.ValueSizeConstraint(1, 1024) + + +class SSLCert(Serializable): + + def __init__(self, cert): + """ + Returns a (common name, [subject alternative names]) tuple. + """ + self.x509 = cert + + def __eq__(self, other): + return self.digest("sha256") == other.digest("sha256") + + def __ne__(self, other): + return not self.__eq__(other) + + def get_state(self): + return self.to_pem() + + def set_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + @classmethod + def from_state(cls, state): + cls.from_pem(state) + + @classmethod + def from_pem(cls, txt): + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) + return cls(x509) + + @classmethod + def from_der(cls, der): + pem = ssl.DER_cert_to_PEM_cert(der) + return cls.from_pem(pem) + + def to_pem(self): + return OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + self.x509) + + def digest(self, name): + return self.x509.digest(name) + + @property + def issuer(self): + return self.x509.get_issuer().get_components() + + @property + def notbefore(self): + t = self.x509.get_notBefore() + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") + + @property + def notafter(self): + t = self.x509.get_notAfter() + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") + + @property + def has_expired(self): + return self.x509.has_expired() + + @property + def subject(self): + return self.x509.get_subject().get_components() + + @property + def serial(self): + return self.x509.get_serial_number() + + @property + def keyinfo(self): + pk = self.x509.get_pubkey() + types = { + OpenSSL.crypto.TYPE_RSA: "RSA", + OpenSSL.crypto.TYPE_DSA: "DSA", + } + return ( + types.get(pk.type(), "UNKNOWN"), + pk.bits() + ) + + @property + def cn(self): + c = None + for i in self.subject: + if i[0] == b"CN": + c = i[1] + return c + + @property + def altnames(self): + """ + Returns: + All DNS altnames. + """ + # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. 
+ altnames = [] + for i in range(self.x509.get_extension_count()): + ext = self.x509.get_extension(i) + if ext.get_short_name() == b"subjectAltName": + try: + dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + except PyAsn1Error: + continue + for i in dec[0]: + altnames.append(i[0].asOctets()) + return altnames diff --git a/netlib/encoding.py b/netlib/encoding.py new file mode 100644 index 00000000..14479e00 --- /dev/null +++ b/netlib/encoding.py @@ -0,0 +1,88 @@ +""" + Utility functions for decoding response bodies. +""" +from __future__ import absolute_import +from io import BytesIO +import gzip +import zlib +from .utils import always_byte_args + + +ENCODINGS = {"identity", "gzip", "deflate"} + + +def decode(e, content): + if not isinstance(content, bytes): + return None + encoding_map = { + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, + } + if e not in encoding_map: + return None + return encoding_map[e](content) + + +def encode(e, content): + if not isinstance(content, bytes): + return None + encoding_map = { + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, + } + if e not in encoding_map: + return None + return encoding_map[e](content) + + +def identity(content): + """ + Returns content unchanged. Identity is the default value of + Accept-Encoding headers. + """ + return content + + +def decode_gzip(content): + gfile = gzip.GzipFile(fileobj=BytesIO(content)) + try: + return gfile.read() + except (IOError, EOFError): + return None + + +def encode_gzip(content): + s = BytesIO() + gf = gzip.GzipFile(fileobj=s, mode='wb') + gf.write(content) + gf.close() + return s.getvalue() + + +def decode_deflate(content): + """ + Returns decompressed data for DEFLATE. Some servers may respond with + compressed data without a zlib header or checksum. An undocumented + feature of zlib permits the lenient decompression of data missing both + values. + + http://bugs.python.org/issue5784 + """ + try: + try: + return zlib.decompress(content) + except zlib.error: + return zlib.decompress(content, -15) + except zlib.error: + return None + + +def encode_deflate(content): + """ + Returns compressed content, always including zlib header and checksum. + """ + return zlib.compress(content) + +__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/exceptions.py b/netlib/exceptions.py new file mode 100644 index 00000000..05f1054b --- /dev/null +++ b/netlib/exceptions.py @@ -0,0 +1,56 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception netlib raises shall be a subclass of NetlibException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" +from __future__ import absolute_import, print_function, division + + +class NetlibException(Exception): + """ + Base class for all exceptions thrown by netlib. 
+    """
+    def __init__(self, message=None):
+        super(NetlibException, self).__init__(message)
+
+
+class Disconnect(object):
+    """Immediate EOF"""
+
+
+class HttpException(NetlibException):
+    pass
+
+
+class HttpReadDisconnect(HttpException, Disconnect):
+    pass
+
+
+class HttpSyntaxException(HttpException):
+    pass
+
+
+class TcpException(NetlibException):
+    pass
+
+
+class TcpDisconnect(TcpException, Disconnect):
+    pass
+
+
+class TcpReadIncomplete(TcpException):
+    pass
+
+
+class TcpTimeout(TcpException):
+    pass
+
+
+class TlsException(NetlibException):
+    pass
+
+
+class InvalidCertificateException(TlsException):
+    pass
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
new file mode 100644
index 00000000..fd632cd5
--- /dev/null
+++ b/netlib/http/__init__.py
@@ -0,0 +1,14 @@
+from __future__ import absolute_import, print_function, division
+from .request import Request
+from .response import Response
+from .headers import Headers
+from .message import decoded, CONTENT_MISSING
+from . import http1, http2
+
+__all__ = [
+    "Request",
+    "Response",
+    "Headers",
+    "decoded", "CONTENT_MISSING",
+    "http1", "http2",
+]
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py
new file mode 100644
index 00000000..d769abe5
--- /dev/null
+++ b/netlib/http/authentication.py
@@ -0,0 +1,167 @@
+from __future__ import (absolute_import, print_function, division)
+from argparse import Action, ArgumentTypeError
+import binascii
+
+
+def parse_http_basic_auth(s):
+    words = s.split()
+    if len(words) != 2:
+        return None
+    scheme = words[0]
+    try:
+        user = binascii.a2b_base64(words[1]).decode("utf8", "replace")
+    except binascii.Error:
+        return None
+    parts = user.split(':')
+    if len(parts) != 2:
+        return None
+    return scheme, parts[0], parts[1]
+
+
+def assemble_http_basic_auth(scheme, username, password):
+    v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii")
+    return scheme + " " + v
+
+
+class NullProxyAuth(object):
+
+    """
+    No proxy auth at all (returns empty challenge headers)
+    """
+
+    def __init__(self, password_manager):
+        self.password_manager = password_manager
+
+    def clean(self, headers_):
+        """
+        Clean up authentication headers, so they're not passed upstream.
+        """
+
+    def authenticate(self, headers_):
+        """
+        Tests that the user is allowed to use the proxy
+        """
+        return True
+
+    def auth_challenge_headers(self):
+        """
+        Returns a dictionary containing the headers required to challenge the user
+        """
+        return {}
+
+
+class BasicProxyAuth(NullProxyAuth):
+    CHALLENGE_HEADER = 'Proxy-Authenticate'
+    AUTH_HEADER = 'Proxy-Authorization'
+
+    def __init__(self, password_manager, realm):
+        NullProxyAuth.__init__(self, password_manager)
+        self.realm = realm
+
+    def clean(self, headers):
+        del headers[self.AUTH_HEADER]
+
+    def authenticate(self, headers):
+        auth_value = headers.get(self.AUTH_HEADER)
+        if not auth_value:
+            return False
+        parts = parse_http_basic_auth(auth_value)
+        if not parts:
+            return False
+        scheme, username, password = parts
+        if scheme.lower() != 'basic':
+            return False
+        if not self.password_manager.test(username, password):
+            return False
+        self.username = username
+        return True
+
+    def auth_challenge_headers(self):
+        return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm}
+
+
+class PassMan(object):
+
+    def test(self, username_, password_token_):
+        return False
+
+
+class PassManNonAnon(PassMan):
+
+    """
+    Ensure the user specifies a username, accept any password.
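+
+    Behaviour sketch (per test() below):
+
+        PassManNonAnon().test("user", "any password")  # True
+        PassManNonAnon().test("", "any password")      # False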
+    """
+
+    def test(self, username, password_token_):
+        if username:
+            return True
+        return False
+
+
+class PassManHtpasswd(PassMan):
+
+    """
+    Read usernames and passwords from an htpasswd file
+    """
+
+    def __init__(self, path):
+        """
+        Raises ValueError if htpasswd file is invalid.
+        """
+        import passlib.apache
+        self.htpasswd = passlib.apache.HtpasswdFile(path)
+
+    def test(self, username, password_token):
+        return bool(self.htpasswd.check_password(username, password_token))
+
+
+class PassManSingleUser(PassMan):
+
+    def __init__(self, username, password):
+        self.username, self.password = username, password
+
+    def test(self, username, password_token):
+        return self.username == username and self.password == password_token
+
+
+class AuthAction(Action):
+
+    """
+    Helper class to allow seamless integration into argparse. Example usage:
+        parser.add_argument(
+            "--nonanonymous",
+            action=NonanonymousAuthAction, nargs=0,
+            help="Allow access to any user as long as credentials are specified."
+        )
+    """
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        passman = self.getPasswordManager(values)
+        authenticator = BasicProxyAuth(passman, "mitmproxy")
+        setattr(namespace, self.dest, authenticator)
+
+    def getPasswordManager(self, s):  # pragma: nocover
+        raise NotImplementedError()
+
+
+class SingleuserAuthAction(AuthAction):
+
+    def getPasswordManager(self, s):
+        if len(s.split(':')) != 2:
+            raise ArgumentTypeError(
+                "Invalid single-user specification. Please use the format username:password"
+            )
+        username, password = s.split(':')
+        return PassManSingleUser(username, password)
+
+
+class NonanonymousAuthAction(AuthAction):
+
+    def getPasswordManager(self, s):
+        return PassManNonAnon()
+
+
+class HtpasswdAuthAction(AuthAction):
+
+    def getPasswordManager(self, s):
+        return PassManHtpasswd(s)
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
new file mode 100644
index 00000000..18544b5e
--- /dev/null
+++ b/netlib/http/cookies.py
@@ -0,0 +1,193 @@
+import re
+
+from .. import odict
+
+"""
+A flexible module for cookie parsing and manipulation.
+
+This module differs from usual standards-compliant cookie modules in a number
+of ways. We try to be as permissive as possible, and to retain even malformed
+information. Duplicate cookies are preserved in parsing, and can be set in
+formatting. We do attempt to escape and quote values where needed, but will not
+reject data that violate the specs.
+
+Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do
+not parse the comma-separated variant of Set-Cookie that allows multiple
+cookies to be set in a single header. Technically this should be feasible, but
+it turns out that violations of RFC6265 that make the parsing problem
+indeterminate are much more common than genuine occurrences of the multi-cookie
+variants. Serialization follows RFC6265.
+
+    http://tools.ietf.org/html/rfc6265
+    http://tools.ietf.org/html/rfc2109
+    http://tools.ietf.org/html/rfc2965
+"""
+
+# TODO: Disallow LHS-only Cookie values
+
+
+def _read_until(s, start, term):
+    """
+    Read until one of the characters in term is reached.
+    """
+    if start == len(s):
+        return "", start + 1
+    for i in range(start, len(s)):
+        if s[i] in term:
+            return s[start:i], i
+    return s[start:i + 1], i + 1
+
+
+def _read_token(s, start):
+    """
+    Read a token - the LHS of a token/value pair in a cookie.
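+
+    Worked example (illustrative):
+
+        _read_token("foo=bar", 0)   # -> ("foo", 3), stopped at "="
+        _read_token("foo;baz", 0)   # -> ("foo", 3), stopped at ";"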
+    """
+    return _read_until(s, start, ";=")
+
+
+def _read_quoted_string(s, start):
+    """
+    start: offset to the first quote of the string to be read
+
+    A sort of loose super-set of the various quoted string specifications.
+
+    RFC6265 disallows backslashes or double quotes within quoted strings.
+    Prior RFCs use backslashes to escape. This leaves us free to apply
+    backslash escaping by default and be compatible with everything.
+    """
+    escaping = False
+    ret = []
+    # Skip the first quote
+    i = start  # initialize in case the loop doesn't run.
+    for i in range(start + 1, len(s)):
+        if escaping:
+            ret.append(s[i])
+            escaping = False
+        elif s[i] == '"':
+            break
+        elif s[i] == "\\":
+            escaping = True
+        else:
+            ret.append(s[i])
+    return "".join(ret), i + 1
+
+
+def _read_value(s, start, delims):
+    """
+    Reads a value - the RHS of a token/value pair in a cookie.
+
+    special: If the value is special, commas are permitted. Else comma
+    terminates. This helps us support old and new style values.
+    """
+    if start >= len(s):
+        return "", start
+    elif s[start] == '"':
+        return _read_quoted_string(s, start)
+    else:
+        return _read_until(s, start, delims)
+
+
+def _read_pairs(s, off=0):
+    """
+    Read pairs of lhs=rhs values.
+
+    off: start offset
+    specials: a lower-cased list of keys that may contain commas
+    """
+    vals = []
+    while True:
+        lhs, off = _read_token(s, off)
+        lhs = lhs.lstrip()
+        if lhs:
+            rhs = None
+            if off < len(s):
+                if s[off] == "=":
+                    rhs, off = _read_value(s, off + 1, ";")
+            vals.append([lhs, rhs])
+        off += 1
+        if not off < len(s):
+            break
+    return vals, off
+
+
+def _has_special(s):
+    for i in s:
+        if i in '",;\\':
+            return True
+        o = ord(i)
+        if o < 0x21 or o > 0x7e:
+            return True
+    return False
+
+
+ESCAPE = re.compile(r"([\"\\])")
+
+
+def _format_pairs(lst, specials=(), sep="; "):
+    """
+    specials: A lower-cased list of keys that will not be quoted.
+    """
+    vals = []
+    for k, v in lst:
+        if v is None:
+            vals.append(k)
+        else:
+            if k.lower() not in specials and _has_special(v):
+                v = ESCAPE.sub(r"\\\1", v)
+                v = '"%s"' % v
+            vals.append("%s=%s" % (k, v))
+    return sep.join(vals)
+
+
+def _format_set_cookie_pairs(lst):
+    return _format_pairs(
+        lst,
+        specials=("expires", "path")
+    )
+
+
+def _parse_set_cookie_pairs(s):
+    """
+    For Set-Cookie, we support multiple cookies as described in RFC2109.
+    This function therefore returns a list of lists.
+    """
+    pairs, off_ = _read_pairs(s)
+    return pairs
+
+
+def parse_set_cookie_header(line):
+    """
+    Parse a Set-Cookie header value
+
+    Returns a (name, value, attrs) tuple, or None, where attrs is an
+    ODictCaseless set of attributes. No attempt is made to parse attribute
+    values - they are treated purely as strings.
+    """
+    pairs = _parse_set_cookie_pairs(line)
+    if pairs:
+        return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:])
+
+
+def format_set_cookie_header(name, value, attrs):
+    """
+    Formats a Set-Cookie header value.
+    """
+    pairs = [[name, value]]
+    pairs.extend(attrs.lst)
+    return _format_set_cookie_pairs(pairs)
+
+
+def parse_cookie_header(line):
+    """
+    Parse a Cookie header value.
+    Returns a (possibly empty) ODict object.
+    """
+    pairs, off_ = _read_pairs(line)
+    return odict.ODict(pairs)
+
+
+def format_cookie_header(od):
+    """
+    Formats a Cookie header value.
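+
+    Sketch (assuming an odict.ODict of pairs, as returned by
+    parse_cookie_header above):
+
+        od = odict.ODict([["a", "b"], ["c", "d"]])
+        format_cookie_header(od)  # -> "a=b; c=d"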
+    """
+    return _format_pairs(od.lst)
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
new file mode 100644
index 00000000..78404796
--- /dev/null
+++ b/netlib/http/headers.py
@@ -0,0 +1,204 @@
+"""
+
+Unicode Handling
+----------------
+See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
+"""
+from __future__ import absolute_import, print_function, division
+import copy
+try:
+    from collections.abc import MutableMapping
+except ImportError:  # pragma: nocover
+    from collections import MutableMapping  # Workaround for Python < 3.3
+
+
+import six
+
+from netlib.utils import always_byte_args, always_bytes, Serializable
+
+if six.PY2:  # pragma: nocover
+    _native = lambda x: x
+    _always_bytes = lambda x: x
+    _always_byte_args = lambda x: x
+else:
+    # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
+    _native = lambda x: x.decode("utf-8", "surrogateescape")
+    _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape")
+    _always_byte_args = always_byte_args("utf-8", "surrogateescape")
+
+
+class Headers(MutableMapping, Serializable):
+    """
+    Header class which allows both convenient access to individual headers as well as
+    direct access to the underlying raw data. Provides a full dictionary interface.
+
+    Example:
+
+    .. code-block:: python
+
+        # Create headers with keyword arguments
+        >>> h = Headers(host="example.com", content_type="application/xml")
+
+        # Headers mostly behave like a normal dict.
+        >>> h["Host"]
+        "example.com"
+
+        # HTTP Headers are case insensitive
+        >>> h["host"]
+        "example.com"
+
+        # Headers can also be created from a list of raw (header_name, header_value) byte tuples
+        >>> h = Headers([
+            [b"Host",b"example.com"],
+            [b"Accept",b"text/html"],
+            [b"accept",b"application/xml"]
+        ])
+
+        # Multiple headers are folded into a single header as per RFC7230
+        >>> h["Accept"]
+        "text/html, application/xml"
+
+        # Setting a header removes all existing headers with the same name.
+        >>> h["Accept"] = "application/text"
+        >>> h["Accept"]
+        "application/text"
+
+        # bytes(h) returns a HTTP1 header block.
+        >>> print(bytes(h))
+        Host: example.com
+        Accept: application/text
+
+        # For full control, the raw header fields can be accessed
+        >>> h.fields
+
+    Caveats:
+        For use with the "Set-Cookie" header, see :py:meth:`get_all`.
+    """
+
+    @_always_byte_args
+    def __init__(self, fields=None, **headers):
+        """
+        Args:
+            fields: (optional) list of ``(name, value)`` header byte tuples,
+                e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes.
+            **headers: Additional headers to set. Will overwrite existing values from `fields`.
+                For convenience, underscores in header names will be transformed to dashes -
+                this behaviour does not extend to other methods.
+                If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
+                the behavior is undefined.
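+
+        Sketch of equivalent constructions (per the class example above):
+
+            Headers(host="example.com")
+            Headers([[b"Host", b"example.com"]])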
+        """
+        self.fields = fields or []
+
+        for name, value in self.fields:
+            if not isinstance(name, bytes) or not isinstance(value, bytes):
+                raise ValueError("Headers passed as fields must be bytes.")
+
+        # content_type -> content-type
+        headers = {
+            _always_bytes(name).replace(b"_", b"-"): value
+            for name, value in six.iteritems(headers)
+        }
+        self.update(headers)
+
+    def __bytes__(self):
+        if self.fields:
+            return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n"
+        else:
+            return b""
+
+    if six.PY2:  # pragma: nocover
+        __str__ = __bytes__
+
+    @_always_byte_args
+    def __getitem__(self, name):
+        values = self.get_all(name)
+        if not values:
+            raise KeyError(name)
+        return ", ".join(values)
+
+    @_always_byte_args
+    def __setitem__(self, name, value):
+        idx = self._index(name)
+
+        # To please the human eye, we insert at the position where the first existing header occurred.
+        if idx is not None:
+            del self[name]
+            self.fields.insert(idx, [name, value])
+        else:
+            self.fields.append([name, value])
+
+    @_always_byte_args
+    def __delitem__(self, name):
+        if name not in self:
+            raise KeyError(name)
+        name = name.lower()
+        self.fields = [
+            field for field in self.fields
+            if name != field[0].lower()
+        ]
+
+    def __iter__(self):
+        seen = set()
+        for name, _ in self.fields:
+            name_lower = name.lower()
+            if name_lower not in seen:
+                seen.add(name_lower)
+                yield _native(name)
+
+    def __len__(self):
+        return len(set(name.lower() for name, _ in self.fields))
+
+    # __hash__ = object.__hash__
+
+    def _index(self, name):
+        name = name.lower()
+        for i, field in enumerate(self.fields):
+            if field[0].lower() == name:
+                return i
+        return None
+
+    def __eq__(self, other):
+        if isinstance(other, Headers):
+            return self.fields == other.fields
+        return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    @_always_byte_args
+    def get_all(self, name):
+        """
+        Like :py:meth:`get`, but does not fold multiple headers into a single one.
+        This is useful for Set-Cookie headers, which do not support folding.
+
+        See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
+        """
+        name_lower = name.lower()
+        values = [_native(value) for n, value in self.fields if n.lower() == name_lower]
+        return values
+
+    @_always_byte_args
+    def set_all(self, name, values):
+        """
+        Explicitly set multiple headers for the given key.
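+        Sketch:
+
+            h.set_all("Set-Cookie", ["a=b", "c=d"])  # two separate fields
+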
+ See: :py:meth:`get_all` + """ + values = map(_always_bytes, values) # _always_byte_args does not fix lists + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + def get_state(self): + return tuple(tuple(field) for field in self.fields) + + def set_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) \ No newline at end of file diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py new file mode 100644 index 00000000..2aa7e26a --- /dev/null +++ b/netlib/http/http1/__init__.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import, print_function, division +from .read import ( + read_request, read_request_head, + read_response, read_response_head, + read_body, + connection_close, + expected_http_body_size, +) +from .assemble import ( + assemble_request, assemble_request_head, + assemble_response, assemble_response_head, + assemble_body, +) + + +__all__ = [ + "read_request", "read_request_head", + "read_response", "read_response_head", + "read_body", + "connection_close", + "expected_http_body_size", + "assemble_request", "assemble_request_head", + "assemble_response", "assemble_response_head", + "assemble_body", +] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py new file mode 100644 index 00000000..785ee8d3 --- /dev/null +++ b/netlib/http/http1/assemble.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import, print_function, division + +from ... import utils +import itertools +from ...exceptions import HttpException +from .. import CONTENT_MISSING + + +def assemble_request(request): + if request.content == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_request_head(request) + body = b"".join(assemble_body(request.data.headers, [request.data.content])) + return head + body + + +def assemble_request_head(request): + first_line = _assemble_request_line(request.data) + headers = _assemble_request_headers(request.data) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_response(response): + if response.content == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_response_head(response) + body = b"".join(assemble_body(response.data.headers, [response.data.content])) + return head + body + + +def assemble_response_head(response): + first_line = _assemble_response_line(response.data) + headers = _assemble_response_headers(response.data) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_body(headers, body_chunks): + if "chunked" in headers.get("transfer-encoding", "").lower(): + for chunk in body_chunks: + if chunk: + yield b"%x\r\n%s\r\n" % (len(chunk), chunk) + yield b"0\r\n\r\n" + else: + for chunk in body_chunks: + yield chunk + + +def _assemble_request_line(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + form = request_data.first_line_format + if form == "relative": + return b"%s %s %s" % ( + request_data.method, + request_data.path, + request_data.http_version + ) + elif form == "authority": + return b"%s %s:%d %s" % ( + request_data.method, + request_data.host, + request_data.port, + request_data.http_version + ) + elif form == "absolute": + return b"%s %s://%s:%d%s %s" % ( + request_data.method, + request_data.scheme, + request_data.host, + 
request_data.port,
+            request_data.path,
+            request_data.http_version
+        )
+    else:
+        raise RuntimeError("Invalid request form")
+
+
+def _assemble_request_headers(request_data):
+    """
+    Args:
+        request_data (netlib.http.request.RequestData)
+    """
+    headers = request_data.headers.copy()
+    if "host" not in headers and request_data.scheme and request_data.host and request_data.port:
+        headers["host"] = utils.hostport(
+            request_data.scheme,
+            request_data.host,
+            request_data.port
+        )
+    return bytes(headers)
+
+
+def _assemble_response_line(response_data):
+    return b"%s %d %s" % (
+        response_data.http_version,
+        response_data.status_code,
+        response_data.reason,
+    )
+
+
+def _assemble_response_headers(response):
+    return bytes(response.headers)
diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py
new file mode 100644
index 00000000..6e3a1b93
--- /dev/null
+++ b/netlib/http/http1/read.py
@@ -0,0 +1,362 @@
+from __future__ import absolute_import, print_function, division
+import time
+import sys
+import re
+
+from ... import utils
+from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
+from .. import Request, Response, Headers
+
+
+def read_request(rfile, body_size_limit=None):
+    request = read_request_head(rfile)
+    expected_body_size = expected_http_body_size(request)
+    request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
+    request.timestamp_end = time.time()
+    return request
+
+
+def read_request_head(rfile):
+    """
+    Parse an HTTP request head (request line + headers) from an input stream
+
+    Args:
+        rfile: The input stream
+
+    Returns:
+        The HTTP request object (without body)
+
+    Raises:
+        HttpReadDisconnect: No bytes can be read from rfile.
+        HttpSyntaxException: The input is malformed HTTP.
+        HttpException: Any other error occurred.
+    """
+    timestamp_start = time.time()
+    if hasattr(rfile, "reset_timestamps"):
+        rfile.reset_timestamps()
+
+    form, method, scheme, host, port, path, http_version = _read_request_line(rfile)
+    headers = _read_headers(rfile)
+
+    if hasattr(rfile, "first_byte_timestamp"):
+        # more accurate timestamp_start
+        timestamp_start = rfile.first_byte_timestamp
+
+    return Request(
+        form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
+    )
+
+
+def read_response(rfile, request, body_size_limit=None):
+    response = read_response_head(rfile)
+    expected_body_size = expected_http_body_size(request, response)
+    response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit))
+    response.timestamp_end = time.time()
+    return response
+
+
+def read_response_head(rfile):
+    """
+    Parse an HTTP response head (response line + headers) from an input stream
+
+    Args:
+        rfile: The input stream
+
+    Returns:
+        The HTTP response object (without body)
+
+    Raises:
+        HttpReadDisconnect: No bytes can be read from rfile.
+        HttpSyntaxException: The input is malformed HTTP.
+        HttpException: Any other error occurred.
+    """
+
+    timestamp_start = time.time()
+    if hasattr(rfile, "reset_timestamps"):
+        rfile.reset_timestamps()
+
+    http_version, status_code, message = _read_response_line(rfile)
+    headers = _read_headers(rfile)
+
+    if hasattr(rfile, "first_byte_timestamp"):
+        # more accurate timestamp_start
+        timestamp_start = rfile.first_byte_timestamp
+
+    return Response(http_version, status_code, message, headers, None, timestamp_start)
+
+
+def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
+    """
+    Read an HTTP message body
+
+    Args:
+        rfile: The input stream
+        expected_size: The expected body size (see :py:meth:`expected_http_body_size`)
+        limit: Maximum body size
+        max_chunk_size: Maximum chunk size that gets yielded
+
+    Returns:
+        A generator that yields byte chunks of the content.
+
+    Raises:
+        HttpException, if an error occurs
+
+    Caveats:
+        max_chunk_size is not considered if the transfer encoding is chunked.
+    """
+    if not limit or limit < 0:
+        limit = sys.maxsize
+    if not max_chunk_size:
+        max_chunk_size = limit
+
+    if expected_size is None:
+        for x in _read_chunked(rfile, limit):
+            yield x
+    elif expected_size >= 0:
+        if limit is not None and expected_size > limit:
+            raise HttpException(
+                "HTTP Body too large. "
+                "Limit is {}, content length was advertised as {}".format(limit, expected_size)
+            )
+        bytes_left = expected_size
+        while bytes_left:
+            chunk_size = min(bytes_left, max_chunk_size)
+            content = rfile.read(chunk_size)
+            if len(content) < chunk_size:
+                raise HttpException("Unexpected EOF")
+            yield content
+            bytes_left -= chunk_size
+    else:
+        bytes_left = limit
+        while bytes_left:
+            chunk_size = min(bytes_left, max_chunk_size)
+            content = rfile.read(chunk_size)
+            if not content:
+                return
+            yield content
+            bytes_left -= chunk_size
+        not_done = rfile.read(1)
+        if not_done:
+            raise HttpException("HTTP body too large. Limit is {}.".format(limit))
+
+
+def connection_close(http_version, headers):
+    """
+    Checks the message to see if the client connection should be closed
+    according to RFC 2616 Section 8.1.
+    """
+    # At first, check if we have an explicit Connection header.
+    if "connection" in headers:
+        tokens = utils.get_header_tokens(headers, "connection")
+        if "close" in tokens:
+            return True
+        elif "keep-alive" in tokens:
+            return False
+
+    # If we don't have a Connection header, HTTP 1.1 connections are assumed to
+    # be persistent
+    return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1"  # FIXME: Remove one case.
+
+
+def expected_http_body_size(request, response=None):
+    """
+    Returns:
+        The expected body length:
+        - a positive integer, if the size is known in advance
+        - None, if the size is unknown in advance (chunked encoding)
+        - -1, if all data should be read until end of stream.
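+
+        Sketch of typical outcomes (informal; per the logic below):
+          - request with "Content-Length: 42"         -> 42
+          - message with "Transfer-Encoding: chunked" -> None
+          - response with neither header              -> -1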
+ + Raises: + HttpSyntaxException, if the content length header is invalid + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False + + if is_request: + if headers.get("expect", "").lower() == "100-continue": + return 0 + else: + if request.method.upper() == "HEAD": + return 0 + if 100 <= response_code <= 199: + return 0 + if response_code == 200 and request.method.upper() == "CONNECT": + return 0 + if response_code in (204, 304): + return 0 + + if "chunked" in headers.get("transfer-encoding", "").lower(): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"]) + if size < 0: + raise ValueError() + return size + except ValueError: + raise HttpSyntaxException("Unparseable Content Length") + if is_request: + return 0 + return -1 + + +def _get_first_line(rfile): + try: + line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + except TcpDisconnect: + raise HttpReadDisconnect("Remote disconnected") + if not line: + raise HttpReadDisconnect("Remote disconnected") + return line.strip() + + +def _read_request_line(rfile): + try: + line = _get_first_line(rfile) + except HttpReadDisconnect: + # We want to provide a better error message. + raise HttpReadDisconnect("Client disconnected") + + try: + method, path, http_version = line.split(b" ") + + if path == b"*" or path.startswith(b"/"): + form = "relative" + scheme, host, port = None, None, None + elif method == b"CONNECT": + form = "authority" + host, port = _parse_authority_form(path) + scheme, path = None, None + else: + form = "absolute" + scheme, host, port, path = utils.parse_url(path) + + _check_http_version(http_version) + except ValueError: + raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) + + return form, method, scheme, host, port, path, http_version + + +def _parse_authority_form(hostport): + """ + Returns (host, port) if hostport is a valid authority-form host specification. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + + Raises: + ValueError, if the input is malformed + """ + try: + host, port = hostport.split(b":") + port = int(port) + if not utils.is_valid_host(host) or not utils.is_valid_port(port): + raise ValueError() + except ValueError: + raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) + + return host, port + + +def _read_response_line(rfile): + try: + line = _get_first_line(rfile) + except HttpReadDisconnect: + # We want to provide a better error message. + raise HttpReadDisconnect("Server disconnected") + + try: + + parts = line.split(b" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append(b"") + + http_version, status_code, message = parts + status_code = int(status_code) + _check_http_version(http_version) + + except ValueError: + raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) + + return http_version, status_code, message + + +def _check_http_version(http_version): + if not re.match(br"^HTTP/\d\.\d$", http_version): + raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) + + +def _read_headers(rfile): + """ + Read a set of headers. + Stop once a blank line is reached. 
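+    A line starting with a space or tab is treated as a continuation of the
+    previous header, e.g. (sketch): b"X-Hdr: a\r\n b\r\n" yields one
+    "X-Hdr" field with value b"a\r\n b".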
+
+    Returns:
+        A headers object
+
+    Raises:
+        HttpSyntaxException
+    """
+    ret = []
+    while True:
+        line = rfile.readline()
+        if not line or line == b"\r\n" or line == b"\n":
+            break
+        if line[0] in b" \t":
+            if not ret:
+                raise HttpSyntaxException("Invalid headers")
+            # continued header
+            ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip()
+        else:
+            try:
+                name, value = line.split(b":", 1)
+                value = value.strip()
+                if not name:
+                    raise ValueError()
+                ret.append([name, value])
+            except ValueError:
+                raise HttpSyntaxException("Invalid headers")
+    return Headers(ret)
+
+
+def _read_chunked(rfile, limit=sys.maxsize):
+    """
+    Read an HTTP body with chunked transfer encoding.
+
+    Args:
+        rfile: the input file
+        limit: A positive integer
+    """
+    total = 0
+    while True:
+        line = rfile.readline(128)
+        if line == b"":
+            raise HttpException("Connection closed prematurely")
+        if line != b"\r\n" and line != b"\n":
+            try:
+                length = int(line, 16)
+            except ValueError:
+                raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
+            total += length
+            if total > limit:
+                raise HttpException(
+                    "HTTP Body too large. Limit is {}, "
+                    "chunked content longer than {}".format(limit, total)
+                )
+            chunk = rfile.read(length)
+            suffix = rfile.readline(5)
+            if suffix != b"\r\n":
+                raise HttpSyntaxException("Malformed chunked body")
+            if length == 0:
+                return
+            yield chunk
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
new file mode 100644
index 00000000..7043d36f
--- /dev/null
+++ b/netlib/http/http2/__init__.py
@@ -0,0 +1,6 @@
+from __future__ import absolute_import, print_function, division
+from .connections import HTTP2Protocol
+
+__all__ = [
+    "HTTP2Protocol"
+]
diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py
new file mode 100644
index 00000000..52fa7193
--- /dev/null
+++ b/netlib/http/http2/connections.py
@@ -0,0 +1,426 @@
+from __future__ import (absolute_import, print_function, division)
+import itertools
+import time
+
+from hpack.hpack import Encoder, Decoder
+from ... import utils
+from ..
import Headers, Response, Request + +from hyperframe import frame + + +class TCPHandler(object): + + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + + HTTP2_DEFAULT_SETTINGS = { + frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, + frame.SettingsFrame.ENABLE_PUSH: 1, + frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, + frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, + frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, + } + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + unhandled_frame_cb=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) + self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() + self.unhandled_frame_cb = unhandled_frame_cb + + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.connection_preface_performed = False + + def read_request( + self, + __rfile, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): + if body_size_limit is not None: + raise NotImplementedError() + + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission( + include_body=include_body, + ) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + authority = headers.get(':authority', b'') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') + host = None + port = None + + if path == '*' or path.startswith("/"): + form_in = "relative" + elif method == 'CONNECT': + form_in = "authority" + if ":" in authority: + host, port = authority.split(":", 1) + else: + host = authority + else: + form_in = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = utils.parse_url(path) + scheme = scheme.decode('ascii') + host = host.decode('ascii') + + if host is None: + host = 'localhost' + if port is None: + port = 80 if scheme == 'http' else 443 + port = int(port) + + request = Request( + form_in, + method.encode('ascii'), + scheme.encode('ascii'), + host.encode('ascii'), + port, + path.encode('ascii'), + b"HTTP/2.0", + headers, + body, + timestamp_start, + timestamp_end, + ) + request.stream_id = stream_id + + return request + + def read_response( + self, + __rfile, + request_method=b'', + body_size_limit=None, + include_body=True, + stream_id=None, + ): + if body_size_limit is not None: + raise NotImplementedError() + + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = 
self._receive_transmission( + stream_id=stream_id, + include_body=include_body, + ) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = Response( + b"HTTP/2.0", + int(headers.get(':status', 502)), + b'', + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + + def assemble_request(self, request): + assert isinstance(request, Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = request.headers.copy() + + if ':authority' not in headers: + headers.fields.insert(0, (b':authority', authority.encode('ascii'))) + if ':scheme' not in headers: + headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) + if ':path' not in headers: + headers.fields.insert(0, (b':path', request.path.encode('ascii'))) + if ':method' not in headers: + headers.fields.insert(0, (b':method', request.method.encode('ascii'))) + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, Response) + + headers = response.headers.copy() + + if ':status' not in headers: + headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), + self._create_body(response.body, stream_id), + )) + + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + frm = frame.SettingsFrame(settings={ + frame.SettingsFrame.ENABLE_PUSH: 0, + frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, + }) + self.send_frame(frm, hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(), hide=True) + self._receive_settings(hide=True) # server announces own settings + self._receive_settings(hide=True) # server acks my settings + + def send_frame(self, 
frm, hide=False): + raw_bytes = frm.serialize() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + while True: + frm = utils.http2_read_frame(self.tcp_handler.rfile) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + + if isinstance(frm, frame.PingFrame): + raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + continue + if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags: + self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0: + self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length) + return frm + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != b'h2': + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _handle_unexpected_frame(self, frm): + if isinstance(frm, frame.SettingsFrame): + return + if self.unhandled_frame_cb: + self.unhandled_frame_cb(frm) + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + else: + self._handle_unexpected_frame(frm) + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame(flags=['ACK']) + self.send_frame(frm, hide) + + def _update_flow_control_window(self, stream_id, increment): + frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment) + self.send_frame(frm) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) + self.send_frame(frm) + + def _create_headers(self, headers, stream_id, end_stream=True): + def frame_cls(chunks): + for i in chunks: + if i == 0: + yield frame.HeadersFrame, i + else: + yield frame.ContinuationFrame, i + + header_block_fragment = self.encoder.encode(headers.fields) + + chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] + chunks = range(0, len(header_block_fragment), chunk_size) + frms = [frm_cls( + flags=[], + stream_id=stream_id, + data=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + + frms[-1].flags.add('END_HEADERS') + if end_stream: + frms[0].flags.add('END_STREAM') + + if self.dump_frames: # pragma no cover + for frm in frms: + print(frm.human_readable(">>")) + + return [frm.serialize() for frm in frms] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] + chunks = range(0, len(body), chunk_size) + frms = [frame.DataFrame( + flags=[], + stream_id=stream_id, + data=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags.add('END_STREAM') + + if self.dump_frames: # pragma no cover + for frm in frms: + print(frm.human_readable(">>")) + + return [frm.serialize() for frm in frms] + + def 
_receive_transmission(self, stream_id=None, include_body=True): + if not include_body: + raise NotImplementedError() + + body_expected = True + + header_blocks = b'' + body = b'' + + while True: + frm = self.read_frame() + if ( + (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and + (stream_id is None or frm.stream_id == stream_id) + ): + stream_id = frm.stream_id + header_blocks += frm.data + if 'END_STREAM' in frm.flags: + body_expected = False + if 'END_HEADERS' in frm.flags: + break + else: + self._handle_unexpected_frame(frm) + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: + body += frm.data + if 'END_STREAM' in frm.flags: + break + else: + self._handle_unexpected_frame(frm) + + headers = Headers( + [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] + ) + + return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py new file mode 100644 index 00000000..e3d8ce37 --- /dev/null +++ b/netlib/http/message.py @@ -0,0 +1,222 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six + +from .headers import Headers +from .. import encoding, utils + +CONTENT_MISSING = 0 + +if six.PY2: # pragma: nocover + _native = lambda x: x + _always_bytes = lambda x: x +else: + # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") + + +class MessageData(utils.Serializable): + def __eq__(self, other): + if isinstance(other, MessageData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def set_state(self, state): + for k, v in state.items(): + if k == "headers": + v = Headers.from_state(v) + setattr(self, k, v) + + def get_state(self): + state = vars(self).copy() + state["headers"] = state["headers"].get_state() + return state + + @classmethod + def from_state(cls, state): + state["headers"] = Headers.from_state(state["headers"]) + return cls(**state) + + +class Message(utils.Serializable): + def __init__(self, data): + self.data = data + + def __eq__(self, other): + if isinstance(other, Message): + return self.data == other.data + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_state(self): + return self.data.get_state() + + def set_state(self, state): + self.data.set_state(state) + + @classmethod + def from_state(cls, state): + return cls(**state) + + @property + def headers(self): + """ + Message headers object + + Returns: + netlib.http.Headers + """ + return self.data.headers + + @headers.setter + def headers(self, h): + self.data.headers = h + + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def http_version(self): + """ + Version string, e.g. 
"HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + + @property + def timestamp_start(self): + """ + First byte timestamp + """ + return self.data.timestamp_start + + @timestamp_start.setter + def timestamp_start(self, timestamp_start): + self.data.timestamp_start = timestamp_start + + @property + def timestamp_end(self): + """ + Last byte timestamp + """ + return self.data.timestamp_end + + @timestamp_end.setter + def timestamp_end(self, timestamp_end): + self.data.timestamp_end = timestamp_end + + @property + def text(self): + """ + The decoded HTTP message body. + Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + + .. note:: + This is not implemented yet. + + See also: :py:attr:`content`, :py:class:`decoded` + """ + # This attribute should be called text, because that's what requests does. + raise NotImplementedError() + + @text.setter + def text(self, text): + raise NotImplementedError() + + def decode(self): + """ + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + ce = self.headers.get("content-encoding") + data = encoding.decode(ce, self.content) + if data is None: + return False + self.content = data + self.headers.pop("content-encoding", None) + return True + + def encode(self, e): + """ + Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + data = encoding.encode(e, self.content) + if data is None: + return False + self.content = data + self.headers["content-encoding"] = e + return True + + # Legacy + + @property + def body(self): # pragma: nocover + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + return self.content + + @body.setter + def body(self, body): # pragma: nocover + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + self.content = body + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + + .. code-block:: python + + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, message): + self.message = message + ce = message.headers.get("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + self.message.decode() + + def __exit__(self, type, value, tb): + if self.ce: + self.message.encode(self.ce) diff --git a/netlib/http/request.py b/netlib/http/request.py new file mode 100644 index 00000000..b9076c0f --- /dev/null +++ b/netlib/http/request.py @@ -0,0 +1,356 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six +from six.moves import urllib + +from netlib import utils +from netlib.http import cookies +from netlib.odict import ODict +from .. 
import encoding +from .headers import Headers +from .message import Message, _native, _always_bytes, MessageData + + +class RequestData(MessageData): + def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not isinstance(headers, Headers): + headers = Headers(headers) + + self.first_line_format = first_line_format + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.http_version = http_version + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + +class Request(Message): + """ + An HTTP request. + """ + def __init__(self, *args, **kwargs): + data = RequestData(*args, **kwargs) + super(Request, self).__init__(data) + + def __repr__(self): + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) + else: + hostport = "" + path = self.path or "" + return "Request({} {}{})".format( + self.method, hostport, path + ) + + @property + def first_line_format(self): + """ + HTTP request form as defined in `RFC7230 `_. + + origin-form and asterisk-form are subsumed as "relative". + """ + return self.data.first_line_format + + @first_line_format.setter + def first_line_format(self, first_line_format): + self.data.first_line_format = first_line_format + + @property + def method(self): + """ + HTTP request method, e.g. "GET". + """ + return _native(self.data.method).upper() + + @method.setter + def method(self, method): + self.data.method = _always_bytes(method) + + @property + def scheme(self): + """ + HTTP request scheme, which should be "http" or "https". + """ + return _native(self.data.scheme) + + @scheme.setter + def scheme(self, scheme): + self.data.scheme = _always_bytes(scheme) + + @property + def host(self): + """ + Target host. This may be parsed from the raw request + (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) + or inferred from the proxy mode (e.g. an IP in transparent mode). + + Setting the host attribute also updates the host header, if present. + """ + + if six.PY2: # pragma: nocover + return self.data.host + + if not self.data.host: + return self.data.host + try: + return self.data.host.decode("idna") + except UnicodeError: + return self.data.host.decode("utf8", "surrogateescape") + + @host.setter + def host(self, host): + if isinstance(host, six.text_type): + try: + # There's no non-strict mode for IDNA encoding. + # We don't want this operation to fail though, so we try + # utf8 as a last resort. + host = host.encode("idna", "strict") + except UnicodeError: + host = host.encode("utf8", "surrogateescape") + + self.data.host = host + + # Update host header + if "host" in self.headers: + if host: + self.headers["host"] = host + else: + self.headers.pop("host") + + @property + def port(self): + """ + Target port + """ + return self.data.port + + @port.setter + def port(self, port): + self.data.port = port + + @property + def path(self): + """ + HTTP request path, e.g. "/index.html". + Guaranteed to start with a slash. 
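+        The query string, if any, is carried as part of the path
+        (see the :py:attr:`query` property, which parses it back out).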
+ """ + return _native(self.data.path) + + @path.setter + def path(self, path): + self.data.path = _always_bytes(path) + + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + """ + Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. + This is useful in transparent mode where :py:attr:`host` is only an IP address, + but may not reflect the actual destination as the Host header could be spoofed. + """ + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + """ + Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. + """ + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + @property + def query(self): + """ + The request query string as an :py:class:`ODict` object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + _, _, _, self.path = utils.parse_url( + urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + + @property + def cookies(self): + """ + The request cookies. + An empty :py:class:`ODict` object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + _, _, _, self.path = utils.parse_url( + urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + self.headers.pop(i, None) + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = "identity" + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( + ', '.join( + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) + + @property + def urlencoded_form(self): + """ + The URL-encoded form data as an :py:class:`ODict` object. 
+ None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.urldecode(self.content)) + return None + + @urlencoded_form.setter + def urlencoded_form(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the appropriate content-type header. + This will overwrite the existing content if there is one. + """ + self.headers["content-type"] = "application/x-www-form-urlencoded" + self.content = utils.urlencode(odict.lst) + + @property + def multipart_form(self): + """ + The multipart form data as an :py:class:`ODict` object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.multipartdecode(self.headers,self.content)) + return None + + @multipart_form.setter + def multipart_form(self, value): + raise NotImplementedError() + + # Legacy + + def get_cookies(self): # pragma: nocover + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): # pragma: nocover + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + def get_query(self): # pragma: nocover + warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) + return self.query or ODict([]) + + def set_query(self, odict): # pragma: nocover + warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) + self.query = odict + + def get_path_components(self): # pragma: nocover + warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) + return self.path_components + + def set_path_components(self, lst): # pragma: nocover + warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) + self.path_components = lst + + def get_form_urlencoded(self): # pragma: nocover + warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + return self.urlencoded_form or ODict([]) + + def set_form_urlencoded(self, odict): # pragma: nocover + warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + self.urlencoded_form = odict + + def get_form_multipart(self): # pragma: nocover + warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) + return self.multipart_form or ODict([]) + + @property + def form_in(self): # pragma: nocover + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_in.setter + def form_in(self, form_in): # pragma: nocover + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_in + + @property + def form_out(self): # pragma: nocover + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_out.setter + def form_out(self, form_out): # pragma: nocover + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_out + diff --git 
a/netlib/http/response.py b/netlib/http/response.py new file mode 100644 index 00000000..8f4d6215 --- /dev/null +++ b/netlib/http/response.py @@ -0,0 +1,116 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +from . import cookies +from .headers import Headers +from .message import Message, _native, _always_bytes, MessageData +from .. import utils +from ..odict import ODict + + +class ResponseData(MessageData): + def __init__(self, http_version, status_code, reason=None, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not isinstance(headers, Headers): + headers = Headers(headers) + + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + +class Response(Message): + """ + An HTTP response. + """ + def __init__(self, *args, **kwargs): + data = ResponseData(*args, **kwargs) + super(Response, self).__init__(data) + + def __repr__(self): + if self.content: + details = "{}, {}".format( + self.headers.get("content-type", "unknown content type"), + utils.pretty_size(len(self.content)) + ) + else: + details = "no content" + return "Response({status_code} {reason}, {details})".format( + status_code=self.status_code, + reason=self.reason, + details=details + ) + + @property + def status_code(self): + """ + HTTP Status Code, e.g. ``200``. + """ + return self.data.status_code + + @status_code.setter + def status_code(self, status_code): + self.data.status_code = status_code + + @property + def reason(self): + """ + HTTP Reason Phrase, e.g. "Not Found". + This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. + """ + return _native(self.data.reason) + + @reason.setter + def reason(self, reason): + self.data.reason = _always_bytes(reason) + + @property + def cookies(self): + """ + Get the contents of all Set-Cookie headers. + + A possibly empty :py:class:`ODict`, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. 
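+        For example, "Set-Cookie: foo=bar; HttpOnly" yields the key "foo"
+        with the value ["bar", attrs], where attrs contains ("HttpOnly", None).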
+ """ + ret = [] + for header in self.headers.get_all("set-cookie"): + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return ODict(ret) + + @cookies.setter + def cookies(self, odict): + values = [] + for i in odict.lst: + header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) + values.append(header) + self.headers.set_all("set-cookie", values) + + # Legacy + + def get_cookies(self): # pragma: nocover + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): # pragma: nocover + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + @property + def msg(self): # pragma: nocover + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + return self.reason + + @msg.setter + def msg(self, reason): # pragma: nocover + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + self.reason = reason diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py new file mode 100644 index 00000000..8a4dc1f5 --- /dev/null +++ b/netlib/http/status_codes.py @@ -0,0 +1,106 @@ +from __future__ import absolute_import, print_function, division + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 +IM_A_TEAPOT = 418 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: "Use Proxy", + # 306 not defined?? 
+ TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + IM_A_TEAPOT: "I'm a teapot", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py new file mode 100644 index 00000000..e8681908 --- /dev/null +++ b/netlib/http/user_agents.py @@ -0,0 +1,52 @@ +from __future__ import (absolute_import, print_function, division) + +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# pylint: line-too-long + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. 
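+        Returns the matching (name, shortcut, string) tuple, or None if no
+        entry has the given shortcut.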
+ """ + for i in UASTRINGS: + if s == i[1]: + return i diff --git a/netlib/netlib/__init__.py b/netlib/netlib/__init__.py deleted file mode 100644 index 9b4faa33..00000000 --- a/netlib/netlib/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import (absolute_import, print_function, division) diff --git a/netlib/netlib/certutils.py b/netlib/netlib/certutils.py deleted file mode 100644 index 616a778e..00000000 --- a/netlib/netlib/certutils.py +++ /dev/null @@ -1,472 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import os -import ssl -import time -import datetime -from six.moves import filter -import ipaddress - -import sys -from pyasn1.type import univ, constraint, char, namedtype, tag -from pyasn1.codec.der.decoder import decode -from pyasn1.error import PyAsn1Error -import OpenSSL - -from .utils import Serializable - -# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 - -DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 -# Generated with "openssl dhparam". It's too slow to generate this on startup. -DEFAULT_DHPARAM = b""" ------BEGIN DH PARAMETERS----- -MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 -O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv -j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ -Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB -chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC -ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq -o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX -IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv -A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 -6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I -rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= ------END DH PARAMETERS----- -""" - - -def create_ca(o, cn, exp): - key = OpenSSL.crypto.PKey() - key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) - cert = OpenSSL.crypto.X509() - cert.set_serial_number(int(time.time() * 10000)) - cert.set_version(2) - cert.get_subject().CN = cn - cert.get_subject().O = o - cert.gmtime_adj_notBefore(-3600 * 48) - cert.gmtime_adj_notAfter(exp) - cert.set_issuer(cert.get_subject()) - cert.set_pubkey(key) - cert.add_extensions([ - OpenSSL.crypto.X509Extension( - b"basicConstraints", - True, - b"CA:TRUE" - ), - OpenSSL.crypto.X509Extension( - b"nsCertType", - False, - b"sslCA" - ), - OpenSSL.crypto.X509Extension( - b"extendedKeyUsage", - False, - b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" - ), - OpenSSL.crypto.X509Extension( - b"keyUsage", - True, - b"keyCertSign, cRLSign" - ), - OpenSSL.crypto.X509Extension( - b"subjectKeyIdentifier", - False, - b"hash", - subject=cert - ), - ]) - cert.sign(key, "sha256") - return key, cert - - -def dummy_cert(privkey, cacert, commonname, sans): - """ - Generates a dummy certificate. - - privkey: CA private key - cacert: CA certificate - commonname: Common name for the generated certificate. - sans: A list of Subject Alternate Names. - - Returns cert if operation succeeded, None if not. 
- """ - ss = [] - for i in sans: - try: - ipaddress.ip_address(i.decode("ascii")) - except ValueError: - ss.append(b"DNS: %s" % i) - else: - ss.append(b"IP: %s" % i) - ss = b", ".join(ss) - - cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600 * 48) - cert.gmtime_adj_notAfter(DEFAULT_EXP) - cert.set_issuer(cacert.get_subject()) - if commonname is not None: - cert.get_subject().CN = commonname - cert.set_serial_number(int(time.time() * 10000)) - if ss: - cert.set_version(2) - cert.add_extensions( - [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) - cert.set_pubkey(cacert.get_pubkey()) - cert.sign(privkey, "sha256") - return SSLCert(cert) - - -# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. -# -# class _Node(UserDict.UserDict): -# def __init__(self): -# UserDict.UserDict.__init__(self) -# self.value = None -# -# -# class DNTree: -# """ -# Domain store that knows about wildcards. DNS wildcards are very -# restricted - the only valid variety is an asterisk on the left-most -# domain component, i.e.: -# -# *.foo.com -# """ -# def __init__(self): -# self.d = _Node() -# -# def add(self, dn, cert): -# parts = dn.split(".") -# parts.reverse() -# current = self.d -# for i in parts: -# current = current.setdefault(i, _Node()) -# current.value = cert -# -# def get(self, dn): -# parts = dn.split(".") -# current = self.d -# for i in reversed(parts): -# if i in current: -# current = current[i] -# elif "*" in current: -# return current["*"].value -# else: -# return None -# return current.value - - -class CertStoreEntry(object): - - def __init__(self, cert, privatekey, chain_file): - self.cert = cert - self.privatekey = privatekey - self.chain_file = chain_file - - -class CertStore(object): - - """ - Implements an in-memory certificate store. - """ - - def __init__( - self, - default_privatekey, - default_ca, - default_chain_file, - dhparams): - self.default_privatekey = default_privatekey - self.default_ca = default_ca - self.default_chain_file = default_chain_file - self.dhparams = dhparams - self.certs = dict() - - @staticmethod - def load_dhparam(path): - - # netlib<=0.10 doesn't generate a dhparam file. - # Create it now if neccessary. 
- if not os.path.exists(path): - with open(path, "wb") as f: - f.write(DEFAULT_DHPARAM) - - bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") - if bio != OpenSSL.SSL._ffi.NULL: - bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) - dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, - OpenSSL.SSL._ffi.NULL, - OpenSSL.SSL._ffi.NULL, - OpenSSL.SSL._ffi.NULL) - dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) - return dh - - @classmethod - def from_store(cls, path, basename): - ca_path = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(ca_path): - key, ca = cls.create_store(path, basename) - else: - with open(ca_path, "rb") as f: - raw = f.read() - ca = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw) - dh_path = os.path.join(path, basename + "-dhparam.pem") - dh = cls.load_dhparam(dh_path) - return cls(key, ca, ca_path, dh) - - @staticmethod - def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): - if not os.path.exists(path): - os.makedirs(path) - - o = o or basename - cn = cn or basename - - key, ca = create_ca(o=o, cn=cn, exp=expiry) - # Dump the CA plus private key - with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: - f.write( - OpenSSL.crypto.dump_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - key)) - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - - # Dump the certificate in PEM format - with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - - # Create a .cer file with the same contents for Android - with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - - # Dump the certificate in PKCS12 format for Windows devices - with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - - with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: - f.write(DEFAULT_DHPARAM) - - return key, ca - - def add_cert_file(self, spec, path): - with open(path, "rb") as f: - raw = f.read() - cert = SSLCert( - OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw)) - try: - privatekey = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw) - except Exception: - privatekey = self.default_privatekey - self.add_cert( - CertStoreEntry(cert, privatekey, path), - spec - ) - - def add_cert(self, entry, *names): - """ - Adds a cert to the certstore. We register the CN in the cert plus - any SANs, and also the list of names provided as an argument. - """ - if entry.cert.cn: - self.certs[entry.cert.cn] = entry - for i in entry.cert.altnames: - self.certs[i] = entry - for i in names: - self.certs[i] = entry - - @staticmethod - def asterisk_forms(dn): - if dn is None: - return [] - parts = dn.split(b".") - parts.reverse() - curr_dn = b"" - dn_forms = [b"*"] - for part in parts[:-1]: - curr_dn = b"." + part + curr_dn # .example.com - dn_forms.append(b"*" + curr_dn) # *.example.com - if parts[-1] != b"*": - dn_forms.append(parts[-1] + curr_dn) - return dn_forms - - def get_cert(self, commonname, sans): - """ - Returns an (cert, privkey, cert_chain) tuple. - - commonname: Common name for the generated certificate. 
Must be a - valid, plain-ASCII, IDNA-encoded domain name. - - sans: A list of Subject Alternate Names. - """ - - potential_keys = self.asterisk_forms(commonname) - for s in sans: - potential_keys.extend(self.asterisk_forms(s)) - potential_keys.append((commonname, tuple(sans))) - - name = next( - filter(lambda key: key in self.certs, potential_keys), - None - ) - if name: - entry = self.certs[name] - else: - entry = CertStoreEntry( - cert=dummy_cert( - self.default_privatekey, - self.default_ca, - commonname, - sans), - privatekey=self.default_privatekey, - chain_file=self.default_chain_file) - self.certs[(commonname, tuple(sans))] = entry - - return entry.cert, entry.privatekey, entry.chain_file - - -class _GeneralName(univ.Choice): - # We are only interested in dNSNames. We use a default handler to ignore - # other types. - # TODO: We should also handle iPAddresses. - componentType = namedtype.NamedTypes( - namedtype.NamedType('dNSName', char.IA5String().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) - ) - ), - ) - - -class _GeneralNames(univ.SequenceOf): - componentType = _GeneralName() - sizeSpec = univ.SequenceOf.sizeSpec + \ - constraint.ValueSizeConstraint(1, 1024) - - -class SSLCert(Serializable): - - def __init__(self, cert): - """ - Returns a (common name, [subject alternative names]) tuple. - """ - self.x509 = cert - - def __eq__(self, other): - return self.digest("sha256") == other.digest("sha256") - - def __ne__(self, other): - return not self.__eq__(other) - - def get_state(self): - return self.to_pem() - - def set_state(self, state): - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) - - @classmethod - def from_state(cls, state): - cls.from_pem(state) - - @classmethod - def from_pem(cls, txt): - x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - return cls(x509) - - @classmethod - def from_der(cls, der): - pem = ssl.DER_cert_to_PEM_cert(der) - return cls.from_pem(pem) - - def to_pem(self): - return OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - self.x509) - - def digest(self, name): - return self.x509.digest(name) - - @property - def issuer(self): - return self.x509.get_issuer().get_components() - - @property - def notbefore(self): - t = self.x509.get_notBefore() - return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") - - @property - def notafter(self): - t = self.x509.get_notAfter() - return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") - - @property - def has_expired(self): - return self.x509.has_expired() - - @property - def subject(self): - return self.x509.get_subject().get_components() - - @property - def serial(self): - return self.x509.get_serial_number() - - @property - def keyinfo(self): - pk = self.x509.get_pubkey() - types = { - OpenSSL.crypto.TYPE_RSA: "RSA", - OpenSSL.crypto.TYPE_DSA: "DSA", - } - return ( - types.get(pk.type(), "UNKNOWN"), - pk.bits() - ) - - @property - def cn(self): - c = None - for i in self.subject: - if i[0] == b"CN": - c = i[1] - return c - - @property - def altnames(self): - """ - Returns: - All DNS altnames. - """ - # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. 
- altnames = [] - for i in range(self.x509.get_extension_count()): - ext = self.x509.get_extension(i) - if ext.get_short_name() == b"subjectAltName": - try: - dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) - except PyAsn1Error: - continue - for i in dec[0]: - altnames.append(i[0].asOctets()) - return altnames diff --git a/netlib/netlib/encoding.py b/netlib/netlib/encoding.py deleted file mode 100644 index 14479e00..00000000 --- a/netlib/netlib/encoding.py +++ /dev/null @@ -1,88 +0,0 @@ -""" - Utility functions for decoding response bodies. -""" -from __future__ import absolute_import -from io import BytesIO -import gzip -import zlib -from .utils import always_byte_args - - -ENCODINGS = {"identity", "gzip", "deflate"} - - -def decode(e, content): - if not isinstance(content, bytes): - return None - encoding_map = { - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, - } - if e not in encoding_map: - return None - return encoding_map[e](content) - - -def encode(e, content): - if not isinstance(content, bytes): - return None - encoding_map = { - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, - } - if e not in encoding_map: - return None - return encoding_map[e](content) - - -def identity(content): - """ - Returns content unchanged. Identity is the default value of - Accept-Encoding headers. - """ - return content - - -def decode_gzip(content): - gfile = gzip.GzipFile(fileobj=BytesIO(content)) - try: - return gfile.read() - except (IOError, EOFError): - return None - - -def encode_gzip(content): - s = BytesIO() - gf = gzip.GzipFile(fileobj=s, mode='wb') - gf.write(content) - gf.close() - return s.getvalue() - - -def decode_deflate(content): - """ - Returns decompressed data for DEFLATE. Some servers may respond with - compressed data without a zlib header or checksum. An undocumented - feature of zlib permits the lenient decompression of data missing both - values. - - http://bugs.python.org/issue5784 - """ - try: - try: - return zlib.decompress(content) - except zlib.error: - return zlib.decompress(content, -15) - except zlib.error: - return None - - -def encode_deflate(content): - """ - Returns compressed content, always including zlib header and checksum. - """ - return zlib.compress(content) - -__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/netlib/exceptions.py b/netlib/netlib/exceptions.py deleted file mode 100644 index 05f1054b..00000000 --- a/netlib/netlib/exceptions.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -We try to be very hygienic regarding the exceptions we throw: -Every Exception netlib raises shall be a subclass of NetlibException. - - -See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ -""" -from __future__ import absolute_import, print_function, division - - -class NetlibException(Exception): - """ - Base class for all exceptions thrown by netlib. 
- """ - def __init__(self, message=None): - super(NetlibException, self).__init__(message) - - -class Disconnect(object): - """Immediate EOF""" - - -class HttpException(NetlibException): - pass - - -class HttpReadDisconnect(HttpException, Disconnect): - pass - - -class HttpSyntaxException(HttpException): - pass - - -class TcpException(NetlibException): - pass - - -class TcpDisconnect(TcpException, Disconnect): - pass - - -class TcpReadIncomplete(TcpException): - pass - - -class TcpTimeout(TcpException): - pass - - -class TlsException(NetlibException): - pass - - -class InvalidCertificateException(TlsException): - pass diff --git a/netlib/netlib/http/__init__.py b/netlib/netlib/http/__init__.py deleted file mode 100644 index fd632cd5..00000000 --- a/netlib/netlib/http/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import absolute_import, print_function, division -from .request import Request -from .response import Response -from .headers import Headers -from .message import decoded, CONTENT_MISSING -from . import http1, http2 - -__all__ = [ - "Request", - "Response", - "Headers", - "decoded", "CONTENT_MISSING", - "http1", "http2", -] diff --git a/netlib/netlib/http/authentication.py b/netlib/netlib/http/authentication.py deleted file mode 100644 index d769abe5..00000000 --- a/netlib/netlib/http/authentication.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from argparse import Action, ArgumentTypeError -import binascii - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]).decode("utf8", "replace") - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") - return scheme + " " + v - - -class NullProxyAuth(object): - - """ - No proxy auth at all (returns empty challange headers) - """ - - def __init__(self, password_manager): - self.password_manager = password_manager - - def clean(self, headers_): - """ - Clean up authentication headers, so they're not passed upstream. - """ - - def authenticate(self, headers_): - """ - Tests that the user is allowed to use the proxy - """ - return True - - def auth_challenge_headers(self): - """ - Returns a dictionary containing the headers require to challenge the user - """ - return {} - - -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' - - def __init__(self, password_manager, realm): - NullProxyAuth.__init__(self, password_manager) - self.realm = realm - - def clean(self, headers): - del headers[self.AUTH_HEADER] - - def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER) - if not auth_value: - return False - parts = parse_http_basic_auth(auth_value) - if not parts: - return False - scheme, username, password = parts - if scheme.lower() != 'basic': - return False - if not self.password_manager.test(username, password): - return False - self.username = username - return True - - def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class PassMan(object): - - def test(self, username_, password_token_): - return False - - -class PassManNonAnon(PassMan): - - """ - Ensure the user specifies a username, accept any password. 
- """ - - def test(self, username, password_token_): - if username: - return True - return False - - -class PassManHtpasswd(PassMan): - - """ - Read usernames and passwords from an htpasswd file - """ - - def __init__(self, path): - """ - Raises ValueError if htpasswd file is invalid. - """ - import passlib.apache - self.htpasswd = passlib.apache.HtpasswdFile(path) - - def test(self, username, password_token): - return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - - def __init__(self, username, password): - self.username, self.password = username, password - - def test(self, username, password_token): - return self.username == username and self.password == password_token - - -class AuthAction(Action): - - """ - Helper class to allow seamless integration int argparse. Example usage: - parser.add_argument( - "--nonanonymous", - action=NonanonymousAuthAction, nargs=0, - help="Allow access to any user long as a credentials are specified." - ) - """ - - def __call__(self, parser, namespace, values, option_string=None): - passman = self.getPasswordManager(values) - authenticator = BasicProxyAuth(passman, "mitmproxy") - setattr(namespace, self.dest, authenticator) - - def getPasswordManager(self, s): # pragma: nocover - raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - - def getPasswordManager(self, s): - if len(s.split(':')) != 2: - raise ArgumentTypeError( - "Invalid single-user specification. Please use the format username:password" - ) - username, password = s.split(':') - return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManHtpasswd(s) diff --git a/netlib/netlib/http/cookies.py b/netlib/netlib/http/cookies.py deleted file mode 100644 index 18544b5e..00000000 --- a/netlib/netlib/http/cookies.py +++ /dev/null @@ -1,193 +0,0 @@ -import re - -from .. import odict - -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Technically this should be feasible, but -it turns out that violations of RFC6265 that makes the parsing problem -indeterminate are much more common than genuine occurences of the multi-cookie -variants. Serialization follows RFC6265. - - http://tools.ietf.org/html/rfc6265 - http://tools.ietf.org/html/rfc2109 - http://tools.ietf.org/html/rfc2965 -""" - -# TODO: Disallow LHS-only Cookie values - - -def _read_until(s, start, term): - """ - Read until one of the characters in term is reached. - """ - if start == len(s): - return "", start + 1 - for i in range(start, len(s)): - if s[i] in term: - return s[start:i], i - return s[start:i + 1], i + 1 - - -def _read_token(s, start): - """ - Read a token - the LHS of a token/value pair in a cookie. 
- """ - return _read_until(s, start, ";=") - - -def _read_quoted_string(s, start): - """ - start: offset to the first quote of the string to be read - - A sort of loose super-set of the various quoted string specifications. - - RFC6265 disallows backslashes or double quotes within quoted strings. - Prior RFCs use backslashes to escape. This leaves us free to apply - backslash escaping by default and be compatible with everything. - """ - escaping = False - ret = [] - # Skip the first quote - i = start # initialize in case the loop doesn't run. - for i in range(start + 1, len(s)): - if escaping: - ret.append(s[i]) - escaping = False - elif s[i] == '"': - break - elif s[i] == "\\": - escaping = True - else: - ret.append(s[i]) - return "".join(ret), i + 1 - - -def _read_value(s, start, delims): - """ - Reads a value - the RHS of a token/value pair in a cookie. - - special: If the value is special, commas are premitted. Else comma - terminates. This helps us support old and new style values. - """ - if start >= len(s): - return "", start - elif s[start] == '"': - return _read_quoted_string(s, start) - else: - return _read_until(s, start, delims) - - -def _read_pairs(s, off=0): - """ - Read pairs of lhs=rhs values. - - off: start offset - specials: a lower-cased list of keys that may contain commas - """ - vals = [] - while True: - lhs, off = _read_token(s, off) - lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - vals.append([lhs, rhs]) - off += 1 - if not off < len(s): - break - return vals, off - - -def _has_special(s): - for i in s: - if i in '",;\\': - return True - o = ord(i) - if o < 0x21 or o > 0x7e: - return True - return False - - -ESCAPE = re.compile(r"([\"\\])") - - -def _format_pairs(lst, specials=(), sep="; "): - """ - specials: A lower-cased list of keys that will not be quoted. - """ - vals = [] - for k, v in lst: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) - return sep.join(vals) - - -def _format_set_cookie_pairs(lst): - return _format_pairs( - lst, - specials=("expires", "path") - ) - - -def _parse_set_cookie_pairs(s): - """ - For Set-Cookie, we support multiple cookies as described in RFC2109. - This function therefore returns a list of lists. - """ - pairs, off_ = _read_pairs(s) - return pairs - - -def parse_set_cookie_header(line): - """ - Parse a Set-Cookie header value - - Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute - values - they are treated purely as strings. - """ - pairs = _parse_set_cookie_pairs(line) - if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) - - -def format_set_cookie_header(name, value, attrs): - """ - Formats a Set-Cookie header value. - """ - pairs = [[name, value]] - pairs.extend(attrs.lst) - return _format_set_cookie_pairs(pairs) - - -def parse_cookie_header(line): - """ - Parse a Cookie header value. - Returns a (possibly empty) ODict object. - """ - pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) - - -def format_cookie_header(od): - """ - Formats a Cookie header value. 
- """ - return _format_pairs(od.lst) diff --git a/netlib/netlib/http/headers.py b/netlib/netlib/http/headers.py deleted file mode 100644 index 78404796..00000000 --- a/netlib/netlib/http/headers.py +++ /dev/null @@ -1,204 +0,0 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -""" -from __future__ import absolute_import, print_function, division -import copy -try: - from collections.abc import MutableMapping -except ImportError: # pragma: nocover - from collections import MutableMapping # Workaround for Python < 3.3 - - -import six - -from netlib.utils import always_byte_args, always_bytes, Serializable - -if six.PY2: # pragma: nocover - _native = lambda x: x - _always_bytes = lambda x: x - _always_byte_args = lambda x: x -else: - # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. - _native = lambda x: x.decode("utf-8", "surrogateescape") - _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") - _always_byte_args = always_byte_args("utf-8", "surrogateescape") - - -class Headers(MutableMapping, Serializable): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create headers with keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples - >>> h = Headers([ - [b"Host",b"example.com"], - [b"Accept",b"text/html"], - [b"accept",b"application/xml"] - ]) - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # bytes(h) returns a HTTP1 header block. - >>> print(bytes(h)) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - @_always_byte_args - def __init__(self, fields=None, **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header byte tuples, - e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. 
- """ - self.fields = fields or [] - - for name, value in self.fields: - if not isinstance(name, bytes) or not isinstance(value, bytes): - raise ValueError("Headers passed as fields must be bytes.") - - # content_type -> content-type - headers = { - _always_bytes(name).replace(b"_", b"-"): value - for name, value in six.iteritems(headers) - } - self.update(headers) - - def __bytes__(self): - if self.fields: - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" - else: - return b"" - - if six.PY2: # pragma: nocover - __str__ = __bytes__ - - @_always_byte_args - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return ", ".join(values) - - @_always_byte_args - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @_always_byte_args - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] - - def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield _native(name) - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @_always_byte_args - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name_lower = name.lower() - values = [_native(value) for n, value in self.fields if n.lower() == name_lower] - return values - - @_always_byte_args - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. 
- See: :py:meth:`get_all` - """ - values = map(_always_bytes, values) # _always_byte_args does not fix lists - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def copy(self): - return Headers(copy.copy(self.fields)) - - def get_state(self): - return tuple(tuple(field) for field in self.fields) - - def set_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) \ No newline at end of file diff --git a/netlib/netlib/http/http1/__init__.py b/netlib/netlib/http/http1/__init__.py deleted file mode 100644 index 2aa7e26a..00000000 --- a/netlib/netlib/http/http1/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import absolute_import, print_function, division -from .read import ( - read_request, read_request_head, - read_response, read_response_head, - read_body, - connection_close, - expected_http_body_size, -) -from .assemble import ( - assemble_request, assemble_request_head, - assemble_response, assemble_response_head, - assemble_body, -) - - -__all__ = [ - "read_request", "read_request_head", - "read_response", "read_response_head", - "read_body", - "connection_close", - "expected_http_body_size", - "assemble_request", "assemble_request_head", - "assemble_response", "assemble_response_head", - "assemble_body", -] diff --git a/netlib/netlib/http/http1/assemble.py b/netlib/netlib/http/http1/assemble.py deleted file mode 100644 index 785ee8d3..00000000 --- a/netlib/netlib/http/http1/assemble.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import absolute_import, print_function, division - -from ... import utils -import itertools -from ...exceptions import HttpException -from .. import CONTENT_MISSING - - -def assemble_request(request): - if request.content == CONTENT_MISSING: - raise HttpException("Cannot assemble flow with CONTENT_MISSING") - head = assemble_request_head(request) - body = b"".join(assemble_body(request.data.headers, [request.data.content])) - return head + body - - -def assemble_request_head(request): - first_line = _assemble_request_line(request.data) - headers = _assemble_request_headers(request.data) - return b"%s\r\n%s\r\n" % (first_line, headers) - - -def assemble_response(response): - if response.content == CONTENT_MISSING: - raise HttpException("Cannot assemble flow with CONTENT_MISSING") - head = assemble_response_head(response) - body = b"".join(assemble_body(response.data.headers, [response.data.content])) - return head + body - - -def assemble_response_head(response): - first_line = _assemble_response_line(response.data) - headers = _assemble_response_headers(response.data) - return b"%s\r\n%s\r\n" % (first_line, headers) - - -def assemble_body(headers, body_chunks): - if "chunked" in headers.get("transfer-encoding", "").lower(): - for chunk in body_chunks: - if chunk: - yield b"%x\r\n%s\r\n" % (len(chunk), chunk) - yield b"0\r\n\r\n" - else: - for chunk in body_chunks: - yield chunk - - -def _assemble_request_line(request_data): - """ - Args: - request_data (netlib.http.request.RequestData) - """ - form = request_data.first_line_format - if form == "relative": - return b"%s %s %s" % ( - request_data.method, - request_data.path, - request_data.http_version - ) - elif form == "authority": - return b"%s %s:%d %s" % ( - request_data.method, - request_data.host, - request_data.port, - request_data.http_version - ) - elif form == "absolute": - return b"%s %s://%s:%d%s %s" % ( - request_data.method, - 
request_data.scheme, - request_data.host, - request_data.port, - request_data.path, - request_data.http_version - ) - else: - raise RuntimeError("Invalid request form") - - -def _assemble_request_headers(request_data): - """ - Args: - request_data (netlib.http.request.RequestData) - """ - headers = request_data.headers.copy() - if "host" not in headers and request_data.scheme and request_data.host and request_data.port: - headers["host"] = utils.hostport( - request_data.scheme, - request_data.host, - request_data.port - ) - return bytes(headers) - - -def _assemble_response_line(response_data): - return b"%s %d %s" % ( - response_data.http_version, - response_data.status_code, - response_data.reason, - ) - - -def _assemble_response_headers(response): - return bytes(response.headers) diff --git a/netlib/netlib/http/http1/read.py b/netlib/netlib/http/http1/read.py deleted file mode 100644 index 6e3a1b93..00000000 --- a/netlib/netlib/http/http1/read.py +++ /dev/null @@ -1,362 +0,0 @@ -from __future__ import absolute_import, print_function, division -import time -import sys -import re - -from ... import utils -from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect -from .. import Request, Response, Headers - - -def read_request(rfile, body_size_limit=None): - request = read_request_head(rfile) - expected_body_size = expected_http_body_size(request) - request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) - request.timestamp_end = time.time() - return request - - -def read_request_head(rfile): - """ - Parse an HTTP request head (request line + headers) from an input stream - - Args: - rfile: The input stream - - Returns: - The HTTP request object (without body) - - Raises: - HttpReadDisconnect: No bytes can be read from rfile. - HttpSyntaxException: The input is malformed HTTP. - HttpException: Any other error occurred. - """ - timestamp_start = time.time() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - form, method, scheme, host, port, path, http_version = _read_request_line(rfile) - headers = _read_headers(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - return Request( - form, method, scheme, host, port, path, http_version, headers, None, timestamp_start - ) - - -def read_response(rfile, request, body_size_limit=None): - response = read_response_head(rfile) - expected_body_size = expected_http_body_size(request, response) - response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) - response.timestamp_end = time.time() - return response - - -def read_response_head(rfile): - """ - Parse an HTTP response head (response line + headers) from an input stream - - Args: - rfile: The input stream - - Returns: - The HTTP response object (without body) - - Raises: - HttpReadDisconnect: No bytes can be read from rfile. - HttpSyntaxException: The input is malformed HTTP. - HttpException: Any other error occurred.
- """ - - timestamp_start = time.time() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - http_version, status_code, message = _read_response_line(rfile) - headers = _read_headers(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - return Response(http_version, status_code, message, headers, None, timestamp_start) - - -def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): - """ - Read an HTTP message body - - Args: - rfile: The input stream - expected_size: The expected body size (see :py:meth:`expected_body_size`) - limit: Maximum body size - max_chunk_size: Maximium chunk size that gets yielded - - Returns: - A generator that yields byte chunks of the content. - - Raises: - HttpException, if an error occurs - - Caveats: - max_chunk_size is not considered if the transfer encoding is chunked. - """ - if not limit or limit < 0: - limit = sys.maxsize - if not max_chunk_size: - max_chunk_size = limit - - if expected_size is None: - for x in _read_chunked(rfile, limit): - yield x - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpException( - "HTTP Body too large. " - "Limit is {}, content length was advertised as {}".format(limit, expected_size) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if len(content) < chunk_size: - raise HttpException("Unexpected EOF") - yield content - bytes_left -= chunk_size - else: - bytes_left = limit - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpException("HTTP body too large. Limit is {}.".format(limit)) - - -def connection_close(http_version, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - tokens = utils.get_header_tokens(headers, "connection") - if "close" in tokens: - return True - elif "keep-alive" in tokens: - return False - - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case. - - -def expected_http_body_size(request, response=None): - """ - Returns: - The expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) - - -1, if all data should be read until end of stream. 
- - Raises: - HttpSyntaxException, if the content length header is invalid - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if not response: - headers = request.headers - response_code = None - is_request = True - else: - headers = response.headers - response_code = response.status_code - is_request = False - - if is_request: - if headers.get("expect", "").lower() == "100-continue": - return 0 - else: - if request.method.upper() == "HEAD": - return 0 - if 100 <= response_code <= 199: - return 0 - if response_code == 200 and request.method.upper() == "CONNECT": - return 0 - if response_code in (204, 304): - return 0 - - if "chunked" in headers.get("transfer-encoding", "").lower(): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"]) - if size < 0: - raise ValueError() - return size - except ValueError: - raise HttpSyntaxException("Unparseable Content Length") - if is_request: - return 0 - return -1 - - -def _get_first_line(rfile): - try: - line = rfile.readline() - if line == b"\r\n" or line == b"\n": - # Possible leftover from previous message - line = rfile.readline() - except TcpDisconnect: - raise HttpReadDisconnect("Remote disconnected") - if not line: - raise HttpReadDisconnect("Remote disconnected") - return line.strip() - - -def _read_request_line(rfile): - try: - line = _get_first_line(rfile) - except HttpReadDisconnect: - # We want to provide a better error message. - raise HttpReadDisconnect("Client disconnected") - - try: - method, path, http_version = line.split(b" ") - - if path == b"*" or path.startswith(b"/"): - form = "relative" - scheme, host, port = None, None, None - elif method == b"CONNECT": - form = "authority" - host, port = _parse_authority_form(path) - scheme, path = None, None - else: - form = "absolute" - scheme, host, port, path = utils.parse_url(path) - - _check_http_version(http_version) - except ValueError: - raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) - - return form, method, scheme, host, port, path, http_version - - -def _parse_authority_form(hostport): - """ - Returns (host, port) if hostport is a valid authority-form host specification. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - - Raises: - ValueError, if the input is malformed - """ - try: - host, port = hostport.split(b":") - port = int(port) - if not utils.is_valid_host(host) or not utils.is_valid_port(port): - raise ValueError() - except ValueError: - raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) - - return host, port - - -def _read_response_line(rfile): - try: - line = _get_first_line(rfile) - except HttpReadDisconnect: - # We want to provide a better error message. - raise HttpReadDisconnect("Server disconnected") - - try: - - parts = line.split(b" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append(b"") - - http_version, status_code, message = parts - status_code = int(status_code) - _check_http_version(http_version) - - except ValueError: - raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) - - return http_version, status_code, message - - -def _check_http_version(http_version): - if not re.match(br"^HTTP/\d\.\d$", http_version): - raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) - - -def _read_headers(rfile): - """ - Read a set of headers. - Stop once a blank line is reached. 
- - Returns: - A headers object - - Raises: - HttpSyntaxException - """ - ret = [] - while True: - line = rfile.readline() - if not line or line == b"\r\n" or line == b"\n": - break - if line[0] in b" \t": - if not ret: - raise HttpSyntaxException("Invalid headers") - # continued header - ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() - else: - try: - name, value = line.split(b":", 1) - value = value.strip() - if not name: - raise ValueError() - ret.append([name, value]) - except ValueError: - raise HttpSyntaxException("Invalid headers") - return Headers(ret) - - -def _read_chunked(rfile, limit=sys.maxsize): - """ - Read a HTTP body with chunked transfer encoding. - - Args: - rfile: the input file - limit: A positive integer - """ - total = 0 - while True: - line = rfile.readline(128) - if line == b"": - raise HttpException("Connection closed prematurely") - if line != b"\r\n" and line != b"\n": - try: - length = int(line, 16) - except ValueError: - raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) - total += length - if total > limit: - raise HttpException( - "HTTP Body too large. Limit is {}, " - "chunked content longer than {}".format(limit, total) - ) - chunk = rfile.read(length) - suffix = rfile.readline(5) - if suffix != b"\r\n": - raise HttpSyntaxException("Malformed chunked body") - if length == 0: - return - yield chunk diff --git a/netlib/netlib/http/http2/__init__.py b/netlib/netlib/http/http2/__init__.py deleted file mode 100644 index 7043d36f..00000000 --- a/netlib/netlib/http/http2/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import absolute_import, print_function, division -from .connections import HTTP2Protocol - -__all__ = [ - "HTTP2Protocol" -] diff --git a/netlib/netlib/http/http2/connections.py b/netlib/netlib/http/http2/connections.py deleted file mode 100644 index 52fa7193..00000000 --- a/netlib/netlib/http/http2/connections.py +++ /dev/null @@ -1,426 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import time - -from hpack.hpack import Encoder, Decoder -from ... import utils -from .. 
import Headers, Response, Request - -from hyperframe import frame - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' - - HTTP2_DEFAULT_SETTINGS = { - frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, - frame.SettingsFrame.ENABLE_PUSH: 1, - frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, - frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, - frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, - } - - def __init__( - self, - tcp_handler=None, - rfile=None, - wfile=None, - is_server=False, - dump_frames=False, - encoder=None, - decoder=None, - unhandled_frame_cb=None, - ): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - self.is_server = is_server - self.dump_frames = dump_frames - self.encoder = encoder or Encoder() - self.decoder = decoder or Decoder() - self.unhandled_frame_cb = unhandled_frame_cb - - self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.connection_preface_performed = False - - def read_request( - self, - __rfile, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - if body_size_limit is not None: - raise NotImplementedError() - - self.perform_connection_preface() - - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission( - include_body=include_body, - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - authority = headers.get(':authority', b'') - method = headers.get(':method', 'GET') - scheme = headers.get(':scheme', 'https') - path = headers.get(':path', '/') - host = None - port = None - - if path == '*' or path.startswith("/"): - form_in = "relative" - elif method == 'CONNECT': - form_in = "authority" - if ":" in authority: - host, port = authority.split(":", 1) - else: - host = authority - else: - form_in = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = utils.parse_url(path) - scheme = scheme.decode('ascii') - host = host.decode('ascii') - - if host is None: - host = 'localhost' - if port is None: - port = 80 if scheme == 'http' else 443 - port = int(port) - - request = Request( - form_in, - method.encode('ascii'), - scheme.encode('ascii'), - host.encode('ascii'), - port, - path.encode('ascii'), - b"HTTP/2.0", - headers, - body, - timestamp_start, - timestamp_end, - ) - request.stream_id = stream_id - - return request - - def read_response( - self, - __rfile, - request_method=b'', - body_size_limit=None, - include_body=True, - stream_id=None, - ): - if body_size_limit is not None: - raise NotImplementedError() - - self.perform_connection_preface() - - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = 
self._receive_transmission( - stream_id=stream_id, - include_body=include_body, - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - response = Response( - b"HTTP/2.0", - int(headers.get(':status', 502)), - b'', - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - response.stream_id = stream_id - - return response - - def assemble(self, message): - if isinstance(message, Request): - return self.assemble_request(message) - elif isinstance(message, Response): - return self.assemble_response(message) - else: - raise ValueError("HTTP message not supported.") - - def assemble_request(self, request): - assert isinstance(request, Request) - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = request.headers.copy() - - if ':authority' not in headers: - headers.fields.insert(0, (b':authority', authority.encode('ascii'))) - if ':scheme' not in headers: - headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) - if ':path' not in headers: - headers.fields.insert(0, (b':path', request.path.encode('ascii'))) - if ':method' not in headers: - headers.fields.insert(0, (b':method', request.method.encode('ascii'))) - - if hasattr(request, 'stream_id'): - stream_id = request.stream_id - else: - stream_id = self._next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), - self._create_body(request.body, stream_id))) - - def assemble_response(self, response): - assert isinstance(response, Response) - - headers = response.headers.copy() - - if ':status' not in headers: - headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) - - if hasattr(response, 'stream_id'): - stream_id = response.stream_id - else: - stream_id = self._next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), - self._create_body(response.body, stream_id), - )) - - def perform_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - if self.is_server: - self.perform_server_connection_preface(force) - else: - self.perform_client_connection_preface(force) - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - frm = frame.SettingsFrame(settings={ - frame.SettingsFrame.ENABLE_PUSH: 0, - frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, - }) - self.send_frame(frm, hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(), hide=True) - self._receive_settings(hide=True) # server announces own settings - self._receive_settings(hide=True) # server acks my settings - - def send_frame(self, 
frm, hide=False): - raw_bytes = frm.serialize() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - while True: - frm = utils.http2_read_frame(self.tcp_handler.rfile) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - - if isinstance(frm, frame.PingFrame): - raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - continue - if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags: - self._apply_settings(frm.settings, hide) - if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0: - self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length) - return frm - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != b'h2': - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _handle_unexpected_frame(self, frm): - if isinstance(frm, frame.SettingsFrame): - return - if self.unhandled_frame_cb: - self.unhandled_frame_cb(frm) - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - else: - self._handle_unexpected_frame(frm) - - def _next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame(flags=['ACK']) - self.send_frame(frm, hide) - - def _update_flow_control_window(self, stream_id, increment): - frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment) - self.send_frame(frm) - frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) - self.send_frame(frm) - - def _create_headers(self, headers, stream_id, end_stream=True): - def frame_cls(chunks): - for i in chunks: - if i == 0: - yield frame.HeadersFrame, i - else: - yield frame.ContinuationFrame, i - - header_block_fragment = self.encoder.encode(headers.fields) - - chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] - chunks = range(0, len(header_block_fragment), chunk_size) - frms = [frm_cls( - flags=[], - stream_id=stream_id, - data=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] - - frms[-1].flags.add('END_HEADERS') - if end_stream: - frms[0].flags.add('END_STREAM') - - if self.dump_frames: # pragma no cover - for frm in frms: - print(frm.human_readable(">>")) - - return [frm.serialize() for frm in frms] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] - chunks = range(0, len(body), chunk_size) - frms = [frame.DataFrame( - flags=[], - stream_id=stream_id, - data=body[i:i+chunk_size]) for i in chunks] - frms[-1].flags.add('END_STREAM') - - if self.dump_frames: # pragma no cover - for frm in frms: - print(frm.human_readable(">>")) - - return [frm.serialize() for frm in frms] - - def 
_receive_transmission(self, stream_id=None, include_body=True): - if not include_body: - raise NotImplementedError() - - body_expected = True - - header_blocks = b'' - body = b'' - - while True: - frm = self.read_frame() - if ( - (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and - (stream_id is None or frm.stream_id == stream_id) - ): - stream_id = frm.stream_id - header_blocks += frm.data - if 'END_STREAM' in frm.flags: - body_expected = False - if 'END_HEADERS' in frm.flags: - break - else: - self._handle_unexpected_frame(frm) - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: - body += frm.data - if 'END_STREAM' in frm.flags: - break - else: - self._handle_unexpected_frame(frm) - - headers = Headers( - [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] - ) - - return stream_id, headers, body diff --git a/netlib/netlib/http/message.py b/netlib/netlib/http/message.py deleted file mode 100644 index e3d8ce37..00000000 --- a/netlib/netlib/http/message.py +++ /dev/null @@ -1,222 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import warnings - -import six - -from .headers import Headers -from .. import encoding, utils - -CONTENT_MISSING = 0 - -if six.PY2: # pragma: nocover - _native = lambda x: x - _always_bytes = lambda x: x -else: - # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. - _native = lambda x: x.decode("utf-8", "surrogateescape") - _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") - - -class MessageData(utils.Serializable): - def __eq__(self, other): - if isinstance(other, MessageData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def set_state(self, state): - for k, v in state.items(): - if k == "headers": - v = Headers.from_state(v) - setattr(self, k, v) - - def get_state(self): - state = vars(self).copy() - state["headers"] = state["headers"].get_state() - return state - - @classmethod - def from_state(cls, state): - state["headers"] = Headers.from_state(state["headers"]) - return cls(**state) - - -class Message(utils.Serializable): - def __init__(self, data): - self.data = data - - def __eq__(self, other): - if isinstance(other, Message): - return self.data == other.data - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def get_state(self): - return self.data.get_state() - - def set_state(self, state): - self.data.set_state(state) - - @classmethod - def from_state(cls, state): - return cls(**state) - - @property - def headers(self): - """ - Message headers object - - Returns: - netlib.http.Headers - """ - return self.data.headers - - @headers.setter - def headers(self, h): - self.data.headers = h - - @property - def content(self): - """ - The raw (encoded) HTTP message body - - See also: :py:attr:`text` - """ - return self.data.content - - @content.setter - def content(self, content): - self.data.content = content - if isinstance(content, bytes): - self.headers["content-length"] = str(len(content)) - - @property - def http_version(self): - """ - Version string, e.g. 
"HTTP/1.1" - """ - return _native(self.data.http_version) - - @http_version.setter - def http_version(self, http_version): - self.data.http_version = _always_bytes(http_version) - - @property - def timestamp_start(self): - """ - First byte timestamp - """ - return self.data.timestamp_start - - @timestamp_start.setter - def timestamp_start(self, timestamp_start): - self.data.timestamp_start = timestamp_start - - @property - def timestamp_end(self): - """ - Last byte timestamp - """ - return self.data.timestamp_end - - @timestamp_end.setter - def timestamp_end(self, timestamp_end): - self.data.timestamp_end = timestamp_end - - @property - def text(self): - """ - The decoded HTTP message body. - Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. - - .. note:: - This is not implemented yet. - - See also: :py:attr:`content`, :py:class:`decoded` - """ - # This attribute should be called text, because that's what requests does. - raise NotImplementedError() - - @text.setter - def text(self, text): - raise NotImplementedError() - - def decode(self): - """ - Decodes body based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. - - Returns: - True, if decoding succeeded. - False, otherwise. - """ - ce = self.headers.get("content-encoding") - data = encoding.decode(ce, self.content) - if data is None: - return False - self.content = data - self.headers.pop("content-encoding", None) - return True - - def encode(self, e): - """ - Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". - - Returns: - True, if decoding succeeded. - False, otherwise. - """ - data = encoding.encode(e, self.content) - if data is None: - return False - self.content = data - self.headers["content-encoding"] = e - return True - - # Legacy - - @property - def body(self): # pragma: nocover - warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) - return self.content - - @body.setter - def body(self, body): # pragma: nocover - warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) - self.content = body - - -class decoded(object): - """ - A context manager that decodes a request or response, and then - re-encodes it with the same encoding after execution of the block. - - Example: - - .. code-block:: python - - with decoded(request): - request.content = request.content.replace("foo", "bar") - """ - - def __init__(self, message): - self.message = message - ce = message.headers.get("content-encoding") - if ce in encoding.ENCODINGS: - self.ce = ce - else: - self.ce = None - - def __enter__(self): - if self.ce: - self.message.decode() - - def __exit__(self, type, value, tb): - if self.ce: - self.message.encode(self.ce) diff --git a/netlib/netlib/http/request.py b/netlib/netlib/http/request.py deleted file mode 100644 index b9076c0f..00000000 --- a/netlib/netlib/http/request.py +++ /dev/null @@ -1,356 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import warnings - -import six -from six.moves import urllib - -from netlib import utils -from netlib.http import cookies -from netlib.odict import ODict -from .. 
import encoding -from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData - - -class RequestData(MessageData): - def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, - timestamp_start=None, timestamp_end=None): - if not isinstance(headers, Headers): - headers = Headers(headers) - - self.first_line_format = first_line_format - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.http_version = http_version - self.headers = headers - self.content = content - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - -class Request(Message): - """ - An HTTP request. - """ - def __init__(self, *args, **kwargs): - data = RequestData(*args, **kwargs) - super(Request, self).__init__(data) - - def __repr__(self): - if self.host and self.port: - hostport = "{}:{}".format(self.host, self.port) - else: - hostport = "" - path = self.path or "" - return "Request({} {}{})".format( - self.method, hostport, path - ) - - @property - def first_line_format(self): - """ - HTTP request form as defined in `RFC7230 `_. - - origin-form and asterisk-form are subsumed as "relative". - """ - return self.data.first_line_format - - @first_line_format.setter - def first_line_format(self, first_line_format): - self.data.first_line_format = first_line_format - - @property - def method(self): - """ - HTTP request method, e.g. "GET". - """ - return _native(self.data.method).upper() - - @method.setter - def method(self, method): - self.data.method = _always_bytes(method) - - @property - def scheme(self): - """ - HTTP request scheme, which should be "http" or "https". - """ - return _native(self.data.scheme) - - @scheme.setter - def scheme(self, scheme): - self.data.scheme = _always_bytes(scheme) - - @property - def host(self): - """ - Target host. This may be parsed from the raw request - (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) - or inferred from the proxy mode (e.g. an IP in transparent mode). - - Setting the host attribute also updates the host header, if present. - """ - - if six.PY2: # pragma: nocover - return self.data.host - - if not self.data.host: - return self.data.host - try: - return self.data.host.decode("idna") - except UnicodeError: - return self.data.host.decode("utf8", "surrogateescape") - - @host.setter - def host(self, host): - if isinstance(host, six.text_type): - try: - # There's no non-strict mode for IDNA encoding. - # We don't want this operation to fail though, so we try - # utf8 as a last resort. - host = host.encode("idna", "strict") - except UnicodeError: - host = host.encode("utf8", "surrogateescape") - - self.data.host = host - - # Update host header - if "host" in self.headers: - if host: - self.headers["host"] = host - else: - self.headers.pop("host") - - @property - def port(self): - """ - Target port - """ - return self.data.port - - @port.setter - def port(self, port): - self.data.port = port - - @property - def path(self): - """ - HTTP request path, e.g. "/index.html". - Guaranteed to start with a slash. 
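        Example (illustrative only; ``req`` stands for any :py:class:`Request`):

        .. code-block:: python

            # The query string remains part of path: for the request line
            # "GET /search?q=mitmproxy HTTP/1.1", req.path is
            # "/search?q=mitmproxy"; the `query` property parses it out.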
- """ - return _native(self.data.path) - - @path.setter - def path(self, path): - self.data.path = _always_bytes(path) - - @property - def url(self): - """ - The URL string, constructed from the request's URL components - """ - return utils.unparse_url(self.scheme, self.host, self.port, self.path) - - @url.setter - def url(self, url): - self.scheme, self.host, self.port, self.path = utils.parse_url(url) - - @property - def pretty_host(self): - """ - Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. - This is useful in transparent mode where :py:attr:`host` is only an IP address, - but may not reflect the actual destination as the Host header could be spoofed. - """ - return self.headers.get("host", self.host) - - @property - def pretty_url(self): - """ - Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. - """ - if self.first_line_format == "authority": - return "%s:%d" % (self.pretty_host, self.port) - return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) - - @property - def query(self): - """ - The request query string as an :py:class:`ODict` object. - None, if there is no query. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None - - @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = utils.parse_url( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) - - @property - def cookies(self): - """ - The request cookies. - An empty :py:class:`ODict` object if the cookie monster ate them all. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) - - @property - def path_components(self): - """ - The URL's path components as a list of strings. - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] - - @path_components.setter - def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) - path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = utils.parse_url( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in encoding.ENCODINGS - if e in accept_encoding - ) - ) - - @property - def urlencoded_form(self): - """ - The URL-encoded form data as an :py:class:`ODict` object. 
- None if there is no data or the content-type indicates non-form data. - """ - is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.urldecode(self.content)) - return None - - @urlencoded_form.setter - def urlencoded_form(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the appropriate content-type header. - This will overwrite the existing content if there is one. - """ - self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = utils.urlencode(odict.lst) - - @property - def multipart_form(self): - """ - The multipart form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. - """ - is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.multipartdecode(self.headers,self.content)) - return None - - @multipart_form.setter - def multipart_form(self, value): - raise NotImplementedError() - - # Legacy - - def get_cookies(self): # pragma: nocover - warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) - return self.cookies - - def set_cookies(self, odict): # pragma: nocover - warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) - self.cookies = odict - - def get_query(self): # pragma: nocover - warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) - return self.query or ODict([]) - - def set_query(self, odict): # pragma: nocover - warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) - self.query = odict - - def get_path_components(self): # pragma: nocover - warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) - return self.path_components - - def set_path_components(self, lst): # pragma: nocover - warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) - self.path_components = lst - - def get_form_urlencoded(self): # pragma: nocover - warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - return self.urlencoded_form or ODict([]) - - def set_form_urlencoded(self, odict): # pragma: nocover - warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - self.urlencoded_form = odict - - def get_form_multipart(self): # pragma: nocover - warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) - return self.multipart_form or ODict([]) - - @property - def form_in(self): # pragma: nocover - warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) - return self.first_line_format - - @form_in.setter - def form_in(self, form_in): # pragma: nocover - warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) - self.first_line_format = form_in - - @property - def form_out(self): # pragma: nocover - warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) - return self.first_line_format - - @form_out.setter - def form_out(self, form_out): # pragma: nocover - warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) - self.first_line_format = form_out - diff --git 
a/netlib/netlib/http/response.py b/netlib/netlib/http/response.py deleted file mode 100644 index 8f4d6215..00000000 --- a/netlib/netlib/http/response.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import warnings - -from . import cookies -from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData -from .. import utils -from ..odict import ODict - - -class ResponseData(MessageData): - def __init__(self, http_version, status_code, reason=None, headers=None, content=None, - timestamp_start=None, timestamp_end=None): - if not isinstance(headers, Headers): - headers = Headers(headers) - - self.http_version = http_version - self.status_code = status_code - self.reason = reason - self.headers = headers - self.content = content - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - -class Response(Message): - """ - An HTTP response. - """ - def __init__(self, *args, **kwargs): - data = ResponseData(*args, **kwargs) - super(Response, self).__init__(data) - - def __repr__(self): - if self.content: - details = "{}, {}".format( - self.headers.get("content-type", "unknown content type"), - utils.pretty_size(len(self.content)) - ) - else: - details = "no content" - return "Response({status_code} {reason}, {details})".format( - status_code=self.status_code, - reason=self.reason, - details=details - ) - - @property - def status_code(self): - """ - HTTP Status Code, e.g. ``200``. - """ - return self.data.status_code - - @status_code.setter - def status_code(self, status_code): - self.data.status_code = status_code - - @property - def reason(self): - """ - HTTP Reason Phrase, e.g. "Not Found". - This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. - """ - return _native(self.data.reason) - - @reason.setter - def reason(self, reason): - self.data.reason = _always_bytes(reason) - - @property - def cookies(self): - """ - Get the contents of all Set-Cookie headers. - - A possibly empty :py:class:`ODict`, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. 
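        Example (an illustrative sketch of the structure described above;
        ``resp`` stands for any :py:class:`Response`):

        .. code-block:: python

            # Given "Set-Cookie: foo=bar; HttpOnly", resp.cookies contains one
            # entry keyed "foo" whose value part is "bar" and whose attribute
            # ODict holds ("HttpOnly", None) - the Null value marking a unary
            # attribute.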
- """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) - - @cookies.setter - def cookies(self, odict): - values = [] - for i in odict.lst: - header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) - values.append(header) - self.headers.set_all("set-cookie", values) - - # Legacy - - def get_cookies(self): # pragma: nocover - warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) - return self.cookies - - def set_cookies(self, odict): # pragma: nocover - warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) - self.cookies = odict - - @property - def msg(self): # pragma: nocover - warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) - return self.reason - - @msg.setter - def msg(self, reason): # pragma: nocover - warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) - self.reason = reason diff --git a/netlib/netlib/http/status_codes.py b/netlib/netlib/http/status_codes.py deleted file mode 100644 index 8a4dc1f5..00000000 --- a/netlib/netlib/http/status_codes.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import absolute_import, print_function, division - -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 -IM_A_TEAPOT = 418 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { - # 100 - CONTINUE: "Continue", - SWITCHING: "Switching Protocols", - - # 200 - OK: "OK", - CREATED: "Created", - ACCEPTED: "Accepted", - NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", - NO_CONTENT: "No Content", - RESET_CONTENT: "Reset Content.", - PARTIAL_CONTENT: "Partial Content", - MULTI_STATUS: "Multi-Status", - - # 300 - MULTIPLE_CHOICE: "Multiple Choices", - MOVED_PERMANENTLY: "Moved Permanently", - FOUND: "Found", - SEE_OTHER: "See Other", - NOT_MODIFIED: "Not Modified", - USE_PROXY: "Use Proxy", - # 306 not defined?? 
- TEMPORARY_REDIRECT: "Temporary Redirect", - - # 400 - BAD_REQUEST: "Bad Request", - UNAUTHORIZED: "Unauthorized", - PAYMENT_REQUIRED: "Payment Required", - FORBIDDEN: "Forbidden", - NOT_FOUND: "Not Found", - NOT_ALLOWED: "Method Not Allowed", - NOT_ACCEPTABLE: "Not Acceptable", - PROXY_AUTH_REQUIRED: "Proxy Authentication Required", - REQUEST_TIMEOUT: "Request Time-out", - CONFLICT: "Conflict", - GONE: "Gone", - LENGTH_REQUIRED: "Length Required", - PRECONDITION_FAILED: "Precondition Failed", - REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", - REQUEST_URI_TOO_LONG: "Request-URI Too Long", - UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", - REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", - EXPECTATION_FAILED: "Expectation Failed", - IM_A_TEAPOT: "I'm a teapot", - - # 500 - INTERNAL_SERVER_ERROR: "Internal Server Error", - NOT_IMPLEMENTED: "Not Implemented", - BAD_GATEWAY: "Bad Gateway", - SERVICE_UNAVAILABLE: "Service Unavailable", - GATEWAY_TIMEOUT: "Gateway Time-out", - HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", - INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", - NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/netlib/http/user_agents.py b/netlib/netlib/http/user_agents.py deleted file mode 100644 index e8681908..00000000 --- a/netlib/netlib/http/user_agents.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -""" - A small collection of useful user-agent header strings. These should be - kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. - -UASTRINGS = [ - ("android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa - ("blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa - ("bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa - ("chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa - ("firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa - ("googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa - ("ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa - ("ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa - ("iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa - ("safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa -] - - -def get_by_shortcut(s): - """ - Retrieve a user agent entry by shortcut. - """ - for i in UASTRINGS: - if s == i[1]: - return i diff --git a/netlib/netlib/odict.py b/netlib/netlib/odict.py deleted file mode 100644 index 1e6e381a..00000000 --- a/netlib/netlib/odict.py +++ /dev/null @@ -1,193 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import re -import copy -import six - -from .utils import Serializable - - -def safe_subn(pattern, repl, target, *args, **kwargs): - """ - There are Unicode conversion problems with re.subn. 
We try to smooth - that over by casting the pattern and replacement to strings. We really - need a better solution that is aware of the actual content encoding. - """ - return re.subn(str(pattern), str(repl), target, *args, **kwargs) - - -class ODict(Serializable): - - """ - A dictionary-like object for managing ordered (key, value) data. Think - about it as a convenient interface to a list of (key, value) tuples. - """ - - def __init__(self, lst=None): - self.lst = lst or [] - - def _kconv(self, s): - return s - - def __eq__(self, other): - return self.lst == other.lst - - def __ne__(self, other): - return not self.__eq__(other) - - def __iter__(self): - return self.lst.__iter__() - - def __getitem__(self, k): - """ - Returns a list of values matching key. - """ - ret = [] - k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - ret.append(i[1]) - return ret - - def keys(self): - return list(set([self._kconv(i[0]) for i in self.lst])) - - def _filter_lst(self, k, lst): - k = self._kconv(k) - new = [] - for i in lst: - if self._kconv(i[0]) != k: - new.append(i) - return new - - def __len__(self): - """ - Total number of (key, value) pairs. - """ - return len(self.lst) - - def __setitem__(self, k, valuelist): - """ - Sets the values for key k. If there are existing values for this - key, they are cleared. - """ - if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type): - raise ValueError( - "Expected list of values instead of string. " - "Example: odict[b'Host'] = [b'www.example.com']" - ) - kc = self._kconv(k) - new = [] - for i in self.lst: - if self._kconv(i[0]) == kc: - if valuelist: - new.append([k, valuelist.pop(0)]) - else: - new.append(i) - while valuelist: - new.append([k, valuelist.pop(0)]) - self.lst = new - - def __delitem__(self, k): - """ - Delete all items matching k. - """ - self.lst = self._filter_lst(k, self.lst) - - def __contains__(self, k): - k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - return True - return False - - def add(self, key, value, prepend=False): - if prepend: - self.lst.insert(0, [key, value]) - else: - self.lst.append([key, value]) - - def get(self, k, d=None): - if k in self: - return self[k] - else: - return d - - def get_first(self, k, d=None): - if k in self: - return self[k][0] - else: - return d - - def items(self): - return self.lst[:] - - def copy(self): - """ - Returns a copy of this object. - """ - lst = copy.deepcopy(self.lst) - return self.__class__(lst) - - def extend(self, other): - """ - Add the contents of other, preserving any duplicates. - """ - self.lst.extend(other.lst) - - def __repr__(self): - return repr(self.lst) - - def in_any(self, key, value, caseless=False): - """ - Do any of the values matching key contain value? - - If caseless is true, value comparison is case-insensitive. - """ - if caseless: - value = value.lower() - for i in self[key]: - if caseless: - i = i.lower() - if value in i: - return True - return False - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both keys and - values. Encoded content will be decoded before replacement, and - re-encoded afterwards. - - Returns the number of replacements made.
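        Example (an illustrative sketch of the list-of-pairs semantics above):

        .. code-block:: python

            >>> od = ODict()
            >>> od["name"] = ["value1", "value2"]   # values are always lists
            >>> od["name"]
            ["value1", "value2"]
            >>> od.get_first("name")
            "value1"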
- """ - nlst, count = [], 0 - for i in self.lst: - k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) - count += c - v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) - count += c - nlst.append([k, v]) - self.lst = nlst - return count - - # Implement the StateObject protocol from mitmproxy - def get_state(self): - return [tuple(i) for i in self.lst] - - def set_state(self, state): - self.lst = [list(i) for i in state] - - @classmethod - def from_state(cls, state): - return cls([list(i) for i in state]) - - -class ODictCaseless(ODict): - - """ - A variant of ODict with "caseless" keys. This version _preserves_ key - case, but does not consider case when setting or getting items. - """ - - def _kconv(self, s): - return s.lower() diff --git a/netlib/netlib/socks.py b/netlib/netlib/socks.py deleted file mode 100644 index 51ad1c63..00000000 --- a/netlib/netlib/socks.py +++ /dev/null @@ -1,176 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import struct -import array -import ipaddress -from . import tcp, utils - - -class SocksError(Exception): - def __init__(self, code, message): - super(SocksError, self).__init__(message) - self.code = code - - -VERSION = utils.BiDi( - SOCKS4=0x04, - SOCKS5=0x05 -) - -CMD = utils.BiDi( - CONNECT=0x01, - BIND=0x02, - UDP_ASSOCIATE=0x03 -) - -ATYP = utils.BiDi( - IPV4_ADDRESS=0x01, - DOMAINNAME=0x03, - IPV6_ADDRESS=0x04 -) - -REP = utils.BiDi( - SUCCEEDED=0x00, - GENERAL_SOCKS_SERVER_FAILURE=0x01, - CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, - NETWORK_UNREACHABLE=0x03, - HOST_UNREACHABLE=0x04, - CONNECTION_REFUSED=0x05, - TTL_EXPIRED=0x06, - COMMAND_NOT_SUPPORTED=0x07, - ADDRESS_TYPE_NOT_SUPPORTED=0x08, -) - -METHOD = utils.BiDi( - NO_AUTHENTICATION_REQUIRED=0x00, - GSSAPI=0x01, - USERNAME_PASSWORD=0x02, - NO_ACCEPTABLE_METHODS=0xFF -) - - -class ClientGreeting(object): - __slots__ = ("ver", "methods") - - def __init__(self, ver, methods): - self.ver = ver - self.methods = array.array("B") - self.methods.extend(methods) - - def assert_socks5(self): - if self.ver != VERSION.SOCKS5: - if self.ver == ord("G") and len(self.methods) == ord("E"): - guess = "Probably not a SOCKS request but a regular HTTP request. " - else: - guess = "" - - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f, fail_early=False): - """ - :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. - """ - ver, nmethods = struct.unpack("!BB", f.safe_read(2)) - client_greeting = cls(ver, []) - if fail_early: - client_greeting.assert_socks5() - client_greeting.methods.fromstring(f.safe_read(nmethods)) - return client_greeting - - def to_file(self, f): - f.write(struct.pack("!BB", self.ver, len(self.methods))) - f.write(self.methods.tostring()) - - -class ServerGreeting(object): - __slots__ = ("ver", "method") - - def __init__(self, ver, method): - self.ver = ver - self.method = method - - def assert_socks5(self): - if self.ver != VERSION.SOCKS5: - if self.ver == ord("H") and self.method == ord("T"): - guess = "Probably not a SOCKS request but a regular HTTP response. " - else: - guess = "" - - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - guess + "Invalid SOCKS version. 
Expected 0x05, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f): - ver, method = struct.unpack("!BB", f.safe_read(2)) - return cls(ver, method) - - def to_file(self, f): - f.write(struct.pack("!BB", self.ver, self.method)) - - -class Message(object): - __slots__ = ("ver", "msg", "atyp", "addr") - - def __init__(self, ver, msg, atyp, addr): - self.ver = ver - self.msg = msg - self.atyp = atyp - self.addr = tcp.Address.wrap(addr) - - def assert_socks5(self): - if self.ver != VERSION.SOCKS5: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver - ) - - @classmethod - def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) - if rsv != 0x00: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Socks Request: Invalid reserved byte: %s" % rsv - ) - if atyp == ATYP.IPV4_ADDRESS: - # We use tnoa here as ntop is not commonly available on Windows. - host = ipaddress.IPv4Address(f.safe_read(4)).compressed - use_ipv6 = False - elif atyp == ATYP.IPV6_ADDRESS: - host = ipaddress.IPv6Address(f.safe_read(16)).compressed - use_ipv6 = True - elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", f.safe_read(1)) - host = f.safe_read(length) - if not utils.is_valid_host(host): - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) - host = host.decode("idna") - use_ipv6 = False - else: - raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, - "Socks Request: Unknown ATYP: %s" % atyp) - - port, = struct.unpack("!H", f.safe_read(2)) - addr = tcp.Address((host, port), use_ipv6=use_ipv6) - return cls(ver, msg, atyp, addr) - - def to_file(self, f): - f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) - if self.atyp == ATYP.IPV4_ADDRESS: - f.write(ipaddress.IPv4Address(self.addr.host).packed) - elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(ipaddress.IPv6Address(self.addr.host).packed) - elif self.atyp == ATYP.DOMAINNAME: - f.write(struct.pack("!B", len(self.addr.host))) - f.write(self.addr.host.encode("idna")) - else: - raise SocksError( - REP.ADDRESS_TYPE_NOT_SUPPORTED, - "Unknown ATYP: %s" % self.atyp - ) - f.write(struct.pack("!H", self.addr.port)) diff --git a/netlib/netlib/tcp.py b/netlib/netlib/tcp.py deleted file mode 100644 index 61b41cdc..00000000 --- a/netlib/netlib/tcp.py +++ /dev/null @@ -1,911 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import os -import select -import socket -import sys -import threading -import time -import traceback - -import binascii -from six.moves import range - -import certifi -from backports import ssl_match_hostname -import six -import OpenSSL -from OpenSSL import SSL - -from . import certutils, version_check, utils - -# This is a rather hackish way to make sure that -# the latest version of pyOpenSSL is actually installed. 
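A round-trip sketch for the SOCKS classes above, assuming netlib.socks and netlib.tcp are importable; tcp.Reader supplies the safe_read() interface that from_file() expects.

    from io import BytesIO
    from netlib import socks, tcp

    # Build a SOCKS5 CONNECT request for example.com:80 ...
    req = socks.Message(socks.VERSION.SOCKS5, socks.CMD.CONNECT,
                        socks.ATYP.DOMAINNAME, ("example.com", 80))
    buf = BytesIO()
    req.to_file(buf)
    buf.seek(0)

    # ... and parse it back from its wire format.
    parsed = socks.Message.from_file(tcp.Reader(buf))
    parsed.assert_socks5()
    assert parsed.addr.host == "example.com"
    assert parsed.addr.port == 80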
-from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ - TcpTimeout, TcpDisconnect, TcpException - -version_check.check_pyopenssl_version() - -if six.PY2: - socket_fileobject = socket._fileobject -else: - socket_fileobject = socket.SocketIO - -EINTR = 4 -if os.environ.get("NO_ALPN"): - HAS_ALPN = False -else: - HAS_ALPN = OpenSSL._util.lib.Cryptography_HAS_ALPN - -# To enable all SSL methods use: SSLv23 -# then add options to disable certain methods -# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 -SSL_BASIC_OPTIONS = ( - SSL.OP_CIPHER_SERVER_PREFERENCE -) -if hasattr(SSL, "OP_NO_COMPRESSION"): - SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION - -SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD -SSL_DEFAULT_OPTIONS = ( - SSL.OP_NO_SSLv2 | - SSL.OP_NO_SSLv3 | - SSL_BASIC_OPTIONS -) -if hasattr(SSL, "OP_NO_COMPRESSION"): - SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION - -""" -Map a reasonable SSL version specification into the format OpenSSL expects. -Don't ask... -https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 -""" -sslversion_choices = { - "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), - # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ - # TLSv1_METHOD would be TLS 1.0 only - "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), - "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), - "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), - "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), - "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), - "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), -} - -class SSLKeyLogger(object): - - def __init__(self, filename): - self.filename = filename - self.f = None - self.lock = threading.Lock() - - # required for functools.wraps, which pyOpenSSL uses. - __name__ = "SSLKeyLogger" - - def __call__(self, connection, where, ret): - if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: - with self.lock: - if not self.f: - d = os.path.dirname(self.filename) - if not os.path.isdir(d): - os.makedirs(d) - self.f = open(self.filename, "ab") - self.f.write(b"\r\n") - client_random = binascii.hexlify(connection.client_random()) - masterkey = binascii.hexlify(connection.master_key()) - self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) - self.f.flush() - - def close(self): - with self.lock: - if self.f: - self.f.close() - - @staticmethod - def create_logfun(filename): - if filename: - return SSLKeyLogger(filename) - return False - -log_ssl_key = SSLKeyLogger.create_logfun( - os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) - - -class _FileLike(object): - BLOCKSIZE = 1024 * 32 - - def __init__(self, o): - self.o = o - self._log = None - self.first_byte_timestamp = None - - def set_descriptor(self, o): - self.o = o - - def __getattr__(self, attr): - return getattr(self.o, attr) - - def start_log(self): - """ - Starts or resets the log. - - This will store all bytes read or written. - """ - self._log = [] - - def stop_log(self): - """ - Stops the log. - """ - self._log = None - - def is_logging(self): - return self._log is not None - - def get_log(self): - """ - Returns the log as a string. 
- """ - if not self.is_logging(): - raise ValueError("Not logging!") - return b"".join(self._log) - - def add_log(self, v): - if self.is_logging(): - self._log.append(v) - - def reset_timestamps(self): - self.first_byte_timestamp = None - - -class Writer(_FileLike): - - def flush(self): - """ - May raise TcpDisconnect - """ - if hasattr(self.o, "flush"): - try: - self.o.flush() - except (socket.error, IOError) as v: - raise TcpDisconnect(str(v)) - - def write(self, v): - """ - May raise TcpDisconnect - """ - if v: - self.first_byte_timestamp = self.first_byte_timestamp or time.time() - try: - if hasattr(self.o, "sendall"): - self.add_log(v) - return self.o.sendall(v) - else: - r = self.o.write(v) - self.add_log(v[:r]) - return r - except (SSL.Error, socket.error) as e: - raise TcpDisconnect(str(e)) - - -class Reader(_FileLike): - - def read(self, length): - """ - If length is -1, we read until connection closes. - """ - result = b'' - start = time.time() - while length == -1 or length > 0: - if length == -1 or length > self.BLOCKSIZE: - rlen = self.BLOCKSIZE - else: - rlen = length - try: - data = self.o.read(rlen) - except SSL.ZeroReturnError: - # TLS connection was shut down cleanly - break - except (SSL.WantWriteError, SSL.WantReadError): - # From the OpenSSL docs: - # If the underlying BIO is non-blocking, SSL_read() will also return when the - # underlying BIO could not satisfy the needs of SSL_read() to continue the - # operation. In this case a call to SSL_get_error with the return value of - # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. - if (time.time() - start) < self.o.gettimeout(): - time.sleep(0.1) - continue - else: - raise TcpTimeout() - except socket.timeout: - raise TcpTimeout() - except socket.error as e: - raise TcpDisconnect(str(e)) - except SSL.SysCallError as e: - if e.args == (-1, 'Unexpected EOF'): - break - raise TlsException(str(e)) - except SSL.Error as e: - raise TlsException(str(e)) - self.first_byte_timestamp = self.first_byte_timestamp or time.time() - if not data: - break - result += data - if length != -1: - length -= len(data) - self.add_log(result) - return result - - def readline(self, size=None): - result = b'' - bytes_read = 0 - while True: - if size is not None and bytes_read >= size: - break - ch = self.read(1) - bytes_read += 1 - if not ch: - break - else: - result += ch - if ch == b'\n': - break - return result - - def safe_read(self, length): - """ - Like .read, but is guaranteed to either return length bytes, or - raise an exception. - """ - result = self.read(length) - if length != -1 and len(result) != length: - if not result: - raise TcpDisconnect() - else: - raise TcpReadIncomplete( - "Expected %s bytes, got %s" % (length, len(result)) - ) - return result - - def peek(self, length): - """ - Tries to peek into the underlying file object. - - Returns: - Up to the next N bytes if peeking is successful. - - Raises: - TcpException if there was an error with the socket - TlsException if there was an error with pyOpenSSL. 
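A small sketch of the Reader and _FileLike logging behaviour above, over an in-memory stream:

    from io import BytesIO
    from netlib.tcp import Reader

    r = Reader(BytesIO(b"hello world"))
    r.start_log()
    assert r.read(5) == b"hello"        # read up to 5 bytes
    assert r.get_log() == b"hello"      # the log captured exactly those bytes
    assert r.safe_read(6) == b" world"  # safe_read raises if bytes are missing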
- NotImplementedError if the underlying file object is not a [pyOpenSSL] socket - """ - if isinstance(self.o, socket_fileobject): - try: - return self.o._sock.recv(length, socket.MSG_PEEK) - except socket.error as e: - raise TcpException(repr(e)) - elif isinstance(self.o, SSL.Connection): - try: - if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): - return self.o.recv(length, socket.MSG_PEEK) - else: - # TODO: remove once a new version is released - # Polyfill for pyOpenSSL <= 0.15.1 - # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 - buf = SSL._ffi.new("char[]", length) - result = SSL._lib.SSL_peek(self.o._ssl, buf, length) - self.o._raise_ssl_error(self.o._ssl, result) - return SSL._ffi.buffer(buf, result)[:] - except SSL.Error as e: - six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2]) - else: - raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") - - -class Address(utils.Serializable): - - """ - This class wraps an IPv4/IPv6 tuple to provide named attributes and - ipv6 information. - """ - - def __init__(self, address, use_ipv6=False): - self.address = tuple(address) - self.use_ipv6 = use_ipv6 - - def get_state(self): - return { - "address": self.address, - "use_ipv6": self.use_ipv6 - } - - def set_state(self, state): - self.address = state["address"] - self.use_ipv6 = state["use_ipv6"] - - @classmethod - def from_state(cls, state): - return Address(**state) - - @classmethod - def wrap(cls, t): - if isinstance(t, cls): - return t - else: - return cls(t) - - def __call__(self): - return self.address - - @property - def host(self): - return self.address[0] - - @property - def port(self): - return self.address[1] - - @property - def use_ipv6(self): - return self.family == socket.AF_INET6 - - @use_ipv6.setter - def use_ipv6(self, b): - self.family = socket.AF_INET6 if b else socket.AF_INET - - def __repr__(self): - return "{}:{}".format(self.host, self.port) - - def __str__(self): - return str(self.address) - - def __eq__(self, other): - if not other: - return False - other = Address.wrap(other) - return (self.address, self.family) == (other.address, other.family) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self.address) ^ 42 # different hash than the tuple alone. - - -def ssl_read_select(rlist, timeout): - """ - This is a wrapper around select.select() which also works for SSL.Connections - by taking ssl_connection.pending() into account. - - Caveats: - If .pending() > 0 for any of the connections in rlist, we avoid the select syscall - and **will not include any other connections which may or may not be ready**. - - Args: - rlist: wait until ready for reading - - Returns: - subset of rlist which is ready for reading. - """ - return [ - conn for conn in rlist - if isinstance(conn, SSL.Connection) and conn.pending() > 0 - ] or select.select(rlist, (), (), timeout)[0] - - -def close_socket(sock): - """ - Does a hard close of a socket, without emitting a RST. - """ - try: - # We already indicate that we close our end. - # may raise "Transport endpoint is not connected" on Linux - sock.shutdown(socket.SHUT_WR) - - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending - # readable data could lead to an immediate RST being sent (which is the - # case on Windows). 
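A sketch of the Address class above: it wraps a (host, port) tuple with named accessors and round-trips through the Serializable state protocol.

    from netlib.tcp import Address

    addr = Address.wrap(("example.com", 80))
    assert addr.host == "example.com" and addr.port == 80
    assert not addr.use_ipv6                             # AF_INET by default
    assert Address.from_state(addr.get_state()) == addr  # state round-trip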
- # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - # - # This in turn results in the following issue: If we send an error page - # to the client and then close the socket, the RST may be received by - # the client before the error page and the users sees a connection - # error rather than the error page. Thus, we try to empty the read - # buffer on Windows first. (see - # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) - # - - if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique - # proposed by the page above: Some remote machines just don't send - # a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. As a workaround, we set a timeout - # here even if we are in blocking mode. - sock.settimeout(sock.gettimeout() or 20) - - # limit at a megabyte so that we don't read infinitely - for _ in range(1024 ** 3 // 4096): - # may raise a timeout/disconnect exception. - if not sock.recv(4096): - break - - # Now we can close the other half as well. - sock.shutdown(socket.SHUT_RD) - - except socket.error: - pass - - sock.close() - - -class _Connection(object): - - rbufsize = -1 - wbufsize = -1 - - def _makefile(self): - """ - Set up .rfile and .wfile attributes from .connection - """ - # Ideally, we would use the Buffered IO in Python 3 by default. - # Unfortunately, the implementation of .peek() is broken for n>1 bytes, - # as it may just return what's left in the buffer and not all the bytes we want. - # As a workaround, we just use unbuffered sockets directly. - # https://mail.python.org/pipermail/python-dev/2009-June/089986.html - if six.PY2: - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(self.connection, "rb")) - self.wfile = Writer(socket.SocketIO(self.connection, "wb")) - - def __init__(self, connection): - if connection: - self.connection = connection - self._makefile() - else: - self.connection = None - self.rfile = None - self.wfile = None - - self.ssl_established = False - self.finished = False - - def get_current_cipher(self): - if not self.ssl_established: - return None - - name = self.connection.get_cipher_name() - bits = self.connection.get_cipher_bits() - version = self.connection.get_cipher_version() - return name, bits, version - - def finish(self): - self.finished = True - # If we have an SSL connection, wfile.close == connection.close - # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close - # then. - if not isinstance(self.connection, SSL.Connection): - if not getattr(self.wfile, "closed", False): - try: - self.wfile.flush() - self.wfile.close() - except TcpDisconnect: - pass - - self.rfile.close() - else: - try: - self.connection.shutdown() - except SSL.Error: - pass - - def _create_ssl_context(self, - method=SSL_DEFAULT_METHOD, - options=SSL_DEFAULT_OPTIONS, - verify_options=SSL.VERIFY_NONE, - ca_path=None, - ca_pemfile=None, - cipher_list=None, - alpn_protos=None, - alpn_select=None, - alpn_select_callback=None, - ): - """ - Creates an SSL Context. 
- - :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD - :param options: A bit field consisting of OpenSSL.SSL.OP_* values - :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values - :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool - :param ca_pemfile: Path to a PEM formatted trusted CA certificate - :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html - :rtype : SSL.Context - """ - context = SSL.Context(method) - # Options (NO_SSLv2/3) - if options is not None: - context.set_options(options) - - # Verify Options (NONE/PEER and trusted CAs) - if verify_options is not None: - def verify_cert(conn, x509, errno, err_depth, is_cert_verified): - if not is_cert_verified: - self.ssl_verification_error = dict(errno=errno, - depth=err_depth) - return is_cert_verified - - context.set_verify(verify_options, verify_cert) - if ca_path is None and ca_pemfile is None: - ca_pemfile = certifi.where() - context.load_verify_locations(ca_pemfile, ca_path) - - # Workaround for - # https://github.com/pyca/pyopenssl/issues/190 - # https://github.com/mitmproxy/mitmproxy/issues/472 - # Options already set before are not cleared. - context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) - - # Cipher List - if cipher_list: - try: - context.set_cipher_list(cipher_list) - - # TODO: maybe change this to with newer pyOpenSSL APIs - context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) - except SSL.Error as v: - raise TlsException("SSL cipher specification error: %s" % str(v)) - - # SSLKEYLOGFILE - if log_ssl_key: - context.set_info_callback(log_ssl_key) - - if HAS_ALPN: - if alpn_protos is not None: - # advertise application layer protocols - context.set_alpn_protos(alpn_protos) - elif alpn_select is not None and alpn_select_callback is None: - # select application layer protocol - def alpn_select_callback(conn_, options): - if alpn_select in options: - return bytes(alpn_select) - else: # pragma no cover - return options[0] - context.set_alpn_select_callback(alpn_select_callback) - elif alpn_select_callback is not None and alpn_select is None: - context.set_alpn_select_callback(alpn_select_callback) - elif alpn_select_callback is not None and alpn_select is not None: - raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") - - return context - - -class TCPClient(_Connection): - - def __init__(self, address, source_address=None): - super(TCPClient, self).__init__(None) - self.address = address - self.source_address = source_address - self.cert = None - self.ssl_verification_error = None - self.sni = None - - @property - def address(self): - return self.__address - - @address.setter - def address(self, address): - if address: - self.__address = Address.wrap(address) - else: - self.__address = None - - @property - def source_address(self): - return self.__source_address - - @source_address.setter - def source_address(self, source_address): - if source_address: - self.__source_address = Address.wrap(source_address) - else: - self.__source_address = None - - def close(self): - # Make sure to close the real socket, not the SSL proxy. - # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, - # it tries to renegotiate... 
- if isinstance(self.connection, SSL.Connection): - close_socket(self.connection._socket) - else: - close_socket(self.connection) - - def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): - context = self._create_ssl_context( - alpn_protos=alpn_protos, - **sslctx_kwargs) - # Client Certs - if cert: - try: - context.use_privatekey_file(cert) - context.use_certificate_file(cert) - except SSL.Error as v: - raise TlsException("SSL client certificate error: %s" % str(v)) - return context - - def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): - """ - cert: Path to a file containing both client cert and private key. - - options: A bit field consisting of OpenSSL.SSL.OP_* values - verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values - ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool - ca_pemfile: Path to a PEM formatted trusted CA certificate - """ - verification_mode = sslctx_kwargs.get('verify_options', None) - if verification_mode == SSL.VERIFY_PEER and not sni: - raise TlsException("Cannot validate certificate hostname without SNI") - - context = self.create_ssl_context( - alpn_protos=alpn_protos, - **sslctx_kwargs - ) - self.connection = SSL.Connection(context, self.connection) - if sni: - self.sni = sni - self.connection.set_tlsext_host_name(sni) - self.connection.set_connect_state() - try: - self.connection.do_handshake() - except SSL.Error as v: - if self.ssl_verification_error: - raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) - else: - raise TlsException("SSL handshake error: %s" % repr(v)) - else: - # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on - # certificate validation failure - if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None: - raise InvalidCertificateException("SSL handshake error: certificate verify failed") - - self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) - - # Validate TLS Hostname - try: - crt = dict( - subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] - ) - if self.cert.cn: - crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] - if sni: - hostname = sni.decode("ascii", "strict") - else: - hostname = "no-hostname" - ssl_match_hostname.match_hostname(crt, hostname) - except (ValueError, ssl_match_hostname.CertificateError) as e: - self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname") - if verification_mode == SSL.VERIFY_PEER: - raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e))) - - self.ssl_established = True - self.rfile.set_descriptor(self.connection) - self.wfile.set_descriptor(self.connection) - - def connect(self): - try: - connection = socket.socket(self.address.family, socket.SOCK_STREAM) - if self.source_address: - connection.bind(self.source_address()) - connection.connect(self.address()) - if not self.source_address: - self.source_address = Address(connection.getsockname()) - except (socket.error, IOError) as err: - raise TcpException( - 'Error connecting to "%s": %s' % - (self.address.host, err)) - self.connection = connection - self._makefile() - - def settimeout(self, n): - self.connection.settimeout(n) - - def gettimeout(self): - return self.connection.gettimeout() - - def get_alpn_proto_negotiated(self): - if HAS_ALPN and self.ssl_established: - return self.connection.get_alpn_proto_negotiated() - else: - return b"" - - -class 
BaseHandler(_Connection): - - """ - The instantiator is expected to call the handle() and finish() methods. - """ - - def __init__(self, connection, address, server): - super(BaseHandler, self).__init__(connection) - self.address = Address.wrap(address) - self.server = server - self.clientcert = None - - def create_ssl_context(self, - cert, key, - handle_sni=None, - request_client_cert=None, - chain_file=None, - dhparams=None, - **sslctx_kwargs): - """ - cert: A certutils.SSLCert object or the path to a certificate - chain file. - - handle_sni: SNI handler, should take a connection object. Server - name can be retrieved like this: - - connection.get_servername() - - And you can specify the connection keys as follows: - - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) - - The request_client_cert argument requires some explanation. We're - supposed to be able to do this with no negative effects - if the - client has no cert to present, we're notified and proceed as usual. - Unfortunately, Android seems to have a bug (tested on 4.2.2) - when - an Android client is asked to present a certificate it does not - have, it hangs up, which is frankly bogus. Some time down the track - we may be able to make the proper behaviour the default again, but - until then we're conservative. - """ - - context = self._create_ssl_context(**sslctx_kwargs) - - context.use_privatekey(key) - if isinstance(cert, certutils.SSLCert): - context.use_certificate(cert.x509) - else: - context.use_certificate_chain_file(cert) - - if handle_sni: - # SNI callback happens during do_handshake() - context.set_tlsext_servername_callback(handle_sni) - - if request_client_cert: - def save_cert(conn_, cert, errno_, depth_, preverify_ok_): - self.clientcert = certutils.SSLCert(cert) - # Return true to prevent cert verification error - return True - context.set_verify(SSL.VERIFY_PEER, save_cert) - - # Cert Verify - if chain_file: - context.load_verify_locations(chain_file) - - if dhparams: - SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) - - return context - - def convert_to_ssl(self, cert, key, **sslctx_kwargs): - """ - Convert connection to SSL. - For a list of parameters, see BaseHandler._create_ssl_context(...) 
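A sketch of an SNI callback of the shape create_ssl_context(handle_sni=...) describes above; cert_for() and key_for() are hypothetical lookup helpers, not part of netlib.

    from OpenSSL import SSL

    def handle_sni(connection):
        host = connection.get_servername()
        # Swap in a per-host context, as outlined in the docstring above.
        new_context = SSL.Context(SSL.TLSv1_2_METHOD)
        new_context.use_privatekey(key_for(host))    # hypothetical lookup
        new_context.use_certificate(cert_for(host))  # hypothetical lookup
        connection.set_context(new_context)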
- """ - - context = self.create_ssl_context( - cert, - key, - **sslctx_kwargs) - self.connection = SSL.Connection(context, self.connection) - self.connection.set_accept_state() - try: - self.connection.do_handshake() - except SSL.Error as v: - raise TlsException("SSL handshake error: %s" % repr(v)) - self.ssl_established = True - self.rfile.set_descriptor(self.connection) - self.wfile.set_descriptor(self.connection) - - def handle(self): # pragma: no cover - raise NotImplementedError - - def settimeout(self, n): - self.connection.settimeout(n) - - def get_alpn_proto_negotiated(self): - if HAS_ALPN and self.ssl_established: - return self.connection.get_alpn_proto_negotiated() - else: - return b"" - - -class TCPServer(object): - request_queue_size = 20 - - def __init__(self, address): - self.address = Address.wrap(address) - self.__is_shut_down = threading.Event() - self.__shutdown_request = False - self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.address()) - self.address = Address.wrap(self.socket.getsockname()) - self.socket.listen(self.request_queue_size) - - def connection_thread(self, connection, client_address): - client_address = Address(client_address) - try: - self.handle_client_connection(connection, client_address) - except: - self.handle_error(connection, client_address) - finally: - close_socket(connection) - - def serve_forever(self, poll_interval=0.1): - self.__is_shut_down.clear() - try: - while not self.__shutdown_request: - try: - r, w_, e_ = select.select( - [self.socket], [], [], poll_interval) - except select.error as ex: # pragma: no cover - if ex[0] == EINTR: - continue - else: - raise - if self.socket in r: - connection, client_address = self.socket.accept() - t = threading.Thread( - target=self.connection_thread, - args=(connection, client_address), - name="ConnectionThread (%s:%s -> %s:%s)" % - (client_address[0], client_address[1], - self.address.host, self.address.port) - ) - t.setDaemon(1) - try: - t.start() - except threading.ThreadError: - self.handle_error(connection, Address(client_address)) - connection.close() - finally: - self.__shutdown_request = False - self.__is_shut_down.set() - - def shutdown(self): - self.__shutdown_request = True - self.__is_shut_down.wait() - self.socket.close() - self.handle_shutdown() - - def handle_error(self, connection_, client_address, fp=sys.stderr): - """ - Called when handle_client_connection raises an exception. - """ - # If a thread has persisted after interpreter exit, the module might be - # none. - if traceback: - exc = six.text_type(traceback.format_exc()) - print(u'-' * 40, file=fp) - print( - u"Error in processing of request from %s" % repr(client_address), file=fp) - print(exc, file=fp) - print(u'-' * 40, file=fp) - - def handle_client_connection(self, conn, client_address): # pragma: no cover - """ - Called after client connection. - """ - raise NotImplementedError - - def handle_shutdown(self): - """ - Called after server shutdown. - """ diff --git a/netlib/netlib/tutils.py b/netlib/netlib/tutils.py deleted file mode 100644 index f6ce8e0a..00000000 --- a/netlib/netlib/tutils.py +++ /dev/null @@ -1,133 +0,0 @@ -from io import BytesIO -import tempfile -import os -import time -import shutil -from contextlib import contextmanager -import six -import sys - -from . import utils, tcp -from .http import Request, Response, Headers - - -def treader(bytes): - """ - Construct a tcp.Read object from bytes. 
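A minimal sketch tying TCPServer and BaseHandler together as an echo server, assuming the tcp module above imports as netlib.tcp:

    import threading
    from netlib.tcp import BaseHandler, TCPServer

    class EchoHandler(BaseHandler):
        def handle(self):
            # Echo one line back to the client.
            self.wfile.write(self.rfile.readline())
            self.wfile.flush()

    class EchoServer(TCPServer):
        def handle_client_connection(self, conn, client_address):
            handler = EchoHandler(conn, client_address, self)
            handler.handle()
            handler.finish()

    server = EchoServer(("127.0.0.1", 0))  # port 0: bind a free port
    threading.Thread(target=server.serve_forever).start()
    # ... connect with TCPClient(server.address()) and exchange a line ...
    server.shutdown()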
- """ - fp = BytesIO(bytes) - return tcp.Reader(fp) - - -@contextmanager -def tmpdir(*args, **kwargs): - orig_workdir = os.getcwd() - temp_workdir = tempfile.mkdtemp(*args, **kwargs) - os.chdir(temp_workdir) - - yield temp_workdir - - os.chdir(orig_workdir) - shutil.rmtree(temp_workdir) - - -def _check_exception(expected, actual, exc_tb): - if isinstance(expected, six.string_types): - if expected.lower() not in str(actual).lower(): - six.reraise(AssertionError, AssertionError( - "Expected %s, but caught %s" % ( - repr(expected), repr(actual) - ) - ), exc_tb) - else: - if not isinstance(actual, expected): - six.reraise(AssertionError, AssertionError( - "Expected %s, but caught %s %s" % ( - expected.__name__, actual.__class__.__name__, repr(actual) - ) - ), exc_tb) - - -def raises(expected_exception, obj=None, *args, **kwargs): - """ - Assert that a callable raises a specified exception. - - :exc An exception class or a string. If a class, assert that an - exception of this type is raised. If a string, assert that the string - occurs in the string representation of the exception, based on a - case-insenstivie match. - - :obj A callable object. - - :args Arguments to be passsed to the callable. - - :kwargs Arguments to be passed to the callable. - """ - if obj is None: - return RaisesContext(expected_exception) - else: - try: - ret = obj(*args, **kwargs) - except Exception as actual: - _check_exception(expected_exception, actual, sys.exc_info()[2]) - else: - raise AssertionError("No exception raised. Return value: {}".format(ret)) - - -class RaisesContext(object): - def __init__(self, expected_exception): - self.expected_exception = expected_exception - - def __enter__(self): - return - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - raise AssertionError("No exception raised.") - else: - _check_exception(self.expected_exception, exc_val, exc_tb) - return True - - -test_data = utils.Data(__name__) -# FIXME: Temporary workaround during repo merge. -import os -test_data.dirname = os.path.join(test_data.dirname,"..","..","test","netlib") - - -def treq(**kwargs): - """ - Returns: - netlib.http.Request - """ - default = dict( - first_line_format="relative", - method=b"GET", - scheme=b"http", - host=b"address", - port=22, - path=b"/path", - http_version=b"HTTP/1.1", - headers=Headers(header="qvalue", content_length="7"), - content=b"content" - ) - default.update(kwargs) - return Request(**default) - - -def tresp(**kwargs): - """ - Returns: - netlib.http.Response - """ - default = dict( - http_version=b"HTTP/1.1", - status_code=200, - reason=b"OK", - headers=Headers(header_response="svalue", content_length="7"), - content=b"message", - timestamp_start=time.time(), - timestamp_end=time.time(), - ) - default.update(kwargs) - return Response(**default) diff --git a/netlib/netlib/utils.py b/netlib/netlib/utils.py deleted file mode 100644 index f7bb5c4b..00000000 --- a/netlib/netlib/utils.py +++ /dev/null @@ -1,418 +0,0 @@ -from __future__ import absolute_import, print_function, division -import os.path -import re -import codecs -import unicodedata -from abc import ABCMeta, abstractmethod -import importlib -import inspect - -import six - -from six.moves import urllib -import hyperframe - - -@six.add_metaclass(ABCMeta) -class Serializable(object): - """ - Abstract Base Class that defines an API to save an object's state and restore it later on. - """ - - @classmethod - @abstractmethod - def from_state(cls, state): - """ - Create a new object from the given state. 
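A sketch of the raises() helper above in both calling styles, plus a canned request from treq(), assuming netlib.tutils imports:

    from netlib.tutils import raises, treq

    with raises(ValueError):
        int("not a number")
    raises("invalid literal", int, "not a number")  # case-insensitive substring match

    req = treq(method=b"POST")  # defaults from above, with overrides
    assert req.method == b"POST"
    assert req.headers["content-length"] == "7"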
- """ - raise NotImplementedError() - - @abstractmethod - def get_state(self): - """ - Retrieve object state. - """ - raise NotImplementedError() - - @abstractmethod - def set_state(self, state): - """ - Set object state to the given state. - """ - raise NotImplementedError() - - -def always_bytes(unicode_or_bytes, *encode_args): - if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes - - -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - -def native(s, *encoding_opts): - """ - Convert :py:class:`bytes` or :py:class:`unicode` to the native - :py:class:`str` type, using latin1 encoding if conversion is necessary. - - https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types - """ - if not isinstance(s, (six.binary_type, six.text_type)): - raise TypeError("%r is neither bytes nor unicode" % s) - if six.PY3: - if isinstance(s, six.binary_type): - return s.decode(*encoding_opts) - else: - if isinstance(s, six.text_type): - return s.encode(*encoding_opts) - return s - - -def isascii(bytes): - try: - bytes.decode("ascii") - except ValueError: - return False - return True - - -def clean_bin(s, keep_spacing=True): - """ - Cleans binary data to make it safe to display. - - Args: - keep_spacing: If False, tabs and newlines will also be replaced. - """ - if isinstance(s, six.text_type): - if keep_spacing: - keep = u" \n\r\t" - else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." - for ch in s - ) - else: - if keep_spacing: - keep = (9, 10, 13) # \t, \n, \r, - else: - keep = () - return b"".join( - six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." - for ch in six.iterbytes(s) - ) - - -def hexdump(s): - """ - Returns: - A generator of (offset, hex, str) tuples - """ - for i in range(0, len(s), 16): - offset = "{:0=10x}".format(i).encode() - part = s[i:i + 16] - x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) - x = x.ljust(47) # 16*2 + 15 - yield (offset, x, clean_bin(part, False)) - - -def setbit(byte, offset, value): - """ - Set a bit in a byte to 1 if value is truthy, 0 if not. - """ - if value: - return byte | (1 << offset) - else: - return byte & ~(1 << offset) - - -def getbit(byte, offset): - mask = 1 << offset - return bool(byte & mask) - - -class BiDi(object): - - """ - A wee utility class for keeping bi-directional mappings, like field - constants in protocols. 
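A few spot checks for the byte/bit helpers above (Python 3 shown for native()):

    from netlib.utils import setbit, getbit, native, clean_bin

    assert setbit(0b1000, 1, True) == 0b1010  # turn bit 1 on
    assert getbit(0b1010, 1)                  # ... and read it back
    assert native(b"abc") == "abc"            # bytes -> str on Python 3
    assert clean_bin(b"a\x00b") == b"a.b"     # control bytes become '.'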
Names are attributes on the object, dict-like - access maps values to names: - - CONST = BiDi(a=1, b=2) - assert CONST.a == 1 - assert CONST.get_name(1) == "a" - """ - - def __init__(self, **kwargs): - self.names = kwargs - self.values = {} - for k, v in kwargs.items(): - self.values[v] = k - if len(self.names) != len(self.values): - raise ValueError("Duplicate values not allowed.") - - def __getattr__(self, k): - if k in self.names: - return self.names[k] - raise AttributeError("No such attribute: %s", k) - - def get_name(self, n, default=None): - return self.values.get(n, default) - - -def pretty_size(size): - suffixes = [ - ("B", 2 ** 10), - ("kB", 2 ** 20), - ("MB", 2 ** 30), - ] - for suf, lim in suffixes: - if size >= lim: - continue - else: - x = round(size / float(lim / 2 ** 10), 2) - if x == int(x): - x = int(x) - return str(x) + suf - - -class Data(object): - - def __init__(self, name): - m = importlib.import_module(name) - dirname = os.path.dirname(inspect.getsourcefile(m)) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module. Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath - - -_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) - - -def is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if len(host) > 255: - return False - if host[-1] == b".": - host = host[:-1] - return all(_label_valid.match(x) for x in host.split(b".")) - - -def is_valid_port(port): - return 0 <= port <= 65535 - - -# PY2 workaround -def decode_parse_result(result, enc): - if hasattr(result, "decode"): - return result.decode(enc) - else: - return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) - - -# PY2 workaround -def encode_parse_result(result, enc): - if hasattr(result, "encode"): - return result.encode(enc) - else: - return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) - - -def parse_url(url): - """ - URL-parsing function that checks that - - port is an integer 0-65535 - - host is a valid IDNA-encoded hostname with no null-bytes - - path is valid ASCII - - Args: - A URL (as bytes or as unicode) - - Returns: - A (scheme, host, port, path) tuple - - Raises: - ValueError, if the URL is not properly formatted. - """ - parsed = urllib.parse.urlparse(url) - - if not parsed.hostname: - raise ValueError("No hostname given") - - if isinstance(url, six.binary_type): - host = parsed.hostname - - # this should not raise a ValueError, - # but we try to be very forgiving here and accept just everything. - # decode_parse_result(parsed, "ascii") - else: - host = parsed.hostname.encode("idna") - parsed = encode_parse_result(parsed, "ascii") - - port = parsed.port - if not port: - port = 443 if parsed.scheme == b"https" else 80 - - full_path = urllib.parse.urlunparse( - (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) - ) - if not full_path.startswith(b"/"): - full_path = b"/" + full_path - - if not is_valid_host(host): - raise ValueError("Invalid Host") - if not is_valid_port(port): - raise ValueError("Invalid Port") - - return parsed.scheme, host, port, full_path - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can contain comma-separated - tokens, and headers can be set multiple times.
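A sketch of parse_url() on a bytes URL: the port is inferred from the scheme, and host/port are checked with the validators above.

    from netlib.utils import parse_url

    assert parse_url(b"https://example.com/foo?x=1") == \
        (b"https", b"example.com", 443, b"/foo?x=1")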
- """ - if key not in headers: - return [] - tokens = headers[key].split(",") - return [token.strip() for token in tokens] - - -def hostport(scheme, host, port): - """ - Returns the host component, with a port specifcation if needed. - """ - if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: - return host - else: - if isinstance(host, six.binary_type): - return b"%s:%d" % (host, port) - else: - return "%s:%d" % (host, port) - - -def unparse_url(scheme, host, port, path=""): - """ - Returns a URL string, constructed from the specified components. - - Args: - All args must be str. - """ - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) - - -def urlencode(s): - """ - Takes a list of (key, value) tuples and returns a urlencoded string. - """ - s = [tuple(i) for i in s] - return urllib.parse.urlencode(s, False) - - -def urldecode(s): - """ - Takes a urlencoded string and returns a list of (key, value) tuples. - """ - return urllib.parse.parse_qsl(s, keep_blank_values=True) - - -def parse_content_type(c): - """ - A simple parser for content-type values. Returns a (type, subtype, - parameters) tuple, where type and subtype are strings, and parameters - is a dict. If the string could not be parsed, return None. - - E.g. the following string: - - text/html; charset=UTF-8 - - Returns: - - ("text", "html", {"charset": "UTF-8"}) - """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) - if len(ts) != 2: - return None - d = {} - if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) - if len(clause) == 2: - d[clause[0].strip()] = clause[1].strip() - return ts[0].lower(), ts[1].lower(), d - - -def multipartdecode(headers, content): - """ - Takes a multipart boundary encoded string and returns list of (key, value) tuples. - """ - v = headers.get("content-type") - if v: - v = parse_content_type(v) - if not v: - return [] - try: - boundary = v[2]["boundary"].encode("ascii") - except (KeyError, UnicodeError): - return [] - - rx = re.compile(br'\bname="([^"]+)"') - r = [] - - for i in content.split(b"--" + boundary): - parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != b"--": - match = rx.search(parts[1]) - if match: - key = match.group(1) - value = b"".join(parts[3 + parts[2:].index(b""):]) - r.append((key, value)) - return r - return [] - - -def http2_read_raw_frame(rfile): - header = rfile.safe_read(9) - length = int(codecs.encode(header[:3], 'hex_codec'), 16) - - if length == 4740180: - raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) - - body = rfile.safe_read(length) - return [header, body] - -def http2_read_frame(rfile): - header, body = http2_read_raw_frame(rfile) - frame, length = hyperframe.frame.Frame.parse_frame_header(header) - frame.parse_body(memoryview(body)) - return frame diff --git a/netlib/netlib/version.py b/netlib/netlib/version.py deleted file mode 100644 index 379fee0f..00000000 --- a/netlib/netlib/version.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -IVERSION = (0, 17) -VERSION = ".".join(str(i) for i in IVERSION) -NAME = "netlib" -NAMEVERSION = NAME + " " + VERSION diff --git a/netlib/netlib/version_check.py b/netlib/netlib/version_check.py deleted file mode 100644 index 9cf27eea..00000000 --- a/netlib/netlib/version_check.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Having installed a wrong version of pyOpenSSL or netlib is unfortunately a -very common source of error. 
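Spot checks for the URL helpers above: hostport() omits default ports, and urlencode()/urldecode() round-trip key/value pairs.

    from netlib.utils import hostport, urlencode, urldecode

    assert hostport("https", "example.com", 443) == "example.com"
    assert hostport("http", "example.com", 8080) == "example.com:8080"
    assert urldecode(urlencode([("a", "1"), ("b", "2")])) == [("a", "1"), ("b", "2")]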
Check before every start that both versions -are somewhat okay. -""" -from __future__ import division, absolute_import, print_function -import sys -import inspect -import os.path -import six - -import OpenSSL -from . import version - -PYOPENSSL_MIN_VERSION = (0, 15) - - -def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): - # We don't introduce backward-incompatible changes in patch versions. Only - # consider major and minor version. - if version.IVERSION[:2] != mitmproxy_version[:2]: - print( - u"You are using mitmproxy %s with netlib %s. " - u"Most likely, that won't work - please upgrade!" % ( - mitmproxy_version, version.VERSION - ), - file=fp - ) - sys.exit(1) - - -def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - min_version_str = u".".join(six.text_type(x) for x in min_version) - try: - v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) - except ValueError: - print( - u"Cannot parse pyOpenSSL version: {}" - u"mitmproxy requires pyOpenSSL {} or greater.".format( - OpenSSL.__version__, min_version_str - ), - file=fp - ) - return - if v < min_version: - print( - u"You are using an outdated version of pyOpenSSL: " - u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), - file=fp - ) - # Some users apparently have multiple versions of pyOpenSSL installed. - # Report which one we got. - pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) - print( - u"Your pyOpenSSL {} installation is located at {}".format( - OpenSSL.__version__, pyopenssl_path - ), - file=fp - ) - sys.exit(1) diff --git a/netlib/netlib/websockets/__init__.py b/netlib/netlib/websockets/__init__.py deleted file mode 100644 index 1c143919..00000000 --- a/netlib/netlib/websockets/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .frame import * -from .protocol import * diff --git a/netlib/netlib/websockets/frame.py b/netlib/netlib/websockets/frame.py deleted file mode 100644 index fce2c9d3..00000000 --- a/netlib/netlib/websockets/frame.py +++ /dev/null @@ -1,316 +0,0 @@ -from __future__ import absolute_import -import os -import struct -import io -import warnings - -import six - -from .protocol import Masker -from netlib import tcp -from netlib import utils - - -MAX_16_BIT_INT = (1 << 16) -MAX_64_BIT_INT = (1 << 64) - -DEFAULT=object() - -OPCODE = utils.BiDi( - CONTINUE=0x00, - TEXT=0x01, - BINARY=0x02, - CLOSE=0x08, - PING=0x09, - PONG=0x0a -) - - -class FrameHeader(object): - - def __init__( - self, - opcode=OPCODE.TEXT, - payload_length=0, - fin=False, - rsv1=False, - rsv2=False, - rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT - ): - if not 0 <= opcode < 2 ** 4: - raise ValueError("opcode must be 0-16") - self.opcode = opcode - self.payload_length = payload_length - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - - if length_code is DEFAULT: - self.length_code = self._make_length_code(self.payload_length) - else: - self.length_code = length_code - - if mask is DEFAULT and masking_key is DEFAULT: - self.mask = False - self.masking_key = b"" - elif mask is DEFAULT: - self.mask = 1 - self.masking_key = masking_key - elif masking_key is DEFAULT: - self.mask = mask - self.masking_key = os.urandom(4) - else: - self.mask = mask - self.masking_key = masking_key - - if self.masking_key and len(self.masking_key) != 4: - raise ValueError("Masking key must be 4 bytes.") - - @classmethod - def _make_length_code(self, length): - """ - A websockets frame contains an initial length_code, and an optional - 
extended length code to represent the actual length if length code is - larger than 125 - """ - if length <= 125: - return length - elif length >= 126 and length <= 65535: - return 126 - else: - return 127 - - def __repr__(self): - vals = [ - "ws frame:", - OPCODE.get_name(self.opcode, hex(self.opcode)).lower() - ] - flags = [] - for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: - if getattr(self, i): - flags.append(i) - if flags: - vals.extend([":", "|".join(flags)]) - if self.masking_key: - vals.append(":key=%s" % repr(self.masking_key)) - if self.payload_length: - vals.append(" %s" % utils.pretty_size(self.payload_length)) - return "".join(vals) - - def human_readable(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return repr(self) - - def __bytes__(self): - first_byte = utils.setbit(0, 7, self.fin) - first_byte = utils.setbit(first_byte, 6, self.rsv1) - first_byte = utils.setbit(first_byte, 5, self.rsv2) - first_byte = utils.setbit(first_byte, 4, self.rsv3) - first_byte = first_byte | self.opcode - - second_byte = utils.setbit(self.length_code, 7, self.mask) - - b = six.int2byte(first_byte) + six.int2byte(second_byte) - - if self.payload_length < 126: - pass - elif self.payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.payload_length) - elif self.payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.payload_length) - if self.masking_key: - b += self.masking_key - return b - - if six.PY2: - __str__ = __bytes__ - - def to_bytes(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return bytes(self) - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame header - """ - first_byte = six.byte2int(fp.safe_read(1)) - second_byte = six.byte2int(fp.safe_read(1)) - - fin = utils.getbit(first_byte, 7) - rsv1 = utils.getbit(first_byte, 6) - rsv2 = utils.getbit(first_byte, 5) - rsv3 = utils.getbit(first_byte, 4) - # grab right-most 4 bits - opcode = first_byte & 15 - mask_bit = utils.getbit(second_byte, 7) - # grab the next 7 bits - length_code = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if length_code <= 125: - payload_length = length_code - elif length_code == 126: - payload_length, = struct.unpack("!H", fp.safe_read(2)) - elif length_code == 127: - payload_length, = struct.unpack("!Q", fp.safe_read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.safe_read(4) - else: - masking_key = None - - return cls( - fin=fin, - rsv1=rsv1, - rsv2=rsv2, - rsv3=rsv3, - opcode=opcode, - mask=mask_bit, - length_code=length_code, - payload_length=payload_length, - masking_key=masking_key, - ) - - def __eq__(self, other): - if isinstance(other, FrameHeader): - return bytes(self) == bytes(other) - return False - - -class Frame(object): - - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. 
- - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - - def __init__(self, payload=b"", **kwargs): - self.payload = payload - kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) - self.header = FrameHeader(**kwargs) - - @classmethod - def default(cls, message, from_client=False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - if from_client: - mask_bit = 1 - masking_key = os.urandom(4) - else: - mask_bit = 0 - masking_key = None - - return cls( - message, - fin=1, # final frame - opcode=OPCODE.TEXT, # text - mask=mask_bit, - masking_key=masking_key, - ) - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - - def __repr__(self): - ret = repr(self.header) - if self.payload: - ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii") - return ret - - def human_readable(self): - warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) - return repr(self) - - def __bytes__(self): - """ - Serialize the frame to wire format. Returns a string. 
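A round-trip sketch for the Frame class above: a masked client text frame serializes and parses back to an equal frame (equality compares the serialized bytes).

    from netlib.websockets import Frame

    f = Frame.default(b"hello", from_client=True)  # fin=1, TEXT, random mask
    assert Frame.from_bytes(bytes(f)) == f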
- """ - b = bytes(self.header) - if self.header.masking_key: - b += Masker(self.header.masking_key)(self.payload) - else: - b += self.payload - return b - - if six.PY2: - __str__ = __bytes__ - - def to_bytes(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return bytes(self) - - def to_file(self, writer): - warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) - writer.write(bytes(self)) - writer.flush() - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame sent by a server or client - - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - header = FrameHeader.from_file(fp) - payload = fp.safe_read(header.payload_length) - - if header.mask == 1 and header.masking_key: - payload = Masker(header.masking_key)(payload) - - return cls( - payload, - fin=header.fin, - opcode=header.opcode, - mask=header.mask, - payload_length=header.payload_length, - masking_key=header.masking_key, - rsv1=header.rsv1, - rsv2=header.rsv2, - rsv3=header.rsv3, - length_code=header.length_code - ) - - def __eq__(self, other): - if isinstance(other, Frame): - return bytes(self) == bytes(other) - return False diff --git a/netlib/netlib/websockets/protocol.py b/netlib/netlib/websockets/protocol.py deleted file mode 100644 index 1e95fa1c..00000000 --- a/netlib/netlib/websockets/protocol.py +++ /dev/null @@ -1,115 +0,0 @@ - - - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -from __future__ import absolute_import -import base64 -import hashlib -import os - -import binascii -import six -from ..http import Headers - -websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" - - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.offset = 0 - - def mask(self, offset, data): - result = bytearray(data) - if six.PY2: - for i in range(len(data)): - result[i] ^= ord(self.key[offset % 4]) - offset += 1 - result = str(result) - else: - - for i in range(len(data)): - result[i] ^= self.key[offset % 4] - offset += 1 - result = bytes(result) - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -class WebsocketsProtocol(object): - - def __init__(self): - pass - - @classmethod - def client_handshake_headers(self, key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. 
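XOR masking is its own inverse, so two fresh Masker objects with the same key (both starting at offset 0) cancel out:

    from netlib.websockets import Masker

    key = b"\x01\x02\x03\x04"
    assert Masker(key)(Masker(key)(b"payload")) == b"payload"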
- - Returns an instance of Headers - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('ascii') - return Headers( - sec_websocket_key=key, - sec_websocket_version=version, - connection="Upgrade", - upgrade="websocket", - ) - - @classmethod - def server_handshake_headers(self, key): - """ - The server response is a valid HTTP 101 response. - """ - return Headers( - sec_websocket_accept=self.create_server_nonce(key), - connection="Upgrade", - upgrade="websocket" - ) - - - @classmethod - def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": - return - return headers.get("sec-websocket-key") - - - @classmethod - def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": - return - return headers.get("sec-websocket-accept") - - - @classmethod - def create_server_nonce(self, client_nonce): - return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest()) diff --git a/netlib/netlib/wsgi.py b/netlib/netlib/wsgi.py deleted file mode 100644 index d6dfae5d..00000000 --- a/netlib/netlib/wsgi.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from io import BytesIO, StringIO -import urllib -import time -import traceback - -import six -from six.moves import urllib - -from netlib.utils import always_bytes, native -from . import http, tcp - -class ClientConn(object): - - def __init__(self, address): - self.address = tcp.Address.wrap(address) - - -class Flow(object): - - def __init__(self, address, request): - self.client_conn = ClientConn(address) - self.request = request - - -class Request(object): - - def __init__(self, scheme, method, path, http_version, headers, content): - self.scheme, self.method, self.path = scheme, method, path - self.headers, self.content = headers, content - self.http_version = http_version - - -def date_time_string(): - """Return the current date and time formatted for a message header.""" - WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [ - None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' - ] - now = time.time() - year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now) - s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss - ) - return s - - -class WSGIAdaptor(object): - - def __init__(self, app, domain, port, sversion): - self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - - def make_environ(self, flow, errsoc, **extra): - path = native(flow.request.path, "latin-1") - if '?' 
in path: - path_info, query = native(path, "latin-1").split('?', 1) - else: - path_info = path - query = '' - environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), - 'wsgi.input': BytesIO(flow.request.content or b""), - 'wsgi.errors': errsoc, - 'wsgi.multithread': True, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': native(flow.request.method, "latin-1"), - 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.parse.unquote(path_info), - 'QUERY_STRING': query, - 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"), - 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"), - 'SERVER_NAME': self.domain, - 'SERVER_PORT': str(self.port), - 'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"), - } - environ.update(extra) - if flow.client_conn.address: - environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1") - environ["REMOTE_PORT"] = flow.client_conn.address.port - - for key, value in flow.request.headers.items(): - key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_') - if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): - environ[key] = value - return environ - - def error_page(self, soc, headers_sent, s): - """ - Make a best-effort attempt to write an error page. If headers are - already sent, we just bung the error into the page. - """ - c = """ - -
-            <html>
-                <h1>
-                    Internal Server Error
-                </h1>
-                <pre>{err}"</pre>
-            </html>
- - """.format(err=s).strip().encode() - - if not headers_sent: - soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") - soc.write(b"Content-Type: text/html\r\n") - soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode()) - soc.write(b"\r\n") - soc.write(c) - - def serve(self, request, soc, **env): - state = dict( - response_started=False, - headers_sent=False, - status=None, - headers=None - ) - - def write(data): - if not state["headers_sent"]: - soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode()) - headers = state["headers"] - if 'server' not in headers: - headers["Server"] = self.sversion - if 'date' not in headers: - headers["Date"] = date_time_string() - soc.write(bytes(headers)) - soc.write(b"\r\n") - state["headers_sent"] = True - if data: - soc.write(data) - soc.flush() - - def start_response(status, headers, exc_info=None): - if exc_info: - if state["headers_sent"]: - six.reraise(*exc_info) - elif state["status"]: - raise AssertionError('Response already started') - state["status"] = status - state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers]) - if exc_info: - self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) - state["headers_sent"] = True - - errs = six.BytesIO() - try: - dataiter = self.app( - self.make_environ(request, errs, **env), start_response - ) - for i in dataiter: - write(i) - if not state["headers_sent"]: - write(b"") - except Exception as e: - try: - s = traceback.format_exc() - errs.write(s.encode("utf-8", "replace")) - self.error_page(soc, state["headers_sent"], s) - except Exception: # pragma: no cover - pass - return errs.getvalue() diff --git a/netlib/odict.py b/netlib/odict.py new file mode 100644 index 00000000..1e6e381a --- /dev/null +++ b/netlib/odict.py @@ -0,0 +1,193 @@ +from __future__ import (absolute_import, print_function, division) +import re +import copy +import six + +from .utils import Serializable + + +def safe_subn(pattern, repl, target, *args, **kwargs): + """ + There are Unicode conversion problems with re.subn. We try to smooth + that over by casting the pattern and replacement to strings. We really + need a better solution that is aware of the actual content ecoding. + """ + return re.subn(str(pattern), str(repl), target, *args, **kwargs) + + +class ODict(Serializable): + + """ + A dictionary-like object for managing ordered (key, value) data. Think + about it as a convenient interface to a list of (key, value) tuples. + """ + + def __init__(self, lst=None): + self.lst = lst or [] + + def _kconv(self, s): + return s + + def __eq__(self, other): + return self.lst == other.lst + + def __ne__(self, other): + return not self.__eq__(other) + + def __iter__(self): + return self.lst.__iter__() + + def __getitem__(self, k): + """ + Returns a list of values matching key. + """ + ret = [] + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + ret.append(i[1]) + return ret + + def keys(self): + return list(set([self._kconv(i[0]) for i in self.lst])) + + def _filter_lst(self, k, lst): + k = self._kconv(k) + new = [] + for i in lst: + if self._kconv(i[0]) != k: + new.append(i) + return new + + def __len__(self): + """ + Total number of (key, value) pairs. + """ + return len(self.lst) + + def __setitem__(self, k, valuelist): + """ + Sets the values for key k. If there are existing values for this + key, they are cleared. 
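+
+        A small usage sketch (bytes keys and values, as used for headers);
+        matching pairs are replaced in place, extra values are appended:
+
+            d = ODict([[b"a", b"1"], [b"b", b"2"], [b"a", b"3"]])
+            d[b"a"] = [b"x", b"y", b"z"]
+            # d.lst == [[b"a", b"x"], [b"b", b"2"], [b"a", b"y"], [b"a", b"z"]]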
+ """ + if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type): + raise ValueError( + "Expected list of values instead of string. " + "Example: odict[b'Host'] = [b'www.example.com']" + ) + kc = self._kconv(k) + new = [] + for i in self.lst: + if self._kconv(i[0]) == kc: + if valuelist: + new.append([k, valuelist.pop(0)]) + else: + new.append(i) + while valuelist: + new.append([k, valuelist.pop(0)]) + self.lst = new + + def __delitem__(self, k): + """ + Delete all items matching k. + """ + self.lst = self._filter_lst(k, self.lst) + + def __contains__(self, k): + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + return True + return False + + def add(self, key, value, prepend=False): + if prepend: + self.lst.insert(0, [key, value]) + else: + self.lst.append([key, value]) + + def get(self, k, d=None): + if k in self: + return self[k] + else: + return d + + def get_first(self, k, d=None): + if k in self: + return self[k][0] + else: + return d + + def items(self): + return self.lst[:] + + def copy(self): + """ + Returns a copy of this object. + """ + lst = copy.deepcopy(self.lst) + return self.__class__(lst) + + def extend(self, other): + """ + Add the contents of other, preserving any duplicates. + """ + self.lst.extend(other.lst) + + def __repr__(self): + return repr(self.lst) + + def in_any(self, key, value, caseless=False): + """ + Do any of the values matching key contain value? + + If caseless is true, value comparison is case-insensitive. + """ + if caseless: + value = value.lower() + for i in self[key]: + if caseless: + i = i.lower() + if value in i: + return True + return False + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both keys and + values. Encoded content will be decoded before replacement, and + re-encoded afterwards. + + Returns the number of replacements made. + """ + nlst, count = [], 0 + for i in self.lst: + k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) + count += c + v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) + count += c + nlst.append([k, v]) + self.lst = nlst + return count + + # Implement the StateObject protocol from mitmproxy + def get_state(self): + return [tuple(i) for i in self.lst] + + def set_state(self, state): + self.lst = [list(i) for i in state] + + @classmethod + def from_state(cls, state): + return cls([list(i) for i in state]) + + +class ODictCaseless(ODict): + + """ + A variant of ODict with "caseless" keys. This version _preserves_ key + case, but does not consider case when setting or getting items. 
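+
+    For example, d[b"Content-Type"] and d[b"content-type"] address the
+    same values, while items() still returns the keys as originally written.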
+ """ + + def _kconv(self, s): + return s.lower() diff --git a/netlib/setup.cfg b/netlib/setup.cfg deleted file mode 100644 index 3480374b..00000000 --- a/netlib/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[bdist_wheel] -universal=1 \ No newline at end of file diff --git a/netlib/setup.py b/netlib/setup.py deleted file mode 100644 index 0c9a721d..00000000 --- a/netlib/setup.py +++ /dev/null @@ -1,70 +0,0 @@ -from setuptools import setup, find_packages -from codecs import open -import os -import sys - -from netlib import version - -# Based on https://github.com/pypa/sampleproject/blob/master/setup.py -# and https://python-packaging-user-guide.readthedocs.org/ -# and https://caremad.io/2014/11/distributing-a-cffi-project/ - -here = os.path.abspath(os.path.dirname(__file__)) - -with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: - long_description = f.read() - -setup( - name="netlib", - version=version.VERSION, - description="A collection of network utilities used by pathod and mitmproxy.", - long_description=long_description, - url="http://github.com/mitmproxy/netlib", - author="Aldo Cortesi", - author_email="aldo@corte.si", - license="MIT", - classifiers=[ - "License :: OSI Approved :: MIT License", - "Development Status :: 3 - Alpha", - "Operating System :: POSIX", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Topic :: Internet", - "Topic :: Internet :: WWW/HTTP", - "Topic :: Internet :: WWW/HTTP :: HTTP Servers", - "Topic :: Software Development :: Testing", - "Topic :: Software Development :: Testing :: Traffic Generation", - ], - packages=find_packages(), - install_requires=[ - "pyasn1>=0.1.9, <0.2", - "pyOpenSSL>=0.15.1, <0.16", - "cryptography>=1.2.2, <1.3", - "passlib>=1.6.5, <1.7", - "hpack>=2.1.0, <3.0", - "hyperframe>=3.2.0, <4.0", - "six>=1.10.0, <1.11", - "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! - "backports.ssl_match_hostname>=3.5.0.1, <3.6", - ], - extras_require={ - # Do not use a range operator here: https://bitbucket.org/pypa/setuptools/issues/380 - # Ubuntu Trusty and other still ship with setuptools < 17.1 - ':python_version == "2.7"': [ - "ipaddress>=1.0.15, <1.1", - ], - 'dev': [ - "mock>=1.3.0, <1.4", - "pytest>=2.8.7, <2.9", - "pytest-xdist>=1.14, <1.15", - "pytest-cov>=2.2.1, <2.3", - "pytest-timeout>=1.0.0, <1.1", - "coveralls>=1.1, <1.2" - ] - }, -) diff --git a/netlib/socks.py b/netlib/socks.py new file mode 100644 index 00000000..51ad1c63 --- /dev/null +++ b/netlib/socks.py @@ -0,0 +1,176 @@ +from __future__ import (absolute_import, print_function, division) +import struct +import array +import ipaddress +from . 
import tcp, utils + + +class SocksError(Exception): + def __init__(self, code, message): + super(SocksError, self).__init__(message) + self.code = code + + +VERSION = utils.BiDi( + SOCKS4=0x04, + SOCKS5=0x05 +) + +CMD = utils.BiDi( + CONNECT=0x01, + BIND=0x02, + UDP_ASSOCIATE=0x03 +) + +ATYP = utils.BiDi( + IPV4_ADDRESS=0x01, + DOMAINNAME=0x03, + IPV6_ADDRESS=0x04 +) + +REP = utils.BiDi( + SUCCEEDED=0x00, + GENERAL_SOCKS_SERVER_FAILURE=0x01, + CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, + NETWORK_UNREACHABLE=0x03, + HOST_UNREACHABLE=0x04, + CONNECTION_REFUSED=0x05, + TTL_EXPIRED=0x06, + COMMAND_NOT_SUPPORTED=0x07, + ADDRESS_TYPE_NOT_SUPPORTED=0x08, +) + +METHOD = utils.BiDi( + NO_AUTHENTICATION_REQUIRED=0x00, + GSSAPI=0x01, + USERNAME_PASSWORD=0x02, + NO_ACCEPTABLE_METHODS=0xFF +) + + +class ClientGreeting(object): + __slots__ = ("ver", "methods") + + def __init__(self, ver, methods): + self.ver = ver + self.methods = array.array("B") + self.methods.extend(methods) + + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + if self.ver == ord("G") and len(self.methods) == ord("E"): + guess = "Probably not a SOCKS request but a regular HTTP request. " + else: + guess = "" + + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + + @classmethod + def from_file(cls, f, fail_early=False): + """ + :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. + """ + ver, nmethods = struct.unpack("!BB", f.safe_read(2)) + client_greeting = cls(ver, []) + if fail_early: + client_greeting.assert_socks5() + client_greeting.methods.fromstring(f.safe_read(nmethods)) + return client_greeting + + def to_file(self, f): + f.write(struct.pack("!BB", self.ver, len(self.methods))) + f.write(self.methods.tostring()) + + +class ServerGreeting(object): + __slots__ = ("ver", "method") + + def __init__(self, ver, method): + self.ver = ver + self.method = method + + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + if self.ver == ord("H") and self.method == ord("T"): + guess = "Probably not a SOCKS request but a regular HTTP response. " + else: + guess = "" + + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + + @classmethod + def from_file(cls, f): + ver, method = struct.unpack("!BB", f.safe_read(2)) + return cls(ver, method) + + def to_file(self, f): + f.write(struct.pack("!BB", self.ver, self.method)) + + +class Message(object): + __slots__ = ("ver", "msg", "atyp", "addr") + + def __init__(self, ver, msg, atyp, addr): + self.ver = ver + self.msg = msg + self.atyp = atyp + self.addr = tcp.Address.wrap(addr) + + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + + @classmethod + def from_file(cls, f): + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) + if rsv != 0x00: + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv + ) + if atyp == ATYP.IPV4_ADDRESS: + # We use tnoa here as ntop is not commonly available on Windows. 
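+            # (ipaddress accepts the four raw "packed" bytes here, and
+            # .compressed renders the dotted-quad string,
+            # e.g. b"\x7f\x00\x00\x01" -> "127.0.0.1".)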
+ host = ipaddress.IPv4Address(f.safe_read(4)).compressed + use_ipv6 = False + elif atyp == ATYP.IPV6_ADDRESS: + host = ipaddress.IPv6Address(f.safe_read(16)).compressed + use_ipv6 = True + elif atyp == ATYP.DOMAINNAME: + length, = struct.unpack("!B", f.safe_read(1)) + host = f.safe_read(length) + if not utils.is_valid_host(host): + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) + host = host.decode("idna") + use_ipv6 = False + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Socks Request: Unknown ATYP: %s" % atyp) + + port, = struct.unpack("!H", f.safe_read(2)) + addr = tcp.Address((host, port), use_ipv6=use_ipv6) + return cls(ver, msg, atyp, addr) + + def to_file(self, f): + f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) + if self.atyp == ATYP.IPV4_ADDRESS: + f.write(ipaddress.IPv4Address(self.addr.host).packed) + elif self.atyp == ATYP.IPV6_ADDRESS: + f.write(ipaddress.IPv6Address(self.addr.host).packed) + elif self.atyp == ATYP.DOMAINNAME: + f.write(struct.pack("!B", len(self.addr.host))) + f.write(self.addr.host.encode("idna")) + else: + raise SocksError( + REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Unknown ATYP: %s" % self.atyp + ) + f.write(struct.pack("!H", self.addr.port)) diff --git a/netlib/tcp.py b/netlib/tcp.py new file mode 100644 index 00000000..61b41cdc --- /dev/null +++ b/netlib/tcp.py @@ -0,0 +1,911 @@ +from __future__ import (absolute_import, print_function, division) +import os +import select +import socket +import sys +import threading +import time +import traceback + +import binascii +from six.moves import range + +import certifi +from backports import ssl_match_hostname +import six +import OpenSSL +from OpenSSL import SSL + +from . import certutils, version_check, utils + +# This is a rather hackish way to make sure that +# the latest version of pyOpenSSL is actually installed. +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ + TcpTimeout, TcpDisconnect, TcpException + +version_check.check_pyopenssl_version() + +if six.PY2: + socket_fileobject = socket._fileobject +else: + socket_fileobject = socket.SocketIO + +EINTR = 4 +if os.environ.get("NO_ALPN"): + HAS_ALPN = False +else: + HAS_ALPN = OpenSSL._util.lib.Cryptography_HAS_ALPN + +# To enable all SSL methods use: SSLv23 +# then add options to disable certain methods +# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +SSL_BASIC_OPTIONS = ( + SSL.OP_CIPHER_SERVER_PREFERENCE +) +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION + +SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD +SSL_DEFAULT_OPTIONS = ( + SSL.OP_NO_SSLv2 | + SSL.OP_NO_SSLv3 | + SSL_BASIC_OPTIONS +) +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION + +""" +Map a reasonable SSL version specification into the format OpenSSL expects. +Don't ask... 
+https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +""" +sslversion_choices = { + "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), + # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ + # TLSv1_METHOD would be TLS 1.0 only + "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), + "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), + "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), + "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), +} + +class SSLKeyLogger(object): + + def __init__(self, filename): + self.filename = filename + self.f = None + self.lock = threading.Lock() + + # required for functools.wraps, which pyOpenSSL uses. + __name__ = "SSLKeyLogger" + + def __call__(self, connection, where, ret): + if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: + with self.lock: + if not self.f: + d = os.path.dirname(self.filename) + if not os.path.isdir(d): + os.makedirs(d) + self.f = open(self.filename, "ab") + self.f.write(b"\r\n") + client_random = binascii.hexlify(connection.client_random()) + masterkey = binascii.hexlify(connection.master_key()) + self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) + self.f.flush() + + def close(self): + with self.lock: + if self.f: + self.f.close() + + @staticmethod + def create_logfun(filename): + if filename: + return SSLKeyLogger(filename) + return False + +log_ssl_key = SSLKeyLogger.create_logfun( + os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) + + +class _FileLike(object): + BLOCKSIZE = 1024 * 32 + + def __init__(self, o): + self.o = o + self._log = None + self.first_byte_timestamp = None + + def set_descriptor(self, o): + self.o = o + + def __getattr__(self, attr): + return getattr(self.o, attr) + + def start_log(self): + """ + Starts or resets the log. + + This will store all bytes read or written. + """ + self._log = [] + + def stop_log(self): + """ + Stops the log. + """ + self._log = None + + def is_logging(self): + return self._log is not None + + def get_log(self): + """ + Returns the log as a string. + """ + if not self.is_logging(): + raise ValueError("Not logging!") + return b"".join(self._log) + + def add_log(self, v): + if self.is_logging(): + self._log.append(v) + + def reset_timestamps(self): + self.first_byte_timestamp = None + + +class Writer(_FileLike): + + def flush(self): + """ + May raise TcpDisconnect + """ + if hasattr(self.o, "flush"): + try: + self.o.flush() + except (socket.error, IOError) as v: + raise TcpDisconnect(str(v)) + + def write(self, v): + """ + May raise TcpDisconnect + """ + if v: + self.first_byte_timestamp = self.first_byte_timestamp or time.time() + try: + if hasattr(self.o, "sendall"): + self.add_log(v) + return self.o.sendall(v) + else: + r = self.o.write(v) + self.add_log(v[:r]) + return r + except (SSL.Error, socket.error) as e: + raise TcpDisconnect(str(e)) + + +class Reader(_FileLike): + + def read(self, length): + """ + If length is -1, we read until connection closes. 
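+
+        May raise TcpTimeout, TcpDisconnect or TlsException, matching the
+        handlers below.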
+ """ + result = b'' + start = time.time() + while length == -1 or length > 0: + if length == -1 or length > self.BLOCKSIZE: + rlen = self.BLOCKSIZE + else: + rlen = length + try: + data = self.o.read(rlen) + except SSL.ZeroReturnError: + # TLS connection was shut down cleanly + break + except (SSL.WantWriteError, SSL.WantReadError): + # From the OpenSSL docs: + # If the underlying BIO is non-blocking, SSL_read() will also return when the + # underlying BIO could not satisfy the needs of SSL_read() to continue the + # operation. In this case a call to SSL_get_error with the return value of + # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. + if (time.time() - start) < self.o.gettimeout(): + time.sleep(0.1) + continue + else: + raise TcpTimeout() + except socket.timeout: + raise TcpTimeout() + except socket.error as e: + raise TcpDisconnect(str(e)) + except SSL.SysCallError as e: + if e.args == (-1, 'Unexpected EOF'): + break + raise TlsException(str(e)) + except SSL.Error as e: + raise TlsException(str(e)) + self.first_byte_timestamp = self.first_byte_timestamp or time.time() + if not data: + break + result += data + if length != -1: + length -= len(data) + self.add_log(result) + return result + + def readline(self, size=None): + result = b'' + bytes_read = 0 + while True: + if size is not None and bytes_read >= size: + break + ch = self.read(1) + bytes_read += 1 + if not ch: + break + else: + result += ch + if ch == b'\n': + break + return result + + def safe_read(self, length): + """ + Like .read, but is guaranteed to either return length bytes, or + raise an exception. + """ + result = self.read(length) + if length != -1 and len(result) != length: + if not result: + raise TcpDisconnect() + else: + raise TcpReadIncomplete( + "Expected %s bytes, got %s" % (length, len(result)) + ) + return result + + def peek(self, length): + """ + Tries to peek into the underlying file object. + + Returns: + Up to the next N bytes if peeking is successful. + + Raises: + TcpException if there was an error with the socket + TlsException if there was an error with pyOpenSSL. + NotImplementedError if the underlying file object is not a [pyOpenSSL] socket + """ + if isinstance(self.o, socket_fileobject): + try: + return self.o._sock.recv(length, socket.MSG_PEEK) + except socket.error as e: + raise TcpException(repr(e)) + elif isinstance(self.o, SSL.Connection): + try: + if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): + return self.o.recv(length, socket.MSG_PEEK) + else: + # TODO: remove once a new version is released + # Polyfill for pyOpenSSL <= 0.15.1 + # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 + buf = SSL._ffi.new("char[]", length) + result = SSL._lib.SSL_peek(self.o._ssl, buf, length) + self.o._raise_ssl_error(self.o._ssl, result) + return SSL._ffi.buffer(buf, result)[:] + except SSL.Error as e: + six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2]) + else: + raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") + + +class Address(utils.Serializable): + + """ + This class wraps an IPv4/IPv6 tuple to provide named attributes and + ipv6 information. 
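+
+    For example, Address.wrap(("127.0.0.1", 8080)) exposes .host, .port
+    and the matching socket family for connect().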
+ """ + + def __init__(self, address, use_ipv6=False): + self.address = tuple(address) + self.use_ipv6 = use_ipv6 + + def get_state(self): + return { + "address": self.address, + "use_ipv6": self.use_ipv6 + } + + def set_state(self, state): + self.address = state["address"] + self.use_ipv6 = state["use_ipv6"] + + @classmethod + def from_state(cls, state): + return Address(**state) + + @classmethod + def wrap(cls, t): + if isinstance(t, cls): + return t + else: + return cls(t) + + def __call__(self): + return self.address + + @property + def host(self): + return self.address[0] + + @property + def port(self): + return self.address[1] + + @property + def use_ipv6(self): + return self.family == socket.AF_INET6 + + @use_ipv6.setter + def use_ipv6(self, b): + self.family = socket.AF_INET6 if b else socket.AF_INET + + def __repr__(self): + return "{}:{}".format(self.host, self.port) + + def __str__(self): + return str(self.address) + + def __eq__(self, other): + if not other: + return False + other = Address.wrap(other) + return (self.address, self.family) == (other.address, other.family) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.address) ^ 42 # different hash than the tuple alone. + + +def ssl_read_select(rlist, timeout): + """ + This is a wrapper around select.select() which also works for SSL.Connections + by taking ssl_connection.pending() into account. + + Caveats: + If .pending() > 0 for any of the connections in rlist, we avoid the select syscall + and **will not include any other connections which may or may not be ready**. + + Args: + rlist: wait until ready for reading + + Returns: + subset of rlist which is ready for reading. + """ + return [ + conn for conn in rlist + if isinstance(conn, SSL.Connection) and conn.pending() > 0 + ] or select.select(rlist, (), (), timeout)[0] + + +def close_socket(sock): + """ + Does a hard close of a socket, without emitting a RST. + """ + try: + # We already indicate that we close our end. + # may raise "Transport endpoint is not connected" on Linux + sock.shutdown(socket.SHUT_WR) + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending + # readable data could lead to an immediate RST being sent (which is the + # case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # This in turn results in the following issue: If we send an error page + # to the client and then close the socket, the RST may be received by + # the client before the error page and the users sees a connection + # error rather than the error page. Thus, we try to empty the read + # buffer on Windows first. (see + # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # + + if os.name == "nt": # pragma: no cover + # We cannot rely on the shutdown()-followed-by-read()-eof technique + # proposed by the page above: Some remote machines just don't send + # a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. As a workaround, we set a timeout + # here even if we are in blocking mode. + sock.settimeout(sock.gettimeout() or 20) + + # limit at a megabyte so that we don't read infinitely + for _ in range(1024 ** 3 // 4096): + # may raise a timeout/disconnect exception. + if not sock.recv(4096): + break + + # Now we can close the other half as well. 
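+            # (The write half was already shut down at the top of this
+            # function, so this completes the two-step close.)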
+ sock.shutdown(socket.SHUT_RD) + + except socket.error: + pass + + sock.close() + + +class _Connection(object): + + rbufsize = -1 + wbufsize = -1 + + def _makefile(self): + """ + Set up .rfile and .wfile attributes from .connection + """ + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + + def __init__(self, connection): + if connection: + self.connection = connection + self._makefile() + else: + self.connection = None + self.rfile = None + self.wfile = None + + self.ssl_established = False + self.finished = False + + def get_current_cipher(self): + if not self.ssl_established: + return None + + name = self.connection.get_cipher_name() + bits = self.connection.get_cipher_bits() + version = self.connection.get_cipher_version() + return name, bits, version + + def finish(self): + self.finished = True + # If we have an SSL connection, wfile.close == connection.close + # (We call _FileLike.set_descriptor(conn)) + # Closing the socket is not our task, therefore we don't call close + # then. + if not isinstance(self.connection, SSL.Connection): + if not getattr(self.wfile, "closed", False): + try: + self.wfile.flush() + self.wfile.close() + except TcpDisconnect: + pass + + self.rfile.close() + else: + try: + self.connection.shutdown() + except SSL.Error: + pass + + def _create_ssl_context(self, + method=SSL_DEFAULT_METHOD, + options=SSL_DEFAULT_OPTIONS, + verify_options=SSL.VERIFY_NONE, + ca_path=None, + ca_pemfile=None, + cipher_list=None, + alpn_protos=None, + alpn_select=None, + alpn_select_callback=None, + ): + """ + Creates an SSL Context. + + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD + :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + :param ca_pemfile: Path to a PEM formatted trusted CA certificate + :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html + :rtype : SSL.Context + """ + context = SSL.Context(method) + # Options (NO_SSLv2/3) + if options is not None: + context.set_options(options) + + # Verify Options (NONE/PEER and trusted CAs) + if verify_options is not None: + def verify_cert(conn, x509, errno, err_depth, is_cert_verified): + if not is_cert_verified: + self.ssl_verification_error = dict(errno=errno, + depth=err_depth) + return is_cert_verified + + context.set_verify(verify_options, verify_cert) + if ca_path is None and ca_pemfile is None: + ca_pemfile = certifi.where() + context.load_verify_locations(ca_pemfile, ca_path) + + # Workaround for + # https://github.com/pyca/pyopenssl/issues/190 + # https://github.com/mitmproxy/mitmproxy/issues/472 + # Options already set before are not cleared. 
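+        # (SSL_MODE_AUTO_RETRY lets blocking reads complete transparently
+        # across renegotiations instead of failing with WANT_READ.)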
+ context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) + + # Cipher List + if cipher_list: + try: + context.set_cipher_list(cipher_list) + + # TODO: maybe change this to with newer pyOpenSSL APIs + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) + except SSL.Error as v: + raise TlsException("SSL cipher specification error: %s" % str(v)) + + # SSLKEYLOGFILE + if log_ssl_key: + context.set_info_callback(log_ssl_key) + + if HAS_ALPN: + if alpn_protos is not None: + # advertise application layer protocols + context.set_alpn_protos(alpn_protos) + elif alpn_select is not None and alpn_select_callback is None: + # select application layer protocol + def alpn_select_callback(conn_, options): + if alpn_select in options: + return bytes(alpn_select) + else: # pragma no cover + return options[0] + context.set_alpn_select_callback(alpn_select_callback) + elif alpn_select_callback is not None and alpn_select is None: + context.set_alpn_select_callback(alpn_select_callback) + elif alpn_select_callback is not None and alpn_select is not None: + raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") + + return context + + +class TCPClient(_Connection): + + def __init__(self, address, source_address=None): + super(TCPClient, self).__init__(None) + self.address = address + self.source_address = source_address + self.cert = None + self.ssl_verification_error = None + self.sni = None + + @property + def address(self): + return self.__address + + @address.setter + def address(self, address): + if address: + self.__address = Address.wrap(address) + else: + self.__address = None + + @property + def source_address(self): + return self.__source_address + + @source_address.setter + def source_address(self, source_address): + if source_address: + self.__source_address = Address.wrap(source_address) + else: + self.__source_address = None + + def close(self): + # Make sure to close the real socket, not the SSL proxy. + # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, + # it tries to renegotiate... + if isinstance(self.connection, SSL.Connection): + close_socket(self.connection._socket) + else: + close_socket(self.connection) + + def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): + context = self._create_ssl_context( + alpn_protos=alpn_protos, + **sslctx_kwargs) + # Client Certs + if cert: + try: + context.use_privatekey_file(cert) + context.use_certificate_file(cert) + except SSL.Error as v: + raise TlsException("SSL client certificate error: %s" % str(v)) + return context + + def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): + """ + cert: Path to a file containing both client cert and private key. 
+ + options: A bit field consisting of OpenSSL.SSL.OP_* values + verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + ca_pemfile: Path to a PEM formatted trusted CA certificate + """ + verification_mode = sslctx_kwargs.get('verify_options', None) + if verification_mode == SSL.VERIFY_PEER and not sni: + raise TlsException("Cannot validate certificate hostname without SNI") + + context = self.create_ssl_context( + alpn_protos=alpn_protos, + **sslctx_kwargs + ) + self.connection = SSL.Connection(context, self.connection) + if sni: + self.sni = sni + self.connection.set_tlsext_host_name(sni) + self.connection.set_connect_state() + try: + self.connection.do_handshake() + except SSL.Error as v: + if self.ssl_verification_error: + raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) + else: + raise TlsException("SSL handshake error: %s" % repr(v)) + else: + # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on + # certificate validation failure + if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None: + raise InvalidCertificateException("SSL handshake error: certificate verify failed") + + self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) + + # Validate TLS Hostname + try: + crt = dict( + subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] + ) + if self.cert.cn: + crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] + if sni: + hostname = sni.decode("ascii", "strict") + else: + hostname = "no-hostname" + ssl_match_hostname.match_hostname(crt, hostname) + except (ValueError, ssl_match_hostname.CertificateError) as e: + self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname") + if verification_mode == SSL.VERIFY_PEER: + raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e))) + + self.ssl_established = True + self.rfile.set_descriptor(self.connection) + self.wfile.set_descriptor(self.connection) + + def connect(self): + try: + connection = socket.socket(self.address.family, socket.SOCK_STREAM) + if self.source_address: + connection.bind(self.source_address()) + connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) + except (socket.error, IOError) as err: + raise TcpException( + 'Error connecting to "%s": %s' % + (self.address.host, err)) + self.connection = connection + self._makefile() + + def settimeout(self, n): + self.connection.settimeout(n) + + def gettimeout(self): + return self.connection.gettimeout() + + def get_alpn_proto_negotiated(self): + if HAS_ALPN and self.ssl_established: + return self.connection.get_alpn_proto_negotiated() + else: + return b"" + + +class BaseHandler(_Connection): + + """ + The instantiator is expected to call the handle() and finish() methods. + """ + + def __init__(self, connection, address, server): + super(BaseHandler, self).__init__(connection) + self.address = Address.wrap(address) + self.server = server + self.clientcert = None + + def create_ssl_context(self, + cert, key, + handle_sni=None, + request_client_cert=None, + chain_file=None, + dhparams=None, + **sslctx_kwargs): + """ + cert: A certutils.SSLCert object or the path to a certificate + chain file. + + handle_sni: SNI handler, should take a connection object. 
Server + name can be retrieved like this: + + connection.get_servername() + + And you can specify the connection keys as follows: + + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) + + The request_client_cert argument requires some explanation. We're + supposed to be able to do this with no negative effects - if the + client has no cert to present, we're notified and proceed as usual. + Unfortunately, Android seems to have a bug (tested on 4.2.2) - when + an Android client is asked to present a certificate it does not + have, it hangs up, which is frankly bogus. Some time down the track + we may be able to make the proper behaviour the default again, but + until then we're conservative. + """ + + context = self._create_ssl_context(**sslctx_kwargs) + + context.use_privatekey(key) + if isinstance(cert, certutils.SSLCert): + context.use_certificate(cert.x509) + else: + context.use_certificate_chain_file(cert) + + if handle_sni: + # SNI callback happens during do_handshake() + context.set_tlsext_servername_callback(handle_sni) + + if request_client_cert: + def save_cert(conn_, cert, errno_, depth_, preverify_ok_): + self.clientcert = certutils.SSLCert(cert) + # Return true to prevent cert verification error + return True + context.set_verify(SSL.VERIFY_PEER, save_cert) + + # Cert Verify + if chain_file: + context.load_verify_locations(chain_file) + + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) + + return context + + def convert_to_ssl(self, cert, key, **sslctx_kwargs): + """ + Convert connection to SSL. + For a list of parameters, see BaseHandler._create_ssl_context(...) + """ + + context = self.create_ssl_context( + cert, + key, + **sslctx_kwargs) + self.connection = SSL.Connection(context, self.connection) + self.connection.set_accept_state() + try: + self.connection.do_handshake() + except SSL.Error as v: + raise TlsException("SSL handshake error: %s" % repr(v)) + self.ssl_established = True + self.rfile.set_descriptor(self.connection) + self.wfile.set_descriptor(self.connection) + + def handle(self): # pragma: no cover + raise NotImplementedError + + def settimeout(self, n): + self.connection.settimeout(n) + + def get_alpn_proto_negotiated(self): + if HAS_ALPN and self.ssl_established: + return self.connection.get_alpn_proto_negotiated() + else: + return b"" + + +class TCPServer(object): + request_queue_size = 20 + + def __init__(self, address): + self.address = Address.wrap(address) + self.__is_shut_down = threading.Event() + self.__shutdown_request = False + self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.address()) + self.address = Address.wrap(self.socket.getsockname()) + self.socket.listen(self.request_queue_size) + + def connection_thread(self, connection, client_address): + client_address = Address(client_address) + try: + self.handle_client_connection(connection, client_address) + except: + self.handle_error(connection, client_address) + finally: + close_socket(connection) + + def serve_forever(self, poll_interval=0.1): + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + try: + r, w_, e_ = select.select( + [self.socket], [], [], poll_interval) + except select.error as ex: # pragma: no cover + if ex[0] == EINTR: + continue + else: + raise + if self.socket in r: + connection, client_address = self.socket.accept() + t = 
threading.Thread( + target=self.connection_thread, + args=(connection, client_address), + name="ConnectionThread (%s:%s -> %s:%s)" % + (client_address[0], client_address[1], + self.address.host, self.address.port) + ) + t.setDaemon(1) + try: + t.start() + except threading.ThreadError: + self.handle_error(connection, Address(client_address)) + connection.close() + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + self.__shutdown_request = True + self.__is_shut_down.wait() + self.socket.close() + self.handle_shutdown() + + def handle_error(self, connection_, client_address, fp=sys.stderr): + """ + Called when handle_client_connection raises an exception. + """ + # If a thread has persisted after interpreter exit, the module might be + # none. + if traceback: + exc = six.text_type(traceback.format_exc()) + print(u'-' * 40, file=fp) + print( + u"Error in processing of request from %s" % repr(client_address), file=fp) + print(exc, file=fp) + print(u'-' * 40, file=fp) + + def handle_client_connection(self, conn, client_address): # pragma: no cover + """ + Called after client connection. + """ + raise NotImplementedError + + def handle_shutdown(self): + """ + Called after server shutdown. + """ diff --git a/netlib/tutils.py b/netlib/tutils.py new file mode 100644 index 00000000..f6ce8e0a --- /dev/null +++ b/netlib/tutils.py @@ -0,0 +1,133 @@ +from io import BytesIO +import tempfile +import os +import time +import shutil +from contextlib import contextmanager +import six +import sys + +from . import utils, tcp +from .http import Request, Response, Headers + + +def treader(bytes): + """ + Construct a tcp.Read object from bytes. + """ + fp = BytesIO(bytes) + return tcp.Reader(fp) + + +@contextmanager +def tmpdir(*args, **kwargs): + orig_workdir = os.getcwd() + temp_workdir = tempfile.mkdtemp(*args, **kwargs) + os.chdir(temp_workdir) + + yield temp_workdir + + os.chdir(orig_workdir) + shutil.rmtree(temp_workdir) + + +def _check_exception(expected, actual, exc_tb): + if isinstance(expected, six.string_types): + if expected.lower() not in str(actual).lower(): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s" % ( + repr(expected), repr(actual) + ) + ), exc_tb) + else: + if not isinstance(actual, expected): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s %s" % ( + expected.__name__, actual.__class__.__name__, repr(actual) + ) + ), exc_tb) + + +def raises(expected_exception, obj=None, *args, **kwargs): + """ + Assert that a callable raises a specified exception. + + :exc An exception class or a string. If a class, assert that an + exception of this type is raised. If a string, assert that the string + occurs in the string representation of the exception, based on a + case-insenstivie match. + + :obj A callable object. + + :args Arguments to be passsed to the callable. + + :kwargs Arguments to be passed to the callable. + """ + if obj is None: + return RaisesContext(expected_exception) + else: + try: + ret = obj(*args, **kwargs) + except Exception as actual: + _check_exception(expected_exception, actual, sys.exc_info()[2]) + else: + raise AssertionError("No exception raised. 
Return value: {}".format(ret)) + + +class RaisesContext(object): + def __init__(self, expected_exception): + self.expected_exception = expected_exception + + def __enter__(self): + return + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + raise AssertionError("No exception raised.") + else: + _check_exception(self.expected_exception, exc_val, exc_tb) + return True + + +test_data = utils.Data(__name__) +# FIXME: Temporary workaround during repo merge. +import os +test_data.dirname = os.path.join(test_data.dirname,"..","..","test","netlib") + + +def treq(**kwargs): + """ + Returns: + netlib.http.Request + """ + default = dict( + first_line_format="relative", + method=b"GET", + scheme=b"http", + host=b"address", + port=22, + path=b"/path", + http_version=b"HTTP/1.1", + headers=Headers(header="qvalue", content_length="7"), + content=b"content" + ) + default.update(kwargs) + return Request(**default) + + +def tresp(**kwargs): + """ + Returns: + netlib.http.Response + """ + default = dict( + http_version=b"HTTP/1.1", + status_code=200, + reason=b"OK", + headers=Headers(header_response="svalue", content_length="7"), + content=b"message", + timestamp_start=time.time(), + timestamp_end=time.time(), + ) + default.update(kwargs) + return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py new file mode 100644 index 00000000..f7bb5c4b --- /dev/null +++ b/netlib/utils.py @@ -0,0 +1,418 @@ +from __future__ import absolute_import, print_function, division +import os.path +import re +import codecs +import unicodedata +from abc import ABCMeta, abstractmethod +import importlib +import inspect + +import six + +from six.moves import urllib +import hyperframe + + +@six.add_metaclass(ABCMeta) +class Serializable(object): + """ + Abstract Base Class that defines an API to save an object's state and restore it later on. + """ + + @classmethod + @abstractmethod + def from_state(cls, state): + """ + Create a new object from the given state. + """ + raise NotImplementedError() + + @abstractmethod + def get_state(self): + """ + Retrieve object state. + """ + raise NotImplementedError() + + @abstractmethod + def set_state(self, state): + """ + Set object state to the given state. + """ + raise NotImplementedError() + + +def always_bytes(unicode_or_bytes, *encode_args): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(*encode_args) + return unicode_or_bytes + + +def always_byte_args(*encode_args): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator + + +def native(s, *encoding_opts): + """ + Convert :py:class:`bytes` or :py:class:`unicode` to the native + :py:class:`str` type, using latin1 encoding if conversion is necessary. 
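+
+        For example, native(b"GET", "latin-1") returns the str "GET" on
+        both Python 2 and Python 3.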
+ + https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types + """ + if not isinstance(s, (six.binary_type, six.text_type)): + raise TypeError("%r is neither bytes nor unicode" % s) + if six.PY3: + if isinstance(s, six.binary_type): + return s.decode(*encoding_opts) + else: + if isinstance(s, six.text_type): + return s.encode(*encoding_opts) + return s + + +def isascii(bytes): + try: + bytes.decode("ascii") + except ValueError: + return False + return True + + +def clean_bin(s, keep_spacing=True): + """ + Cleans binary data to make it safe to display. + + Args: + keep_spacing: If False, tabs and newlines will also be replaced. + """ + if isinstance(s, six.text_type): + if keep_spacing: + keep = u" \n\r\t" + else: + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + for ch in s + ) + else: + if keep_spacing: + keep = (9, 10, 13) # \t, \n, \r, + else: + keep = () + return b"".join( + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) + ) + + +def hexdump(s): + """ + Returns: + A generator of (offset, hex, str) tuples + """ + for i in range(0, len(s), 16): + offset = "{:0=10x}".format(i).encode() + part = s[i:i + 16] + x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) + x = x.ljust(47) # 16*2 + 15 + yield (offset, x, clean_bin(part, False)) + + +def setbit(byte, offset, value): + """ + Set a bit in a byte to 1 if value is truthy, 0 if not. + """ + if value: + return byte | (1 << offset) + else: + return byte & ~(1 << offset) + + +def getbit(byte, offset): + mask = 1 << offset + return bool(byte & mask) + + +class BiDi(object): + + """ + A wee utility class for keeping bi-directional mappings, like field + constants in protocols. Names are attributes on the object, dict-like + access maps values to names: + + CONST = BiDi(a=1, b=2) + assert CONST.a == 1 + assert CONST.get_name(1) == "a" + """ + + def __init__(self, **kwargs): + self.names = kwargs + self.values = {} + for k, v in kwargs.items(): + self.values[v] = k + if len(self.names) != len(self.values): + raise ValueError("Duplicate values not allowed.") + + def __getattr__(self, k): + if k in self.names: + return self.names[k] + raise AttributeError("No such attribute: %s", k) + + def get_name(self, n, default=None): + return self.values.get(n, default) + + +def pretty_size(size): + suffixes = [ + ("B", 2 ** 10), + ("kB", 2 ** 20), + ("MB", 2 ** 30), + ] + for suf, lim in suffixes: + if size >= lim: + continue + else: + x = round(size / float(lim / 2 ** 10), 2) + if x == int(x): + x = int(x) + return str(x) + suf + + +class Data(object): + + def __init__(self, name): + m = importlib.import_module(name) + dirname = os.path.dirname(inspect.getsourcefile(m)) + self.dirname = os.path.abspath(dirname) + + def path(self, path): + """ + Returns a path to the package data housed at 'path' under this + module.Path can be a path to a file, or to a directory. + + This function will raise ValueError if the path does not exist. + """ + fullpath = os.path.join(self.dirname, path) + if not os.path.exists(fullpath): + raise ValueError("dataPath: %s does not exist." % fullpath) + return fullpath + + +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 
<!-)$", re.IGNORECASE)
+
+
+def is_valid_host(host):
+    try:
+        host.decode("idna")
+    except ValueError:
+        return False
+    if len(host) >
255: + return False + if host[-1] == b".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) + + +def parse_url(url): + """ + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII + + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. + """ + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError, + # but we try to be very forgiving here and accept just everything. + # decode_parse_result(parsed, "ascii") + else: + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + + if not is_valid_host(host): + raise ValueError("Invalid Host") + if not is_valid_port(port): + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + if key not in headers: + return [] + tokens = headers[key].split(",") + return [token.strip() for token in tokens] + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: + return host + else: + if isinstance(host, six.binary_type): + return b"%s:%d" % (host, port) + else: + return "%s:%d" % (host, port) + + +def unparse_url(scheme, host, port, path=""): + """ + Returns a URL string, constructed from the specified components. + + Args: + All args must be str. + """ + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + + +def urlencode(s): + """ + Takes a list of (key, value) tuples and returns a urlencoded string. + """ + s = [tuple(i) for i in s] + return urllib.parse.urlencode(s, False) + + +def urldecode(s): + """ + Takes a urlencoded string and returns a list of (key, value) tuples. + """ + return urllib.parse.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. 
the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(headers, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = headers.get("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + try: + boundary = v[2]["boundary"].encode("ascii") + except (KeyError, UnicodeError): + return [] + + rx = re.compile(br'\bname="([^"]+)"') + r = [] + + for i in content.split(b"--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != b"--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = b"".join(parts[3 + parts[2:].index(b""):]) + r.append((key, value)) + return r + return [] + + +def http2_read_raw_frame(rfile): + header = rfile.safe_read(9) + length = int(codecs.encode(header[:3], 'hex_codec'), 16) + + if length == 4740180: + raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) + + body = rfile.safe_read(length) + return [header, body] + +def http2_read_frame(rfile): + header, body = http2_read_raw_frame(rfile) + frame, length = hyperframe.frame.Frame.parse_frame_header(header) + frame.parse_body(memoryview(body)) + return frame diff --git a/netlib/version.py b/netlib/version.py new file mode 100644 index 00000000..379fee0f --- /dev/null +++ b/netlib/version.py @@ -0,0 +1,6 @@ +from __future__ import (absolute_import, print_function, division) + +IVERSION = (0, 17) +VERSION = ".".join(str(i) for i in IVERSION) +NAME = "netlib" +NAMEVERSION = NAME + " " + VERSION diff --git a/netlib/version_check.py b/netlib/version_check.py new file mode 100644 index 00000000..9cf27eea --- /dev/null +++ b/netlib/version_check.py @@ -0,0 +1,60 @@ +""" +Having installed a wrong version of pyOpenSSL or netlib is unfortunately a +very common source of error. Check before every start that both versions +are somewhat okay. +""" +from __future__ import division, absolute_import, print_function +import sys +import inspect +import os.path +import six + +import OpenSSL +from . import version + +PYOPENSSL_MIN_VERSION = (0, 15) + + +def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): + # We don't introduce backward-incompatible changes in patch versions. Only + # consider major and minor version. + if version.IVERSION[:2] != mitmproxy_version[:2]: + print( + u"You are using mitmproxy %s with netlib %s. " + u"Most likely, that won't work - please upgrade!" % ( + mitmproxy_version, version.VERSION + ), + file=fp + ) + sys.exit(1) + + +def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): + min_version_str = u".".join(six.text_type(x) for x in min_version) + try: + v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) + except ValueError: + print( + u"Cannot parse pyOpenSSL version: {}" + u"mitmproxy requires pyOpenSSL {} or greater.".format( + OpenSSL.__version__, min_version_str + ), + file=fp + ) + return + if v < min_version: + print( + u"You are using an outdated version of pyOpenSSL: " + u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), + file=fp + ) + # Some users apparently have multiple versions of pyOpenSSL installed. 
+ # Report which one we got. + pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) + print( + u"Your pyOpenSSL {} installation is located at {}".format( + OpenSSL.__version__, pyopenssl_path + ), + file=fp + ) + sys.exit(1) diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..1c143919 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1,2 @@ +from .frame import * +from .protocol import * diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py new file mode 100644 index 00000000..fce2c9d3 --- /dev/null +++ b/netlib/websockets/frame.py @@ -0,0 +1,316 @@ +from __future__ import absolute_import +import os +import struct +import io +import warnings + +import six + +from .protocol import Masker +from netlib import tcp +from netlib import utils + + +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) + +DEFAULT=object() + +OPCODE = utils.BiDi( + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a +) + + +class FrameHeader(object): + + def __init__( + self, + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT + ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + + if length_code is DEFAULT: + self.length_code = self._make_length_code(self.payload_length) + else: + self.length_code = length_code + + if mask is DEFAULT and masking_key is DEFAULT: + self.mask = False + self.masking_key = b"" + elif mask is DEFAULT: + self.mask = 1 + self.masking_key = masking_key + elif masking_key is DEFAULT: + self.mask = mask + self.masking_key = os.urandom(4) + else: + self.mask = mask + self.masking_key = masking_key + + if self.masking_key and len(self.masking_key) != 4: + raise ValueError("Masking key must be 4 bytes.") + + @classmethod + def _make_length_code(self, length): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if length <= 125: + return length + elif length >= 126 and length <= 65535: + return 126 + else: + return 127 + + def __repr__(self): + vals = [ + "ws frame:", + OPCODE.get_name(self.opcode, hex(self.opcode)).lower() + ] + flags = [] + for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: + if getattr(self, i): + flags.append(i) + if flags: + vals.extend([":", "|".join(flags)]) + if self.masking_key: + vals.append(":key=%s" % repr(self.masking_key)) + if self.payload_length: + vals.append(" %s" % utils.pretty_size(self.payload_length)) + return "".join(vals) + + def human_readable(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return repr(self) + + def __bytes__(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode + + second_byte = utils.setbit(self.length_code, 7, self.mask) + + b = six.int2byte(first_byte) + six.int2byte(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', 
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
new file mode 100644
index 00000000..1c143919
--- /dev/null
+++ b/netlib/websockets/__init__.py
@@ -0,0 +1,2 @@
+from .frame import *
+from .protocol import *
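
Because the package re-exports everything from frame and protocol, callers can work from the top-level namespace. A hypothetical round trip using the Frame class defined in frame.py below (assumes netlib is importable):

    from netlib import websockets

    # Build a masked client text frame and serialize it to wire format...
    frame = websockets.Frame.default(b"hello", from_client=True)
    raw = bytes(frame)

    # ...then parse it back; from_bytes()/from_file() undo the masking.
    assert websockets.Frame.from_bytes(raw).payload == b"hello"
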
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
new file mode 100644
index 00000000..fce2c9d3
--- /dev/null
+++ b/netlib/websockets/frame.py
@@ -0,0 +1,316 @@
+from __future__ import absolute_import
+import os
+import struct
+import io
+import warnings
+
+import six
+
+from .protocol import Masker
+from netlib import tcp
+from netlib import utils
+
+
+MAX_16_BIT_INT = (1 << 16)
+MAX_64_BIT_INT = (1 << 64)
+
+DEFAULT = object()
+
+OPCODE = utils.BiDi(
+    CONTINUE=0x00,
+    TEXT=0x01,
+    BINARY=0x02,
+    CLOSE=0x08,
+    PING=0x09,
+    PONG=0x0a
+)
+
+
+class FrameHeader(object):
+
+    def __init__(
+        self,
+        opcode=OPCODE.TEXT,
+        payload_length=0,
+        fin=False,
+        rsv1=False,
+        rsv2=False,
+        rsv3=False,
+        masking_key=DEFAULT,
+        mask=DEFAULT,
+        length_code=DEFAULT
+    ):
+        if not 0 <= opcode < 2 ** 4:
+            raise ValueError("opcode must be 0-15")
+        self.opcode = opcode
+        self.payload_length = payload_length
+        self.fin = fin
+        self.rsv1 = rsv1
+        self.rsv2 = rsv2
+        self.rsv3 = rsv3
+
+        if length_code is DEFAULT:
+            self.length_code = self._make_length_code(self.payload_length)
+        else:
+            self.length_code = length_code
+
+        if mask is DEFAULT and masking_key is DEFAULT:
+            self.mask = False
+            self.masking_key = b""
+        elif mask is DEFAULT:
+            self.mask = 1
+            self.masking_key = masking_key
+        elif masking_key is DEFAULT:
+            self.mask = mask
+            self.masking_key = os.urandom(4)
+        else:
+            self.mask = mask
+            self.masking_key = masking_key
+
+        if self.masking_key and len(self.masking_key) != 4:
+            raise ValueError("Masking key must be 4 bytes.")
+
+    @classmethod
+    def _make_length_code(cls, length):
+        """
+        A websockets frame contains an initial length_code, and an optional
+        extended length code to represent the actual length if the length
+        code is larger than 125.
+        """
+        if length <= 125:
+            return length
+        elif 126 <= length <= 65535:
+            return 126
+        else:
+            return 127
+
+    def __repr__(self):
+        vals = [
+            "ws frame:",
+            OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
+        ]
+        flags = []
+        for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]:
+            if getattr(self, i):
+                flags.append(i)
+        if flags:
+            vals.extend([":", "|".join(flags)])
+        if self.masking_key:
+            vals.append(":key=%s" % repr(self.masking_key))
+        if self.payload_length:
+            vals.append(" %s" % utils.pretty_size(self.payload_length))
+        return "".join(vals)
+
+    def human_readable(self):
+        warnings.warn("FrameHeader.human_readable is deprecated, use repr(frame_header) instead.", DeprecationWarning)
+        return repr(self)
+
+    def __bytes__(self):
+        first_byte = utils.setbit(0, 7, self.fin)
+        first_byte = utils.setbit(first_byte, 6, self.rsv1)
+        first_byte = utils.setbit(first_byte, 5, self.rsv2)
+        first_byte = utils.setbit(first_byte, 4, self.rsv3)
+        first_byte = first_byte | self.opcode
+
+        second_byte = utils.setbit(self.length_code, 7, self.mask)
+
+        b = six.int2byte(first_byte) + six.int2byte(second_byte)
+
+        if self.payload_length < 126:
+            pass
+        elif self.payload_length < MAX_16_BIT_INT:
+            # '!H' = pack as 16 bit unsigned short
+            # add 2 byte extended payload length
+            b += struct.pack('!H', self.payload_length)
+        elif self.payload_length < MAX_64_BIT_INT:
+            # '!Q' = pack as 64 bit unsigned long long
+            # add 8 bytes extended payload length
+            b += struct.pack('!Q', self.payload_length)
+        if self.masking_key:
+            b += self.masking_key
+        return b
+
+    if six.PY2:
+        __str__ = __bytes__
+
+    def to_bytes(self):
+        warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
+        return bytes(self)
+
+    @classmethod
+    def from_file(cls, fp):
+        """
+        Read a websockets frame header.
+        """
+        first_byte = six.byte2int(fp.safe_read(1))
+        second_byte = six.byte2int(fp.safe_read(1))
+
+        fin = utils.getbit(first_byte, 7)
+        rsv1 = utils.getbit(first_byte, 6)
+        rsv2 = utils.getbit(first_byte, 5)
+        rsv3 = utils.getbit(first_byte, 4)
+        # grab right-most 4 bits
+        opcode = first_byte & 15
+        mask_bit = utils.getbit(second_byte, 7)
+        # grab the next 7 bits
+        length_code = second_byte & 127
+
+        # length_code > 125 indicates you need to read more bytes
+        # to get the actual payload length
+        if length_code <= 125:
+            payload_length = length_code
+        elif length_code == 126:
+            payload_length, = struct.unpack("!H", fp.safe_read(2))
+        elif length_code == 127:
+            payload_length, = struct.unpack("!Q", fp.safe_read(8))
+
+        # masking key only present if mask bit set
+        if mask_bit == 1:
+            masking_key = fp.safe_read(4)
+        else:
+            masking_key = None
+
+        return cls(
+            fin=fin,
+            rsv1=rsv1,
+            rsv2=rsv2,
+            rsv3=rsv3,
+            opcode=opcode,
+            mask=mask_bit,
+            length_code=length_code,
+            payload_length=payload_length,
+            masking_key=masking_key,
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, FrameHeader):
+            return bytes(self) == bytes(other)
+        return False
+
+
+class Frame(object):
+
+    """
+    Represents one websockets frame.
+    The constructor takes human readable forms of the frame components;
+    from_bytes() is also available.
+
+    WebSockets Frame as defined in RFC6455
+
+     0                   1                   2                   3
+     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+    +-+-+-+-+-------+-+-------------+-------------------------------+
+    |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
+    |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
+    |N|V|V|V|       |S|             |   (if payload len==126/127)   |
+    | |1|2|3|       |K|             |                               |
+    +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+    |     Extended payload length continued, if payload len == 127  |
+    + - - - - - - - - - - - - - - - +-------------------------------+
+    |                               |Masking-key, if MASK set to 1  |
+    +-------------------------------+-------------------------------+
+    | Masking-key (continued)       |          Payload Data         |
+    +-------------------------------- - - - - - - - - - - - - - - - +
+    :                     Payload Data continued ...                :
+    + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+    |                     Payload Data continued ...                |
+    +---------------------------------------------------------------+
+    """
+
+    def __init__(self, payload=b"", **kwargs):
+        self.payload = payload
+        kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
+        self.header = FrameHeader(**kwargs)
+
+    @classmethod
+    def default(cls, message, from_client=False):
+        """
+        Construct a basic websocket frame from some default values.
+        Creates a non-fragmented text frame.
+        """
+        if from_client:
+            mask_bit = 1
+            masking_key = os.urandom(4)
+        else:
+            mask_bit = 0
+            masking_key = None
+
+        return cls(
+            message,
+            fin=1,  # final frame
+            opcode=OPCODE.TEXT,  # text
+            mask=mask_bit,
+            masking_key=masking_key,
+        )
+
+    @classmethod
+    def from_bytes(cls, bytestring):
+        """
+        Construct a websocket frame from an in-memory bytestring.
+        To construct a frame from a stream of bytes, use from_file() directly.
+        """
+        return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
+
+    def __repr__(self):
+        ret = repr(self.header)
+        if self.payload:
+            ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii")
+        return ret
+
+    def human_readable(self):
+        warnings.warn("Frame.human_readable is deprecated, use repr(frame) instead.", DeprecationWarning)
+        return repr(self)
+
+    def __bytes__(self):
+        """
+        Serialize the frame to wire format. Returns bytes.
+        """
+        b = bytes(self.header)
+        if self.header.masking_key:
+            b += Masker(self.header.masking_key)(self.payload)
+        else:
+            b += self.payload
+        return b
+
+    if six.PY2:
+        __str__ = __bytes__
+
+    def to_bytes(self):
+        warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
+        return bytes(self)
+
+    def to_file(self, writer):
+        warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
+        writer.write(bytes(self))
+        writer.flush()
+
+    @classmethod
+    def from_file(cls, fp):
+        """
+        Read a websockets frame sent by a server or client.
+
+        fp is a "file like" object that could be backed by a network
+        stream, a disk, or an in-memory stream reader.
+        """
+        header = FrameHeader.from_file(fp)
+        payload = fp.safe_read(header.payload_length)
+
+        if header.mask == 1 and header.masking_key:
+            payload = Masker(header.masking_key)(payload)
+
+        return cls(
+            payload,
+            fin=header.fin,
+            opcode=header.opcode,
+            mask=header.mask,
+            payload_length=header.payload_length,
+            masking_key=header.masking_key,
+            rsv1=header.rsv1,
+            rsv2=header.rsv2,
+            rsv3=header.rsv3,
+            length_code=header.length_code
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, Frame):
+            return bytes(self) == bytes(other)
+        return False
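
To make the header layout concrete: for a small unmasked final text frame the two fixed header bytes suffice, and the extended length only appears once the length code is 126 or 127. A worked example of the byte values __bytes__ produces under those rules:

    import struct

    # FIN=1, opcode TEXT=0x1 -> first byte 0b10000001 = 0x81
    # MASK=0, payload length 5 -> second byte 0b00000101 = 0x05
    assert struct.pack("!BB", 0x81, 0x05) == b"\x81\x05"

    # A 300-byte payload does not fit in 7 bits: length code 126,
    # followed by a 16-bit extended payload length.
    assert struct.pack("!BBH", 0x81, 0x7e, 300) == b"\x81\x7e\x01\x2c"
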
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
new file mode 100644
index 00000000..1e95fa1c
--- /dev/null
+++ b/netlib/websockets/protocol.py
@@ -0,0 +1,115 @@
+# Collection of utility functions that implement small portions of the RFC6455
+# WebSockets Protocol. Useful for building WebSocket clients and servers.
+#
+# Emphasis is on readability, simplicity and modularity, not performance or
+# completeness.
+#
+# This is a work in progress and does not yet contain all the utilities needed
+# to create fully compliant clients/servers.
+#
+# Spec: https://tools.ietf.org/html/rfc6455
+
+from __future__ import absolute_import
+import base64
+import hashlib
+import os
+
+import binascii
+import six
+from ..http import Headers
+
+# The magic sha that websocket servers must know to prove they understand
+# RFC6455.
+websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+VERSION = "13"
+
+
+class Masker(object):
+
+    """
+    Data sent from the client must be masked, so that malicious scripts
+    cannot send data over the wire in predictable patterns.
+
+    Servers do not have to mask data they send to the client.
+
+    https://tools.ietf.org/html/rfc6455#section-5.3
+    """
+
+    def __init__(self, key):
+        self.key = key
+        self.offset = 0
+
+    def mask(self, offset, data):
+        result = bytearray(data)
+        if six.PY2:
+            for i in range(len(data)):
+                result[i] ^= ord(self.key[offset % 4])
+                offset += 1
+            result = str(result)
+        else:
+            for i in range(len(data)):
+                result[i] ^= self.key[offset % 4]
+                offset += 1
+            result = bytes(result)
+        return result
+
+    def __call__(self, data):
+        ret = self.mask(self.offset, data)
+        self.offset += len(ret)
+        return ret
+
+
+class WebsocketsProtocol(object):
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def client_handshake_headers(cls, key=None, version=VERSION):
+        """
+        Create the headers for a valid HTTP upgrade request. If key is not
+        specified, it is generated, and can be found in sec-websocket-key in
+        the returned header set.
+
+        Returns an instance of Headers.
+        """
+        if not key:
+            key = base64.b64encode(os.urandom(16)).decode('ascii')
+        return Headers(
+            sec_websocket_key=key,
+            sec_websocket_version=version,
+            connection="Upgrade",
+            upgrade="websocket",
+        )
+
+    @classmethod
+    def server_handshake_headers(cls, key):
+        """
+        The server response is a valid HTTP 101 response.
+        """
+        return Headers(
+            sec_websocket_accept=cls.create_server_nonce(key),
+            connection="Upgrade",
+            upgrade="websocket"
+        )
+
+    @classmethod
+    def check_client_handshake(cls, headers):
+        if headers.get("upgrade") != "websocket":
+            return
+        return headers.get("sec-websocket-key")
+
+    @classmethod
+    def check_server_handshake(cls, headers):
+        if headers.get("upgrade") != "websocket":
+            return
+        return headers.get("sec-websocket-accept")
+
+    @classmethod
+    def create_server_nonce(cls, client_nonce):
+        return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest())
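
The create_server_nonce computation above can be sanity-checked against the worked handshake example in RFC 6455, section 1.3, which uses the sample key "dGhlIHNhbXBsZSBub25jZQ==":

    import base64
    import hashlib

    magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
    key = b"dGhlIHNhbXBsZSBub25jZQ=="
    accept = base64.b64encode(hashlib.sha1(key + magic).digest())
    assert accept == b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="  # value given in the RFC
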
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
new file mode 100644
index 00000000..d6dfae5d
--- /dev/null
+++ b/netlib/wsgi.py
@@ -0,0 +1,164 @@
+from __future__ import (absolute_import, print_function, division)
+from io import BytesIO
+import time
+import traceback
+
+import six
+from six.moves import urllib
+
+from netlib.utils import always_bytes, native
+from . import http, tcp
+
+
+class ClientConn(object):
+
+    def __init__(self, address):
+        self.address = tcp.Address.wrap(address)
+
+
+class Flow(object):
+
+    def __init__(self, address, request):
+        self.client_conn = ClientConn(address)
+        self.request = request
+
+
+class Request(object):
+
+    def __init__(self, scheme, method, path, http_version, headers, content):
+        self.scheme, self.method, self.path = scheme, method, path
+        self.headers, self.content = headers, content
+        self.http_version = http_version
+
+
+def date_time_string():
+    """Return the current date and time formatted for a message header."""
+    WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+    MONTHS = [
+        None,
+        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+        'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'
+    ]
+    now = time.time()
+    year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now)
+    s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
+        WEEKS[wd],
+        day, MONTHS[month], year,
+        hh, mm, ss
+    )
+    return s
+
+
+class WSGIAdaptor(object):
+
+    def __init__(self, app, domain, port, sversion):
+        self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
+
+    def make_environ(self, flow, errsoc, **extra):
+        path = native(flow.request.path, "latin-1")
+        if '?' in path:
+            path_info, query = path.split('?', 1)
+        else:
+            path_info = path
+            query = ''
+        environ = {
+            'wsgi.version': (1, 0),
+            'wsgi.url_scheme': native(flow.request.scheme, "latin-1"),
+            'wsgi.input': BytesIO(flow.request.content or b""),
+            'wsgi.errors': errsoc,
+            'wsgi.multithread': True,
+            'wsgi.multiprocess': False,
+            'wsgi.run_once': False,
+            'SERVER_SOFTWARE': self.sversion,
+            'REQUEST_METHOD': native(flow.request.method, "latin-1"),
+            'SCRIPT_NAME': '',
+            'PATH_INFO': urllib.parse.unquote(path_info),
+            'QUERY_STRING': query,
+            'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"),
+            'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"),
+            'SERVER_NAME': self.domain,
+            'SERVER_PORT': str(self.port),
+            'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"),
+        }
+        environ.update(extra)
+        if flow.client_conn.address:
+            environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1")
+            environ["REMOTE_PORT"] = flow.client_conn.address.port
+
+        for key, value in flow.request.headers.items():
+            key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_')
+            if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
+                environ[key] = value
+        return environ
+
+    def error_page(self, soc, headers_sent, s):
+        """
+        Make a best-effort attempt to write an error page. If headers are
+        already sent, we just bung the error into the page.
+        """
+        c = """
+            <html>
+                <h1>Internal Server Error</h1>
+                <pre>{err}</pre>
+            </html>
+        """.format(err=s).strip().encode()
+
+        if not headers_sent:
+            soc.write(b"HTTP/1.1 500 Internal Server Error\r\n")
+            soc.write(b"Content-Type: text/html\r\n")
+            soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode())
+            soc.write(b"\r\n")
+        soc.write(c)
+
+    def serve(self, request, soc, **env):
+        state = dict(
+            response_started=False,
+            headers_sent=False,
+            status=None,
+            headers=None
+        )
+
+        def write(data):
+            if not state["headers_sent"]:
+                soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode())
+                headers = state["headers"]
+                if 'server' not in headers:
+                    headers["Server"] = self.sversion
+                if 'date' not in headers:
+                    headers["Date"] = date_time_string()
+                soc.write(bytes(headers))
+                soc.write(b"\r\n")
+                state["headers_sent"] = True
+            if data:
+                soc.write(data)
+            soc.flush()
+
+        def start_response(status, headers, exc_info=None):
+            if exc_info:
+                if state["headers_sent"]:
+                    six.reraise(*exc_info)
+            elif state["status"]:
+                raise AssertionError('Response already started')
+            state["status"] = status
+            state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k, v in headers])
+            if exc_info:
+                self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2]))
+                state["headers_sent"] = True
+
+        errs = six.BytesIO()
+        try:
+            dataiter = self.app(
+                self.make_environ(request, errs, **env), start_response
+            )
+            for i in dataiter:
+                write(i)
+            if not state["headers_sent"]:
+                write(b"")
+        except Exception:
+            try:
+                s = traceback.format_exc()
+                errs.write(s.encode("utf-8", "replace"))
+                self.error_page(soc, state["headers_sent"], s)
+            except Exception:  # pragma: no cover
+                pass
+        return errs.getvalue()
--
cgit v1.2.3
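
As a closing illustration of the WSGI adaptor: serve() drives an ordinary WSGI application and writes the HTTP response to whatever file-like object it is given, so a BytesIO can stand in for the client socket. A rough, untested sketch (it assumes http.Headers accepts str keys, as serve() itself does):

    from io import BytesIO
    from netlib import http, wsgi

    def hello_app(environ, start_response):
        # minimal WSGI app: one header, one body chunk
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"hello world"]

    req = wsgi.Request(b"http", b"GET", b"/", b"HTTP/1.1", http.Headers(), b"")
    flow = wsgi.Flow(("127.0.0.1", 54321), req)
    adaptor = wsgi.WSGIAdaptor(hello_app, "example.com", 80, "netlib-test")
    out = BytesIO()  # stands in for the client connection
    adaptor.serve(flow, out)
    assert out.getvalue().startswith(b"HTTP/1.1 200 OK")
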