diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/certutils.py | 56 | ||||
-rw-r--r-- | netlib/h2/frame.py | 34 | ||||
-rw-r--r-- | netlib/h2/h2.py | 30 | ||||
-rw-r--r-- | netlib/http.py | 13 | ||||
-rw-r--r-- | netlib/http_auth.py | 22 | ||||
-rw-r--r-- | netlib/http_cookies.py | 10 | ||||
-rw-r--r-- | netlib/http_status.py | 84 | ||||
-rw-r--r-- | netlib/odict.py | 10 | ||||
-rw-r--r-- | netlib/socks.py | 43 | ||||
-rw-r--r-- | netlib/tcp.py | 62 | ||||
-rw-r--r-- | netlib/test.py | 24 | ||||
-rw-r--r-- | netlib/utils.py | 10 | ||||
-rw-r--r-- | netlib/websockets.py | 87 | ||||
-rw-r--r-- | netlib/wsgi.py | 52 |
14 files changed, 308 insertions, 229 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py index f5375c03..da0e3355 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,12 +1,15 @@ from __future__ import (absolute_import, print_function, division) -import os, ssl, time, datetime +import os +import ssl +import time +import datetime import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 +DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 @@ -14,31 +17,32 @@ zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK 1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC -----END DH PARAMETERS-----""" + def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) cert = OpenSSL.crypto.X509() - cert.set_serial_number(int(time.time()*10000)) + cert.set_serial_number(int(time.time() * 10000)) cert.set_version(2) cert.get_subject().CN = cn cert.get_subject().O = o - cert.gmtime_adj_notBefore(-3600*48) + cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(exp) cert.set_issuer(cert.get_subject()) cert.set_pubkey(key) cert.add_extensions([ - OpenSSL.crypto.X509Extension("basicConstraints", True, - "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", False, - "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", False, - "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" - ), - OpenSSL.crypto.X509Extension("keyUsage", True, - "keyCertSign, cRLSign"), - OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", - subject=cert), - ]) + OpenSSL.crypto.X509Extension("basicConstraints", True, + "CA:TRUE"), + OpenSSL.crypto.X509Extension("nsCertType", False, + "sslCA"), + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, + "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension("keyUsage", True, + "keyCertSign, cRLSign"), + OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", + subject=cert), + ]) cert.sign(key, "sha1") return key, cert @@ -56,15 +60,15 @@ def dummy_cert(privkey, cacert, commonname, sans): """ ss = [] for i in sans: - ss.append("DNS: %s"%i) + ss.append("DNS: %s" % i) ss = ", ".join(ss) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600*48) + cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname - cert.set_serial_number(int(time.time()*10000)) + cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) @@ -114,6 +118,7 @@ def dummy_cert(privkey, cacert, commonname, sans): class CertStoreEntry(object): + def __init__(self, cert, privatekey, chain_file): self.cert = cert self.privatekey = privatekey @@ -121,9 +126,11 @@ class CertStoreEntry(object): class CertStore(object): + """ Implements an in-memory certificate store. """ + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): self.default_privatekey = default_privatekey self.default_ca = default_ca @@ -144,11 +151,11 @@ class CertStore(object): if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL - ) + bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL + ) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh - + @classmethod def from_store(cls, path, basename): ca_path = os.path.join(path, basename + "-ca.pem") @@ -277,8 +284,8 @@ class _GeneralName(univ.Choice): # other types. componentType = namedtype.NamedTypes( namedtype.NamedType('dNSName', char.IA5String().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) - ) + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) ), ) @@ -289,6 +296,7 @@ class _GeneralNames(univ.SequenceOf): class SSLCert(object): + def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 52cc2992..d846b3b9 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -5,8 +5,11 @@ import struct import io from .. import utils, odict, tcp +from functools import reduce + class Frame(object): + """ Baseclass Frame contains header @@ -53,6 +56,7 @@ class Frame(object): def __eq__(self, other): return self.to_bytes() == other.to_bytes() + class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -89,11 +93,13 @@ class DataFrame(Frame): return b + class HeadersFrame(Frame): TYPE = 0x1 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', + pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) self.header_block_fragment = header_block_fragment self.pad_length = pad_length @@ -137,6 +143,7 @@ class HeadersFrame(Frame): return b + class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] @@ -166,6 +173,7 @@ class PriorityFrame(Frame): return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] @@ -186,18 +194,19 @@ class RstStreamFrame(Frame): return struct.pack('!L', self.error_code) + class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE = 0x1, - SETTINGS_ENABLE_PUSH = 0x2, - SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, - SETTINGS_INITIAL_WINDOW_SIZE = 0x4, - SETTINGS_MAX_FRAME_SIZE = 0x5, - SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, - ) + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): super(SettingsFrame, self).__init__(length, flags, stream_id) @@ -208,7 +217,7 @@ class SettingsFrame(Frame): f = self(length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i+6]) + identifier, value = struct.unpack("!HL", payload[i:i + 6]) f.settings[identifier] = value return f @@ -223,6 +232,7 @@ class SettingsFrame(Frame): return b + class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] @@ -267,6 +277,7 @@ class PushPromiseFrame(Frame): return b + class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] @@ -289,6 +300,7 @@ class PingFrame(Frame): b += b'\0' * (8 - len(b)) return b + class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] @@ -317,6 +329,7 @@ class GoAwayFrame(Frame): b += bytes(self.data) return b + class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] @@ -335,11 +348,12 @@ class WindowUpdateFrame(Frame): return f def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2**31: + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 5d74c1c8..1a39a635 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -8,18 +8,18 @@ import io CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' ERROR_CODES = utils.BiDi( - NO_ERROR = 0x0, - PROTOCOL_ERROR = 0x1, - INTERNAL_ERROR = 0x2, - FLOW_CONTROL_ERROR = 0x3, - SETTINGS_TIMEOUT = 0x4, - STREAM_CLOSED = 0x5, - FRAME_SIZE_ERROR = 0x6, - REFUSED_STREAM = 0x7, - CANCEL = 0x8, - COMPRESSION_ERROR = 0x9, - CONNECT_ERROR = 0xa, - ENHANCE_YOUR_CALM = 0xb, - INADEQUATE_SECURITY = 0xc, - HTTP_1_1_REQUIRED = 0xd - ) + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) diff --git a/netlib/http.py b/netlib/http.py index 43155486..47658097 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -8,6 +8,7 @@ from . import odict, utils, tcp, http_status class HttpError(Exception): + def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -95,7 +96,7 @@ def read_headers(fp): """ ret = [] name = '' - while 1: + while True: line = fp.readline() if not line or line == '\r\n' or line == '\n': break @@ -337,7 +338,7 @@ def read_http_body_chunked( otherwise """ if max_chunk_size is None: - max_chunk_size = limit or sys.maxint + max_chunk_size = limit or sys.maxsize expected_size = expected_http_body_size( headers, is_request, request_method, response_code @@ -399,10 +400,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code): request_method = request_method.upper() if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): return 0 if has_chunked_encoding(headers): return None diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 296e094c..261b6654 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -4,9 +4,11 @@ from . import http class NullProxyAuth(object): + """ No proxy auth at all (returns empty challange headers) """ + def __init__(self, password_manager): self.password_manager = password_manager @@ -48,7 +50,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower()!='basic': + if scheme.lower() != 'basic': return False if not self.password_manager.test(username, password): return False @@ -56,18 +58,21 @@ class BasicProxyAuth(NullProxyAuth): return True def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} class PassMan(object): + def test(self, username, password_token): return False class PassManNonAnon(PassMan): + """ Ensure the user specifies a username, accept any password. """ + def test(self, username, password_token): if username: return True @@ -75,9 +80,11 @@ class PassManNonAnon(PassMan): class PassManHtpasswd(PassMan): + """ Read usernames and passwords from an htpasswd file """ + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. @@ -90,14 +97,16 @@ class PassManHtpasswd(PassMan): class PassManSingleUser(PassMan): + def __init__(self, username, password): self.username, self.password = username, password def test(self, username, password_token): - return self.username==username and self.password==password_token + return self.username == username and self.password == password_token class AuthAction(Action): + """ Helper class to allow seamless integration int argparse. Example usage: parser.add_argument( @@ -106,16 +115,18 @@ class AuthAction(Action): help="Allow access to any user long as a credentials are specified." ) """ + def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, self.dest, authenticator) - def getPasswordManager(self, s): # pragma: nocover + def getPasswordManager(self, s): # pragma: nocover raise NotImplementedError() class SingleuserAuthAction(AuthAction): + def getPasswordManager(self, s): if len(s.split(':')) != 2: raise ArgumentTypeError( @@ -126,11 +137,12 @@ class SingleuserAuthAction(AuthAction): class NonanonymousAuthAction(AuthAction): + def getPasswordManager(self, s): return PassManNonAnon() class HtpasswdAuthAction(AuthAction): + def getPasswordManager(self, s): return PassManHtpasswd(s) - diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 8e245891..73e3f589 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -96,7 +96,7 @@ def _read_pairs(s, off=0, specials=()): specials: a lower-cased list of keys that may contain commas """ vals = [] - while 1: + while True: lhs, off = _read_token(s, off) lhs = lhs.lstrip() if lhs: @@ -135,15 +135,15 @@ def _format_pairs(lst, specials=(), sep="; "): else: if k.lower() not in specials and _has_special(v): v = ESCAPE.sub(r"\\\1", v) - v = '"%s"'%v - vals.append("%s=%s"%(k, v)) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) return sep.join(vals) def _format_set_cookie_pairs(lst): return _format_pairs( lst, - specials = ("expires", "path") + specials=("expires", "path") ) @@ -154,7 +154,7 @@ def _parse_set_cookie_pairs(s): """ pairs, off = _read_pairs( s, - specials = ("expires", "path") + specials=("expires", "path") ) return pairs diff --git a/netlib/http_status.py b/netlib/http_status.py index 7dba2d56..dc09f465 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,51 +1,51 @@ from __future__ import (absolute_import, print_function, division) -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 +EXPECTATION_FAILED = 417 -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 RESPONSES = { # 100 diff --git a/netlib/odict.py b/netlib/odict.py index dd738c55..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@ from __future__ import (absolute_import, print_function, division) -import re, copy +import re +import copy def safe_subn(pattern, repl, target, *args, **kwargs): @@ -12,10 +13,12 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): + """ A dictionary-like object for managing ordered (key, value) data. Think about it as a convenient interface to a list of (key, value) tuples. """ + def __init__(self, lst=None): self.lst = lst or [] @@ -157,7 +160,7 @@ class ODict(object): "key: value" """ for k, v in self.lst: - s = "%s: %s"%(k, v) + s = "%s: %s" % (k, v) if re.search(expr, s): return True return False @@ -192,11 +195,12 @@ class ODict(object): return klass([list(i) for i in state]) - class ODictCaseless(ODict): + """ A variant of ODict with "caseless" keys. This version _preserves_ key case, but does not consider case when setting or getting items. """ + def _kconv(self, s): return s.lower() diff --git a/netlib/socks.py b/netlib/socks.py index 6f9f57bd..5a73c61a 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -6,49 +6,50 @@ from . import tcp, utils class SocksError(Exception): + def __init__(self, code, message): super(SocksError, self).__init__(message) self.code = code VERSION = utils.BiDi( - SOCKS4 = 0x04, - SOCKS5 = 0x05 + SOCKS4=0x04, + SOCKS5=0x05 ) CMD = utils.BiDi( - CONNECT = 0x01, - BIND = 0x02, - UDP_ASSOCIATE = 0x03 + CONNECT=0x01, + BIND=0x02, + UDP_ASSOCIATE=0x03 ) ATYP = utils.BiDi( - IPV4_ADDRESS = 0x01, - DOMAINNAME = 0x03, - IPV6_ADDRESS = 0x04 + IPV4_ADDRESS=0x01, + DOMAINNAME=0x03, + IPV6_ADDRESS=0x04 ) REP = utils.BiDi( - SUCCEEDED = 0x00, - GENERAL_SOCKS_SERVER_FAILURE = 0x01, - CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02, - NETWORK_UNREACHABLE = 0x03, - HOST_UNREACHABLE = 0x04, - CONNECTION_REFUSED = 0x05, - TTL_EXPIRED = 0x06, - COMMAND_NOT_SUPPORTED = 0x07, - ADDRESS_TYPE_NOT_SUPPORTED = 0x08, + SUCCEEDED=0x00, + GENERAL_SOCKS_SERVER_FAILURE=0x01, + CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, + NETWORK_UNREACHABLE=0x03, + HOST_UNREACHABLE=0x04, + CONNECTION_REFUSED=0x05, + TTL_EXPIRED=0x06, + COMMAND_NOT_SUPPORTED=0x07, + ADDRESS_TYPE_NOT_SUPPORTED=0x08, ) METHOD = utils.BiDi( - NO_AUTHENTICATION_REQUIRED = 0x00, - GSSAPI = 0x01, - USERNAME_PASSWORD = 0x02, - NO_ACCEPTABLE_METHODS = 0xFF + NO_AUTHENTICATION_REQUIRED=0x00, + GSSAPI=0x01, + USERNAME_PASSWORD=0x02, + NO_ACCEPTABLE_METHODS=0xFF ) diff --git a/netlib/tcp.py b/netlib/tcp.py index 399203bb..7c115554 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,14 +22,28 @@ OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -class NetLibError(Exception): pass -class NetLibDisconnect(NetLibError): pass -class NetLibIncomplete(NetLibError): pass -class NetLibTimeout(NetLibError): pass -class NetLibSSLError(NetLibError): pass +class NetLibError(Exception): + pass + + +class NetLibDisconnect(NetLibError): + pass + + +class NetLibIncomplete(NetLibError): + pass + + +class NetLibTimeout(NetLibError): + pass + + +class NetLibSSLError(NetLibError): + pass class SSLKeyLogger(object): + def __init__(self, filename): self.filename = filename self.f = None @@ -67,6 +81,7 @@ log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or class _FileLike(object): BLOCKSIZE = 1024 * 32 + def __init__(self, o): self.o = o self._log = None @@ -112,6 +127,7 @@ class _FileLike(object): class Writer(_FileLike): + def flush(self): """ May raise NetLibDisconnect @@ -119,7 +135,7 @@ class Writer(_FileLike): if hasattr(self.o, "flush"): try: self.o.flush() - except (socket.error, IOError), v: + except (socket.error, IOError) as v: raise NetLibDisconnect(str(v)) def write(self, v): @@ -135,11 +151,12 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error) as e: + except (SSL.Error, socket.error) as e: raise NetLibDisconnect(str(e)) class Reader(_FileLike): + def read(self, length): """ If length is -1, we read until connection closes. @@ -180,7 +197,7 @@ class Reader(_FileLike): self.add_log(result) return result - def readline(self, size = None): + def readline(self, size=None): result = '' bytes_read = 0 while True: @@ -204,16 +221,18 @@ class Reader(_FileLike): result = self.read(length) if length != -1 and len(result) != length: raise NetLibIncomplete( - "Expected %s bytes, got %s"%(length, len(result)) + "Expected %s bytes, got %s" % (length, len(result)) ) return result class Address(object): + """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ + def __init__(self, address, use_ipv6=False): self.address = tuple(address) self.use_ipv6 = use_ipv6 @@ -304,6 +323,7 @@ def close_socket(sock): class _Connection(object): + def get_current_cipher(self): if not self.ssl_established: return None @@ -319,7 +339,7 @@ class _Connection(object): # (We call _FileLike.set_descriptor(conn)) # Closing the socket is not our task, therefore we don't call close # then. - if type(self.connection) != SSL.Connection: + if not isinstance(self.connection, SSL.Connection): if not getattr(self.wfile, "closed", False): try: self.wfile.flush() @@ -337,6 +357,7 @@ class _Connection(object): """ Creates an SSL Context. """ + def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), @@ -362,8 +383,8 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) + except SSL.Error as v: + raise NetLibError("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: @@ -380,7 +401,7 @@ class TCPClient(_Connection): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... - if type(self.connection) == SSL.Connection: + if isinstance(self.connection, SSL.Connection): close_socket(self.connection._socket) else: close_socket(self.connection) @@ -400,8 +421,8 @@ class TCPClient(_Connection): try: context.use_privatekey_file(cert) context.use_certificate_file(cert) - except SSL.Error, v: - raise NetLibError("SSL client certificate error: %s"%str(v)) + except SSL.Error as v: + raise NetLibError("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, **sslctx_kwargs): @@ -418,8 +439,8 @@ class TCPClient(_Connection): self.connection.set_connect_state() try: self.connection.do_handshake() - except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%repr(v)) + except SSL.Error as v: + raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) @@ -435,7 +456,7 @@ class TCPClient(_Connection): self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - except (socket.error, IOError), err: + except (socket.error, IOError) as err: raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection @@ -447,6 +468,7 @@ class TCPClient(_Connection): class BaseHandler(_Connection): + """ The instantiator is expected to call the handle() and finish() methods. @@ -531,8 +553,8 @@ class BaseHandler(_Connection): self.connection.set_accept_state() try: self.connection.do_handshake() - except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%repr(v)) + except SSL.Error as v: + raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) diff --git a/netlib/test.py b/netlib/test.py index db30c0e6..b6f94273 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,9 +1,13 @@ from __future__ import (absolute_import, print_function, division) -import threading, Queue, cStringIO +import threading +import Queue +import cStringIO import OpenSSL from . import tcp, certutils + class ServerThread(threading.Thread): + def __init__(self, server): self.server = server threading.Thread.__init__(self) @@ -19,6 +23,7 @@ class ServerTestBase(object): ssl = None handler = None addr = ("localhost", 0) + @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -41,10 +46,11 @@ class ServerTestBase(object): class TServer(tcp.TCPServer): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A dictionary of SSL parameters: - + cert, key, request_client_cert, cipher_list, dhparams, v3_only """ @@ -70,13 +76,13 @@ class TServer(tcp.TCPServer): options = None h.convert_to_ssl( cert, key, - method = method, - options = options, - handle_sni = getattr(h, "handle_sni", None), - request_client_cert = self.ssl["request_client_cert"], - cipher_list = self.ssl.get("cipher_list", None), - dhparams = self.ssl.get("dhparams", None), - chain_file = self.ssl.get("chain_file", None) + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl["request_client_cert"], + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None) ) h.handle() h.finish() diff --git a/netlib/utils.py b/netlib/utils.py index 7e539977..9c5404e6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -68,6 +68,7 @@ def getbit(byte, offset): class BiDi: + """ A wee utility class for keeping bi-directional mappings, like field constants in protocols. Names are attributes on the object, dict-like @@ -77,6 +78,7 @@ class BiDi: assert CONST.a == 1 assert CONST.get_name(1) == "a" """ + def __init__(self, **kwargs): self.names = kwargs self.values = {} @@ -96,15 +98,15 @@ class BiDi: def pretty_size(size): suffixes = [ - ("B", 2**10), - ("kB", 2**20), - ("MB", 2**30), + ("B", 2 ** 10), + ("kB", 2 ** 20), + ("MB", 2 ** 30), ] for suf, lim in suffixes: if size >= lim: continue else: - x = round(size/float(lim/2**10), 2) + x = round(size / float(lim / 2 ** 10), 2) if x == int(x): x = int(x) return str(x) + suf diff --git a/netlib/websockets.py b/netlib/websockets.py index a2d55c19..63dc03f1 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -26,16 +26,17 @@ MAX_64_BIT_INT = (1 << 64) OPCODE = utils.BiDi( - CONTINUE = 0x00, - TEXT = 0x01, - BINARY = 0x02, - CLOSE = 0x08, - PING = 0x09, - PONG = 0x0a + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a ) class Masker: + """ Data sent from the server must be masked to prevent malicious clients from sending data over the wire in predictable patterns @@ -43,6 +44,7 @@ class Masker: Servers do not have to mask data they send to the client. https://tools.ietf.org/html/rfc6455#section-5.3 """ + def __init__(self, key): self.key = key self.masks = [utils.bytes_to_int(byte) for byte in key] @@ -128,17 +130,18 @@ DEFAULT = object() class FrameHeader: + def __init__( self, - opcode = OPCODE.TEXT, - payload_length = 0, - fin = False, - rsv1 = False, - rsv2 = False, - rsv3 = False, - masking_key = DEFAULT, - mask = DEFAULT, - length_code = DEFAULT + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -182,9 +185,9 @@ class FrameHeader: if flags: vals.extend([":", "|".join(flags)]) if self.masking_key: - vals.append(":key=%s"%repr(self.masking_key)) + vals.append(":key=%s" % repr(self.masking_key)) if self.payload_length: - vals.append(" %s"%utils.pretty_size(self.payload_length)) + vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) def to_bytes(self): @@ -246,15 +249,15 @@ class FrameHeader: masking_key = None return klass( - fin = fin, - rsv1 = rsv1, - rsv2 = rsv2, - rsv3 = rsv3, - opcode = opcode, - mask = mask_bit, - length_code = length_code, - payload_length = payload_length, - masking_key = masking_key, + fin=fin, + rsv1=rsv1, + rsv2=rsv2, + rsv3=rsv3, + opcode=opcode, + mask=mask_bit, + length_code=length_code, + payload_length=payload_length, + masking_key=masking_key, ) def __eq__(self, other): @@ -262,6 +265,7 @@ class FrameHeader: class Frame(object): + """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -287,13 +291,14 @@ class Frame(object): | Payload Data continued ... | +---------------------------------------------------------------+ """ - def __init__(self, payload = "", **kwargs): + + def __init__(self, payload="", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @classmethod - def default(cls, message, from_client = False): + def default(cls, message, from_client=False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -307,10 +312,10 @@ class Frame(object): return cls( message, - fin = 1, # final frame - opcode = OPCODE.TEXT, # text - mask = mask_bit, - masking_key = masking_key, + fin=1, # final frame + opcode=OPCODE.TEXT, # text + mask=mask_bit, + masking_key=masking_key, ) @classmethod @@ -356,15 +361,15 @@ class Frame(object): return cls( payload, - fin = header.fin, - opcode = header.opcode, - mask = header.mask, - payload_length = header.payload_length, - masking_key = header.masking_key, - rsv1 = header.rsv1, - rsv2 = header.rsv2, - rsv3 = header.rsv3, - length_code = header.length_code + fin=header.fin, + opcode=header.opcode, + mask=header.mask, + payload_length=header.payload_length, + masking_key=header.masking_key, + rsv1=header.rsv1, + rsv2=header.rsv2, + rsv3=header.rsv3, + length_code=header.length_code ) def __eq__(self, other): diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 1b979608..f393039a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -7,17 +7,20 @@ from . import odict, tcp class ClientConn(object): + def __init__(self, address): self.address = tcp.Address.wrap(address) class Flow(object): + def __init__(self, address, request): self.client_conn = ClientConn(address) self.request = request class Request(object): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content @@ -42,6 +45,7 @@ def date_time_string(): class WSGIAdaptor(object): + def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion @@ -52,24 +56,24 @@ class WSGIAdaptor(object): path_info = flow.request.path query = '' environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), - 'wsgi.errors': errsoc, - 'wsgi.multithread': True, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': flow.request.method, - 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), - 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], - 'SERVER_NAME': self.domain, - 'SERVER_PORT': str(self.port), + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.errors': errsoc, + 'wsgi.multithread': True, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': self.sversion, + 'REQUEST_METHOD': flow.request.method, + 'SCRIPT_NAME': '', + 'PATH_INFO': urllib.unquote(path_info), + 'QUERY_STRING': query, + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], + 'SERVER_NAME': self.domain, + 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", + 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) if flow.client_conn.address: @@ -91,25 +95,25 @@ class WSGIAdaptor(object): <h1>Internal Server Error</h1> <pre>%s"</pre> </html> - """%s + """ % s if not headers_sent: soc.write("HTTP/1.1 500 Internal Server Error\r\n") soc.write("Content-Type: text/html\r\n") - soc.write("Content-Length: %s\r\n"%len(c)) + soc.write("Content-Length: %s\r\n" % len(c)) soc.write("\r\n") soc.write(c) def serve(self, request, soc, **env): state = dict( - response_started = False, - headers_sent = False, - status = None, - headers = None + response_started=False, + headers_sent=False, + status=None, + headers=None ) def write(data): if not state["headers_sent"]: - soc.write("HTTP/1.1 %s\r\n"%state["status"]) + soc.write("HTTP/1.1 %s\r\n" % state["status"]) h = state["headers"] if 'server' not in h: h["Server"] = [self.sversion] |