aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py31
-rw-r--r--test/test_tcp.py7
2 files changed, 20 insertions, 18 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 91b0c742..3c5c89b7 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -48,29 +48,30 @@ class FileLike:
class TCPClient:
- def __init__(self, ssl, host, port, clientcert):
- self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
+ def __init__(self, host, port):
+ self.host, self.port = host, port
self.connection, self.rfile, self.wfile = None, None, None
self.cert = None
+ def convert_to_ssl(self, clientcert=None):
+ context = SSL.Context(SSL.SSLv23_METHOD)
+ if clientcert:
+ context.use_certificate_file(self.clientcert)
+ self.connection = SSL.Connection(context, self.connection)
+ self.connection.set_connect_state()
+ self.cert = self.connection.get_peer_certificate()
+ self.rfile = FileLike(self.connection)
+ self.wfile = FileLike(self.connection)
+
def connect(self):
try:
addr = socket.gethostbyname(self.host)
- server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- if self.ssl:
- context = SSL.Context(SSL.SSLv23_METHOD)
- if self.clientcert:
- context.use_certificate_file(self.clientcert)
- server = SSL.Connection(context, server)
- server.connect((addr, self.port))
- if self.ssl:
- self.cert = server.get_peer_certificate()
- self.rfile, self.wfile = FileLike(server), FileLike(server)
- else:
- self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
+ connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ connection.connect((addr, self.port))
+ self.rfile, self.wfile = connection.makefile('rb'), connection.makefile('wb')
except socket.error, err:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
- self.connection = server
+ self.connection = connection
class BaseHandler:
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 1bad9a04..26286bc4 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -66,7 +66,7 @@ class TestServer(ServerTestBase):
def test_echo(self):
testval = "echo!\n"
- c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
+ c = tcp.TCPClient("127.0.0.1", self.port)
c.connect()
c.wfile.write(testval)
c.wfile.flush()
@@ -82,8 +82,9 @@ class TestServerSSL(ServerTestBase):
return s
def test_echo(self):
- c = tcp.TCPClient(True, "127.0.0.1", self.port, None)
+ c = tcp.TCPClient("127.0.0.1", self.port)
c.connect()
+ c.convert_to_ssl()
testval = "echo!\n"
c.wfile.write(testval)
c.wfile.flush()
@@ -92,7 +93,7 @@ class TestServerSSL(ServerTestBase):
class TestTCPClient:
def test_conerr(self):
- c = tcp.TCPClient(True, "127.0.0.1", 0, None)
+ c = tcp.TCPClient("127.0.0.1", 0)
tutils.raises(tcp.NetLibError, c.connect)