aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/certutils.py2
-rw-r--r--netlib/tcp.py56
-rw-r--r--netlib/test.py11
3 files changed, 48 insertions, 21 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index 0349bec7..94294f6e 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -237,7 +237,7 @@ class SSLCert:
def get_remote_cert(host, port, sni):
- c = tcp.TCPClient(host, port)
+ c = tcp.TCPClient((host, port))
c.connect()
c.convert_to_ssl(sni=sni)
return c.cert
diff --git a/netlib/tcp.py b/netlib/tcp.py
index e48f4f6b..bad166d0 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -173,6 +173,35 @@ class Reader(_FileLike):
return result
+class Address(tuple):
+ """
+ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
+ """
+ def __new__(cls, address, use_ipv6=False):
+ a = super(Address, cls).__new__(cls, tuple(address))
+ a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET
+ return a
+
+ @classmethod
+ def wrap(cls, t):
+ if isinstance(t, cls):
+ return t
+ else:
+ return cls(t)
+
+ @property
+ def host(self):
+ return self[0]
+
+ @property
+ def port(self):
+ return self[1]
+
+ @property
+ def is_ipv6(self):
+ return self.family == socket.AF_INET6
+
+
class SocketCloseMixin:
def finish(self):
self.finished = True
@@ -209,10 +238,9 @@ class SocketCloseMixin:
class TCPClient(SocketCloseMixin):
rbufsize = -1
wbufsize = -1
- def __init__(self, host, port, source_address=None, use_ipv6=False):
- self.host, self.port = host, port
+ def __init__(self, address, source_address=None):
+ self.address = Address.wrap(address)
self.source_address = source_address
- self.use_ipv6 = use_ipv6
self.connection, self.rfile, self.wfile = None, None, None
self.cert = None
self.ssl_established = False
@@ -245,14 +273,14 @@ class TCPClient(SocketCloseMixin):
def connect(self):
try:
- connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
+ connection = socket.socket(self.address.family, socket.SOCK_STREAM)
if self.source_address:
connection.bind(self.source_address)
- connection.connect((self.host, self.port))
+ connection.connect(self.address)
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError), err:
- raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
+ raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err))
self.connection = connection
def settimeout(self, n):
@@ -269,8 +297,9 @@ class BaseHandler(SocketCloseMixin):
"""
rbufsize = -1
wbufsize = -1
- def __init__(self, connection):
+ def __init__(self, connection, address):
self.connection = connection
+ self.address = Address.wrap(address)
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
@@ -339,19 +368,18 @@ class BaseHandler(SocketCloseMixin):
class TCPServer:
request_queue_size = 20
- def __init__(self, server_address, use_ipv6=False):
- self.server_address = server_address
- self.use_ipv6 = use_ipv6
+ def __init__(self, address):
+ self.address = Address.wrap(address)
self.__is_shut_down = threading.Event()
self.__shutdown_request = False
- self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
+ self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.socket.bind(self.server_address)
- self.server_address = self.socket.getsockname()
- self.port = self.server_address[1]
+ self.socket.bind(self.address)
+ self.address = Address.wrap(self.socket.getsockname())
self.socket.listen(self.request_queue_size)
def connection_thread(self, connection, client_address):
+ client_address = Address(client_address)
try:
self.handle_client_connection(connection, client_address)
except:
diff --git a/netlib/test.py b/netlib/test.py
index f5599082..565b97cd 100644
--- a/netlib/test.py
+++ b/netlib/test.py
@@ -17,19 +17,18 @@ class ServerTestBase:
ssl = None
handler = None
addr = ("localhost", 0)
- use_ipv6 = False
@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
s = cls.makeserver()
- cls.port = s.port
+ cls.port = s.address.port
cls.server = ServerThread(s)
cls.server.start()
@classmethod
def makeserver(cls):
- return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6)
+ return TServer(cls.ssl, cls.q, cls.handler, cls.addr)
@classmethod
def teardownAll(cls):
@@ -41,17 +40,17 @@ class ServerTestBase:
class TServer(tcp.TCPServer):
- def __init__(self, ssl, q, handler_klass, addr, use_ipv6):
+ def __init__(self, ssl, q, handler_klass, addr):
"""
ssl: A {cert, key, v3_only} dict.
"""
- tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6)
+ tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.handler_klass = handler_klass
self.last_handler = None
def handle_client_connection(self, request, client_address):
- h = self.handler_klass(request)
+ h = self.handler_klass(request, client_address)
self.last_handler = h
if self.ssl:
cert = certutils.SSLCert.from_pem(