aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/tcp.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/tcp.py')
-rw-r--r--netlib/tcp.py103
1 files changed, 65 insertions, 38 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index d26bb5f7..de12102e 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -6,6 +6,7 @@ import sys
import threading
import time
import traceback
+import contextlib
import binascii
from six.moves import range
@@ -16,13 +17,10 @@ import six
import OpenSSL
from OpenSSL import SSL
-from . import certutils, version_check, utils
+from netlib import certutils, version_check, basetypes, exceptions
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
-from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
- TcpTimeout, TcpDisconnect, TcpException
-
version_check.check_pyopenssl_version()
if six.PY2:
@@ -71,6 +69,7 @@ sslversion_choices = {
"TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
}
+
class SSLKeyLogger(object):
def __init__(self, filename):
@@ -161,17 +160,17 @@ class Writer(_FileLike):
def flush(self):
"""
- May raise TcpDisconnect
+ May raise exceptions.TcpDisconnect
"""
if hasattr(self.o, "flush"):
try:
self.o.flush()
except (socket.error, IOError) as v:
- raise TcpDisconnect(str(v))
+ raise exceptions.TcpDisconnect(str(v))
def write(self, v):
"""
- May raise TcpDisconnect
+ May raise exceptions.TcpDisconnect
"""
if v:
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
@@ -184,7 +183,7 @@ class Writer(_FileLike):
self.add_log(v[:r])
return r
except (SSL.Error, socket.error) as e:
- raise TcpDisconnect(str(e))
+ raise exceptions.TcpDisconnect(str(e))
class Reader(_FileLike):
@@ -215,17 +214,17 @@ class Reader(_FileLike):
time.sleep(0.1)
continue
else:
- raise TcpTimeout()
+ raise exceptions.TcpTimeout()
except socket.timeout:
- raise TcpTimeout()
+ raise exceptions.TcpTimeout()
except socket.error as e:
- raise TcpDisconnect(str(e))
+ raise exceptions.TcpDisconnect(str(e))
except SSL.SysCallError as e:
if e.args == (-1, 'Unexpected EOF'):
break
- raise TlsException(str(e))
+ raise exceptions.TlsException(str(e))
except SSL.Error as e:
- raise TlsException(str(e))
+ raise exceptions.TlsException(str(e))
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
if not data:
break
@@ -259,9 +258,9 @@ class Reader(_FileLike):
result = self.read(length)
if length != -1 and len(result) != length:
if not result:
- raise TcpDisconnect()
+ raise exceptions.TcpDisconnect()
else:
- raise TcpReadIncomplete(
+ raise exceptions.TcpReadIncomplete(
"Expected %s bytes, got %s" % (length, len(result))
)
return result
@@ -274,7 +273,7 @@ class Reader(_FileLike):
Up to the next N bytes if peeking is successful.
Raises:
- TcpException if there was an error with the socket
+ exceptions.TcpException if there was an error with the socket
TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
"""
@@ -282,7 +281,7 @@ class Reader(_FileLike):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e:
- raise TcpException(repr(e))
+ raise exceptions.TcpException(repr(e))
elif isinstance(self.o, SSL.Connection):
try:
if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
@@ -296,12 +295,12 @@ class Reader(_FileLike):
self.o._raise_ssl_error(self.o._ssl, result)
return SSL._ffi.buffer(buf, result)[:]
except SSL.Error as e:
- six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2])
+ six.reraise(exceptions.TlsException, exceptions.TlsException(str(e)), sys.exc_info()[2])
else:
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
-class Address(utils.Serializable):
+class Address(basetypes.Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and
@@ -489,7 +488,7 @@ class _Connection(object):
try:
self.wfile.flush()
self.wfile.close()
- except TcpDisconnect:
+ except exceptions.TcpDisconnect:
pass
self.rfile.close()
@@ -553,7 +552,7 @@ class _Connection(object):
# TODO: maybe change this to with newer pyOpenSSL APIs
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
- raise TlsException("SSL cipher specification error: %s" % str(v))
+ raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
# SSLKEYLOGFILE
if log_ssl_key:
@@ -574,11 +573,17 @@ class _Connection(object):
elif alpn_select_callback is not None and alpn_select is None:
context.set_alpn_select_callback(alpn_select_callback)
elif alpn_select_callback is not None and alpn_select is not None:
- raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
+ raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
return context
+@contextlib.contextmanager
+def _closer(client):
+ yield
+ client.close()
+
+
class TCPClient(_Connection):
def __init__(self, address, source_address=None):
@@ -631,7 +636,7 @@ class TCPClient(_Connection):
context.use_privatekey_file(cert)
context.use_certificate_file(cert)
except SSL.Error as v:
- raise TlsException("SSL client certificate error: %s" % str(v))
+ raise exceptions.TlsException("SSL client certificate error: %s" % str(v))
return context
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
@@ -645,7 +650,7 @@ class TCPClient(_Connection):
"""
verification_mode = sslctx_kwargs.get('verify_options', None)
if verification_mode == SSL.VERIFY_PEER and not sni:
- raise TlsException("Cannot validate certificate hostname without SNI")
+ raise exceptions.TlsException("Cannot validate certificate hostname without SNI")
context = self.create_ssl_context(
alpn_protos=alpn_protos,
@@ -660,14 +665,14 @@ class TCPClient(_Connection):
self.connection.do_handshake()
except SSL.Error as v:
if self.ssl_verification_error:
- raise InvalidCertificateException("SSL handshake error: %s" % repr(v))
+ raise exceptions.InvalidCertificateException("SSL handshake error: %s" % repr(v))
else:
- raise TlsException("SSL handshake error: %s" % repr(v))
+ raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
else:
# Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
# certificate validation failure
if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None:
- raise InvalidCertificateException("SSL handshake error: certificate verify failed")
+ raise exceptions.InvalidCertificateException("SSL handshake error: certificate verify failed")
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
@@ -690,7 +695,7 @@ class TCPClient(_Connection):
except (ValueError, ssl_match_hostname.CertificateError) as e:
self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname")
if verification_mode == SSL.VERIFY_PEER:
- raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e)))
+ raise exceptions.InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e)))
self.ssl_established = True
self.rfile.set_descriptor(self.connection)
@@ -704,12 +709,13 @@ class TCPClient(_Connection):
connection.connect(self.address())
self.source_address = Address(connection.getsockname())
except (socket.error, IOError) as err:
- raise TcpException(
+ raise exceptions.TcpException(
'Error connecting to "%s": %s' %
(self.address.host, err))
self.connection = connection
self.ip_address = Address(connection.getpeername())
self._makefile()
+ return _closer(self)
def settimeout(self, n):
self.connection.settimeout(n)
@@ -817,7 +823,7 @@ class BaseHandler(_Connection):
try:
self.connection.do_handshake()
except SSL.Error as v:
- raise TlsException("SSL handshake error: %s" % repr(v))
+ raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
self.ssl_established = True
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
@@ -835,6 +841,25 @@ class BaseHandler(_Connection):
return b""
+class Counter:
+ def __init__(self):
+ self._count = 0
+ self._lock = threading.Lock()
+
+ @property
+ def count(self):
+ with self._lock:
+ return self._count
+
+ def __enter__(self):
+ with self._lock:
+ self._count += 1
+
+ def __exit__(self, *args):
+ with self._lock:
+ self._count -= 1
+
+
class TCPServer(object):
request_queue_size = 20
@@ -847,15 +872,17 @@ class TCPServer(object):
self.socket.bind(self.address())
self.address = Address.wrap(self.socket.getsockname())
self.socket.listen(self.request_queue_size)
+ self.handler_counter = Counter()
def connection_thread(self, connection, client_address):
- client_address = Address(client_address)
- try:
- self.handle_client_connection(connection, client_address)
- except:
- self.handle_error(connection, client_address)
- finally:
- close_socket(connection)
+ with self.handler_counter:
+ client_address = Address(client_address)
+ try:
+ self.handle_client_connection(connection, client_address)
+ except:
+ self.handle_error(connection, client_address)
+ finally:
+ close_socket(connection)
def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()
@@ -900,7 +927,7 @@ class TCPServer(object):
"""
# If a thread has persisted after interpreter exit, the module might be
# none.
- if traceback:
+ if traceback and six:
exc = six.text_type(traceback.format_exc())
print(u'-' * 40, file=fp)
print(