aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/tcp.py11
-rw-r--r--netlib/test.py67
2 files changed, 73 insertions, 5 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index d0ca09f3..56cc0dea 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -1,4 +1,4 @@
-import select, socket, threading, traceback, sys, time
+import select, socket, threading, sys, time, traceback
from OpenSSL import SSL
import certutils
@@ -84,13 +84,14 @@ class _FileLike:
def reset_timestamps(self):
self.first_byte_timestamp = None
+
class Writer(_FileLike):
def flush(self):
- try:
- if hasattr(self.o, "flush"):
+ if hasattr(self.o, "flush"):
+ try:
self.o.flush()
- except socket.error, v:
- raise NetLibDisconnect(str(v))
+ except socket.error, v:
+ raise NetLibDisconnect(str(v))
def write(self, v):
if v:
diff --git a/netlib/test.py b/netlib/test.py
new file mode 100644
index 00000000..2f72f979
--- /dev/null
+++ b/netlib/test.py
@@ -0,0 +1,67 @@
+import threading, Queue, cStringIO
+import tcp
+
+class ServerThread(threading.Thread):
+ def __init__(self, server):
+ self.server = server
+ threading.Thread.__init__(self)
+
+ def run(self):
+ self.server.serve_forever()
+
+ def shutdown(self):
+ self.server.shutdown()
+
+
+class ServerTestBase:
+ @classmethod
+ def setupAll(cls):
+ cls.q = Queue.Queue()
+ s = cls.makeserver()
+ cls.port = s.port
+ cls.server = ServerThread(s)
+ cls.server.start()
+
+ @classmethod
+ def teardownAll(cls):
+ cls.server.shutdown()
+
+
+ @property
+ def last_handler(self):
+ return self.server.server.last_handler
+
+
+class TServer(tcp.TCPServer):
+ def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)):
+ """
+ ssl: A {cert, key, v3_only} dict.
+ """
+ tcp.TCPServer.__init__(self, addr)
+ self.ssl, self.q = ssl, q
+ self.handler_klass = handler_klass
+ self.last_handler = None
+
+ def handle_connection(self, request, client_address):
+ h = self.handler_klass(request, client_address, self)
+ self.last_handler = h
+ if self.ssl:
+ if self.ssl["v3_only"]:
+ method = tcp.SSLv3_METHOD
+ options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
+ else:
+ method = tcp.SSLv23_METHOD
+ options = None
+ h.convert_to_ssl(
+ self.ssl["cert"],
+ self.ssl["key"],
+ method = method,
+ options = options,
+ )
+ h.handle()
+ h.finish()
+
+ def handle_error(self, request, client_address):
+ s = cStringIO.StringIO()
+ tcp.TCPServer.handle_error(self, request, client_address, s)
+ self.q.put(s.getvalue())