diff options
Diffstat (limited to 'netlib/tcp.py')
-rw-r--r-- | netlib/tcp.py | 46 |
1 files changed, 35 insertions, 11 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index c5f97f94..2704eeae 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import select, socket, threading, sys, time, traceback from OpenSSL import SSL -import certutils +from . import certutils EINTR = 4 @@ -17,7 +18,10 @@ OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +try: + OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +except AttributeError: + pass OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG @@ -212,10 +216,16 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __repr__(self): + return repr(self.address) + def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) + def __ne__(self, other): + return not self.__eq__(other) + class _Connection(object): def get_current_cipher(self): @@ -309,6 +319,8 @@ class TCPClient(_Connection): if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: @@ -341,10 +353,9 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, - method=SSLv23_METHOD, options=None, handle_sni=None, - request_client_cert=False, cipher_list=None, dhparams=None - ): + def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, + handle_sni=None, request_client_cert=None, cipher_list=None, + dhparams=None, ca_file=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -372,6 +383,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if ca_file: + ctx.load_verify_locations(ca_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) @@ -390,6 +403,14 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + return ctx + + def convert_to_ssl(self, cert, key, **sslctx_kwargs): + """ + Convert connection to SSL. + For a list of parameters, see BaseHandler._create_ssl_context(...) + """ + ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() @@ -443,7 +464,7 @@ class TCPServer(object): if ex[0] == EINTR: continue else: - raise + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( @@ -473,10 +494,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-'*40, file=fp) + print( + "Error in processing of request from %s:%s" % ( + client_address.host, client_address.port + ), file=fp) + print(exc, file=fp) + print('-'*40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ |