diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/certutils.py | 2 | ||||
-rw-r--r-- | netlib/odict.py | 4 | ||||
-rw-r--r-- | netlib/tcp.py | 186 | ||||
-rw-r--r-- | netlib/test.py | 11 | ||||
-rw-r--r-- | netlib/wsgi.py | 15 |
5 files changed, 124 insertions, 94 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py index 0349bec7..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -237,7 +237,7 @@ class SSLCert: def get_remote_cert(host, port, sni): - c = tcp.TCPClient(host, port) + c = tcp.TCPClient((host, port)) c.connect() c.convert_to_ssl(sni=sni) return c.cert diff --git a/netlib/odict.py b/netlib/odict.py index 0759a5bf..46b74e8e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@ import re, copy + def safe_subn(pattern, repl, target, *args, **kwargs): """ There are Unicode conversion problems with re.subn. We try to smooth @@ -98,6 +99,9 @@ class ODict: def _get_state(self): return [tuple(i) for i in self.lst] + def _load_state(self, state): + self.list = [list(i) for i in state] + @classmethod def _from_state(klass, state): return klass([list(i) for i in state]) diff --git a/netlib/tcp.py b/netlib/tcp.py index 33f7ef3a..34e47999 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,16 +173,88 @@ class Reader(_FileLike): return result -class TCPClient: +class Address(object): + """ + This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + """ + def __init__(self, address, use_ipv6=False): + self.address = tuple(address) + self.use_ipv6 = use_ipv6 + + @classmethod + def wrap(cls, t): + if isinstance(t, cls): + return t + else: + return cls(t) + + def __call__(self): + return self.address + + @property + def host(self): + return self.address[0] + + @property + def port(self): + return self.address[1] + + @property + def use_ipv6(self): + return self.family == socket.AF_INET6 + + @use_ipv6.setter + def use_ipv6(self, b): + self.family = socket.AF_INET6 if b else socket.AF_INET + + def __eq__(self, other): + other = Address.wrap(other) + return (self.address, self.family) == (other.address, other.family) + + +class SocketCloseMixin(object): + def finish(self): + self.finished = True + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.close() + self.wfile.close() + self.rfile.close() + except (socket.error, NetLibDisconnect): + # Remote has disconnected + pass + + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) + else: + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass + self.connection.close() + except (socket.error, SSL.Error, IOError): + # Socket probably already closed + pass + + +class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None, use_ipv6=False): - self.host, self.port = host, port - self.source_address = source_address - self.use_ipv6 = use_ipv6 + def __init__(self, address, source_address=None): + self.address = Address.wrap(address) + self.source_address = Address.wrap(source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.sni = None def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -200,6 +272,7 @@ class TCPClient: self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: + self.sni = sni self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() try: @@ -212,14 +285,14 @@ class TCPClient: def connect(self): try: - connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: - connection.bind(self.source_address) - connection.connect((self.host, self.port)) + connection.bind(self.source_address()) + connection.connect(self.address()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection def settimeout(self, n): @@ -228,43 +301,24 @@ class TCPClient: def gettimeout(self): return self.connection.gettimeout() - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass - -class BaseHandler: +class BaseHandler(SocketCloseMixin): """ The instantiator is expected to call the handle() and finish() methods. """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, client_address, server): + + def __init__(self, connection, address, server): self.connection = connection + self.address = Address.wrap(address) + self.server = server self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - self.client_address = client_address - self.server = server self.finished = False self.ssl_established = False - self.clientcert = None def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): @@ -318,66 +372,34 @@ class BaseHandler: self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def finish(self): - self.finished = True - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.close() - self.wfile.close() - self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): self.connection.settimeout(n) - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - except (socket.error, SSL.Error): - # Socket probably already closed - pass - self.connection.close() class TCPServer: request_queue_size = 20 - def __init__(self, server_address, use_ipv6=False): - self.server_address = server_address - self.use_ipv6 = use_ipv6 + def __init__(self, address): + self.address = Address.wrap(address) self.__is_shut_down = threading.Event() self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.port = self.server_address[1] + self.socket.bind(self.address()) + self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) - def request_thread(self, request, client_address): + def connection_thread(self, connection, client_address): + client_address = Address(client_address) try: - self.handle_connection(request, client_address) - request.close() + self.handle_client_connection(connection, client_address) except: - self.handle_error(request, client_address) - request.close() + self.handle_error(connection, client_address) + finally: + connection.close() def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -391,10 +413,10 @@ class TCPServer: else: raise if self.socket in r: - request, client_address = self.socket.accept() + connection, client_address = self.socket.accept() t = threading.Thread( - target = self.request_thread, - args = (request, client_address) + target = self.connection_thread, + args = (connection, client_address) ) t.setDaemon(1) t.start() @@ -410,18 +432,18 @@ class TCPServer: def handle_error(self, request, client_address, fp=sys.stderr): """ - Called when handle_connection raises an exception. + Called when handle_client_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be # none. if traceback: exc = traceback.format_exc() print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s"%client_address + print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) print >> fp, exc print >> fp, '-'*40 - def handle_connection(self, request, client_address): # pragma: no cover + def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ diff --git a/netlib/test.py b/netlib/test.py index 85a56739..2f6a7107 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -17,19 +17,18 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - use_ipv6 = False @classmethod def setupAll(cls): cls.q = Queue.Queue() s = cls.makeserver() - cls.port = s.port + cls.port = s.address.port cls.server = ServerThread(s) cls.server.start() @classmethod def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6) + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) @classmethod def teardownAll(cls): @@ -41,16 +40,16 @@ class ServerTestBase: class TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr, use_ipv6): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A {cert, key, v3_only} dict. """ - tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6) + tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None - def handle_connection(self, request, client_address): + def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 647cb899..b576bdff 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,17 +1,22 @@ import cStringIO, urllib, time, traceback -import odict +import odict, tcp class ClientConn: def __init__(self, address): - self.address = address + self.address = tcp.Address.wrap(address) + + +class Flow: + def __init__(self, client_conn): + self.client_conn = client_conn class Request: def __init__(self, client_conn, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.client_conn = client_conn + self.flow = Flow(client_conn) def date_time_string(): @@ -60,8 +65,8 @@ class WSGIAdaptor: 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + if request.flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() for key, value in request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') |