aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py34
-rw-r--r--test/test_tcp.py34
2 files changed, 64 insertions, 4 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 6ba58d86..b7f2b3bc 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -66,7 +66,7 @@ class FileLike:
if v:
try:
return self.o.sendall(v)
- except SSL.SysCallError:
+ except SSL.Error:
raise NetLibDisconnect()
def readline(self, size = None):
@@ -125,6 +125,20 @@ class TCPClient:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
self.connection = connection
+ def close(self):
+ """
+ Does a hard close of the socket, i.e. a shutdown, followed by a close.
+ """
+ try:
+ if self.ssl_established:
+ self.connection.shutdown()
+ else:
+ self.connection.shutdown(socket.SHUT_RDWR)
+ self.connection.close()
+ except (socket.error, SSL.Error):
+ # Socket probably already closed
+ pass
+
class BaseHandler:
"""
@@ -170,7 +184,7 @@ class BaseHandler:
self.wfile.flush()
self.wfile.close()
self.rfile.close()
- self.connection.close()
+ self.close()
except socket.error:
# Remote has disconnected
pass
@@ -195,6 +209,20 @@ class BaseHandler:
def handle(self): # pragma: no cover
raise NotImplementedError
+ def close(self):
+ """
+ Does a hard close of the socket, i.e. a shutdown, followed by a close.
+ """
+ try:
+ if self.ssl_established:
+ self.connection.shutdown()
+ else:
+ self.connection.shutdown(socket.SHUT_RDWR)
+ self.connection.close()
+ except (socket.error, SSL.Error):
+ # Socket probably already closed
+ pass
+
class TCPServer:
request_queue_size = 20
@@ -252,7 +280,7 @@ class TCPServer:
Called when handle_connection raises an exception.
"""
# If a thread has persisted after interpreter exit, the module might be
- # none.
+ # none.
if traceback:
exc = traceback.format_exc()
print >> fp, '-'*40
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 359890d5..cb27c63b 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler):
class DisconnectHandler(tcp.BaseHandler):
def handle(self):
- self.finish()
+ self.close()
class TServer(tcp.TCPServer):
@@ -102,6 +102,20 @@ class TestServer(ServerTestBase):
assert c.rfile.readline() == testval
+class TestDisconnect(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
+
+ def test_echo(self):
+ testval = "echo!\n"
+ c = tcp.TCPClient("127.0.0.1", self.port)
+ c.connect()
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
+
+
class TestServerSSL(ServerTestBase):
@classmethod
def makeserver(cls):
@@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase):
c.convert_to_ssl()
# Excercise SSL.ZeroReturnError
c.rfile.read(10)
+ c.close()
+ tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
+ tutils.raises(Queue.Empty, self.q.get_nowait)
+
+
+class TestDisconnect(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
+
+ def test_echo(self):
+ c = tcp.TCPClient("127.0.0.1", self.port)
+ c.connect()
+ # Excercise SSL.ZeroReturnError
+ c.rfile.read(10)
+ c.wfile.write("foo")
+ c.close()
+ c.close()
class TestTCPClient: