diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/__init__.py | 1 | ||||
-rw-r--r-- | netlib/certffi.py | 9 | ||||
-rw-r--r-- | netlib/certutils.py | 247 | ||||
-rw-r--r-- | netlib/contrib/__init__.py | 0 | ||||
-rw-r--r-- | netlib/contrib/md5crypt.py | 94 | ||||
-rw-r--r-- | netlib/http.py | 191 | ||||
-rw-r--r-- | netlib/http_auth.py | 32 | ||||
-rw-r--r-- | netlib/http_status.py | 1 | ||||
-rw-r--r-- | netlib/http_uastrings.py | 2 | ||||
-rw-r--r-- | netlib/odict.py | 26 | ||||
-rw-r--r-- | netlib/socks.py | 149 | ||||
-rw-r--r-- | netlib/tcp.py | 175 | ||||
-rw-r--r-- | netlib/test.py | 9 | ||||
-rw-r--r-- | netlib/utils.py | 10 | ||||
-rw-r--r-- | netlib/version.py | 3 | ||||
-rw-r--r-- | netlib/wsgi.py | 35 |
16 files changed, 593 insertions, 391 deletions
diff --git a/netlib/__init__.py b/netlib/__init__.py index e69de29b..9b4faa33 100644 --- a/netlib/__init__.py +++ b/netlib/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/certffi.py b/netlib/certffi.py index c5d7c95e..81dc72e8 100644 --- a/netlib/certffi.py +++ b/netlib/certffi.py @@ -1,7 +1,9 @@ +from __future__ import (absolute_import, print_function, division) import cffi import OpenSSL + xffi = cffi.FFI() -xffi.cdef (""" +xffi.cdef(""" struct rsa_meth_st { int flags; ...; @@ -18,6 +20,7 @@ xffi.verify( extra_compile_args=['-w'] ) + def handle(privkey): new = xffi.new("struct rsa_st*") newbuf = xffi.buffer(new) @@ -26,11 +29,13 @@ def handle(privkey): newbuf[:] = oldbuf[:] return new + def set_flags(privkey, val): hdl = handle(privkey) - hdl.meth.flags = val + hdl.meth.flags = val return privkey + def get_flags(privkey): hdl = handle(privkey) return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 187abfae..af6177d8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,10 +1,10 @@ +from __future__ import (absolute_import, print_function, division) import os, ssl, time, datetime +import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -import tcp -import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -29,12 +29,12 @@ def create_ca(o, cn, exp): cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", True, + OpenSSL.crypto.X509Extension("nsCertType", False, "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", True, + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" ), - OpenSSL.crypto.X509Extension("keyUsage", False, + OpenSSL.crypto.X509Extension("keyUsage", True, "keyCertSign, cRLSign"), OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", subject=cert), @@ -67,62 +67,72 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) - cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha1") return SSLCert(cert) -class _Node(UserDict.UserDict): - def __init__(self): - UserDict.UserDict.__init__(self) - self.value = None - - -class DNTree: - """ - Domain store that knows about wildcards. DNS wildcards are very - restricted - the only valid variety is an asterisk on the left-most - domain component, i.e.: - - *.foo.com - """ - def __init__(self): - self.d = _Node() - - def add(self, dn, cert): - parts = dn.split(".") - parts.reverse() - current = self.d - for i in parts: - current = current.setdefault(i, _Node()) - current.value = cert - - def get(self, dn): - parts = dn.split(".") - current = self.d - for i in reversed(parts): - if i in current: - current = current[i] - elif "*" in current: - return current["*"].value - else: - return None - return current.value - +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value + + +class CertStoreEntry(object): + def __init__(self, cert, privatekey, chain_file): + self.cert = cert + self.privatekey = privatekey + self.chain_file = chain_file class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert, dhparams=None): - self.privkey, self.cacert = privkey, cacert + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + self.default_privatekey = default_privatekey + self.default_ca = default_ca + self.default_chain_file = default_chain_file self.dhparams = dhparams - self.certs = DNTree() + self.certs = dict() - @classmethod - def load_dhparam(klass, path): + @staticmethod + def load_dhparam(path): # netlib<=0.10 doesn't generate a dhparam file. # Create it now if neccessary. @@ -140,21 +150,21 @@ class CertStore: return dh @classmethod - def from_store(klass, path, basename): - p = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(p): - key, ca = klass.create_store(path, basename) + def from_store(cls, path, basename): + ca_path = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(ca_path): + key, ca = cls.create_store(path, basename) else: - p = os.path.join(path, basename + "-ca.pem") - raw = file(p, "rb").read() + with open(ca_path, "rb") as f: + raw = f.read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - dhp = os.path.join(path, basename + "-dhparam.pem") - dh = klass.load_dhparam(dhp) - return klass(key, ca, dh) + dh_path = os.path.join(path, basename + "-dhparam.pem") + dh = cls.load_dhparam(dh_path) + return cls(key, ca, ca_path, dh) - @classmethod - def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + @staticmethod + def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -163,58 +173,71 @@ class CertStore: key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key - f = open(os.path.join(path, basename + "-ca.pem"), "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() - - f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") - f.write(DEFAULT_DHPARAM) - f.close() + with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + + with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + f.write(DEFAULT_DHPARAM) + return key, ca def add_cert_file(self, spec, path): - raw = file(path, "rb").read() - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + with open(path, "rb") as f: + raw = f.read() + cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - privkey = None - self.add_cert(SSLCert(cert), privkey, spec) + privatekey = self.default_privatekey + self.add_cert( + CertStoreEntry(cert, privatekey, path), + spec + ) - def add_cert(self, cert, privkey, *names): + def add_cert(self, entry, *names): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. """ - if cert.cn: - self.certs.add(cert.cn, (cert, privkey)) - for i in cert.altnames: - self.certs.add(i, (cert, privkey)) + if entry.cert.cn: + self.certs[entry.cert.cn] = entry + for i in entry.cert.altnames: + self.certs[i] = entry for i in names: - self.certs.add(i, (cert, privkey)) + self.certs[i] = entry + + @staticmethod + def asterisk_forms(dn): + parts = dn.split(".") + parts.reverse() + curr_dn = "" + dn_forms = ["*"] + for part in parts[:-1]: + curr_dn = "." + part + curr_dn # .example.com + dn_forms.append("*" + curr_dn) # *.example.com + if parts[-1] != "*": + dn_forms.append(parts[-1] + curr_dn) + return dn_forms def get_cert(self, commonname, sans): """ - Returns an (cert, privkey) tuple. + Returns an (cert, privkey, cert_chain) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -223,17 +246,30 @@ class CertStore: Return None if the certificate could not be found or generated. """ - c = self.certs.get(commonname) - if not c: - c = dummy_cert(self.privkey, self.cacert, commonname, sans) - self.add_cert(c, None) - c = (c, None) - return (c[0], c[1] or self.privkey) + + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + if name: + entry = self.certs[name] + else: + entry = CertStoreEntry( + cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file + ) + self.certs[(commonname, tuple(sans))] = entry + + return entry.cert, entry.privatekey, entry.chain_file def gen_pkey(self, cert): - import certffi - certffi.set_flags(self.privkey, 1) - return self.privkey + # FIXME: We should do something with cert here? + from . import certffi + certffi.set_flags(self.default_privatekey, 1) + return self.default_privatekey class _GeneralName(univ.Choice): @@ -262,6 +298,9 @@ class SSLCert: def __eq__(self, other): return self.digest("sha1") == other.digest("sha1") + def __ne__(self, other): + return not self.__eq__(other) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) @@ -337,11 +376,3 @@ class SSLCert: for i in dec[0]: altnames.append(i[0].asOctets()) return altnames - - - -def get_remote_cert(host, port, sni): - c = tcp.TCPClient((host, port)) - c.connect() - c.convert_to_ssl(sni=sni) - return c.cert diff --git a/netlib/contrib/__init__.py b/netlib/contrib/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/netlib/contrib/__init__.py +++ /dev/null diff --git a/netlib/contrib/md5crypt.py b/netlib/contrib/md5crypt.py deleted file mode 100644 index d64ea8ac..00000000 --- a/netlib/contrib/md5crypt.py +++ /dev/null @@ -1,94 +0,0 @@ -# Based on FreeBSD src/lib/libcrypt/crypt.c 1.2 -# http://www.freebsd.org/cgi/cvsweb.cgi/~checkout~/src/lib/libcrypt/crypt.c?rev=1.2&content-type=text/plain - -# Original license: -# * "THE BEER-WARE LICENSE" (Revision 42): -# * <phk@login.dknet.dk> wrote this file. As long as you retain this notice you -# * can do whatever you want with this stuff. If we meet some day, and you think -# * this stuff is worth it, you can buy me a beer in return. Poul-Henning Kamp - -# This port adds no further stipulations. I forfeit any copyright interest. - -import md5 - -def md5crypt(password, salt, magic='$1$'): - # /* The password first, since that is what is most unknown */ /* Then our magic string */ /* Then the raw salt */ - m = md5.new() - m.update(password + magic + salt) - - # /* Then just as many characters of the MD5(pw,salt,pw) */ - mixin = md5.md5(password + salt + password).digest() - for i in range(0, len(password)): - m.update(mixin[i % 16]) - - # /* Then something really weird... */ - # Also really broken, as far as I can tell. -m - i = len(password) - while i: - if i & 1: - m.update('\x00') - else: - m.update(password[0]) - i >>= 1 - - final = m.digest() - - # /* and now, just to make sure things don't run too fast */ - for i in range(1000): - m2 = md5.md5() - if i & 1: - m2.update(password) - else: - m2.update(final) - - if i % 3: - m2.update(salt) - - if i % 7: - m2.update(password) - - if i & 1: - m2.update(final) - else: - m2.update(password) - - final = m2.digest() - - # This is the bit that uses to64() in the original code. - - itoa64 = './0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - - rearranged = '' - for a, b, c in ((0, 6, 12), (1, 7, 13), (2, 8, 14), (3, 9, 15), (4, 10, 5)): - v = ord(final[a]) << 16 | ord(final[b]) << 8 | ord(final[c]) - for i in range(4): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - v = ord(final[11]) - for i in range(2): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - return magic + salt + '$' + rearranged - -if __name__ == '__main__': - - def test(clear_password, the_hash): - magic, salt = the_hash[1:].split('$')[:2] - magic = '$' + magic + '$' - return md5crypt(clear_password, salt, magic) == the_hash - - test_cases = ( - (' ', '$1$yiiZbNIH$YiCsHZjcTkYd31wkgW8JF.'), - ('pass', '$1$YeNsbWdH$wvOF8JdqsoiLix754LTW90'), - ('____fifteen____', '$1$s9lUWACI$Kk1jtIVVdmT01p0z3b/hw1'), - ('____sixteen_____', '$1$dL3xbVZI$kkgqhCanLdxODGq14g/tW1'), - ('____seventeen____', '$1$NaH5na7J$j7y8Iss0hcRbu3kzoJs5V.'), - ('__________thirty-three___________', '$1$HO7Q6vzJ$yGwp2wbL5D7eOVzOmxpsy.'), - ('apache', '$apr1$J.w5a/..$IW9y6DR0oO/ADuhlMF5/X1') - ) - - for clearpw, hashpw in test_cases: - if test(clearpw, hashpw): - print '%s: pass' % clearpw - else: - print '%s: FAIL' % clearpw diff --git a/netlib/http.py b/netlib/http.py index f5b8118a..d2fc6343 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,5 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import string, urlparse, binascii -import odict, utils +import sys +from . import odict, utils class HttpError(Exception): @@ -43,6 +45,11 @@ def parse_url(url): return None if not scheme: return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) if ':' in netloc: host, port = string.rsplit(netloc, ':', maxsplit=1) try: @@ -88,14 +95,14 @@ def read_headers(fp): # We're being liberal in what we accept, here. if i > 0: name = line[:i] - value = line[i+1:].strip() + value = line[i + 1:].strip() ret.append([name, value]) else: return None return odict.ODictCaseless(ret) -def read_chunked(fp, headers, limit, is_request): +def read_chunked(fp, limit, is_request): """ Read a chunked HTTP body. @@ -103,10 +110,9 @@ def read_chunked(fp, headers, limit, is_request): """ # FIXME: Should check if chunked is the final encoding in the headers # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. - content = "" total = 0 code = 400 if is_request else 502 - while 1: + while True: line = fp.readline(128) if line == "": raise HttpErrorConnClosed(code, "Connection closed prematurely") @@ -114,27 +120,22 @@ def read_chunked(fp, headers, limit, is_request): try: length = int(line, 16) except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid chunked encoding length: %s"%line) - if not length: - break + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) total += length if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) raise HttpError(code, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line == '\r\n' or line == '\n': - break - return content + yield line, chunk, '\r\n' + if length == 0: + return def get_header_tokens(headers, key): @@ -151,7 +152,9 @@ def get_header_tokens(headers, key): def has_chunked_encoding(headers): - return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] def parse_http_protocol(s): @@ -263,7 +266,9 @@ def parse_init_http(line): def connection_close(httpversion, headers): """ - Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. """ # At first, check if we have an explicit Connection header. if "connection" in headers: @@ -272,7 +277,8 @@ def connection_close(httpversion, headers): return True elif "keep-alive" in toks: return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to be persistent + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent if httpversion == (1, 1): return False return True @@ -280,7 +286,7 @@ def connection_close(httpversion, headers): def parse_response_line(line): parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully + if len(parts) == 2: # handle missing message gracefully parts.append("") if len(parts) != 3: return None @@ -292,60 +298,141 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit): +def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") parts = parse_response_line(line) if not parts: - raise HttpError(502, "Invalid server response: %s"%repr(line)) + raise HttpError(502, "Invalid server response: %s" % repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: - content = "" + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) else: - content = read_http_body(rfile, headers, body_size_limit, False) + # if include_body==False then a None content means the body should be + # read separately + content = None return httpversion, code, msg, headers, content -def read_http_body(rfile, headers, limit, is_request): +def read_http_body(*args, **kwargs): + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) + + +def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): """ Read an HTTP message body: rfile: A file descriptor to read from headers: An ODictCaseless object limit: Size limit. - is_request: True if the body to read belongs to a request, False otherwise + is_request: True if the body to read belongs to a request, False + otherwise """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxint + + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + +def expected_http_body_size(headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 if has_chunked_encoding(headers): - content = read_chunked(rfile, headers, limit, is_request) - elif "content-length" in headers: + return None + if "content-length" in headers: try: - l = int(headers["content-length"][0]) - if l < 0: + size = int(headers["content-length"][0]) + if size < 0: raise ValueError() + return size except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif is_request: - content = "" - else: - content = rfile.read(limit if limit else -1) - not_done = rfile.read(1) - if not_done: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) - return content
\ No newline at end of file + return None + if is_request: + return 0 + return -1 diff --git a/netlib/http_auth.py b/netlib/http_auth.py index b0451e3b..dca6e2f3 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,6 +1,6 @@ -from .contrib import md5crypt -import http +from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError +from . import http class NullProxyAuth(): @@ -78,32 +78,15 @@ class PassManHtpasswd: """ Read usernames and passwords from an htpasswd file """ - def __init__(self, fp): + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. """ - self.usernames = {} - for l in fp: - l = l.strip().split(':') - if len(l) != 2: - raise ValueError("Invalid htpasswd file.") - parts = l[1].split('$') - if len(parts) != 4: - raise ValueError("Invalid htpasswd file.") - self.usernames[l[0]] = dict( - token = l[1], - dummy = parts[0], - magic = parts[1], - salt = parts[2], - hashed_password = parts[3] - ) + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) def test(self, username, password_token): - ui = self.usernames.get(username) - if not ui: - return False - expected = md5crypt.md5crypt(password_token, ui["salt"], '$'+ui["magic"]+'$') - return expected==ui["token"] + return bool(self.htpasswd.check_password(username, password_token)) class PassManSingleUser: @@ -149,6 +132,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): - with open(s, "r") as f: - return PassManHtpasswd(f) + return PassManHtpasswd(s) diff --git a/netlib/http_status.py b/netlib/http_status.py index 9f3f7e15..7dba2d56 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) CONTINUE = 100 SWITCHING = 101 diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index 826c31a5..d0d145da 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + """ A small collection of useful user-agent header strings. These should be kept reasonably current to reflect common usage. diff --git a/netlib/odict.py b/netlib/odict.py index ea95a586..61448e6d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) import re, copy @@ -23,6 +24,9 @@ class ODict: def __eq__(self, other): return self.lst == other.lst + def __ne__(self, other): + return not self.__eq__(other) + def __iter__(self): return self.lst.__iter__() @@ -97,16 +101,6 @@ class ODict: def items(self): return self.lst[:] - def _get_state(self): - return [tuple(i) for i in self.lst] - - def _load_state(self, state): - self.list = [list(i) for i in state] - - @classmethod - def _from_state(klass, state): - return klass([list(i) for i in state]) - def copy(self): """ Returns a copy of this object. @@ -167,6 +161,18 @@ class ODict: self.lst = nlst return count + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return [tuple(i) for i in self.lst] + + def load_state(self, state): + self.list = [list(i) for i in state] + + @classmethod + def from_state(klass, state): + return klass([list(i) for i in state]) + + class ODictCaseless(ODict): """ diff --git a/netlib/socks.py b/netlib/socks.py new file mode 100644 index 00000000..a3c4e9a2 --- /dev/null +++ b/netlib/socks.py @@ -0,0 +1,149 @@ +from __future__ import (absolute_import, print_function, division) +import socket +import struct +import array +from . import tcp + + +class SocksError(Exception): + def __init__(self, code, message): + super(SocksError, self).__init__(message) + self.code = code + + +class VERSION(object): + SOCKS4 = 0x04 + SOCKS5 = 0x05 + + +class CMD(object): + CONNECT = 0x01 + BIND = 0x02 + UDP_ASSOCIATE = 0x03 + + +class ATYP(object): + IPV4_ADDRESS = 0x01 + DOMAINNAME = 0x03 + IPV6_ADDRESS = 0x04 + + +class REP(object): + SUCCEEDED = 0x00 + GENERAL_SOCKS_SERVER_FAILURE = 0x01 + CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 + NETWORK_UNREACHABLE = 0x03 + HOST_UNREACHABLE = 0x04 + CONNECTION_REFUSED = 0x05 + TTL_EXPIRED = 0x06 + COMMAND_NOT_SUPPORTED = 0x07 + ADDRESS_TYPE_NOT_SUPPORTED = 0x08 + + +class METHOD(object): + NO_AUTHENTICATION_REQUIRED = 0x00 + GSSAPI = 0x01 + USERNAME_PASSWORD = 0x02 + NO_ACCEPTABLE_METHODS = 0xFF + + +def _read(f, n): + try: + d = f.read(n) + if len(d) == n: + return d + else: + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Incomplete Read" + ) + except socket.error as e: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) + + +class ClientGreeting(object): + __slots__ = ("ver", "methods") + + def __init__(self, ver, methods): + self.ver = ver + self.methods = methods + + @classmethod + def from_file(cls, f): + ver, nmethods = struct.unpack("!BB", _read(f, 2)) + methods = array.array("B") + methods.fromstring(_read(f, nmethods)) + return cls(ver, methods) + + def to_file(self, f): + f.write(struct.pack("!BB", self.ver, len(self.methods))) + f.write(self.methods.tostring()) + + +class ServerGreeting(object): + __slots__ = ("ver", "method") + + def __init__(self, ver, method): + self.ver = ver + self.method = method + + @classmethod + def from_file(cls, f): + ver, method = struct.unpack("!BB", _read(f, 2)) + return cls(ver, method) + + def to_file(self, f): + f.write(struct.pack("!BB", self.ver, self.method)) + + +class Message(object): + __slots__ = ("ver", "msg", "atyp", "addr") + + def __init__(self, ver, msg, atyp, addr): + self.ver = ver + self.msg = msg + self.atyp = atyp + self.addr = addr + + @classmethod + def from_file(cls, f): + ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) + if rsv != 0x00: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv) + + if atyp == ATYP.IPV4_ADDRESS: + # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(_read(f, 4)) + use_ipv6 = False + elif atyp == ATYP.IPV6_ADDRESS: + host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) + use_ipv6 = True + elif atyp == ATYP.DOMAINNAME: + length, = struct.unpack("!B", _read(f, 1)) + host = _read(f, length) + use_ipv6 = False + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Socks Request: Unknown ATYP: %s" % atyp) + + port, = struct.unpack("!H", _read(f, 2)) + addr = tcp.Address((host, port), use_ipv6=use_ipv6) + return cls(ver, msg, atyp, addr) + + def to_file(self, f): + f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) + if self.atyp == ATYP.IPV4_ADDRESS: + f.write(socket.inet_aton(self.addr.host)) + elif self.atyp == ATYP.IPV6_ADDRESS: + f.write(socket.inet_pton(socket.AF_INET6, self.addr.host)) + elif self.atyp == ATYP.DOMAINNAME: + f.write(struct.pack("!B", len(self.addr.host))) + f.write(self.addr.host) + else: + raise SocksError( + REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Unknown ATYP: %s" % self.atyp + ) + f.write(struct.pack("!H", self.addr.port)) + diff --git a/netlib/tcp.py b/netlib/tcp.py index e72d5e48..6b7540aa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,13 @@ -import select, socket, threading, sys, time, traceback +from __future__ import (absolute_import, print_function, division) +import select +import socket +import sys +import threading +import time +import traceback from OpenSSL import SSL -import certutils + +from . import certutils EINTR = 4 @@ -10,32 +17,6 @@ SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD -OP_ALL = SSL.OP_ALL -OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE -OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE -OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS -OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA -OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER -OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING -OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG -OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG -OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG -OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG -OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -OP_NO_TICKET = SSL.OP_NO_TICKET -OP_NO_TLSv1 = SSL.OP_NO_TLSv1 -OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1 -OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2 -OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE -OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG -OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG -OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG -OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG -OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG - class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass @@ -212,10 +193,47 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __repr__(self): + return repr(self.address) + def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) + def __ne__(self, other): + return not self.__eq__(other) + + +def close_socket(sock): + """ + Does a hard close of a socket, without emitting a RST. + """ + try: + # We already indicate that we close our end. + # If we close RD, any further received bytes would result in a RST being set, which we want to avoid + # for our purposes + sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we are in blocking mode. + # Please let us know if you have a better solution to this problem. + + sock.settimeout(sock.gettimeout() or 20) + # may raise a timeout/disconnect exception. + while sock.recv(4096): # pragma: no cover + pass + + except socket.error: + pass + + sock.close() + class _Connection(object): def get_current_cipher(self): @@ -229,40 +247,39 @@ class _Connection(object): def finish(self): self.finished = True - try: + + # If we have an SSL connection, wfile.close == connection.close + # (We call _FileLike.set_descriptor(conn)) + # Closing the socket is not our task, therefore we don't call close then. + if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.close() + try: + self.wfile.flush() + except NetLibDisconnect: + pass + self.wfile.close() self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: + else: + try: self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): # pragma: no cover + except SSL.Error: pass - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + + def close(self): + # Make sure to close the real socket, not the SSL proxy. + # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, + # it tries to renegotiate... + if type(self.connection) == SSL.Connection: + close_socket(self.connection._socket) + else: + close_socket(self.connection) + def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap(source_address) if source_address else None @@ -274,6 +291,8 @@ class TCPClient(_Connection): def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): """ cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values """ context = SSL.Context(method) if cipher_list: @@ -290,7 +309,6 @@ class TCPClient(_Connection): except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) - self.ssl_established = True if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) @@ -298,7 +316,8 @@ class TCPClient(_Connection): try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -309,6 +328,8 @@ class TCPClient(_Connection): if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: @@ -343,21 +364,25 @@ class BaseHandler(_Connection): def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None ): + dhparams=None, chain_file=None): """ cert: A certutils.SSLCert object. + method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: - connection.get_servername() + connection.get_servername() + + options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: + And you can specify the connection keys as follows: - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) The request_client_cert argument requires some explanation. We're supposed to be able to do this with no negative effects - if the @@ -371,6 +396,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if chain_file: + ctx.load_verify_locations(chain_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) @@ -398,12 +425,12 @@ class BaseHandler(_Connection): """ ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) - self.ssl_established = True self.connection.set_accept_state() try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -414,7 +441,6 @@ class BaseHandler(_Connection): self.connection.settimeout(n) - class TCPServer(object): request_queue_size = 20 def __init__(self, address): @@ -434,11 +460,7 @@ class TCPServer(object): except: self.handle_error(connection, client_address) finally: - try: - connection.shutdown(socket.SHUT_RDWR) - except: - pass - connection.close() + close_socket(connection) def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -450,7 +472,7 @@ class TCPServer(object): if ex[0] == EINTR: continue else: - raise + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( @@ -472,7 +494,7 @@ class TCPServer(object): self.socket.close() self.handle_shutdown() - def handle_error(self, request, client_address, fp=sys.stderr): + def handle_error(self, connection, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ @@ -480,10 +502,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-' * 40, file=fp) + print( + "Error in processing of request from %s:%s" % ( + client_address.host, client_address.port + ), file=fp) + print(exc, file=fp) + print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/test.py b/netlib/test.py index bb0012ad..fb468907 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import threading, Queue, cStringIO -import tcp, certutils import OpenSSL +from . import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): @@ -63,7 +64,7 @@ class TServer(tcp.TCPServer): key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD - options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: method = tcp.SSLv23_METHOD options = None @@ -79,7 +80,7 @@ class TServer(tcp.TCPServer): h.handle() h.finish() - def handle_error(self, request, client_address): + def handle_error(self, connection, client_address, fp=None): s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, request, client_address, s) + tcp.TCPServer.handle_error(self, connection, client_address, s) self.q.put(s.getvalue()) diff --git a/netlib/utils.py b/netlib/utils.py index 61fd54ae..79077ac6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + def isascii(s): try: @@ -32,13 +34,13 @@ def hexdump(s): """ parts = [] for i in range(0, len(s), 16): - o = "%.10x"%i - part = s[i:i+16] - x = " ".join("%.2x"%ord(i) for i in part) + o = "%.10x" % i + part = s[i:i + 16] + x = " ".join("%.2x" % ord(i) for i in part) if len(part) < 16: x += " " x += " ".join(" " for i in range(16 - len(part))) parts.append( (o, x, cleanBin(part, True)) ) - return parts + return parts
\ No newline at end of file diff --git a/netlib/version.py b/netlib/version.py index 1d3250e1..913f753a 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,7 @@ +from __future__ import (absolute_import, print_function, division) + IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) +MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION diff --git a/netlib/wsgi.py b/netlib/wsgi.py index b576bdff..568b1f9c 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,6 @@ +from __future__ import (absolute_import, print_function, division) import cStringIO, urllib, time, traceback -import odict, tcp +from . import odict, tcp class ClientConn: @@ -8,15 +9,15 @@ class ClientConn: class Flow: - def __init__(self, client_conn): - self.client_conn = client_conn + def __init__(self, address, request): + self.client_conn = ClientConn(address) + self.request = request class Request: - def __init__(self, client_conn, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.flow = Flow(client_conn) def date_time_string(): @@ -38,37 +39,37 @@ class WSGIAdaptor: def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - def make_environ(self, request, errsoc, **extra): - if '?' in request.path: - path_info, query = request.path.split('?', 1) + def make_environ(self, flow, errsoc, **extra): + if '?' in flow.request.path: + path_info, query = flow.request.path.split('?', 1) else: - path_info = request.path + path_info = flow.request.path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': request.scheme, - 'wsgi.input': cStringIO.StringIO(request.content), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': request.method, + 'REQUEST_METHOD': flow.request.method, 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.flow.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() + if flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = flow.client_conn.address() - for key, value in request.headers.items(): + for key, value in flow.request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value |