aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py2
-rw-r--r--test/test_tcp.py42
2 files changed, 34 insertions, 10 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 3c5c89b7..276d3162 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -59,6 +59,7 @@ class TCPClient:
context.use_certificate_file(self.clientcert)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_connect_state()
+ self.connection.do_handshake()
self.cert = self.connection.get_peer_certificate()
self.rfile = FileLike(self.connection)
self.wfile = FileLike(self.connection)
@@ -95,6 +96,7 @@ class BaseHandler:
ctx.use_certificate_file(cert)
self.connection = SSL.Connection(ctx, self.connection)
self.connection.set_accept_state()
+ self.connection.do_handshake()
self.rfile = FileLike(self.connection)
self.wfile = FileLike(self.connection)
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 26286bc4..a81632e7 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -25,13 +25,8 @@ class ServerTestBase:
cls.server.shutdown()
-class THandler(tcp.BaseHandler):
+class EchoHandler(tcp.BaseHandler):
def handle(self):
- if self.server.ssl:
- self.convert_to_ssl(
- tutils.test_data.path("data/server.crt"),
- tutils.test_data.path("data/server.key"),
- )
v = self.rfile.readline()
if v.startswith("echo"):
self.wfile.write(v)
@@ -40,13 +35,24 @@ class THandler(tcp.BaseHandler):
self.wfile.flush()
+class DisconnectHandler(tcp.BaseHandler):
+ def handle(self):
+ self.finish()
+
+
class TServer(tcp.TCPServer):
- def __init__(self, addr, ssl, q):
+ def __init__(self, addr, ssl, q, handler):
tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
+ self.handler = handler
def handle_connection(self, request, client_address):
- h = THandler(request, client_address, self)
+ h = self.handler(request, client_address, self)
+ if self.ssl:
+ h.convert_to_ssl(
+ tutils.test_data.path("data/server.crt"),
+ tutils.test_data.path("data/server.key"),
+ )
h.handle()
h.finish()
@@ -60,7 +66,7 @@ class TestServer(ServerTestBase):
@classmethod
def makeserver(cls):
cls.q = Queue.Queue()
- s = TServer(("127.0.0.1", 0), False, cls.q)
+ s = TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
cls.port = s.port
return s
@@ -77,7 +83,7 @@ class TestServerSSL(ServerTestBase):
@classmethod
def makeserver(cls):
cls.q = Queue.Queue()
- s = TServer(("127.0.0.1", 0), True, cls.q)
+ s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler)
cls.port = s.port
return s
@@ -91,6 +97,22 @@ class TestServerSSL(ServerTestBase):
assert c.rfile.readline() == testval
+class TestSSLDisconnect(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ cls.q = Queue.Queue()
+ s = TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler)
+ cls.port = s.port
+ return s
+
+ def test_echo(self):
+ c = tcp.TCPClient("127.0.0.1", self.port)
+ c.connect()
+ c.convert_to_ssl()
+ # Excercise SSL.ZeroReturnError
+ c.rfile.read(10)
+
+
class TestTCPClient:
def test_conerr(self):
c = tcp.TCPClient("127.0.0.1", 0)