diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/__init__.py | 1 | ||||
-rw-r--r-- | netlib/certffi.py | 9 | ||||
-rw-r--r-- | netlib/certutils.py | 134 | ||||
-rw-r--r-- | netlib/contrib/__init__.py | 0 | ||||
-rw-r--r-- | netlib/contrib/md5crypt.py | 94 | ||||
-rw-r--r-- | netlib/http.py | 161 | ||||
-rw-r--r-- | netlib/http_auth.py | 32 | ||||
-rw-r--r-- | netlib/http_status.py | 1 | ||||
-rw-r--r-- | netlib/http_uastrings.py | 2 | ||||
-rw-r--r-- | netlib/odict.py | 7 | ||||
-rw-r--r-- | netlib/socks.py | 128 | ||||
-rw-r--r-- | netlib/tcp.py | 46 | ||||
-rw-r--r-- | netlib/test.py | 3 | ||||
-rw-r--r-- | netlib/utils.py | 10 | ||||
-rw-r--r-- | netlib/version.py | 3 | ||||
-rw-r--r-- | netlib/wsgi.py | 35 |
16 files changed, 399 insertions, 267 deletions
diff --git a/netlib/__init__.py b/netlib/__init__.py index e69de29b..9b4faa33 100644 --- a/netlib/__init__.py +++ b/netlib/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/certffi.py b/netlib/certffi.py index c5d7c95e..81dc72e8 100644 --- a/netlib/certffi.py +++ b/netlib/certffi.py @@ -1,7 +1,9 @@ +from __future__ import (absolute_import, print_function, division) import cffi import OpenSSL + xffi = cffi.FFI() -xffi.cdef (""" +xffi.cdef(""" struct rsa_meth_st { int flags; ...; @@ -18,6 +20,7 @@ xffi.verify( extra_compile_args=['-w'] ) + def handle(privkey): new = xffi.new("struct rsa_st*") newbuf = xffi.buffer(new) @@ -26,11 +29,13 @@ def handle(privkey): newbuf[:] = oldbuf[:] return new + def set_flags(privkey, val): hdl = handle(privkey) - hdl.meth.flags = val + hdl.meth.flags = val return privkey + def get_flags(privkey): hdl = handle(privkey) return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 4c50b984..fe067ca1 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,10 +1,10 @@ +from __future__ import (absolute_import, print_function, division) import os, ssl, time, datetime +import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -import tcp -import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -22,7 +22,7 @@ def create_ca(o, cn, exp): cert.set_version(2) cert.get_subject().CN = cn cert.get_subject().O = o - cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(exp) cert.set_issuer(cert.get_subject()) cert.set_pubkey(key) @@ -73,42 +73,44 @@ def dummy_cert(privkey, cacert, commonname, sans): return SSLCert(cert) -class _Node(UserDict.UserDict): - def __init__(self): - UserDict.UserDict.__init__(self) - self.value = None - - -class DNTree: - """ - Domain store that knows about wildcards. DNS wildcards are very - restricted - the only valid variety is an asterisk on the left-most - domain component, i.e.: - - *.foo.com - """ - def __init__(self): - self.d = _Node() - - def add(self, dn, cert): - parts = dn.split(".") - parts.reverse() - current = self.d - for i in parts: - current = current.setdefault(i, _Node()) - current.value = cert - - def get(self, dn): - parts = dn.split(".") - current = self.d - for i in reversed(parts): - if i in current: - current = current[i] - elif "*" in current: - return current["*"].value - else: - return None - return current.value +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value @@ -119,7 +121,7 @@ class CertStore: def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert self.dhparams = dhparams - self.certs = DNTree() + self.certs = dict() @classmethod def load_dhparam(klass, path): @@ -206,11 +208,24 @@ class CertStore: any SANs, and also the list of names provided as an argument. """ if cert.cn: - self.certs.add(cert.cn, (cert, privkey)) + self.certs[cert.cn] = (cert, privkey) for i in cert.altnames: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) for i in names: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) + + @staticmethod + def asterisk_forms(dn): + parts = dn.split(".") + parts.reverse() + curr_dn = "" + dn_forms = ["*"] + for part in parts[:-1]: + curr_dn = "." + part + curr_dn # .example.com + dn_forms.append("*" + curr_dn) # *.example.com + if parts[-1] != "*": + dn_forms.append(parts[-1] + curr_dn) + return dn_forms def get_cert(self, commonname, sans): """ @@ -223,15 +238,23 @@ class CertStore: Return None if the certificate could not be found or generated. """ - c = self.certs.get(commonname) - if not c: - c = dummy_cert(self.privkey, self.cacert, commonname, sans) - self.add_cert(c, None) - c = (c, None) - return (c[0], c[1] or self.privkey) + + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + if name: + c = self.certs[name] + else: + c = dummy_cert(self.privkey, self.cacert, commonname, sans), None + self.certs[(commonname, tuple(sans))] = c + + return c[0], (c[1] or self.privkey) def gen_pkey(self, cert): - import certffi + from . import certffi certffi.set_flags(self.privkey, 1) return self.privkey @@ -262,6 +285,9 @@ class SSLCert: def __eq__(self, other): return self.digest("sha1") == other.digest("sha1") + def __ne__(self, other): + return not self.__eq__(other) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) @@ -337,11 +363,3 @@ class SSLCert: for i in dec[0]: altnames.append(i[0].asOctets()) return altnames - - - -def get_remote_cert(host, port, sni): - c = tcp.TCPClient((host, port)) - c.connect() - c.convert_to_ssl(sni=sni) - return c.cert diff --git a/netlib/contrib/__init__.py b/netlib/contrib/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/netlib/contrib/__init__.py +++ /dev/null diff --git a/netlib/contrib/md5crypt.py b/netlib/contrib/md5crypt.py deleted file mode 100644 index d64ea8ac..00000000 --- a/netlib/contrib/md5crypt.py +++ /dev/null @@ -1,94 +0,0 @@ -# Based on FreeBSD src/lib/libcrypt/crypt.c 1.2 -# http://www.freebsd.org/cgi/cvsweb.cgi/~checkout~/src/lib/libcrypt/crypt.c?rev=1.2&content-type=text/plain - -# Original license: -# * "THE BEER-WARE LICENSE" (Revision 42): -# * <phk@login.dknet.dk> wrote this file. As long as you retain this notice you -# * can do whatever you want with this stuff. If we meet some day, and you think -# * this stuff is worth it, you can buy me a beer in return. Poul-Henning Kamp - -# This port adds no further stipulations. I forfeit any copyright interest. - -import md5 - -def md5crypt(password, salt, magic='$1$'): - # /* The password first, since that is what is most unknown */ /* Then our magic string */ /* Then the raw salt */ - m = md5.new() - m.update(password + magic + salt) - - # /* Then just as many characters of the MD5(pw,salt,pw) */ - mixin = md5.md5(password + salt + password).digest() - for i in range(0, len(password)): - m.update(mixin[i % 16]) - - # /* Then something really weird... */ - # Also really broken, as far as I can tell. -m - i = len(password) - while i: - if i & 1: - m.update('\x00') - else: - m.update(password[0]) - i >>= 1 - - final = m.digest() - - # /* and now, just to make sure things don't run too fast */ - for i in range(1000): - m2 = md5.md5() - if i & 1: - m2.update(password) - else: - m2.update(final) - - if i % 3: - m2.update(salt) - - if i % 7: - m2.update(password) - - if i & 1: - m2.update(final) - else: - m2.update(password) - - final = m2.digest() - - # This is the bit that uses to64() in the original code. - - itoa64 = './0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - - rearranged = '' - for a, b, c in ((0, 6, 12), (1, 7, 13), (2, 8, 14), (3, 9, 15), (4, 10, 5)): - v = ord(final[a]) << 16 | ord(final[b]) << 8 | ord(final[c]) - for i in range(4): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - v = ord(final[11]) - for i in range(2): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - return magic + salt + '$' + rearranged - -if __name__ == '__main__': - - def test(clear_password, the_hash): - magic, salt = the_hash[1:].split('$')[:2] - magic = '$' + magic + '$' - return md5crypt(clear_password, salt, magic) == the_hash - - test_cases = ( - (' ', '$1$yiiZbNIH$YiCsHZjcTkYd31wkgW8JF.'), - ('pass', '$1$YeNsbWdH$wvOF8JdqsoiLix754LTW90'), - ('____fifteen____', '$1$s9lUWACI$Kk1jtIVVdmT01p0z3b/hw1'), - ('____sixteen_____', '$1$dL3xbVZI$kkgqhCanLdxODGq14g/tW1'), - ('____seventeen____', '$1$NaH5na7J$j7y8Iss0hcRbu3kzoJs5V.'), - ('__________thirty-three___________', '$1$HO7Q6vzJ$yGwp2wbL5D7eOVzOmxpsy.'), - ('apache', '$apr1$J.w5a/..$IW9y6DR0oO/ADuhlMF5/X1') - ) - - for clearpw, hashpw in test_cases: - if test(clearpw, hashpw): - print '%s: pass' % clearpw - else: - print '%s: FAIL' % clearpw diff --git a/netlib/http.py b/netlib/http.py index 51f85627..35e959cd 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,15 +1,17 @@ +from __future__ import (absolute_import, print_function, division) import string, urlparse, binascii -import odict, utils +import sys +from . import odict, utils -class HttpError(Exception): - def __init__(self, code, msg): - self.code, self.msg = code, msg - def __str__(self): - return "HttpError(%s, %s)"%(self.code, self.msg) +class HttpError(Exception): + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code -class HttpErrorConnClosed(HttpError): pass +class HttpErrorConnClosed(HttpError): + pass def _is_valid_port(port): @@ -43,6 +45,11 @@ def parse_url(url): return None if not scheme: return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) if ':' in netloc: host, port = string.rsplit(netloc, ':', maxsplit=1) try: @@ -88,14 +95,14 @@ def read_headers(fp): # We're being liberal in what we accept, here. if i > 0: name = line[:i] - value = line[i+1:].strip() + value = line[i + 1:].strip() ret.append([name, value]) else: return None return odict.ODictCaseless(ret) -def read_chunked(fp, headers, limit, is_request): +def read_chunked(fp, limit, is_request): """ Read a chunked HTTP body. @@ -103,10 +110,9 @@ def read_chunked(fp, headers, limit, is_request): """ # FIXME: Should check if chunked is the final encoding in the headers # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. - content = "" total = 0 code = 400 if is_request else 502 - while 1: + while True: line = fp.readline(128) if line == "": raise HttpErrorConnClosed(code, "Connection closed prematurely") @@ -114,27 +120,19 @@ def read_chunked(fp, headers, limit, is_request): try: length = int(line, 16) except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid chunked encoding length: %s"%line) - if not length: - break + raise HttpError(code, "Invalid chunked encoding length: %s" % line) total += length if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) + msg = "HTTP Body too large." \ + " Limit is %s, chunked content length was at least %s" % (limit, total) raise HttpError(code, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line == '\r\n' or line == '\n': - break - return content + yield line, chunk, '\r\n' + if length == 0: + return def get_header_tokens(headers, key): @@ -264,6 +262,7 @@ def parse_init_http(line): def connection_close(httpversion, headers): """ Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 + Note that a connection should be closed as well if the response has been read until end of the stream. """ # At first, check if we have an explicit Connection header. if "connection" in headers: @@ -280,7 +279,7 @@ def connection_close(httpversion, headers): def parse_response_line(line): parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully + if len(parts) == 2: # handle missing message gracefully parts.append("") if len(parts) != 3: return None @@ -292,35 +291,43 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit): +def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") parts = parse_response_line(line) if not parts: - raise HttpError(502, "Invalid server response: %s"%repr(line)) + raise HttpError(502, "Invalid server response: %s" % repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: - content = "" + if include_body: + content = read_http_body(rfile, headers, body_size_limit, request_method, code, False) else: - content = read_http_body(rfile, headers, body_size_limit, False) + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content -def read_http_body(rfile, headers, limit, is_request): +def read_http_body(*args, **kwargs): + return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs)) + + +def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): """ Read an HTTP message body: @@ -329,23 +336,69 @@ def read_http_body(rfile, headers, limit, is_request): limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise """ - if has_chunked_encoding(headers): - content = read_chunked(rfile, headers, limit, is_request) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - if l < 0: - raise ValueError() - except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif is_request: - content = "" + if max_chunk_size is None: + max_chunk_size = limit or sys.maxint + + expected_size = expected_http_body_size(headers, is_request, request_method, response_code) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding") + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError(400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size)) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size else: - content = rfile.read(limit if limit else -1) + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size not_done = rfile.read(1) if not_done: raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) - return content
\ No newline at end of file + + +def expected_http_body_size(headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. + """ + + # Determine response size according to http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"]) + if is_request: + return 0 + return -1 diff --git a/netlib/http_auth.py b/netlib/http_auth.py index b0451e3b..49f5925f 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,6 +1,7 @@ -from .contrib import md5crypt -import http +from __future__ import (absolute_import, print_function, division) +from passlib.apache import HtpasswdFile from argparse import Action, ArgumentTypeError +from . import http class NullProxyAuth(): @@ -78,32 +79,14 @@ class PassManHtpasswd: """ Read usernames and passwords from an htpasswd file """ - def __init__(self, fp): + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. """ - self.usernames = {} - for l in fp: - l = l.strip().split(':') - if len(l) != 2: - raise ValueError("Invalid htpasswd file.") - parts = l[1].split('$') - if len(parts) != 4: - raise ValueError("Invalid htpasswd file.") - self.usernames[l[0]] = dict( - token = l[1], - dummy = parts[0], - magic = parts[1], - salt = parts[2], - hashed_password = parts[3] - ) + self.htpasswd = HtpasswdFile(path) def test(self, username, password_token): - ui = self.usernames.get(username) - if not ui: - return False - expected = md5crypt.md5crypt(password_token, ui["salt"], '$'+ui["magic"]+'$') - return expected==ui["token"] + return bool(self.htpasswd.check_password(username, password_token)) class PassManSingleUser: @@ -149,6 +132,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): - with open(s, "r") as f: - return PassManHtpasswd(f) + return PassManHtpasswd(s) diff --git a/netlib/http_status.py b/netlib/http_status.py index 9f3f7e15..7dba2d56 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) CONTINUE = 100 SWITCHING = 101 diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index 826c31a5..d0d145da 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + """ A small collection of useful user-agent header strings. These should be kept reasonably current to reflect common usage. diff --git a/netlib/odict.py b/netlib/odict.py index 46b74e8e..1e51bb3f 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) import re, copy @@ -23,6 +24,9 @@ class ODict: def __eq__(self, other): return self.lst == other.lst + def __ne__(self, other): + return not self.__eq__(other) + def __iter__(self): return self.lst.__iter__() @@ -60,7 +64,8 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("ODict valuelist should be lists.") + raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") + new = self._filter_lst(k, self.lst) for i in valuelist: new.append([k, i]) diff --git a/netlib/socks.py b/netlib/socks.py new file mode 100644 index 00000000..1da5b6cc --- /dev/null +++ b/netlib/socks.py @@ -0,0 +1,128 @@ +from __future__ import (absolute_import, print_function, division) +import socket +import struct +import array +from . import tcp + + +class SocksError(Exception): + def __init__(self, code, message): + super(SocksError, self).__init__(message) + self.code = code + + +class VERSION(object): + SOCKS4 = 0x04 + SOCKS5 = 0x05 + + +class CMD(object): + CONNECT = 0x01 + BIND = 0x02 + UDP_ASSOCIATE = 0x03 + + +class ATYP(object): + IPV4_ADDRESS = 0x01 + DOMAINNAME = 0x03 + IPV6_ADDRESS = 0x04 + + +class REP(object): + SUCCEEDED = 0x00 + GENERAL_SOCKS_SERVER_FAILURE = 0x01 + CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 + NETWORK_UNREACHABLE = 0x03 + HOST_UNREACHABLE = 0x04 + CONNECTION_REFUSED = 0x05 + TTL_EXPIRED = 0x06 + COMMAND_NOT_SUPPORTED = 0x07 + ADDRESS_TYPE_NOT_SUPPORTED = 0x08 + + +class METHOD(object): + NO_AUTHENTICATION_REQUIRED = 0x00 + GSSAPI = 0x01 + USERNAME_PASSWORD = 0x02 + NO_ACCEPTABLE_METHODS = 0xFF + + +class ClientGreeting(object): + __slots__ = ("ver", "methods") + + def __init__(self, ver, methods): + self.ver = ver + self.methods = methods + + @classmethod + def from_file(cls, f): + ver, nmethods = struct.unpack("!BB", f.read(2)) + methods = array.array("B") + methods.fromstring(f.read(nmethods)) + return cls(ver, methods) + + def to_file(self, f): + f.write(struct.pack("!BB", self.ver, len(self.methods))) + f.write(self.methods.tostring()) + +class ServerGreeting(object): + __slots__ = ("ver", "method") + + def __init__(self, ver, method): + self.ver = ver + self.method = method + + @classmethod + def from_file(cls, f): + ver, method = struct.unpack("!BB", f.read(2)) + return cls(ver, method) + + def to_file(self, f): + f.write(struct.pack("!BB", self.ver, self.method)) + +class Message(object): + __slots__ = ("ver", "msg", "atyp", "addr") + + def __init__(self, ver, msg, atyp, addr): + self.ver = ver + self.msg = msg + self.atyp = atyp + self.addr = addr + + @classmethod + def from_file(cls, f): + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.read(4)) + if rsv != 0x00: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv) + + if atyp == ATYP.IPV4_ADDRESS: + host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + use_ipv6 = False + elif atyp == ATYP.IPV6_ADDRESS: + host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + use_ipv6 = True + elif atyp == ATYP.DOMAINNAME: + length, = struct.unpack("!B", f.read(1)) + host = f.read(length) + use_ipv6 = False + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Socks Request: Unknown ATYP: %s" % atyp) + + port, = struct.unpack("!H", f.read(2)) + addr = tcp.Address((host, port), use_ipv6=use_ipv6) + return cls(ver, msg, atyp, addr) + + def to_file(self, f): + f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) + if self.atyp == ATYP.IPV4_ADDRESS: + f.write(socket.inet_aton(self.addr.host)) + elif self.atyp == ATYP.IPV6_ADDRESS: + f.write(socket.inet_pton(socket.AF_INET6, self.addr.host)) + elif self.atyp == ATYP.DOMAINNAME: + f.write(struct.pack("!B", len(self.addr.host))) + f.write(self.addr.host) + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) + f.write(struct.pack("!H", self.addr.port))
\ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index c5f97f94..2704eeae 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import select, socket, threading, sys, time, traceback from OpenSSL import SSL -import certutils +from . import certutils EINTR = 4 @@ -17,7 +18,10 @@ OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +try: + OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +except AttributeError: + pass OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG @@ -212,10 +216,16 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __repr__(self): + return repr(self.address) + def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) + def __ne__(self, other): + return not self.__eq__(other) + class _Connection(object): def get_current_cipher(self): @@ -309,6 +319,8 @@ class TCPClient(_Connection): if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: @@ -341,10 +353,9 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, - method=SSLv23_METHOD, options=None, handle_sni=None, - request_client_cert=False, cipher_list=None, dhparams=None - ): + def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, + handle_sni=None, request_client_cert=None, cipher_list=None, + dhparams=None, ca_file=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -372,6 +383,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if ca_file: + ctx.load_verify_locations(ca_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) @@ -390,6 +403,14 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + return ctx + + def convert_to_ssl(self, cert, key, **sslctx_kwargs): + """ + Convert connection to SSL. + For a list of parameters, see BaseHandler._create_ssl_context(...) + """ + ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() @@ -443,7 +464,7 @@ class TCPServer(object): if ex[0] == EINTR: continue else: - raise + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( @@ -473,10 +494,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-'*40, file=fp) + print( + "Error in processing of request from %s:%s" % ( + client_address.host, client_address.port + ), file=fp) + print(exc, file=fp) + print('-'*40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/test.py b/netlib/test.py index bb0012ad..31a848a6 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import threading, Queue, cStringIO -import tcp, certutils import OpenSSL +from . import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): diff --git a/netlib/utils.py b/netlib/utils.py index 61fd54ae..79077ac6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + def isascii(s): try: @@ -32,13 +34,13 @@ def hexdump(s): """ parts = [] for i in range(0, len(s), 16): - o = "%.10x"%i - part = s[i:i+16] - x = " ".join("%.2x"%ord(i) for i in part) + o = "%.10x" % i + part = s[i:i + 16] + x = " ".join("%.2x" % ord(i) for i in part) if len(part) < 16: x += " " x += " ".join(" " for i in range(16 - len(part))) parts.append( (o, x, cleanBin(part, True)) ) - return parts + return parts
\ No newline at end of file diff --git a/netlib/version.py b/netlib/version.py index 1d3250e1..913f753a 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,7 @@ +from __future__ import (absolute_import, print_function, division) + IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) +MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION diff --git a/netlib/wsgi.py b/netlib/wsgi.py index b576bdff..568b1f9c 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,6 @@ +from __future__ import (absolute_import, print_function, division) import cStringIO, urllib, time, traceback -import odict, tcp +from . import odict, tcp class ClientConn: @@ -8,15 +9,15 @@ class ClientConn: class Flow: - def __init__(self, client_conn): - self.client_conn = client_conn + def __init__(self, address, request): + self.client_conn = ClientConn(address) + self.request = request class Request: - def __init__(self, client_conn, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.flow = Flow(client_conn) def date_time_string(): @@ -38,37 +39,37 @@ class WSGIAdaptor: def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - def make_environ(self, request, errsoc, **extra): - if '?' in request.path: - path_info, query = request.path.split('?', 1) + def make_environ(self, flow, errsoc, **extra): + if '?' in flow.request.path: + path_info, query = flow.request.path.split('?', 1) else: - path_info = request.path + path_info = flow.request.path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': request.scheme, - 'wsgi.input': cStringIO.StringIO(request.content), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': request.method, + 'REQUEST_METHOD': flow.request.method, 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.flow.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() + if flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = flow.client_conn.address() - for key, value in request.headers.items(): + for key, value in flow.request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value |