diff options
-rw-r--r-- | netlib/certutils.py | 29 | ||||
-rw-r--r-- | netlib/http.py | 17 | ||||
-rw-r--r-- | netlib/http_auth.py | 12 | ||||
-rw-r--r-- | netlib/tcp.py | 21 | ||||
-rw-r--r-- | netlib/test.py | 5 | ||||
-rw-r--r-- | netlib/version.py | 2 | ||||
-rw-r--r-- | setup.py | 2 | ||||
-rw-r--r-- | test/test_certutils.py | 9 | ||||
-rw-r--r-- | test/test_http.py | 22 | ||||
-rw-r--r-- | test/test_http_auth.py | 25 | ||||
-rw-r--r-- | test/test_tcp.py | 24 |
11 files changed, 94 insertions, 74 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py index dab7e318..d9b8ce57 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,17 +5,20 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp +default_exp = 62208000 # =24 * 60 * 60 * 720 +default_o = "mitmproxy" +default_cn = "mitmproxy" -def create_ca(): +def create_ca(o=default_o, cn=default_cn, exp=default_exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) ca = OpenSSL.crypto.X509() ca.set_serial_number(int(time.time()*10000)) ca.set_version(2) - ca.get_subject().CN = "mitmproxy" - ca.get_subject().O = "mitmproxy" + ca.get_subject().CN = cn + ca.get_subject().O = o ca.gmtime_adj_notBefore(0) - ca.gmtime_adj_notAfter(24 * 60 * 60 * 720) + ca.gmtime_adj_notAfter(exp) ca.set_issuer(ca.get_subject()) ca.set_pubkey(key) ca.add_extensions([ @@ -35,7 +38,7 @@ def create_ca(): return key, ca -def dummy_ca(path): +def dummy_ca(path, o=default_o, cn=default_cn, exp=default_exp): dirname = os.path.dirname(path) if not os.path.exists(dirname): os.makedirs(dirname) @@ -45,7 +48,7 @@ def dummy_ca(path): else: basename = os.path.basename(path) - key, ca = create_ca() + key, ca = create_ca(o=o, cn=cn, exp=exp) # Dump the CA plus private key f = open(path, "wb") @@ -113,18 +116,6 @@ class CertStore: def __init__(self): self.certs = {} - def check_domain(self, commonname): - try: - commonname.decode("idna") - commonname.decode("ascii") - except: - return False - if ".." in commonname: - return False - if "/" in commonname: - return False - return True - def get_cert(self, commonname, sans, cacert): """ Returns an SSLCert object. @@ -138,8 +129,6 @@ class CertStore: Return None if the certificate could not be found or generated. """ - if not self.check_domain(commonname): - return None if commonname in self.certs: return self.certs[commonname] c = dummy_cert(cacert, commonname, sans) diff --git a/netlib/http.py b/netlib/http.py index f1a2bfb5..7060b688 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -283,32 +283,23 @@ def parse_init_http(line): return method, url, httpversion -def request_connection_close(httpversion, headers): +def connection_close(httpversion, headers): """ - Checks the request to see if the client connection should be closed. + Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 """ + # At first, check if we have an explicit Connection header. if "connection" in headers: toks = get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: return False - # HTTP 1.1 connections are assumed to be persistent + # If we don't have a Connection header, HTTP 1.1 connections are assumed to be persistent if httpversion == (1, 1): return False return True -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif (not has_chunked_encoding(headers)) and "content-length" in headers: - return False - return True - def read_http_body_request(rfile, wfile, headers, httpversion, limit): """ diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 948d503a..69bee5c1 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -131,17 +131,16 @@ class AuthAction(Action): authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, self.dest, authenticator) - def getPasswordManager(self, s): - """ - returns the password manager - """ + def getPasswordManager(self, s): # pragma: nocover raise NotImplementedError() class SingleuserAuthAction(AuthAction): def getPasswordManager(self, s): if len(s.split(':')) != 2: - raise ArgumentTypeError("Invalid single-user specification. Please use the format username:password") + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) username, password = s.split(':') return PassManSingleUser(username, password) @@ -154,4 +153,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): with open(s, "r") as f: - return PassManHtpasswd(f)
\ No newline at end of file + return PassManHtpasswd(f) + diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..b3be43d6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -235,6 +235,7 @@ class TCPClient: try: if self.ssl_established: self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. @@ -266,7 +267,7 @@ class BaseHandler: self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -294,6 +295,8 @@ class BaseHandler: ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if cipher_list: + ctx.set_cipher_list(cipher_list) if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) @@ -302,6 +305,8 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) + # Return true to prevent cert verification error + return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True @@ -338,10 +343,12 @@ class BaseHandler: try: if self.ssl_established: self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent. + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass except (socket.error, SSL.Error): @@ -376,7 +383,13 @@ class TCPServer: self.__is_shut_down.clear() try: while not self.__shutdown_request: - r, w, e = select.select([self.socket], [], [], poll_interval) + try: + r, w, e = select.select([self.socket], [], [], poll_interval) + except select.error, ex: + if ex[0] == 4: + continue + else: + raise if self.socket in r: request, client_address = self.socket.accept() t = threading.Thread( diff --git a/netlib/test.py b/netlib/test.py index 661395c5..e7d4c233 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -52,7 +52,7 @@ class TServer(tcp.TCPServer): self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( - file(self.ssl["cert"], "r").read() + file(self.ssl["cert"], "rb").read() ) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD @@ -66,7 +66,8 @@ class TServer(tcp.TCPServer): method = method, options = options, handle_sni = getattr(h, "handle_sni", None), - request_client_cert = self.ssl["request_client_cert"] + request_client_cert = self.ssl["request_client_cert"], + cipher_list = self.ssl.get("cipher_list", None) ) h.handle() h.finish() diff --git a/netlib/version.py b/netlib/version.py index 63a9d862..32013c35 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9, 1) +IVERSION = (0, 9, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION @@ -65,7 +65,7 @@ def findPackages(path, dataExclude=[]): return packages, package_data -long_description = file("README").read() +long_description = file("README","rb").read() packages, package_data = findPackages("netlib") setup( name = "netlib", diff --git a/test/test_certutils.py b/test/test_certutils.py index 0b4baf75..7a00caca 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -32,15 +32,6 @@ class TestCertStore: assert c.get_cert("foo.com", [], ca) assert c.get_cert("*.foo.com", [], ca) - def test_check_domain(self): - c = certutils.CertStore() - assert c.check_domain("foo") - assert c.check_domain("\x01foo") - assert not c.check_domain("\xfefoo") - assert not c.check_domain("xn--\0") - assert not c.check_domain("foo..foo") - assert not c.check_domain("foo/foo") - class TestDummyCert: def test_with_ca(self): diff --git a/test/test_http.py b/test/test_http.py index 62d0c3dc..4d89bf24 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -38,28 +38,16 @@ def test_read_chunked(): tutils.raises("too large", http.read_chunked, 500, s, 2) -def test_request_connection_close(): +def test_connection_close(): h = odict.ODictCaseless() - assert http.request_connection_close((1, 0), h) - assert not http.request_connection_close((1, 1), h) + assert http.connection_close((1, 0), h) + assert not http.connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not http.request_connection_close((1, 1), h) + assert not http.connection_close((1, 1), h) h["connection"] = ["close"] - assert http.request_connection_close((1, 1), h) - - -def test_response_connection_close(): - h = odict.ODictCaseless() - assert http.response_connection_close((1, 1), h) - - h["content-length"] = [10] - assert not http.response_connection_close((1, 1), h) - - h["connection"] = ["close"] - assert http.response_connection_close((1, 1), h) - + assert http.connection_close((1, 1), h) def test_read_http_body_response(): h = odict.ODictCaseless() diff --git a/test/test_http_auth.py b/test/test_http_auth.py index cae69f5e..8238d4ca 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -1,5 +1,6 @@ import binascii, cStringIO from netlib import odict, http_auth, http +import mock import tutils class TestPassManNonAnon: @@ -17,7 +18,7 @@ class TestPassManHtpasswd: tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) def test_simple(self): - f = open(tutils.test_data.path("data/htpasswd")) + f = open(tutils.test_data.path("data/htpasswd"),"rb") pm = http_auth.PassManHtpasswd(f) vals = ("basic", "test", "test") @@ -79,3 +80,25 @@ class TestBasicProxyAuth: hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) + +class Bunch: pass + +class TestAuthAction: + def test_nonanonymous(self): + m = Bunch() + aa = http_auth.NonanonymousAuthAction(None, None) + aa(None, m, None, None) + assert m.authenticator + + def test_singleuser(self): + m = Bunch() + aa = http_auth.SingleuserAuthAction(None, None) + aa(None, m, "foo:bar", None) + assert m.authenticator + tutils.raises("invalid", aa, None, m, "foo", None) + + def test_httppasswd(self): + m = Bunch() + aa = http_auth.HtpasswdAuthAction(None, None) + aa(None, m, tutils.test_data.path("data/htpasswd"), None) + assert m.authenticator diff --git a/test/test_tcp.py b/test/test_tcp.py index 318d2abc..f45acb00 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -34,6 +34,14 @@ class CertHandler(tcp.BaseHandler): self.wfile.flush() +class ClientCipherListHandler(tcp.BaseHandler): + sni = None + + def handle(self): + self.wfile.write("%s"%self.connection.get_cipher_list()) + self.wfile.flush() + + class DisconnectHandler(tcp.BaseHandler): def handle(self): self.close() @@ -180,6 +188,22 @@ class TestSNI(test.ServerTestBase): assert c.rfile.readline() == "foo.com" +class TestClientCipherList(test.ServerTestBase): + handler = ClientCipherListHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = 'RC4-SHA' + ) + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl(sni="foo.com") + assert c.rfile.readline() == "['RC4-SHA']" + + class TestSSLDisconnect(test.ServerTestBase): handler = DisconnectHandler ssl = dict( |