aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/tcp.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/tcp.py')
-rw-r--r--netlib/tcp.py989
1 files changed, 0 insertions, 989 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
deleted file mode 100644
index ac368a9c..00000000
--- a/netlib/tcp.py
+++ /dev/null
@@ -1,989 +0,0 @@
-import os
-import select
-import socket
-import sys
-import threading
-import time
-import traceback
-
-import binascii
-
-from typing import Optional # noqa
-
-from mitmproxy.utils import strutils
-
-import certifi
-from backports import ssl_match_hostname
-import OpenSSL
-from OpenSSL import SSL
-
-from mitmproxy import certs
-from mitmproxy.utils import version_check
-from mitmproxy.types import serializable
-from mitmproxy import exceptions
-from mitmproxy.types import basethread
-
-# This is a rather hackish way to make sure that
-# the latest version of pyOpenSSL is actually installed.
-version_check.check_pyopenssl_version()
-
-socket_fileobject = socket.SocketIO
-
-EINTR = 4
-if os.environ.get("NO_ALPN"):
- HAS_ALPN = False
-else:
- HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN
-
-# To enable all SSL methods use: SSLv23
-# then add options to disable certain methods
-# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
-SSL_BASIC_OPTIONS = (
- SSL.OP_CIPHER_SERVER_PREFERENCE
-)
-if hasattr(SSL, "OP_NO_COMPRESSION"):
- SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION
-
-SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD
-SSL_DEFAULT_OPTIONS = (
- SSL.OP_NO_SSLv2 |
- SSL.OP_NO_SSLv3 |
- SSL_BASIC_OPTIONS
-)
-if hasattr(SSL, "OP_NO_COMPRESSION"):
- SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
-
-"""
-Map a reasonable SSL version specification into the format OpenSSL expects.
-Don't ask...
-https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
-"""
-sslversion_choices = {
- "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS),
- # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+
- # TLSv1_METHOD would be TLS 1.0 only
- "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)),
- "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS),
- "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS),
- "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS),
- "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS),
- "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
-}
-
-
-class SSLKeyLogger:
-
- def __init__(self, filename):
- self.filename = filename
- self.f = None
- self.lock = threading.Lock()
-
- # required for functools.wraps, which pyOpenSSL uses.
- __name__ = "SSLKeyLogger"
-
- def __call__(self, connection, where, ret):
- if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1:
- with self.lock:
- if not self.f:
- d = os.path.dirname(self.filename)
- if not os.path.isdir(d):
- os.makedirs(d)
- self.f = open(self.filename, "ab")
- self.f.write(b"\r\n")
- client_random = binascii.hexlify(connection.client_random())
- masterkey = binascii.hexlify(connection.master_key())
- self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey))
- self.f.flush()
-
- def close(self):
- with self.lock:
- if self.f:
- self.f.close()
-
- @staticmethod
- def create_logfun(filename):
- if filename:
- return SSLKeyLogger(filename)
- return False
-
-log_ssl_key = SSLKeyLogger.create_logfun(
- os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE"))
-
-
-class _FileLike:
- BLOCKSIZE = 1024 * 32
-
- def __init__(self, o):
- self.o = o
- self._log = None
- self.first_byte_timestamp = None
-
- def set_descriptor(self, o):
- self.o = o
-
- def __getattr__(self, attr):
- return getattr(self.o, attr)
-
- def start_log(self):
- """
- Starts or resets the log.
-
- This will store all bytes read or written.
- """
- self._log = []
-
- def stop_log(self):
- """
- Stops the log.
- """
- self._log = None
-
- def is_logging(self):
- return self._log is not None
-
- def get_log(self):
- """
- Returns the log as a string.
- """
- if not self.is_logging():
- raise ValueError("Not logging!")
- return b"".join(self._log)
-
- def add_log(self, v):
- if self.is_logging():
- self._log.append(v)
-
- def reset_timestamps(self):
- self.first_byte_timestamp = None
-
-
-class Writer(_FileLike):
-
- def flush(self):
- """
- May raise exceptions.TcpDisconnect
- """
- if hasattr(self.o, "flush"):
- try:
- self.o.flush()
- except (socket.error, IOError) as v:
- raise exceptions.TcpDisconnect(str(v))
-
- def write(self, v):
- """
- May raise exceptions.TcpDisconnect
- """
- if v:
- self.first_byte_timestamp = self.first_byte_timestamp or time.time()
- try:
- if hasattr(self.o, "sendall"):
- self.add_log(v)
- return self.o.sendall(v)
- else:
- r = self.o.write(v)
- self.add_log(v[:r])
- return r
- except (SSL.Error, socket.error) as e:
- raise exceptions.TcpDisconnect(str(e))
-
-
-class Reader(_FileLike):
-
- def read(self, length):
- """
- If length is -1, we read until connection closes.
- """
- result = b''
- start = time.time()
- while length == -1 or length > 0:
- if length == -1 or length > self.BLOCKSIZE:
- rlen = self.BLOCKSIZE
- else:
- rlen = length
- try:
- data = self.o.read(rlen)
- except SSL.ZeroReturnError:
- # TLS connection was shut down cleanly
- break
- except (SSL.WantWriteError, SSL.WantReadError):
- # From the OpenSSL docs:
- # If the underlying BIO is non-blocking, SSL_read() will also return when the
- # underlying BIO could not satisfy the needs of SSL_read() to continue the
- # operation. In this case a call to SSL_get_error with the return value of
- # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
- if (time.time() - start) < self.o.gettimeout():
- time.sleep(0.1)
- continue
- else:
- raise exceptions.TcpTimeout()
- except socket.timeout:
- raise exceptions.TcpTimeout()
- except socket.error as e:
- raise exceptions.TcpDisconnect(str(e))
- except SSL.SysCallError as e:
- if e.args == (-1, 'Unexpected EOF'):
- break
- raise exceptions.TlsException(str(e))
- except SSL.Error as e:
- raise exceptions.TlsException(str(e))
- self.first_byte_timestamp = self.first_byte_timestamp or time.time()
- if not data:
- break
- result += data
- if length != -1:
- length -= len(data)
- self.add_log(result)
- return result
-
- def readline(self, size=None):
- result = b''
- bytes_read = 0
- while True:
- if size is not None and bytes_read >= size:
- break
- ch = self.read(1)
- bytes_read += 1
- if not ch:
- break
- else:
- result += ch
- if ch == b'\n':
- break
- return result
-
- def safe_read(self, length):
- """
- Like .read, but is guaranteed to either return length bytes, or
- raise an exception.
- """
- result = self.read(length)
- if length != -1 and len(result) != length:
- if not result:
- raise exceptions.TcpDisconnect()
- else:
- raise exceptions.TcpReadIncomplete(
- "Expected %s bytes, got %s" % (length, len(result))
- )
- return result
-
- def peek(self, length):
- """
- Tries to peek into the underlying file object.
-
- Returns:
- Up to the next N bytes if peeking is successful.
-
- Raises:
- exceptions.TcpException if there was an error with the socket
- TlsException if there was an error with pyOpenSSL.
- NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
- """
- if isinstance(self.o, socket_fileobject):
- try:
- return self.o._sock.recv(length, socket.MSG_PEEK)
- except socket.error as e:
- raise exceptions.TcpException(repr(e))
- elif isinstance(self.o, SSL.Connection):
- try:
- return self.o.recv(length, socket.MSG_PEEK)
- except SSL.Error as e:
- raise exceptions.TlsException(str(e))
- else:
- raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
-
-
-class Address(serializable.Serializable):
-
- """
- This class wraps an IPv4/IPv6 tuple to provide named attributes and
- ipv6 information.
- """
-
- def __init__(self, address, use_ipv6=False):
- self.address = tuple(address)
- self.use_ipv6 = use_ipv6
-
- def get_state(self):
- return {
- "address": self.address,
- "use_ipv6": self.use_ipv6
- }
-
- def set_state(self, state):
- self.address = state["address"]
- self.use_ipv6 = state["use_ipv6"]
-
- @classmethod
- def from_state(cls, state):
- return Address(**state)
-
- @classmethod
- def wrap(cls, t):
- if isinstance(t, cls):
- return t
- else:
- return cls(t)
-
- def __call__(self):
- return self.address
-
- @property
- def host(self):
- return self.address[0]
-
- @property
- def port(self):
- return self.address[1]
-
- @property
- def use_ipv6(self):
- return self.family == socket.AF_INET6
-
- @use_ipv6.setter
- def use_ipv6(self, b):
- self.family = socket.AF_INET6 if b else socket.AF_INET
-
- def __repr__(self):
- return "{}:{}".format(self.host, self.port)
-
- def __eq__(self, other):
- if not other:
- return False
- other = Address.wrap(other)
- return (self.address, self.family) == (other.address, other.family)
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash(self.address) ^ 42 # different hash than the tuple alone.
-
-
-def ssl_read_select(rlist, timeout):
- """
- This is a wrapper around select.select() which also works for SSL.Connections
- by taking ssl_connection.pending() into account.
-
- Caveats:
- If .pending() > 0 for any of the connections in rlist, we avoid the select syscall
- and **will not include any other connections which may or may not be ready**.
-
- Args:
- rlist: wait until ready for reading
-
- Returns:
- subset of rlist which is ready for reading.
- """
- return [
- conn for conn in rlist
- if isinstance(conn, SSL.Connection) and conn.pending() > 0
- ] or select.select(rlist, (), (), timeout)[0]
-
-
-def close_socket(sock):
- """
- Does a hard close of a socket, without emitting a RST.
- """
- try:
- # We already indicate that we close our end.
- # may raise "Transport endpoint is not connected" on Linux
- sock.shutdown(socket.SHUT_WR)
-
- # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
- # readable data could lead to an immediate RST being sent (which is the
- # case on Windows).
- # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
- #
- # This in turn results in the following issue: If we send an error page
- # to the client and then close the socket, the RST may be received by
- # the client before the error page and the users sees a connection
- # error rather than the error page. Thus, we try to empty the read
- # buffer on Windows first. (see
- # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
- #
-
- if os.name == "nt": # pragma: no cover
- # We cannot rely on the shutdown()-followed-by-read()-eof technique
- # proposed by the page above: Some remote machines just don't send
- # a TCP FIN, which would leave us in the unfortunate situation that
- # recv() would block infinitely. As a workaround, we set a timeout
- # here even if we are in blocking mode.
- sock.settimeout(sock.gettimeout() or 20)
-
- # limit at a megabyte so that we don't read infinitely
- for _ in range(1024 ** 3 // 4096):
- # may raise a timeout/disconnect exception.
- if not sock.recv(4096):
- break
-
- # Now we can close the other half as well.
- sock.shutdown(socket.SHUT_RD)
-
- except socket.error:
- pass
-
- sock.close()
-
-
-class _Connection:
-
- rbufsize = -1
- wbufsize = -1
-
- def _makefile(self):
- """
- Set up .rfile and .wfile attributes from .connection
- """
- # Ideally, we would use the Buffered IO in Python 3 by default.
- # Unfortunately, the implementation of .peek() is broken for n>1 bytes,
- # as it may just return what's left in the buffer and not all the bytes we want.
- # As a workaround, we just use unbuffered sockets directly.
- # https://mail.python.org/pipermail/python-dev/2009-June/089986.html
- self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
- self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
-
- def __init__(self, connection):
- if connection:
- self.connection = connection
- self.ip_address = Address(connection.getpeername())
- self._makefile()
- else:
- self.connection = None
- self.ip_address = None
- self.rfile = None
- self.wfile = None
-
- self.ssl_established = False
- self.finished = False
-
- def get_current_cipher(self):
- if not self.ssl_established:
- return None
-
- name = self.connection.get_cipher_name()
- bits = self.connection.get_cipher_bits()
- version = self.connection.get_cipher_version()
- return name, bits, version
-
- def finish(self):
- self.finished = True
- # If we have an SSL connection, wfile.close == connection.close
- # (We call _FileLike.set_descriptor(conn))
- # Closing the socket is not our task, therefore we don't call close
- # then.
- if not isinstance(self.connection, SSL.Connection):
- if not getattr(self.wfile, "closed", False):
- try:
- self.wfile.flush()
- self.wfile.close()
- except exceptions.TcpDisconnect:
- pass
-
- self.rfile.close()
- else:
- try:
- self.connection.shutdown()
- except SSL.Error:
- pass
-
- def _create_ssl_context(self,
- method=SSL_DEFAULT_METHOD,
- options=SSL_DEFAULT_OPTIONS,
- verify_options=SSL.VERIFY_NONE,
- ca_path=None,
- ca_pemfile=None,
- cipher_list=None,
- alpn_protos=None,
- alpn_select=None,
- alpn_select_callback=None,
- sni=None,
- ):
- """
- Creates an SSL Context.
-
- :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
- :param options: A bit field consisting of OpenSSL.SSL.OP_* values
- :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values
- :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool
- :param ca_pemfile: Path to a PEM formatted trusted CA certificate
- :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
- :rtype : SSL.Context
- """
- context = SSL.Context(method)
- # Options (NO_SSLv2/3)
- if options is not None:
- context.set_options(options)
-
- # Verify Options (NONE/PEER and trusted CAs)
- if verify_options is not None:
- def verify_cert(conn, x509, errno, err_depth, is_cert_verified):
- if not is_cert_verified:
- self.ssl_verification_error = exceptions.InvalidCertificateException(
- "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format(
- sni,
- strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"),
- errno,
- err_depth
- )
- )
- return is_cert_verified
-
- context.set_verify(verify_options, verify_cert)
- if ca_path is None and ca_pemfile is None:
- ca_pemfile = certifi.where()
- context.load_verify_locations(ca_pemfile, ca_path)
-
- # Workaround for
- # https://github.com/pyca/pyopenssl/issues/190
- # https://github.com/mitmproxy/mitmproxy/issues/472
- # Options already set before are not cleared.
- context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY)
-
- # Cipher List
- if cipher_list:
- try:
- context.set_cipher_list(cipher_list)
-
- # TODO: maybe change this to with newer pyOpenSSL APIs
- context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
- except SSL.Error as v:
- raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
-
- # SSLKEYLOGFILE
- if log_ssl_key:
- context.set_info_callback(log_ssl_key)
-
- if HAS_ALPN:
- if alpn_protos is not None:
- # advertise application layer protocols
- context.set_alpn_protos(alpn_protos)
- elif alpn_select is not None and alpn_select_callback is None:
- # select application layer protocol
- def alpn_select_callback(conn_, options):
- if alpn_select in options:
- return bytes(alpn_select)
- else: # pragma no cover
- return options[0]
- context.set_alpn_select_callback(alpn_select_callback)
- elif alpn_select_callback is not None and alpn_select is None:
- context.set_alpn_select_callback(alpn_select_callback)
- elif alpn_select_callback is not None and alpn_select is not None:
- raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
-
- return context
-
-
-class ConnectionCloser:
- def __init__(self, conn):
- self.conn = conn
- self._canceled = False
-
- def pop(self):
- """
- Cancel the current closer, and return a fresh one.
- """
- self._canceled = True
- return ConnectionCloser(self.conn)
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- if not self._canceled:
- self.conn.close()
-
-
-class TCPClient(_Connection):
-
- def __init__(self, address, source_address=None, spoof_source_address=None):
- super().__init__(None)
- self.address = address
- self.source_address = source_address
- self.cert = None
- self.server_certs = []
- self.ssl_verification_error = None # type: Optional[exceptions.InvalidCertificateException]
- self.sni = None
- self.spoof_source_address = spoof_source_address
-
- @property
- def address(self):
- return self.__address
-
- @address.setter
- def address(self, address):
- if address:
- self.__address = Address.wrap(address)
- else:
- self.__address = None
-
- @property
- def source_address(self):
- return self.__source_address
-
- @source_address.setter
- def source_address(self, source_address):
- if source_address:
- self.__source_address = Address.wrap(source_address)
- else:
- self.__source_address = None
-
- def close(self):
- # Make sure to close the real socket, not the SSL proxy.
- # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
- # it tries to renegotiate...
- if isinstance(self.connection, SSL.Connection):
- close_socket(self.connection._socket)
- else:
- close_socket(self.connection)
-
- def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs):
- context = self._create_ssl_context(
- alpn_protos=alpn_protos,
- **sslctx_kwargs)
- # Client Certs
- if cert:
- try:
- context.use_privatekey_file(cert)
- context.use_certificate_file(cert)
- except SSL.Error as v:
- raise exceptions.TlsException("SSL client certificate error: %s" % str(v))
- return context
-
- def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
- """
- cert: Path to a file containing both client cert and private key.
-
- options: A bit field consisting of OpenSSL.SSL.OP_* values
- verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values
- ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool
- ca_pemfile: Path to a PEM formatted trusted CA certificate
- """
- verification_mode = sslctx_kwargs.get('verify_options', None)
- if verification_mode == SSL.VERIFY_PEER and not sni:
- raise exceptions.TlsException("Cannot validate certificate hostname without SNI")
-
- context = self.create_ssl_context(
- alpn_protos=alpn_protos,
- sni=sni,
- **sslctx_kwargs
- )
- self.connection = SSL.Connection(context, self.connection)
- if sni:
- self.sni = sni
- self.connection.set_tlsext_host_name(sni.encode("idna"))
- self.connection.set_connect_state()
- try:
- self.connection.do_handshake()
- except SSL.Error as v:
- if self.ssl_verification_error:
- raise self.ssl_verification_error
- else:
- raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
- else:
- # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
- # certificate validation failure
- if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error:
- raise self.ssl_verification_error
-
- self.cert = certs.SSLCert(self.connection.get_peer_certificate())
-
- # Keep all server certificates in a list
- for i in self.connection.get_peer_cert_chain():
- self.server_certs.append(certs.SSLCert(i))
-
- # Validate TLS Hostname
- try:
- crt = dict(
- subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames]
- )
- if self.cert.cn:
- crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]]
- if sni:
- hostname = sni
- else:
- hostname = "no-hostname"
- ssl_match_hostname.match_hostname(crt, hostname)
- except (ValueError, ssl_match_hostname.CertificateError) as e:
- self.ssl_verification_error = exceptions.InvalidCertificateException(
- "Certificate Verification Error for {}: {}".format(
- sni or repr(self.address),
- str(e)
- )
- )
- if verification_mode == SSL.VERIFY_PEER:
- raise self.ssl_verification_error
-
- self.ssl_established = True
- self.rfile.set_descriptor(self.connection)
- self.wfile.set_descriptor(self.connection)
-
- def makesocket(self):
- # some parties (cuckoo sandbox) need to hook this
- return socket.socket(self.address.family, socket.SOCK_STREAM)
-
- def connect(self):
- try:
- connection = self.makesocket()
-
- if self.spoof_source_address:
- try:
- # 19 is `IP_TRANSPARENT`, which is only available on Python 3.3+ on some OSes
- if not connection.getsockopt(socket.SOL_IP, 19):
- connection.setsockopt(socket.SOL_IP, 19, 1)
- except socket.error as e:
- raise exceptions.TcpException(
- "Failed to spoof the source address: " + e.strerror
- )
- if self.source_address:
- connection.bind(self.source_address())
- connection.connect(self.address())
- self.source_address = Address(connection.getsockname())
- except (socket.error, IOError) as err:
- raise exceptions.TcpException(
- 'Error connecting to "%s": %s' %
- (self.address.host, err)
- )
- self.connection = connection
- self.ip_address = Address(connection.getpeername())
- self._makefile()
- return ConnectionCloser(self)
-
- def settimeout(self, n):
- self.connection.settimeout(n)
-
- def gettimeout(self):
- return self.connection.gettimeout()
-
- def get_alpn_proto_negotiated(self):
- if HAS_ALPN and self.ssl_established:
- return self.connection.get_alpn_proto_negotiated()
- else:
- return b""
-
-
-class BaseHandler(_Connection):
-
- """
- The instantiator is expected to call the handle() and finish() methods.
- """
-
- def __init__(self, connection, address, server):
- super().__init__(connection)
- self.address = Address.wrap(address)
- self.server = server
- self.clientcert = None
-
- def create_ssl_context(self,
- cert, key,
- handle_sni=None,
- request_client_cert=None,
- chain_file=None,
- dhparams=None,
- extra_chain_certs=None,
- **sslctx_kwargs):
- """
- cert: A certs.SSLCert object or the path to a certificate
- chain file.
-
- handle_sni: SNI handler, should take a connection object. Server
- name can be retrieved like this:
-
- connection.get_servername()
-
- And you can specify the connection keys as follows:
-
- new_context = Context(TLSv1_METHOD)
- new_context.use_privatekey(key)
- new_context.use_certificate(cert)
- connection.set_context(new_context)
-
- The request_client_cert argument requires some explanation. We're
- supposed to be able to do this with no negative effects - if the
- client has no cert to present, we're notified and proceed as usual.
- Unfortunately, Android seems to have a bug (tested on 4.2.2) - when
- an Android client is asked to present a certificate it does not
- have, it hangs up, which is frankly bogus. Some time down the track
- we may be able to make the proper behaviour the default again, but
- until then we're conservative.
- """
-
- context = self._create_ssl_context(ca_pemfile=chain_file, **sslctx_kwargs)
-
- context.use_privatekey(key)
- if isinstance(cert, certs.SSLCert):
- context.use_certificate(cert.x509)
- else:
- context.use_certificate_chain_file(cert)
-
- if extra_chain_certs:
- for i in extra_chain_certs:
- context.add_extra_chain_cert(i.x509)
-
- if handle_sni:
- # SNI callback happens during do_handshake()
- context.set_tlsext_servername_callback(handle_sni)
-
- if request_client_cert:
- def save_cert(conn_, cert, errno_, depth_, preverify_ok_):
- self.clientcert = certs.SSLCert(cert)
- # Return true to prevent cert verification error
- return True
- context.set_verify(SSL.VERIFY_PEER, save_cert)
-
- if dhparams:
- SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
-
- return context
-
- def convert_to_ssl(self, cert, key, **sslctx_kwargs):
- """
- Convert connection to SSL.
- For a list of parameters, see BaseHandler._create_ssl_context(...)
- """
-
- context = self.create_ssl_context(
- cert,
- key,
- **sslctx_kwargs)
- self.connection = SSL.Connection(context, self.connection)
- self.connection.set_accept_state()
- try:
- self.connection.do_handshake()
- except SSL.Error as v:
- raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
- self.ssl_established = True
- self.rfile.set_descriptor(self.connection)
- self.wfile.set_descriptor(self.connection)
-
- def handle(self): # pragma: no cover
- raise NotImplementedError
-
- def settimeout(self, n):
- self.connection.settimeout(n)
-
- def get_alpn_proto_negotiated(self):
- if HAS_ALPN and self.ssl_established:
- return self.connection.get_alpn_proto_negotiated()
- else:
- return b""
-
-
-class Counter:
- def __init__(self):
- self._count = 0
- self._lock = threading.Lock()
-
- @property
- def count(self):
- with self._lock:
- return self._count
-
- def __enter__(self):
- with self._lock:
- self._count += 1
-
- def __exit__(self, *args):
- with self._lock:
- self._count -= 1
-
-
-class TCPServer:
- request_queue_size = 20
-
- def __init__(self, address):
- self.address = Address.wrap(address)
- self.__is_shut_down = threading.Event()
- self.__shutdown_request = False
- self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
- self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.socket.bind(self.address())
- self.address = Address.wrap(self.socket.getsockname())
- self.socket.listen(self.request_queue_size)
- self.handler_counter = Counter()
-
- def connection_thread(self, connection, client_address):
- with self.handler_counter:
- client_address = Address(client_address)
- try:
- self.handle_client_connection(connection, client_address)
- except:
- self.handle_error(connection, client_address)
- finally:
- close_socket(connection)
-
- def serve_forever(self, poll_interval=0.1):
- self.__is_shut_down.clear()
- try:
- while not self.__shutdown_request:
- try:
- r, w_, e_ = select.select(
- [self.socket], [], [], poll_interval)
- except select.error as ex: # pragma: no cover
- if ex[0] == EINTR:
- continue
- else:
- raise
- if self.socket in r:
- connection, client_address = self.socket.accept()
- t = basethread.BaseThread(
- "TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
- self.__class__.__name__,
- client_address[0],
- client_address[1],
- self.address.host,
- self.address.port
- ),
- target=self.connection_thread,
- args=(connection, client_address),
- )
- t.setDaemon(1)
- try:
- t.start()
- except threading.ThreadError:
- self.handle_error(connection, Address(client_address))
- connection.close()
- finally:
- self.__shutdown_request = False
- self.__is_shut_down.set()
-
- def shutdown(self):
- self.__shutdown_request = True
- self.__is_shut_down.wait()
- self.socket.close()
- self.handle_shutdown()
-
- def handle_error(self, connection_, client_address, fp=sys.stderr):
- """
- Called when handle_client_connection raises an exception.
- """
- # If a thread has persisted after interpreter exit, the module might be
- # none.
- if traceback:
- exc = str(traceback.format_exc())
- print(u'-' * 40, file=fp)
- print(
- u"Error in processing of request from %s" % repr(client_address), file=fp)
- print(exc, file=fp)
- print(u'-' * 40, file=fp)
-
- def handle_client_connection(self, conn, client_address): # pragma: no cover
- """
- Called after client connection.
- """
- raise NotImplementedError
-
- def handle_shutdown(self):
- """
- Called after server shutdown.
- """
-
- def wait_for_silence(self, timeout=5):
- start = time.time()
- while 1:
- if time.time() - start >= timeout:
- raise exceptions.Timeout(
- "%s service threads still alive" %
- self.handler_counter.count
- )
- if self.handler_counter.count == 0:
- return