diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/mitmproxy/net/test_tcp.py | 6 | ||||
-rw-r--r-- | test/mitmproxy/test_certs.py | 27 | ||||
-rw-r--r-- | test/mitmproxy/test_connections.py | 199 |
3 files changed, 197 insertions, 35 deletions
diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index 252d896c..cf010f6e 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -602,12 +602,6 @@ class TestDHParams(tservers.ServerTestBase): ret = c.get_current_cipher() assert ret[0] == "DHE-RSA-AES256-SHA" - def test_create_dhparams(self): - with tutils.tmpdir() as d: - filename = os.path.join(d, "dhparam.pem") - certs.CertStore.load_dhparam(filename) - assert os.path.exists(filename) - class TestTCPClient: diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index f1eff9ba..9bd3ad25 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -117,6 +117,12 @@ class TestCertStore: ret = ca1.get_cert(b"foo.com", []) assert ret[0].serial == dc[0].serial + def test_create_dhparams(self): + with tutils.tmpdir() as d: + filename = os.path.join(d, "dhparam.pem") + certs.CertStore.load_dhparam(filename) + assert os.path.exists(filename) + class TestDummyCert: @@ -127,9 +133,10 @@ class TestDummyCert: ca.default_privatekey, ca.default_ca, b"foo.com", - [b"one.com", b"two.com", b"*.three.com"] + [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"] ) assert r.cn == b"foo.com" + assert r.altnames == [b'one.com', b'two.com', b'*.three.com'] r = certs.dummy_cert( ca.default_privatekey, @@ -138,6 +145,7 @@ class TestDummyCert: [] ) assert r.cn is None + assert r.altnames == [] class TestSSLCert: @@ -179,3 +187,20 @@ class TestSSLCert: d = f.read() s = certs.SSLCert.from_der(d) assert s.cn + + def test_state(self): + with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: + d = f.read() + c = certs.SSLCert.from_pem(d) + + c.get_state() + c2 = c.copy() + a = c.get_state() + b = c2.get_state() + assert a == b + assert c == c2 + assert c is not c2 + + x = certs.SSLCert('') + x.set_state(a) + assert x == c diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index fa23a53c..0083f57c 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -1,13 +1,57 @@ +import socket +import os +import threading +import ssl +import OpenSSL +import pytest from unittest import mock from mitmproxy import connections from mitmproxy import exceptions +from mitmproxy.net import tcp from mitmproxy.net.http import http1 from mitmproxy.test import tflow +from mitmproxy.test import tutils +from .net import tservers from pathod import test class TestClientConnection: + + def test_send(self): + c = tflow.tclient_conn() + c.send(b'foobar') + c.send([b'foo', b'bar']) + with pytest.raises(TypeError): + c.send('string') + with pytest.raises(TypeError): + c.send(['string', 'not']) + assert c.wfile.getvalue() == b'foobarfoobar' + + def test_repr(self): + c = tflow.tclient_conn() + assert 'address:22' in repr(c) + assert 'ALPN' in repr(c) + assert 'TLS' not in repr(c) + + c.alpn_proto_negotiated = None + c.tls_established = True + assert 'ALPN' not in repr(c) + assert 'TLS' in repr(c) + + def test_tls_established_property(self): + c = tflow.tclient_conn() + c.tls_established = True + assert c.ssl_established + assert c.tls_established + c.tls_established = False + assert not c.ssl_established + assert not c.tls_established + + def test_make_dummy(self): + c = connections.ClientConnection.make_dummy(('foobar', 1234)) + assert c.address == ('foobar', 1234) + def test_state(self): c = tflow.tclient_conn() assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ @@ -24,44 +68,143 @@ class TestClientConnection: c3 = c.copy() assert c3.get_state() == c.get_state() - assert str(c) - class TestServerConnection: + def test_send(self): + c = tflow.tserver_conn() + c.send(b'foobar') + c.send([b'foo', b'bar']) + with pytest.raises(TypeError): + c.send('string') + with pytest.raises(TypeError): + c.send(['string', 'not']) + assert c.wfile.getvalue() == b'foobarfoobar' + + def test_repr(self): + c = tflow.tserver_conn() + + c.sni = 'foobar' + c.tls_established = True + c.alpn_proto_negotiated = b'h2' + assert 'address:22' in repr(c) + assert 'ALPN' in repr(c) + assert 'TLS: foobar' in repr(c) + + c.sni = None + c.tls_established = True + c.alpn_proto_negotiated = None + assert 'ALPN' not in repr(c) + assert 'TLS' in repr(c) + + c.sni = None + c.tls_established = False + assert 'TLS' not in repr(c) + + def test_tls_established_property(self): + c = tflow.tserver_conn() + c.tls_established = True + assert c.ssl_established + assert c.tls_established + c.tls_established = False + assert not c.ssl_established + assert not c.tls_established + + def test_make_dummy(self): + c = connections.ServerConnection.make_dummy(('foobar', 1234)) + assert c.address == ('foobar', 1234) + def test_simple(self): - self.d = test.Daemon() - sc = connections.ServerConnection((self.d.IFACE, self.d.port)) - sc.connect() + d = test.Daemon() + c = connections.ServerConnection((d.IFACE, d.port)) + c.connect() f = tflow.tflow() - f.server_conn = sc + f.server_conn = c f.request.path = "/p/200:da" # use this protocol just to assemble - not for actual sending - sc.wfile.write(http1.assemble_request(f.request)) - sc.wfile.flush() + c.wfile.write(http1.assemble_request(f.request)) + c.wfile.flush() - assert http1.read_response(sc.rfile, f.request, 1000) - assert self.d.last_log() + assert http1.read_response(c.rfile, f.request, 1000) + assert d.last_log() - sc.finish() - self.d.shutdown() + c.finish() + d.shutdown() def test_terminate_error(self): - self.d = test.Daemon() - sc = connections.ServerConnection((self.d.IFACE, self.d.port)) - sc.connect() - sc.connection = mock.Mock() - sc.connection.recv = mock.Mock(return_value=False) - sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) - sc.finish() - self.d.shutdown() + d = test.Daemon() + c = connections.ServerConnection((d.IFACE, d.port)) + c.connect() + c.connection = mock.Mock() + c.connection.recv = mock.Mock(return_value=False) + c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) + c.finish() + d.shutdown() - def test_repr(self): - sc = tflow.tserver_conn() - assert "address:22" in repr(sc) - assert "ssl" not in repr(sc) - sc.ssl_established = True - assert "ssl" in repr(sc) - sc.sni = "foo" - assert "foo" in repr(sc) + def test_sni(self): + c = connections.ServerConnection(('', 1234)) + with pytest.raises(ValueError, matches='sni must be str, not '): + c.establish_ssl(None, b'foobar') + + +class TestClientConnectionTLS: + + @pytest.mark.parametrize("sni", [ + None, + "example.com" + ]) + def test_tls_with_sni(self, sni): + address = ('127.0.0.1', 0) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen() + address = sock.getsockname() + + def client_run(): + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + s = socket.create_connection(address) + s = ctx.wrap_socket(s, server_hostname=sni) + s.send(b'foobar') + s.shutdown(socket.SHUT_RDWR) + threading.Thread(target=client_run).start() + + connection, client_address = sock.accept() + c = connections.ClientConnection(connection, client_address, None) + + cert = tutils.test_data.path("mitmproxy/net/data/server.crt") + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read()) + c.convert_to_ssl(cert, key) + assert c.connected() + assert c.sni == sni + assert c.tls_established + assert c.rfile.read(6) == b'foobar' + c.finish() + + +class TestServerConnectionTLS(tservers.ServerTestBase): + ssl = True + + class handler(tcp.BaseHandler): + def handle(self): + self.finish() + + @pytest.mark.parametrize("clientcert", [ + None, + tutils.test_data.path("mitmproxy/data/clientcert"), + os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"), + ]) + def test_tls(self, clientcert): + c = connections.ServerConnection(("127.0.0.1", self.port)) + c.connect() + c.establish_ssl(clientcert, "foo.com") + assert c.connected() + assert c.sni == "foo.com" + assert c.tls_established + c.close() + c.finish() |