diff options
-rw-r--r-- | netlib/tcp.py | 31 | ||||
-rw-r--r-- | test/test_tcp.py | 7 |
2 files changed, 20 insertions, 18 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index 91b0c742..3c5c89b7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,29 +48,30 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert): - self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + def __init__(self, host, port): + self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None + def convert_to_ssl(self, clientcert=None): + context = SSL.Context(SSL.SSLv23_METHOD) + if clientcert: + context.use_certificate_file(self.clientcert) + self.connection = SSL.Connection(context, self.connection) + self.connection.set_connect_state() + self.cert = self.connection.get_peer_certificate() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + def connect(self): try: addr = socket.gethostbyname(self.host) - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self.ssl: - context = SSL.Context(SSL.SSLv23_METHOD) - if self.clientcert: - context.use_certificate_file(self.clientcert) - server = SSL.Connection(context, server) - server.connect((addr, self.port)) - if self.ssl: - self.cert = server.get_peer_certificate() - self.rfile, self.wfile = FileLike(server), FileLike(server) - else: - self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection.connect((addr, self.port)) + self.rfile, self.wfile = connection.makefile('rb'), connection.makefile('wb') except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) - self.connection = server + self.connection = connection class BaseHandler: diff --git a/test/test_tcp.py b/test/test_tcp.py index 1bad9a04..26286bc4 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -66,7 +66,7 @@ class TestServer(ServerTestBase): def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c = tcp.TCPClient("127.0.0.1", self.port) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -82,8 +82,9 @@ class TestServerSSL(ServerTestBase): return s def test_echo(self): - c = tcp.TCPClient(True, "127.0.0.1", self.port, None) + c = tcp.TCPClient("127.0.0.1", self.port) c.connect() + c.convert_to_ssl() testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -92,7 +93,7 @@ class TestServerSSL(ServerTestBase): class TestTCPClient: def test_conerr(self): - c = tcp.TCPClient(True, "127.0.0.1", 0, None) + c = tcp.TCPClient("127.0.0.1", 0) tutils.raises(tcp.NetLibError, c.connect) |