aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/certutils.py56
-rw-r--r--netlib/h2/frame.py34
-rw-r--r--netlib/h2/h2.py30
-rw-r--r--netlib/http.py13
-rw-r--r--netlib/http_auth.py22
-rw-r--r--netlib/http_cookies.py10
-rw-r--r--netlib/http_status.py84
-rw-r--r--netlib/odict.py10
-rw-r--r--netlib/socks.py43
-rw-r--r--netlib/tcp.py62
-rw-r--r--netlib/test.py24
-rw-r--r--netlib/utils.py10
-rw-r--r--netlib/websockets.py87
-rw-r--r--netlib/wsgi.py52
14 files changed, 308 insertions, 229 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index f5375c03..da0e3355 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -1,12 +1,15 @@
from __future__ import (absolute_import, print_function, division)
-import os, ssl, time, datetime
+import os
+import ssl
+import time
+import datetime
import itertools
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
-DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5
+DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5
# Generated with "openssl dhparam". It's too slow to generate this on startup.
DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS-----
MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5
@@ -14,31 +17,32 @@ zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK
1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC
-----END DH PARAMETERS-----"""
+
def create_ca(o, cn, exp):
key = OpenSSL.crypto.PKey()
key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024)
cert = OpenSSL.crypto.X509()
- cert.set_serial_number(int(time.time()*10000))
+ cert.set_serial_number(int(time.time() * 10000))
cert.set_version(2)
cert.get_subject().CN = cn
cert.get_subject().O = o
- cert.gmtime_adj_notBefore(-3600*48)
+ cert.gmtime_adj_notBefore(-3600 * 48)
cert.gmtime_adj_notAfter(exp)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(key)
cert.add_extensions([
- OpenSSL.crypto.X509Extension("basicConstraints", True,
- "CA:TRUE"),
- OpenSSL.crypto.X509Extension("nsCertType", False,
- "sslCA"),
- OpenSSL.crypto.X509Extension("extendedKeyUsage", False,
- "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
- ),
- OpenSSL.crypto.X509Extension("keyUsage", True,
- "keyCertSign, cRLSign"),
- OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash",
- subject=cert),
- ])
+ OpenSSL.crypto.X509Extension("basicConstraints", True,
+ "CA:TRUE"),
+ OpenSSL.crypto.X509Extension("nsCertType", False,
+ "sslCA"),
+ OpenSSL.crypto.X509Extension("extendedKeyUsage", False,
+ "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
+ ),
+ OpenSSL.crypto.X509Extension("keyUsage", True,
+ "keyCertSign, cRLSign"),
+ OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash",
+ subject=cert),
+ ])
cert.sign(key, "sha1")
return key, cert
@@ -56,15 +60,15 @@ def dummy_cert(privkey, cacert, commonname, sans):
"""
ss = []
for i in sans:
- ss.append("DNS: %s"%i)
+ ss.append("DNS: %s" % i)
ss = ", ".join(ss)
cert = OpenSSL.crypto.X509()
- cert.gmtime_adj_notBefore(-3600*48)
+ cert.gmtime_adj_notBefore(-3600 * 48)
cert.gmtime_adj_notAfter(DEFAULT_EXP)
cert.set_issuer(cacert.get_subject())
cert.get_subject().CN = commonname
- cert.set_serial_number(int(time.time()*10000))
+ cert.set_serial_number(int(time.time() * 10000))
if ss:
cert.set_version(2)
cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)])
@@ -114,6 +118,7 @@ def dummy_cert(privkey, cacert, commonname, sans):
class CertStoreEntry(object):
+
def __init__(self, cert, privatekey, chain_file):
self.cert = cert
self.privatekey = privatekey
@@ -121,9 +126,11 @@ class CertStoreEntry(object):
class CertStore(object):
+
"""
Implements an in-memory certificate store.
"""
+
def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None):
self.default_privatekey = default_privatekey
self.default_ca = default_ca
@@ -144,11 +151,11 @@ class CertStore(object):
if bio != OpenSSL.SSL._ffi.NULL:
bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
- bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL
- )
+ bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL
+ )
dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
return dh
-
+
@classmethod
def from_store(cls, path, basename):
ca_path = os.path.join(path, basename + "-ca.pem")
@@ -277,8 +284,8 @@ class _GeneralName(univ.Choice):
# other types.
componentType = namedtype.NamedTypes(
namedtype.NamedType('dNSName', char.IA5String().subtype(
- implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
- )
+ implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
+ )
),
)
@@ -289,6 +296,7 @@ class _GeneralNames(univ.SequenceOf):
class SSLCert(object):
+
def __init__(self, cert):
"""
Returns a (common name, [subject alternative names]) tuple.
diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py
index 52cc2992..d846b3b9 100644
--- a/netlib/h2/frame.py
+++ b/netlib/h2/frame.py
@@ -5,8 +5,11 @@ import struct
import io
from .. import utils, odict, tcp
+from functools import reduce
+
class Frame(object):
+
"""
Baseclass Frame
contains header
@@ -53,6 +56,7 @@ class Frame(object):
def __eq__(self, other):
return self.to_bytes() == other.to_bytes()
+
class DataFrame(Frame):
TYPE = 0x0
VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED]
@@ -89,11 +93,13 @@ class DataFrame(Frame):
return b
+
class HeadersFrame(Frame):
TYPE = 0x1
VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY]
- def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0):
+ def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'',
+ pad_length=0, exclusive=False, stream_dependency=0x0, weight=0):
super(HeadersFrame, self).__init__(length, flags, stream_id)
self.header_block_fragment = header_block_fragment
self.pad_length = pad_length
@@ -137,6 +143,7 @@ class HeadersFrame(Frame):
return b
+
class PriorityFrame(Frame):
TYPE = 0x2
VALID_FLAGS = []
@@ -166,6 +173,7 @@ class PriorityFrame(Frame):
return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight)
+
class RstStreamFrame(Frame):
TYPE = 0x3
VALID_FLAGS = []
@@ -186,18 +194,19 @@ class RstStreamFrame(Frame):
return struct.pack('!L', self.error_code)
+
class SettingsFrame(Frame):
TYPE = 0x4
VALID_FLAGS = [Frame.FLAG_ACK]
SETTINGS = utils.BiDi(
- SETTINGS_HEADER_TABLE_SIZE = 0x1,
- SETTINGS_ENABLE_PUSH = 0x2,
- SETTINGS_MAX_CONCURRENT_STREAMS = 0x3,
- SETTINGS_INITIAL_WINDOW_SIZE = 0x4,
- SETTINGS_MAX_FRAME_SIZE = 0x5,
- SETTINGS_MAX_HEADER_LIST_SIZE = 0x6,
- )
+ SETTINGS_HEADER_TABLE_SIZE=0x1,
+ SETTINGS_ENABLE_PUSH=0x2,
+ SETTINGS_MAX_CONCURRENT_STREAMS=0x3,
+ SETTINGS_INITIAL_WINDOW_SIZE=0x4,
+ SETTINGS_MAX_FRAME_SIZE=0x5,
+ SETTINGS_MAX_HEADER_LIST_SIZE=0x6,
+ )
def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}):
super(SettingsFrame, self).__init__(length, flags, stream_id)
@@ -208,7 +217,7 @@ class SettingsFrame(Frame):
f = self(length=length, flags=flags, stream_id=stream_id)
for i in xrange(0, len(payload), 6):
- identifier, value = struct.unpack("!HL", payload[i:i+6])
+ identifier, value = struct.unpack("!HL", payload[i:i + 6])
f.settings[identifier] = value
return f
@@ -223,6 +232,7 @@ class SettingsFrame(Frame):
return b
+
class PushPromiseFrame(Frame):
TYPE = 0x5
VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED]
@@ -267,6 +277,7 @@ class PushPromiseFrame(Frame):
return b
+
class PingFrame(Frame):
TYPE = 0x6
VALID_FLAGS = [Frame.FLAG_ACK]
@@ -289,6 +300,7 @@ class PingFrame(Frame):
b += b'\0' * (8 - len(b))
return b
+
class GoAwayFrame(Frame):
TYPE = 0x7
VALID_FLAGS = []
@@ -317,6 +329,7 @@ class GoAwayFrame(Frame):
b += bytes(self.data)
return b
+
class WindowUpdateFrame(Frame):
TYPE = 0x8
VALID_FLAGS = []
@@ -335,11 +348,12 @@ class WindowUpdateFrame(Frame):
return f
def payload_bytes(self):
- if self.window_size_increment <= 0 or self.window_size_increment >= 2**31:
+ if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31:
raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.')
return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF)
+
class ContinuationFrame(Frame):
TYPE = 0x9
VALID_FLAGS = [Frame.FLAG_END_HEADERS]
diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py
index 5d74c1c8..1a39a635 100644
--- a/netlib/h2/h2.py
+++ b/netlib/h2/h2.py
@@ -8,18 +8,18 @@ import io
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
ERROR_CODES = utils.BiDi(
- NO_ERROR = 0x0,
- PROTOCOL_ERROR = 0x1,
- INTERNAL_ERROR = 0x2,
- FLOW_CONTROL_ERROR = 0x3,
- SETTINGS_TIMEOUT = 0x4,
- STREAM_CLOSED = 0x5,
- FRAME_SIZE_ERROR = 0x6,
- REFUSED_STREAM = 0x7,
- CANCEL = 0x8,
- COMPRESSION_ERROR = 0x9,
- CONNECT_ERROR = 0xa,
- ENHANCE_YOUR_CALM = 0xb,
- INADEQUATE_SECURITY = 0xc,
- HTTP_1_1_REQUIRED = 0xd
- )
+ NO_ERROR=0x0,
+ PROTOCOL_ERROR=0x1,
+ INTERNAL_ERROR=0x2,
+ FLOW_CONTROL_ERROR=0x3,
+ SETTINGS_TIMEOUT=0x4,
+ STREAM_CLOSED=0x5,
+ FRAME_SIZE_ERROR=0x6,
+ REFUSED_STREAM=0x7,
+ CANCEL=0x8,
+ COMPRESSION_ERROR=0x9,
+ CONNECT_ERROR=0xa,
+ ENHANCE_YOUR_CALM=0xb,
+ INADEQUATE_SECURITY=0xc,
+ HTTP_1_1_REQUIRED=0xd
+)
diff --git a/netlib/http.py b/netlib/http.py
index 43155486..47658097 100644
--- a/netlib/http.py
+++ b/netlib/http.py
@@ -8,6 +8,7 @@ from . import odict, utils, tcp, http_status
class HttpError(Exception):
+
def __init__(self, code, message):
super(HttpError, self).__init__(message)
self.code = code
@@ -95,7 +96,7 @@ def read_headers(fp):
"""
ret = []
name = ''
- while 1:
+ while True:
line = fp.readline()
if not line or line == '\r\n' or line == '\n':
break
@@ -337,7 +338,7 @@ def read_http_body_chunked(
otherwise
"""
if max_chunk_size is None:
- max_chunk_size = limit or sys.maxint
+ max_chunk_size = limit or sys.maxsize
expected_size = expected_http_body_size(
headers, is_request, request_method, response_code
@@ -399,10 +400,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code):
request_method = request_method.upper()
if (not is_request and (
- request_method == "HEAD" or
- (request_method == "CONNECT" and response_code == 200) or
- response_code in [204, 304] or
- 100 <= response_code <= 199)):
+ request_method == "HEAD" or
+ (request_method == "CONNECT" and response_code == 200) or
+ response_code in [204, 304] or
+ 100 <= response_code <= 199)):
return 0
if has_chunked_encoding(headers):
return None
diff --git a/netlib/http_auth.py b/netlib/http_auth.py
index 296e094c..261b6654 100644
--- a/netlib/http_auth.py
+++ b/netlib/http_auth.py
@@ -4,9 +4,11 @@ from . import http
class NullProxyAuth(object):
+
"""
No proxy auth at all (returns empty challange headers)
"""
+
def __init__(self, password_manager):
self.password_manager = password_manager
@@ -48,7 +50,7 @@ class BasicProxyAuth(NullProxyAuth):
if not parts:
return False
scheme, username, password = parts
- if scheme.lower()!='basic':
+ if scheme.lower() != 'basic':
return False
if not self.password_manager.test(username, password):
return False
@@ -56,18 +58,21 @@ class BasicProxyAuth(NullProxyAuth):
return True
def auth_challenge_headers(self):
- return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm}
+ return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm}
class PassMan(object):
+
def test(self, username, password_token):
return False
class PassManNonAnon(PassMan):
+
"""
Ensure the user specifies a username, accept any password.
"""
+
def test(self, username, password_token):
if username:
return True
@@ -75,9 +80,11 @@ class PassManNonAnon(PassMan):
class PassManHtpasswd(PassMan):
+
"""
Read usernames and passwords from an htpasswd file
"""
+
def __init__(self, path):
"""
Raises ValueError if htpasswd file is invalid.
@@ -90,14 +97,16 @@ class PassManHtpasswd(PassMan):
class PassManSingleUser(PassMan):
+
def __init__(self, username, password):
self.username, self.password = username, password
def test(self, username, password_token):
- return self.username==username and self.password==password_token
+ return self.username == username and self.password == password_token
class AuthAction(Action):
+
"""
Helper class to allow seamless integration int argparse. Example usage:
parser.add_argument(
@@ -106,16 +115,18 @@ class AuthAction(Action):
help="Allow access to any user long as a credentials are specified."
)
"""
+
def __call__(self, parser, namespace, values, option_string=None):
passman = self.getPasswordManager(values)
authenticator = BasicProxyAuth(passman, "mitmproxy")
setattr(namespace, self.dest, authenticator)
- def getPasswordManager(self, s): # pragma: nocover
+ def getPasswordManager(self, s): # pragma: nocover
raise NotImplementedError()
class SingleuserAuthAction(AuthAction):
+
def getPasswordManager(self, s):
if len(s.split(':')) != 2:
raise ArgumentTypeError(
@@ -126,11 +137,12 @@ class SingleuserAuthAction(AuthAction):
class NonanonymousAuthAction(AuthAction):
+
def getPasswordManager(self, s):
return PassManNonAnon()
class HtpasswdAuthAction(AuthAction):
+
def getPasswordManager(self, s):
return PassManHtpasswd(s)
-
diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py
index 8e245891..73e3f589 100644
--- a/netlib/http_cookies.py
+++ b/netlib/http_cookies.py
@@ -96,7 +96,7 @@ def _read_pairs(s, off=0, specials=()):
specials: a lower-cased list of keys that may contain commas
"""
vals = []
- while 1:
+ while True:
lhs, off = _read_token(s, off)
lhs = lhs.lstrip()
if lhs:
@@ -135,15 +135,15 @@ def _format_pairs(lst, specials=(), sep="; "):
else:
if k.lower() not in specials and _has_special(v):
v = ESCAPE.sub(r"\\\1", v)
- v = '"%s"'%v
- vals.append("%s=%s"%(k, v))
+ v = '"%s"' % v
+ vals.append("%s=%s" % (k, v))
return sep.join(vals)
def _format_set_cookie_pairs(lst):
return _format_pairs(
lst,
- specials = ("expires", "path")
+ specials=("expires", "path")
)
@@ -154,7 +154,7 @@ def _parse_set_cookie_pairs(s):
"""
pairs, off = _read_pairs(
s,
- specials = ("expires", "path")
+ specials=("expires", "path")
)
return pairs
diff --git a/netlib/http_status.py b/netlib/http_status.py
index 7dba2d56..dc09f465 100644
--- a/netlib/http_status.py
+++ b/netlib/http_status.py
@@ -1,51 +1,51 @@
from __future__ import (absolute_import, print_function, division)
-CONTINUE = 100
-SWITCHING = 101
-OK = 200
-CREATED = 201
-ACCEPTED = 202
-NON_AUTHORITATIVE_INFORMATION = 203
-NO_CONTENT = 204
-RESET_CONTENT = 205
-PARTIAL_CONTENT = 206
-MULTI_STATUS = 207
+CONTINUE = 100
+SWITCHING = 101
+OK = 200
+CREATED = 201
+ACCEPTED = 202
+NON_AUTHORITATIVE_INFORMATION = 203
+NO_CONTENT = 204
+RESET_CONTENT = 205
+PARTIAL_CONTENT = 206
+MULTI_STATUS = 207
-MULTIPLE_CHOICE = 300
-MOVED_PERMANENTLY = 301
-FOUND = 302
-SEE_OTHER = 303
-NOT_MODIFIED = 304
-USE_PROXY = 305
-TEMPORARY_REDIRECT = 307
+MULTIPLE_CHOICE = 300
+MOVED_PERMANENTLY = 301
+FOUND = 302
+SEE_OTHER = 303
+NOT_MODIFIED = 304
+USE_PROXY = 305
+TEMPORARY_REDIRECT = 307
-BAD_REQUEST = 400
-UNAUTHORIZED = 401
-PAYMENT_REQUIRED = 402
-FORBIDDEN = 403
-NOT_FOUND = 404
-NOT_ALLOWED = 405
-NOT_ACCEPTABLE = 406
-PROXY_AUTH_REQUIRED = 407
-REQUEST_TIMEOUT = 408
-CONFLICT = 409
-GONE = 410
-LENGTH_REQUIRED = 411
-PRECONDITION_FAILED = 412
-REQUEST_ENTITY_TOO_LARGE = 413
-REQUEST_URI_TOO_LONG = 414
-UNSUPPORTED_MEDIA_TYPE = 415
+BAD_REQUEST = 400
+UNAUTHORIZED = 401
+PAYMENT_REQUIRED = 402
+FORBIDDEN = 403
+NOT_FOUND = 404
+NOT_ALLOWED = 405
+NOT_ACCEPTABLE = 406
+PROXY_AUTH_REQUIRED = 407
+REQUEST_TIMEOUT = 408
+CONFLICT = 409
+GONE = 410
+LENGTH_REQUIRED = 411
+PRECONDITION_FAILED = 412
+REQUEST_ENTITY_TOO_LARGE = 413
+REQUEST_URI_TOO_LONG = 414
+UNSUPPORTED_MEDIA_TYPE = 415
REQUESTED_RANGE_NOT_SATISFIABLE = 416
-EXPECTATION_FAILED = 417
+EXPECTATION_FAILED = 417
-INTERNAL_SERVER_ERROR = 500
-NOT_IMPLEMENTED = 501
-BAD_GATEWAY = 502
-SERVICE_UNAVAILABLE = 503
-GATEWAY_TIMEOUT = 504
-HTTP_VERSION_NOT_SUPPORTED = 505
-INSUFFICIENT_STORAGE_SPACE = 507
-NOT_EXTENDED = 510
+INTERNAL_SERVER_ERROR = 500
+NOT_IMPLEMENTED = 501
+BAD_GATEWAY = 502
+SERVICE_UNAVAILABLE = 503
+GATEWAY_TIMEOUT = 504
+HTTP_VERSION_NOT_SUPPORTED = 505
+INSUFFICIENT_STORAGE_SPACE = 507
+NOT_EXTENDED = 510
RESPONSES = {
# 100
diff --git a/netlib/odict.py b/netlib/odict.py
index dd738c55..f52acd50 100644
--- a/netlib/odict.py
+++ b/netlib/odict.py
@@ -1,5 +1,6 @@
from __future__ import (absolute_import, print_function, division)
-import re, copy
+import re
+import copy
def safe_subn(pattern, repl, target, *args, **kwargs):
@@ -12,10 +13,12 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
class ODict(object):
+
"""
A dictionary-like object for managing ordered (key, value) data. Think
about it as a convenient interface to a list of (key, value) tuples.
"""
+
def __init__(self, lst=None):
self.lst = lst or []
@@ -157,7 +160,7 @@ class ODict(object):
"key: value"
"""
for k, v in self.lst:
- s = "%s: %s"%(k, v)
+ s = "%s: %s" % (k, v)
if re.search(expr, s):
return True
return False
@@ -192,11 +195,12 @@ class ODict(object):
return klass([list(i) for i in state])
-
class ODictCaseless(ODict):
+
"""
A variant of ODict with "caseless" keys. This version _preserves_ key
case, but does not consider case when setting or getting items.
"""
+
def _kconv(self, s):
return s.lower()
diff --git a/netlib/socks.py b/netlib/socks.py
index 6f9f57bd..5a73c61a 100644
--- a/netlib/socks.py
+++ b/netlib/socks.py
@@ -6,49 +6,50 @@ from . import tcp, utils
class SocksError(Exception):
+
def __init__(self, code, message):
super(SocksError, self).__init__(message)
self.code = code
VERSION = utils.BiDi(
- SOCKS4 = 0x04,
- SOCKS5 = 0x05
+ SOCKS4=0x04,
+ SOCKS5=0x05
)
CMD = utils.BiDi(
- CONNECT = 0x01,
- BIND = 0x02,
- UDP_ASSOCIATE = 0x03
+ CONNECT=0x01,
+ BIND=0x02,
+ UDP_ASSOCIATE=0x03
)
ATYP = utils.BiDi(
- IPV4_ADDRESS = 0x01,
- DOMAINNAME = 0x03,
- IPV6_ADDRESS = 0x04
+ IPV4_ADDRESS=0x01,
+ DOMAINNAME=0x03,
+ IPV6_ADDRESS=0x04
)
REP = utils.BiDi(
- SUCCEEDED = 0x00,
- GENERAL_SOCKS_SERVER_FAILURE = 0x01,
- CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02,
- NETWORK_UNREACHABLE = 0x03,
- HOST_UNREACHABLE = 0x04,
- CONNECTION_REFUSED = 0x05,
- TTL_EXPIRED = 0x06,
- COMMAND_NOT_SUPPORTED = 0x07,
- ADDRESS_TYPE_NOT_SUPPORTED = 0x08,
+ SUCCEEDED=0x00,
+ GENERAL_SOCKS_SERVER_FAILURE=0x01,
+ CONNECTION_NOT_ALLOWED_BY_RULESET=0x02,
+ NETWORK_UNREACHABLE=0x03,
+ HOST_UNREACHABLE=0x04,
+ CONNECTION_REFUSED=0x05,
+ TTL_EXPIRED=0x06,
+ COMMAND_NOT_SUPPORTED=0x07,
+ ADDRESS_TYPE_NOT_SUPPORTED=0x08,
)
METHOD = utils.BiDi(
- NO_AUTHENTICATION_REQUIRED = 0x00,
- GSSAPI = 0x01,
- USERNAME_PASSWORD = 0x02,
- NO_ACCEPTABLE_METHODS = 0xFF
+ NO_AUTHENTICATION_REQUIRED=0x00,
+ GSSAPI=0x01,
+ USERNAME_PASSWORD=0x02,
+ NO_ACCEPTABLE_METHODS=0xFF
)
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 399203bb..7c115554 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -22,14 +22,28 @@ OP_NO_SSLv2 = SSL.OP_NO_SSLv2
OP_NO_SSLv3 = SSL.OP_NO_SSLv3
-class NetLibError(Exception): pass
-class NetLibDisconnect(NetLibError): pass
-class NetLibIncomplete(NetLibError): pass
-class NetLibTimeout(NetLibError): pass
-class NetLibSSLError(NetLibError): pass
+class NetLibError(Exception):
+ pass
+
+
+class NetLibDisconnect(NetLibError):
+ pass
+
+
+class NetLibIncomplete(NetLibError):
+ pass
+
+
+class NetLibTimeout(NetLibError):
+ pass
+
+
+class NetLibSSLError(NetLibError):
+ pass
class SSLKeyLogger(object):
+
def __init__(self, filename):
self.filename = filename
self.f = None
@@ -67,6 +81,7 @@ log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or
class _FileLike(object):
BLOCKSIZE = 1024 * 32
+
def __init__(self, o):
self.o = o
self._log = None
@@ -112,6 +127,7 @@ class _FileLike(object):
class Writer(_FileLike):
+
def flush(self):
"""
May raise NetLibDisconnect
@@ -119,7 +135,7 @@ class Writer(_FileLike):
if hasattr(self.o, "flush"):
try:
self.o.flush()
- except (socket.error, IOError), v:
+ except (socket.error, IOError) as v:
raise NetLibDisconnect(str(v))
def write(self, v):
@@ -135,11 +151,12 @@ class Writer(_FileLike):
r = self.o.write(v)
self.add_log(v[:r])
return r
- except (SSL.Error, socket.error) as e:
+ except (SSL.Error, socket.error) as e:
raise NetLibDisconnect(str(e))
class Reader(_FileLike):
+
def read(self, length):
"""
If length is -1, we read until connection closes.
@@ -180,7 +197,7 @@ class Reader(_FileLike):
self.add_log(result)
return result
- def readline(self, size = None):
+ def readline(self, size=None):
result = ''
bytes_read = 0
while True:
@@ -204,16 +221,18 @@ class Reader(_FileLike):
result = self.read(length)
if length != -1 and len(result) != length:
raise NetLibIncomplete(
- "Expected %s bytes, got %s"%(length, len(result))
+ "Expected %s bytes, got %s" % (length, len(result))
)
return result
class Address(object):
+
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and
ipv6 information.
"""
+
def __init__(self, address, use_ipv6=False):
self.address = tuple(address)
self.use_ipv6 = use_ipv6
@@ -304,6 +323,7 @@ def close_socket(sock):
class _Connection(object):
+
def get_current_cipher(self):
if not self.ssl_established:
return None
@@ -319,7 +339,7 @@ class _Connection(object):
# (We call _FileLike.set_descriptor(conn))
# Closing the socket is not our task, therefore we don't call close
# then.
- if type(self.connection) != SSL.Connection:
+ if not isinstance(self.connection, SSL.Connection):
if not getattr(self.wfile, "closed", False):
try:
self.wfile.flush()
@@ -337,6 +357,7 @@ class _Connection(object):
"""
Creates an SSL Context.
"""
+
def _create_ssl_context(self,
method=SSLv23_METHOD,
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
@@ -362,8 +383,8 @@ class _Connection(object):
if cipher_list:
try:
context.set_cipher_list(cipher_list)
- except SSL.Error, v:
- raise NetLibError("SSL cipher specification error: %s"%str(v))
+ except SSL.Error as v:
+ raise NetLibError("SSL cipher specification error: %s" % str(v))
# SSLKEYLOGFILE
if log_ssl_key:
@@ -380,7 +401,7 @@ class TCPClient(_Connection):
# Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
# it tries to renegotiate...
- if type(self.connection) == SSL.Connection:
+ if isinstance(self.connection, SSL.Connection):
close_socket(self.connection._socket)
else:
close_socket(self.connection)
@@ -400,8 +421,8 @@ class TCPClient(_Connection):
try:
context.use_privatekey_file(cert)
context.use_certificate_file(cert)
- except SSL.Error, v:
- raise NetLibError("SSL client certificate error: %s"%str(v))
+ except SSL.Error as v:
+ raise NetLibError("SSL client certificate error: %s" % str(v))
return context
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
@@ -418,8 +439,8 @@ class TCPClient(_Connection):
self.connection.set_connect_state()
try:
self.connection.do_handshake()
- except SSL.Error, v:
- raise NetLibError("SSL handshake error: %s"%repr(v))
+ except SSL.Error as v:
+ raise NetLibError("SSL handshake error: %s" % repr(v))
self.ssl_established = True
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
self.rfile.set_descriptor(self.connection)
@@ -435,7 +456,7 @@ class TCPClient(_Connection):
self.source_address = Address(connection.getsockname())
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
- except (socket.error, IOError), err:
+ except (socket.error, IOError) as err:
raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))
self.connection = connection
@@ -447,6 +468,7 @@ class TCPClient(_Connection):
class BaseHandler(_Connection):
+
"""
The instantiator is expected to call the handle() and finish() methods.
@@ -531,8 +553,8 @@ class BaseHandler(_Connection):
self.connection.set_accept_state()
try:
self.connection.do_handshake()
- except SSL.Error, v:
- raise NetLibError("SSL handshake error: %s"%repr(v))
+ except SSL.Error as v:
+ raise NetLibError("SSL handshake error: %s" % repr(v))
self.ssl_established = True
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
diff --git a/netlib/test.py b/netlib/test.py
index db30c0e6..b6f94273 100644
--- a/netlib/test.py
+++ b/netlib/test.py
@@ -1,9 +1,13 @@
from __future__ import (absolute_import, print_function, division)
-import threading, Queue, cStringIO
+import threading
+import Queue
+import cStringIO
import OpenSSL
from . import tcp, certutils
+
class ServerThread(threading.Thread):
+
def __init__(self, server):
self.server = server
threading.Thread.__init__(self)
@@ -19,6 +23,7 @@ class ServerTestBase(object):
ssl = None
handler = None
addr = ("localhost", 0)
+
@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
@@ -41,10 +46,11 @@ class ServerTestBase(object):
class TServer(tcp.TCPServer):
+
def __init__(self, ssl, q, handler_klass, addr):
"""
ssl: A dictionary of SSL parameters:
-
+
cert, key, request_client_cert, cipher_list,
dhparams, v3_only
"""
@@ -70,13 +76,13 @@ class TServer(tcp.TCPServer):
options = None
h.convert_to_ssl(
cert, key,
- method = method,
- options = options,
- handle_sni = getattr(h, "handle_sni", None),
- request_client_cert = self.ssl["request_client_cert"],
- cipher_list = self.ssl.get("cipher_list", None),
- dhparams = self.ssl.get("dhparams", None),
- chain_file = self.ssl.get("chain_file", None)
+ method=method,
+ options=options,
+ handle_sni=getattr(h, "handle_sni", None),
+ request_client_cert=self.ssl["request_client_cert"],
+ cipher_list=self.ssl.get("cipher_list", None),
+ dhparams=self.ssl.get("dhparams", None),
+ chain_file=self.ssl.get("chain_file", None)
)
h.handle()
h.finish()
diff --git a/netlib/utils.py b/netlib/utils.py
index 7e539977..9c5404e6 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -68,6 +68,7 @@ def getbit(byte, offset):
class BiDi:
+
"""
A wee utility class for keeping bi-directional mappings, like field
constants in protocols. Names are attributes on the object, dict-like
@@ -77,6 +78,7 @@ class BiDi:
assert CONST.a == 1
assert CONST.get_name(1) == "a"
"""
+
def __init__(self, **kwargs):
self.names = kwargs
self.values = {}
@@ -96,15 +98,15 @@ class BiDi:
def pretty_size(size):
suffixes = [
- ("B", 2**10),
- ("kB", 2**20),
- ("MB", 2**30),
+ ("B", 2 ** 10),
+ ("kB", 2 ** 20),
+ ("MB", 2 ** 30),
]
for suf, lim in suffixes:
if size >= lim:
continue
else:
- x = round(size/float(lim/2**10), 2)
+ x = round(size / float(lim / 2 ** 10), 2)
if x == int(x):
x = int(x)
return str(x) + suf
diff --git a/netlib/websockets.py b/netlib/websockets.py
index a2d55c19..63dc03f1 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -26,16 +26,17 @@ MAX_64_BIT_INT = (1 << 64)
OPCODE = utils.BiDi(
- CONTINUE = 0x00,
- TEXT = 0x01,
- BINARY = 0x02,
- CLOSE = 0x08,
- PING = 0x09,
- PONG = 0x0a
+ CONTINUE=0x00,
+ TEXT=0x01,
+ BINARY=0x02,
+ CLOSE=0x08,
+ PING=0x09,
+ PONG=0x0a
)
class Masker:
+
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
@@ -43,6 +44,7 @@ class Masker:
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
+
def __init__(self, key):
self.key = key
self.masks = [utils.bytes_to_int(byte) for byte in key]
@@ -128,17 +130,18 @@ DEFAULT = object()
class FrameHeader:
+
def __init__(
self,
- opcode = OPCODE.TEXT,
- payload_length = 0,
- fin = False,
- rsv1 = False,
- rsv2 = False,
- rsv3 = False,
- masking_key = DEFAULT,
- mask = DEFAULT,
- length_code = DEFAULT
+ opcode=OPCODE.TEXT,
+ payload_length=0,
+ fin=False,
+ rsv1=False,
+ rsv2=False,
+ rsv3=False,
+ masking_key=DEFAULT,
+ mask=DEFAULT,
+ length_code=DEFAULT
):
if not 0 <= opcode < 2 ** 4:
raise ValueError("opcode must be 0-16")
@@ -182,9 +185,9 @@ class FrameHeader:
if flags:
vals.extend([":", "|".join(flags)])
if self.masking_key:
- vals.append(":key=%s"%repr(self.masking_key))
+ vals.append(":key=%s" % repr(self.masking_key))
if self.payload_length:
- vals.append(" %s"%utils.pretty_size(self.payload_length))
+ vals.append(" %s" % utils.pretty_size(self.payload_length))
return "".join(vals)
def to_bytes(self):
@@ -246,15 +249,15 @@ class FrameHeader:
masking_key = None
return klass(
- fin = fin,
- rsv1 = rsv1,
- rsv2 = rsv2,
- rsv3 = rsv3,
- opcode = opcode,
- mask = mask_bit,
- length_code = length_code,
- payload_length = payload_length,
- masking_key = masking_key,
+ fin=fin,
+ rsv1=rsv1,
+ rsv2=rsv2,
+ rsv3=rsv3,
+ opcode=opcode,
+ mask=mask_bit,
+ length_code=length_code,
+ payload_length=payload_length,
+ masking_key=masking_key,
)
def __eq__(self, other):
@@ -262,6 +265,7 @@ class FrameHeader:
class Frame(object):
+
"""
Represents one websockets frame.
Constructor takes human readable forms of the frame components
@@ -287,13 +291,14 @@ class Frame(object):
| Payload Data continued ... |
+---------------------------------------------------------------+
"""
- def __init__(self, payload = "", **kwargs):
+
+ def __init__(self, payload="", **kwargs):
self.payload = payload
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
self.header = FrameHeader(**kwargs)
@classmethod
- def default(cls, message, from_client = False):
+ def default(cls, message, from_client=False):
"""
Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame.
@@ -307,10 +312,10 @@ class Frame(object):
return cls(
message,
- fin = 1, # final frame
- opcode = OPCODE.TEXT, # text
- mask = mask_bit,
- masking_key = masking_key,
+ fin=1, # final frame
+ opcode=OPCODE.TEXT, # text
+ mask=mask_bit,
+ masking_key=masking_key,
)
@classmethod
@@ -356,15 +361,15 @@ class Frame(object):
return cls(
payload,
- fin = header.fin,
- opcode = header.opcode,
- mask = header.mask,
- payload_length = header.payload_length,
- masking_key = header.masking_key,
- rsv1 = header.rsv1,
- rsv2 = header.rsv2,
- rsv3 = header.rsv3,
- length_code = header.length_code
+ fin=header.fin,
+ opcode=header.opcode,
+ mask=header.mask,
+ payload_length=header.payload_length,
+ masking_key=header.masking_key,
+ rsv1=header.rsv1,
+ rsv2=header.rsv2,
+ rsv3=header.rsv3,
+ length_code=header.length_code
)
def __eq__(self, other):
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
index 1b979608..f393039a 100644
--- a/netlib/wsgi.py
+++ b/netlib/wsgi.py
@@ -7,17 +7,20 @@ from . import odict, tcp
class ClientConn(object):
+
def __init__(self, address):
self.address = tcp.Address.wrap(address)
class Flow(object):
+
def __init__(self, address, request):
self.client_conn = ClientConn(address)
self.request = request
class Request(object):
+
def __init__(self, scheme, method, path, headers, content):
self.scheme, self.method, self.path = scheme, method, path
self.headers, self.content = headers, content
@@ -42,6 +45,7 @@ def date_time_string():
class WSGIAdaptor(object):
+
def __init__(self, app, domain, port, sversion):
self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
@@ -52,24 +56,24 @@ class WSGIAdaptor(object):
path_info = flow.request.path
query = ''
environ = {
- 'wsgi.version': (1, 0),
- 'wsgi.url_scheme': flow.request.scheme,
- 'wsgi.input': cStringIO.StringIO(flow.request.content),
- 'wsgi.errors': errsoc,
- 'wsgi.multithread': True,
- 'wsgi.multiprocess': False,
- 'wsgi.run_once': False,
- 'SERVER_SOFTWARE': self.sversion,
- 'REQUEST_METHOD': flow.request.method,
- 'SCRIPT_NAME': '',
- 'PATH_INFO': urllib.unquote(path_info),
- 'QUERY_STRING': query,
- 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0],
- 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0],
- 'SERVER_NAME': self.domain,
- 'SERVER_PORT': str(self.port),
+ 'wsgi.version': (1, 0),
+ 'wsgi.url_scheme': flow.request.scheme,
+ 'wsgi.input': cStringIO.StringIO(flow.request.content),
+ 'wsgi.errors': errsoc,
+ 'wsgi.multithread': True,
+ 'wsgi.multiprocess': False,
+ 'wsgi.run_once': False,
+ 'SERVER_SOFTWARE': self.sversion,
+ 'REQUEST_METHOD': flow.request.method,
+ 'SCRIPT_NAME': '',
+ 'PATH_INFO': urllib.unquote(path_info),
+ 'QUERY_STRING': query,
+ 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0],
+ 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0],
+ 'SERVER_NAME': self.domain,
+ 'SERVER_PORT': str(self.port),
# FIXME: We need to pick up the protocol read from the request.
- 'SERVER_PROTOCOL': "HTTP/1.1",
+ 'SERVER_PROTOCOL': "HTTP/1.1",
}
environ.update(extra)
if flow.client_conn.address:
@@ -91,25 +95,25 @@ class WSGIAdaptor(object):
<h1>Internal Server Error</h1>
<pre>%s"</pre>
</html>
- """%s
+ """ % s
if not headers_sent:
soc.write("HTTP/1.1 500 Internal Server Error\r\n")
soc.write("Content-Type: text/html\r\n")
- soc.write("Content-Length: %s\r\n"%len(c))
+ soc.write("Content-Length: %s\r\n" % len(c))
soc.write("\r\n")
soc.write(c)
def serve(self, request, soc, **env):
state = dict(
- response_started = False,
- headers_sent = False,
- status = None,
- headers = None
+ response_started=False,
+ headers_sent=False,
+ status=None,
+ headers=None
)
def write(data):
if not state["headers_sent"]:
- soc.write("HTTP/1.1 %s\r\n"%state["status"])
+ soc.write("HTTP/1.1 %s\r\n" % state["status"])
h = state["headers"]
if 'server' not in h:
h["Server"] = [self.sversion]