aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/tcp.py140
1 files changed, 86 insertions, 54 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 7f98b4f9..ba4f008c 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -302,6 +302,43 @@ class _Connection(object):
except SSL.Error:
pass
+ """
+ Creates an SSL Context.
+ """
+ def _create_ssl_context(self,
+ method=SSLv23_METHOD,
+ options=(OP_NO_SSLv2 | OP_NO_SSLv3),
+ cipher_list=None
+ ):
+ """
+ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
+ :param options: A bit field consisting of OpenSSL.SSL.OP_* values
+ :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
+ :rtype : SSL.Context
+ """
+ context = SSL.Context(method)
+ # Options (NO_SSLv2/3)
+ if options is not None:
+ context.set_options(options)
+
+ # Workaround for
+ # https://github.com/pyca/pyopenssl/issues/190
+ # https://github.com/mitmproxy/mitmproxy/issues/472
+ context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared.
+
+ # Cipher List
+ if cipher_list:
+ try:
+ context.set_cipher_list(cipher_list)
+ except SSL.Error, v:
+ raise NetLibError("SSL cipher specification error: %s"%str(v))
+
+ # SSLKEYLOGFILE
+ if log_ssl_key:
+ context.set_info_callback(log_ssl_key)
+
+ return context
+
class TCPClient(_Connection):
rbufsize = -1
@@ -324,32 +361,28 @@ class TCPClient(_Connection):
self.ssl_established = False
self.sni = None
- def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None):
- """
- cert: Path to a file containing both client cert and private key.
-
- options: A bit field consisting of OpenSSL.SSL.OP_* values
- """
- context = SSL.Context(method)
- if cipher_list:
- try:
- context.set_cipher_list(cipher_list)
- except SSL.Error, v:
- raise NetLibError("SSL cipher specification error: %s"%str(v))
- if options is not None:
- context.set_options(options)
+ def create_ssl_context(self, cert=None, **sslctx_kwargs):
+ context = self._create_ssl_context(**sslctx_kwargs)
+ # Client Certs
if cert:
try:
context.use_privatekey_file(cert)
context.use_certificate_file(cert)
except SSL.Error, v:
raise NetLibError("SSL client certificate error: %s"%str(v))
+ return context
+
+ def convert_to_ssl(self, sni=None, **sslctx_kwargs):
+ """
+ cert: Path to a file containing both client cert and private key.
+
+ options: A bit field consisting of OpenSSL.SSL.OP_* values
+ """
+ context = self.create_ssl_context(**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
self.connection.set_tlsext_host_name(sni)
- if log_ssl_key:
- context.set_info_callback(log_ssl_key)
self.connection.set_connect_state()
try:
self.connection.do_handshake()
@@ -400,21 +433,21 @@ class BaseHandler(_Connection):
self.ssl_established = False
self.clientcert = None
- def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2,
- handle_sni=None, request_client_cert=None, cipher_list=None,
- dhparams=None, chain_file=None):
+ def create_ssl_context(self,
+ cert, key,
+ handle_sni=None,
+ request_client_cert=None,
+ chain_file=None,
+ dhparams=None,
+ **sslctx_kwargs):
"""
cert: A certutils.SSLCert object.
- method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
-
handle_sni: SNI handler, should take a connection object. Server
name can be retrieved like this:
connection.get_servername()
- options: A bit field consisting of OpenSSL.SSL.OP_* values
-
And you can specify the connection keys as follows:
new_context = Context(TLSv1_METHOD)
@@ -431,40 +464,38 @@ class BaseHandler(_Connection):
we may be able to make the proper behaviour the default again, but
until then we're conservative.
"""
- ctx = SSL.Context(method)
- if not options is None:
- ctx.set_options(options)
- if chain_file:
- ctx.load_verify_locations(chain_file)
- if cipher_list:
- try:
- ctx.set_cipher_list(cipher_list)
- except SSL.Error, v:
- raise NetLibError("SSL cipher specification error: %s"%str(v))
+ context = self._create_ssl_context(**sslctx_kwargs)
+
+ context.use_privatekey(key)
+ context.use_certificate(cert.x509)
+
if handle_sni:
# SNI callback happens during do_handshake()
- ctx.set_tlsext_servername_callback(handle_sni)
- ctx.use_privatekey(key)
- ctx.use_certificate(cert.x509)
- if dhparams:
- SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams)
+ context.set_tlsext_servername_callback(handle_sni)
+
if request_client_cert:
- def ver(*args):
- self.clientcert = certutils.SSLCert(args[1])
+ def save_cert(conn, cert, errno, depth, preverify_ok):
+ self.clientcert = certutils.SSLCert(cert)
# Return true to prevent cert verification error
return True
- ctx.set_verify(SSL.VERIFY_PEER, ver)
- if log_ssl_key:
- ctx.set_info_callback(log_ssl_key)
- return ctx
+ context.set_verify(SSL.VERIFY_PEER, save_cert)
+
+ # Cert Verify
+ if chain_file:
+ context.load_verify_locations(chain_file)
+
+ if dhparams:
+ SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
+
+ return context
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...)
"""
- ctx = self._create_ssl_context(cert, key, **sslctx_kwargs)
- self.connection = SSL.Connection(ctx, self.connection)
+ context = self.create_ssl_context(cert, key, **sslctx_kwargs)
+ self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
try:
self.connection.do_handshake()
@@ -474,7 +505,7 @@ class BaseHandler(_Connection):
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
- def handle(self): # pragma: no cover
+ def handle(self): # pragma: no cover
raise NotImplementedError
def settimeout(self, n):
@@ -483,6 +514,7 @@ class BaseHandler(_Connection):
class TCPServer(object):
request_queue_size = 20
+
def __init__(self, address):
self.address = Address.wrap(address)
self.__is_shut_down = threading.Event()
@@ -508,7 +540,7 @@ class TCPServer(object):
while not self.__shutdown_request:
try:
r, w, e = select.select([self.socket], [], [], poll_interval)
- except select.error, ex: # pragma: no cover
+ except select.error as ex: # pragma: no cover
if ex[0] == EINTR:
continue
else:
@@ -516,12 +548,12 @@ class TCPServer(object):
if self.socket in r:
connection, client_address = self.socket.accept()
t = threading.Thread(
- target = self.connection_thread,
- args = (connection, client_address),
- name = "ConnectionThread (%s:%s -> %s:%s)" %
- (client_address[0], client_address[1],
- self.address.host, self.address.port)
- )
+ target=self.connection_thread,
+ args=(connection, client_address),
+ name="ConnectionThread (%s:%s -> %s:%s)" %
+ (client_address[0], client_address[1],
+ self.address.host, self.address.port)
+ )
t.setDaemon(1)
t.start()
finally: