diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/tcp.py | 140 |
1 files changed, 86 insertions, 54 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index 7f98b4f9..ba4f008c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -302,6 +302,43 @@ class _Connection(object): except SSL.Error: pass + """ + Creates an SSL Context. + """ + def _create_ssl_context(self, + method=SSLv23_METHOD, + options=(OP_NO_SSLv2 | OP_NO_SSLv3), + cipher_list=None + ): + """ + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html + :rtype : SSL.Context + """ + context = SSL.Context(method) + # Options (NO_SSLv2/3) + if options is not None: + context.set_options(options) + + # Workaround for + # https://github.com/pyca/pyopenssl/issues/190 + # https://github.com/mitmproxy/mitmproxy/issues/472 + context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared. + + # Cipher List + if cipher_list: + try: + context.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) + + # SSLKEYLOGFILE + if log_ssl_key: + context.set_info_callback(log_ssl_key) + + return context + class TCPClient(_Connection): rbufsize = -1 @@ -324,32 +361,28 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None): - """ - cert: Path to a file containing both client cert and private key. - - options: A bit field consisting of OpenSSL.SSL.OP_* values - """ - context = SSL.Context(method) - if cipher_list: - try: - context.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) - if options is not None: - context.set_options(options) + def create_ssl_context(self, cert=None, **sslctx_kwargs): + context = self._create_ssl_context(**sslctx_kwargs) + # Client Certs if cert: try: context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) + return context + + def convert_to_ssl(self, sni=None, **sslctx_kwargs): + """ + cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values + """ + context = self.create_ssl_context(**sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) - if log_ssl_key: - context.set_info_callback(log_ssl_key) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -400,21 +433,21 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2, - handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None, chain_file=None): + def create_ssl_context(self, + cert, key, + handle_sni=None, + request_client_cert=None, + chain_file=None, + dhparams=None, + **sslctx_kwargs): """ cert: A certutils.SSLCert object. - method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD - handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: connection.get_servername() - options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: new_context = Context(TLSv1_METHOD) @@ -431,40 +464,38 @@ class BaseHandler(_Connection): we may be able to make the proper behaviour the default again, but until then we're conservative. """ - ctx = SSL.Context(method) - if not options is None: - ctx.set_options(options) - if chain_file: - ctx.load_verify_locations(chain_file) - if cipher_list: - try: - ctx.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) + context = self._create_ssl_context(**sslctx_kwargs) + + context.use_privatekey(key) + context.use_certificate(cert.x509) + if handle_sni: # SNI callback happens during do_handshake() - ctx.set_tlsext_servername_callback(handle_sni) - ctx.use_privatekey(key) - ctx.use_certificate(cert.x509) - if dhparams: - SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams) + context.set_tlsext_servername_callback(handle_sni) + if request_client_cert: - def ver(*args): - self.clientcert = certutils.SSLCert(args[1]) + def save_cert(conn, cert, errno, depth, preverify_ok): + self.clientcert = certutils.SSLCert(cert) # Return true to prevent cert verification error return True - ctx.set_verify(SSL.VERIFY_PEER, ver) - if log_ssl_key: - ctx.set_info_callback(log_ssl_key) - return ctx + context.set_verify(SSL.VERIFY_PEER, save_cert) + + # Cert Verify + if chain_file: + context.load_verify_locations(chain_file) + + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) + + return context def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) - self.connection = SSL.Connection(ctx, self.connection) + context = self.create_ssl_context(cert, key, **sslctx_kwargs) + self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: self.connection.do_handshake() @@ -474,7 +505,7 @@ class BaseHandler(_Connection): self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def handle(self): # pragma: no cover + def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): @@ -483,6 +514,7 @@ class BaseHandler(_Connection): class TCPServer(object): request_queue_size = 20 + def __init__(self, address): self.address = Address.wrap(address) self.__is_shut_down = threading.Event() @@ -508,7 +540,7 @@ class TCPServer(object): while not self.__shutdown_request: try: r, w, e = select.select([self.socket], [], [], poll_interval) - except select.error, ex: # pragma: no cover + except select.error as ex: # pragma: no cover if ex[0] == EINTR: continue else: @@ -516,12 +548,12 @@ class TCPServer(object): if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( - target = self.connection_thread, - args = (connection, client_address), - name = "ConnectionThread (%s:%s -> %s:%s)" % - (client_address[0], client_address[1], - self.address.host, self.address.port) - ) + target=self.connection_thread, + args=(connection, client_address), + name="ConnectionThread (%s:%s -> %s:%s)" % + (client_address[0], client_address[1], + self.address.host, self.address.port) + ) t.setDaemon(1) t.start() finally: |