aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py11
-rw-r--r--netlib/test.py67
-rw-r--r--test/test_certutils.py1
-rw-r--r--test/test_tcp.py159
4 files changed, 146 insertions, 92 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index d0ca09f3..56cc0dea 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -1,4 +1,4 @@
-import select, socket, threading, traceback, sys, time
+import select, socket, threading, sys, time, traceback
from OpenSSL import SSL
import certutils
@@ -84,13 +84,14 @@ class _FileLike:
def reset_timestamps(self):
self.first_byte_timestamp = None
+
class Writer(_FileLike):
def flush(self):
- try:
- if hasattr(self.o, "flush"):
+ if hasattr(self.o, "flush"):
+ try:
self.o.flush()
- except socket.error, v:
- raise NetLibDisconnect(str(v))
+ except socket.error, v:
+ raise NetLibDisconnect(str(v))
def write(self, v):
if v:
diff --git a/netlib/test.py b/netlib/test.py
new file mode 100644
index 00000000..2f72f979
--- /dev/null
+++ b/netlib/test.py
@@ -0,0 +1,67 @@
+import threading, Queue, cStringIO
+import tcp
+
+class ServerThread(threading.Thread):
+ def __init__(self, server):
+ self.server = server
+ threading.Thread.__init__(self)
+
+ def run(self):
+ self.server.serve_forever()
+
+ def shutdown(self):
+ self.server.shutdown()
+
+
+class ServerTestBase:
+ @classmethod
+ def setupAll(cls):
+ cls.q = Queue.Queue()
+ s = cls.makeserver()
+ cls.port = s.port
+ cls.server = ServerThread(s)
+ cls.server.start()
+
+ @classmethod
+ def teardownAll(cls):
+ cls.server.shutdown()
+
+
+ @property
+ def last_handler(self):
+ return self.server.server.last_handler
+
+
+class TServer(tcp.TCPServer):
+ def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)):
+ """
+ ssl: A {cert, key, v3_only} dict.
+ """
+ tcp.TCPServer.__init__(self, addr)
+ self.ssl, self.q = ssl, q
+ self.handler_klass = handler_klass
+ self.last_handler = None
+
+ def handle_connection(self, request, client_address):
+ h = self.handler_klass(request, client_address, self)
+ self.last_handler = h
+ if self.ssl:
+ if self.ssl["v3_only"]:
+ method = tcp.SSLv3_METHOD
+ options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
+ else:
+ method = tcp.SSLv23_METHOD
+ options = None
+ h.convert_to_ssl(
+ self.ssl["cert"],
+ self.ssl["key"],
+ method = method,
+ options = options,
+ )
+ h.handle()
+ h.finish()
+
+ def handle_error(self, request, client_address):
+ s = cStringIO.StringIO()
+ tcp.TCPServer.handle_error(self, request, client_address, s)
+ self.q.put(s.getvalue())
diff --git a/test/test_certutils.py b/test/test_certutils.py
index 582fb9c4..334a6be4 100644
--- a/test/test_certutils.py
+++ b/test/test_certutils.py
@@ -30,6 +30,7 @@ class TestCertStore:
ca = os.path.join(d, "ca")
assert certutils.dummy_ca(ca)
c = certutils.CertStore()
+ assert not c.get_cert("../foo.com", [])
assert not c.get_cert("foo.com", [])
assert c.get_cert("foo.com", [], ca)
assert c.get_cert("foo.com", [], ca)
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 0417aa21..ce06ad66 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -1,38 +1,7 @@
import cStringIO, threading, Queue, time
-from netlib import tcp, certutils
+from netlib import tcp, certutils, test
import tutils
-class ServerThread(threading.Thread):
- def __init__(self, server):
- self.server = server
- threading.Thread.__init__(self)
-
- def run(self):
- self.server.serve_forever()
-
- def shutdown(self):
- self.server.shutdown()
-
-
-class ServerTestBase:
- @classmethod
- def setupAll(cls):
- cls.q = Queue.Queue()
- s = cls.makeserver()
- cls.port = s.port
- cls.server = ServerThread(s)
- cls.server.start()
-
- @classmethod
- def teardownAll(cls):
- cls.server.shutdown()
-
-
- @property
- def last_handler(self):
- return self.server.server.last_handler
-
-
class SNIHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
@@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler):
self.timeout = True
-class TServer(tcp.TCPServer):
- def __init__(self, addr, ssl, q, handler_klass, v3_only=False):
- tcp.TCPServer.__init__(self, addr)
- self.ssl, self.q = ssl, q
- self.v3_only = v3_only
- self.handler_klass = handler_klass
- self.last_handler = None
-
- def handle_connection(self, request, client_address):
- h = self.handler_klass(request, client_address, self)
- self.last_handler = h
- if self.ssl:
- if self.v3_only:
- method = tcp.SSLv3_METHOD
- options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
- else:
- method = tcp.SSLv23_METHOD
- options = None
- h.convert_to_ssl(
- tutils.test_data.path("data/server.crt"),
- tutils.test_data.path("data/server.key"),
- method = method,
- options = options,
- )
- h.handle()
- h.finish()
-
- def handle_error(self, request, client_address):
- s = cStringIO.StringIO()
- tcp.TCPServer.handle_error(self, request, client_address, s)
- self.q.put(s.getvalue())
-
-
-class TestServer(ServerTestBase):
+class TestServer(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
+ return test.TServer(False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
@@ -135,10 +71,10 @@ class TestServer(ServerTestBase):
assert c.rfile.readline() == testval
-class TestDisconnect(ServerTestBase):
+class TestDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
+ return test.TServer(False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
@@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase):
assert c.rfile.readline() == testval
-class TestServerSSL(ServerTestBase):
+class TestServerSSL(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ EchoHandler
+ )
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase):
assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1")
-class TestSSLv3Only(ServerTestBase):
+class TestSSLv3Only(test.ServerTestBase):
+ v3_only = True
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = True
+ ),
+ cls.q,
+ EchoHandler,
+ )
def test_failure(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase):
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
-class TestSSLClientCert(ServerTestBase):
+class TestSSLClientCert(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, CertHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ CertHandler
+ )
def test_clientcert(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase):
)
-class TestSNI(ServerTestBase):
+class TestSNI(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ SNIHandler
+ )
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -211,10 +180,18 @@ class TestSNI(ServerTestBase):
assert c.rfile.readline() == "foo.com"
-class TestSSLDisconnect(ServerTestBase):
+class TestSSLDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ DisconnectHandler
+ )
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase):
tutils.raises(Queue.Empty, self.q.get_nowait)
-class TestDisconnect(ServerTestBase):
+class TestSSLDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
+ return test.TServer(False, cls.q, DisconnectHandler)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase):
c.close()
-class TestServerTimeOut(ServerTestBase):
+class TestServerTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler)
+ return test.TServer(False, cls.q, TimeoutHandler)
def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase):
assert self.last_handler.timeout
-class TestTimeOut(ServerTestBase):
+class TestTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), False, cls.q, HangHandler)
+ return test.TServer(False, cls.q, HangHandler)
def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
-class TestSSLTimeOut(ServerTestBase):
+class TestSSLTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
- return TServer(("127.0.0.1", 0), True, cls.q, HangHandler)
+ return test.TServer(
+ dict(
+ cert = tutils.test_data.path("data/server.crt"),
+ key = tutils.test_data.path("data/server.key"),
+ v3_only = False
+ ),
+ cls.q,
+ HangHandler
+ )
def test_timeout_client(self):
c = tcp.TCPClient("127.0.0.1", self.port)