diff options
-rw-r--r-- | mitmproxy/net/tls.py | 20 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/tls.py | 2 | ||||
-rw-r--r-- | mitmproxy/proxy/root_context.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/net/test_tls.py | 90 |
4 files changed, 95 insertions, 19 deletions
diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index 3d824114..0e43a2ac 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -363,7 +363,7 @@ def is_tls_record_magic(d): ) -def get_client_hello(client_conn): +def get_client_hello(rfile): """ Peek into the socket and read all records that contain the initial client hello message. @@ -377,12 +377,12 @@ def get_client_hello(client_conn): client_hello_size = 1 offset = 0 while len(client_hello) < client_hello_size: - record_header = client_conn.rfile.peek(offset + 5)[offset:] - if not is_tls_record_magic(record_header) or len(record_header) != 5: + record_header = rfile.peek(offset + 5)[offset:] + if not is_tls_record_magic(record_header) or len(record_header) < 5: raise exceptions.TlsProtocolException( 'Expected TLS record, got "%s" instead.' % record_header) - record_size = struct.unpack("!H", record_header[3:])[0] + 5 - record_body = client_conn.rfile.peek(offset + record_size)[offset + 5:] + record_size = struct.unpack_from("!H", record_header, 3)[0] + 5 + record_body = rfile.peek(offset + record_size)[offset + 5:] if len(record_body) != record_size - 5: raise exceptions.TlsProtocolException( "Unexpected EOF in TLS handshake: %s" % record_body) @@ -396,10 +396,8 @@ class ClientHello: def __init__(self, raw_client_hello): self._client_hello = tls_client_hello.TlsClientHello( - KaitaiStream(io.BytesIO(raw_client_hello))) - - def raw(self): - return self._client_hello + KaitaiStream(io.BytesIO(raw_client_hello)) + ) @property def cipher_suites(self): @@ -437,7 +435,7 @@ class ClientHello: return ret @classmethod - def from_client_conn(cls, client_conn) -> "ClientHello": + def from_file(cls, client_conn) -> "ClientHello": """ Peek into the connection, read the initial client hello and parse it to obtain ALPN values. client_conn: @@ -455,7 +453,7 @@ class ClientHello: except EOFError as e: raise exceptions.TlsProtocolException( 'Cannot parse Client Hello: %s, Raw Client Hello: %s' % - (repr(e), raw_client_hello.encode("hex")) + (repr(e), binascii.hexlify(raw_client_hello)) ) def __repr__(self): diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index 63023871..d04c9801 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -242,7 +242,7 @@ class TlsLayer(base.Layer): if self._client_tls: # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values. try: - self._client_hello = net_tls.ClientHello.from_client_conn(self.client_conn) + self._client_hello = net_tls.ClientHello.from_file(self.client_conn.rfile) except exceptions.TlsProtocolException as e: self.log("Cannot parse Client Hello: %s" % repr(e), "error") diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 0af8b364..eb0008cf 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -53,7 +53,7 @@ class RootContext: ignore = self.config.check_ignore(top_layer.server_conn.address) if not ignore and client_tls: try: - client_hello = tls.ClientHello.from_client_conn(self.client_conn) + client_hello = tls.ClientHello.from_file(self.client_conn.rfile) except exceptions.TlsProtocolException as e: self.log("Cannot parse Client Hello: %s" % repr(e), "error") else: diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index f551b904..c67b59cd 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -1,3 +1,5 @@ +import io + import pytest from mitmproxy import exceptions @@ -6,6 +8,17 @@ from mitmproxy.net.tcp import TCPClient from test.mitmproxy.net.test_tcp import EchoHandler from . import tservers +CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( + "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" + "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" + "61006200640100" +) +FULL_CLIENT_HELLO_NO_EXTENSIONS = ( + b"\x16\x03\x03\x00\x65" # record layer + b"\x01\x00\x00\x61" + # handshake header + CLIENT_HELLO_NO_EXTENSIONS +) + class TestMasterSecretLogger(tservers.ServerTestBase): handler = EchoHandler @@ -55,16 +68,43 @@ class TestTLSInvalid: tls.create_client_context(alpn_select="foo", alpn_select_callback="bar") +def test_is_record_magic(): + assert not tls.is_tls_record_magic(b"POST /") + assert not tls.is_tls_record_magic(b"\x16\x03") + assert not tls.is_tls_record_magic(b"\x16\x03\x04") + assert tls.is_tls_record_magic(b"\x16\x03\x00") + assert tls.is_tls_record_magic(b"\x16\x03\x01") + assert tls.is_tls_record_magic(b"\x16\x03\x02") + assert tls.is_tls_record_magic(b"\x16\x03\x03") + + +def test_get_client_hello(): + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS + )) + assert tls.get_client_hello(rfile) + + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS[:30] + )) + with pytest.raises(exceptions.TlsProtocolException, message="Unexpected EOF"): + tls.get_client_hello(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"GET /" + )) + with pytest.raises(exceptions.TlsProtocolException, message="Expected TLS record"): + tls.get_client_hello(rfile) + + class TestClientHello: def test_no_extensions(self): - data = bytes.fromhex( - "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" - "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" - "61006200640100" - ) - c = tls.ClientHello(data) + c = tls.ClientHello(CLIENT_HELLO_NO_EXTENSIONS) + assert repr(c) assert c.sni is None + assert c.cipher_suites == [53, 47, 10, 5, 4, 9, 3, 6, 8, 96, 97, 98, 100] assert c.alpn_protocols == [] + assert c.extensions == [] def test_extensions(self): data = bytes.fromhex( @@ -75,5 +115,43 @@ class TestClientHello: "170018" ) c = tls.ClientHello(data) + assert repr(c) assert c.sni == 'example.com' + assert c.cipher_suites == [ + 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161, + 49171, 49162, 49172, 156, 157, 47, 53, 10 + ] assert c.alpn_protocols == [b'h2', b'http/1.1'] + assert c.extensions == [ + (65281, b'\x00'), + (0, b'\x00\x0e\x00\x00\x0bexample.com'), + (23, b''), + (35, b''), + (13, b'\x00\x10\x06\x01\x06\x03\x05\x01\x05\x03\x04\x01\x04\x03\x02\x01\x02\x03'), + (5, b'\x01\x00\x00\x00\x00'), + (18, b''), + (16, b'\x00\x0c\x02h2\x08http/1.1'), + (30032, b''), + (11, b'\x01\x00'), + (10, b'\x00\x06\x00\x1d\x00\x17\x00\x18') + ] + + def test_from_conn(self): + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS + )) + assert tls.ClientHello.from_file(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"" + )) + with pytest.raises(exceptions.TlsProtocolException): + tls.ClientHello.from_file(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"\x16\x03\x03\x00\x07" # record layer + b"\x01\x00\x00\x03" + # handshake header + b"foo" + )) + with pytest.raises(exceptions.TlsProtocolException, message='Cannot parse Client Hello'): + tls.ClientHello.from_file(rfile) |