diff options
Diffstat (limited to 'test/test_tcp.py')
-rw-r--r-- | test/test_tcp.py | 159 |
1 files changed, 72 insertions, 87 deletions
diff --git a/test/test_tcp.py b/test/test_tcp.py index 0417aa21..ce06ad66 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,38 +1,7 @@ import cStringIO, threading, Queue, time -from netlib import tcp, certutils +from netlib import tcp, certutils, test import tutils -class ServerThread(threading.Thread): - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase: - @classmethod - def setupAll(cls): - cls.q = Queue.Queue() - s = cls.makeserver() - cls.port = s.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def teardownAll(cls): - cls.server.shutdown() - - - @property - def last_handler(self): - return self.server.server.last_handler - - class SNIHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler): self.timeout = True -class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q, handler_klass, v3_only=False): - tcp.TCPServer.__init__(self, addr) - self.ssl, self.q = ssl, q - self.v3_only = v3_only - self.handler_klass = handler_klass - self.last_handler = None - - def handle_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl: - if self.v3_only: - method = tcp.SSLv3_METHOD - options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 - else: - method = tcp.SSLv23_METHOD - options = None - h.convert_to_ssl( - tutils.test_data.path("data/server.crt"), - tutils.test_data.path("data/server.key"), - method = method, - options = options, - ) - h.handle() - h.finish() - - def handle_error(self, request, client_address): - s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, request, client_address, s) - self.q.put(s.getvalue()) - - -class TestServer(ServerTestBase): +class TestServer(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + return test.TServer(False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -135,10 +71,10 @@ class TestServer(ServerTestBase): assert c.rfile.readline() == testval -class TestDisconnect(ServerTestBase): +class TestDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + return test.TServer(False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase): assert c.rfile.readline() == testval -class TestServerSSL(ServerTestBase): +class TestServerSSL(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + EchoHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase): assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") -class TestSSLv3Only(ServerTestBase): +class TestSSLv3Only(test.ServerTestBase): + v3_only = True @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = True + ), + cls.q, + EchoHandler, + ) def test_failure(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) -class TestSSLClientCert(ServerTestBase): +class TestSSLClientCert(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, CertHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + CertHandler + ) def test_clientcert(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase): ) -class TestSNI(ServerTestBase): +class TestSNI(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + SNIHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -211,10 +180,18 @@ class TestSNI(ServerTestBase): assert c.rfile.readline() == "foo.com" -class TestSSLDisconnect(ServerTestBase): +class TestSSLDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + DisconnectHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) -class TestDisconnect(ServerTestBase): +class TestSSLDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler) + return test.TServer(False, cls.q, DisconnectHandler) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase): c.close() -class TestServerTimeOut(ServerTestBase): +class TestServerTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler) + return test.TServer(False, cls.q, TimeoutHandler) def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase): assert self.last_handler.timeout -class TestTimeOut(ServerTestBase): +class TestTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) + return test.TServer(False, cls.q, HangHandler) def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestSSLTimeOut(ServerTestBase): +class TestSSLTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, HangHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + HangHandler + ) def test_timeout_client(self): c = tcp.TCPClient("127.0.0.1", self.port) |