aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_tcp.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_tcp.py')
-rw-r--r--test/test_tcp.py159
1 files changed, 72 insertions, 87 deletions
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 0417aa21..ce06ad66 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -1,38 +1,7 @@
import cStringIO, threading, Queue, time
-from netlib import tcp, certutils
+from netlib import tcp, certutils, test
import tutils
-class ServerThread(threading.Thread):
- def __init__(self, server):
- self.server = server
- threading.Thread.__init__(self)
-
- def run(self):
- self.server.serve_forever()
-
- def shutdown(self):
- self.server.shutdown()
-
-
-class ServerTestBase:
- @classmethod
- def setupAll(cls):
- cls.q = Queue.Queue()
- s = cls.makeserver()
- cls.port = s.port
- cls.server = ServerThread(s)
- cls.server.start()
-
- @classmethod
- def teardownAll(cls):
- cls.server.shutdown()
-
-
- @property
- def last_handler(self):
- return self.server.server.last_handler
-
-
class SNIHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
@@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler):
self.timeout = True
-class TServer(tcp.TCPServer):
- def __init__(self, addr, ssl, q, handler_klass, v3_only=False):
- tcp.TCPServer.__init__(self, addr)
- self.ssl, self.q = ssl, q
- self.v3_only = v3_only
- self.handler_klass = handler_klass
- self.last_handler = None
-
- def handle_connection(self, request, client_address):
- h = self.handler_klass(request, client_address, self)
- self.last_handler = h
- if self.ssl:
- if self.v3_only:
- method = tcp.SSLv3_METHOD
- options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
- else:
- method = tcp.SSLv23_METHOD
- options = None
- h.convert_to_ssl(
- tutils.test_data.path("data/server.crt"),
- tutils.test_data.path("data/server.key"),
- method = method,
- options = options,
- )
- h.handle()
- h.finish()
-
- def handle_error(self, request, client_address):
- s = cStringIO.StringIO()
- tcp.TCPServer.handle_error(self, request, client_address, s)
- self.q.put(s.getvalue())
-
-
-class TestServer(ServerTestBase):
+class TestServer(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
+ return test.TServer(False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
@@ -135,10 +71,10 @@ class TestServer(ServerTestBase):
assert c.rfile.readline() == testval
-class TestDisconnect(ServerTestBase):
+class TestDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
+ return test.TServer(False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
@@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase):
assert c.rfile.readline() == testval
-class TestServerSSL(ServerTestBase):
+class TestServerSSL(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ EchoHandler
+ )
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase):
assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1")
-class TestSSLv3Only(ServerTestBase):
+class TestSSLv3Only(test.ServerTestBase):
+ v3_only = True
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = True
+ ),
+ cls.q,
+ EchoHandler,
+ )
def test_failure(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase):
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
-class TestSSLClientCert(ServerTestBase):
+class TestSSLClientCert(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, CertHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ CertHandler
+ )
def test_clientcert(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase):
)
-class TestSNI(ServerTestBase):
+class TestSNI(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ SNIHandler
+ )
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -211,10 +180,18 @@ class TestSNI(ServerTestBase):
assert c.rfile.readline() == "foo.com"
-class TestSSLDisconnect(ServerTestBase):
+class TestSSLDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ DisconnectHandler
+ )
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase):
tutils.raises(Queue.Empty, self.q.get_nowait)
-class TestDisconnect(ServerTestBase):
+class TestSSLDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
+ return test.TServer(False, cls.q, DisconnectHandler)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase):
c.close()
-class TestServerTimeOut(ServerTestBase):
+class TestServerTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler)
+ return test.TServer(False, cls.q, TimeoutHandler)
def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase):
assert self.last_handler.timeout
-class TestTimeOut(ServerTestBase):
+class TestTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, HangHandler)
+ return test.TServer(False, cls.q, HangHandler)
def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
-class TestSSLTimeOut(ServerTestBase):
+class TestSSLTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, HangHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ HangHandler
+ )
def test_timeout_client(self):
c = tcp.TCPClient("127.0.0.1", self.port)