diff options
Diffstat (limited to 'test')
56 files changed, 1335 insertions, 490 deletions
diff --git a/test/examples/test_xss_scanner.py b/test/examples/test_xss_scanner.py index e15d7e10..8cf06a2a 100644 --- a/test/examples/test_xss_scanner.py +++ b/test/examples/test_xss_scanner.py @@ -343,10 +343,10 @@ class TestXSSScanner(): monkeypatch.setattr("mitmproxy.ctx.log", logger) xss.log_SQLi_data(None) assert logger.args == [] - xss.log_SQLi_data(xss.SQLiData(b'https://example.com', - b'Location', - b'Oracle.*Driver', - b'Oracle')) + xss.log_SQLi_data(xss.SQLiData('https://example.com', + 'Location', + 'Oracle.*Driver', + 'Oracle')) assert logger.args[0] == '===== SQLi Found =====' assert logger.args[1] == 'SQLi URL: https://example.com' assert logger.args[2] == 'Injection Point: Location' diff --git a/test/mitmproxy/addons/test_browser.py b/test/mitmproxy/addons/test_browser.py index d1b32186..407a3fe6 100644 --- a/test/mitmproxy/addons/test_browser.py +++ b/test/mitmproxy/addons/test_browser.py @@ -5,7 +5,8 @@ from mitmproxy.test import taddons def test_browser(): - with mock.patch("subprocess.Popen") as po: + with mock.patch("subprocess.Popen") as po, mock.patch("shutil.which") as which: + which.return_value = "chrome" b = browser.Browser() with taddons.context() as tctx: b.start() @@ -18,3 +19,13 @@ def test_browser(): assert tctx.master.has_log("already running") b.done() assert not b.browser + + +def test_no_browser(): + with mock.patch("shutil.which") as which: + which.return_value = False + + b = browser.Browser() + with taddons.context() as tctx: + b.start() + assert tctx.master.has_log("platform is not supported") diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 2dc7eb92..3f990668 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -52,6 +52,10 @@ class TestClientPlayback: cp.stop_replay() assert not cp.flows + df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True) + with pytest.raises(exceptions.CommandError, match="Can't replay live flow."): + cp.start_replay([df]) + def test_load_file(self, tmpdir): cp = clientplayback.ClientPlayback() with taddons.context(): diff --git a/test/mitmproxy/addons/test_core.py b/test/mitmproxy/addons/test_core.py index c132d80a..5aa4ef37 100644 --- a/test/mitmproxy/addons/test_core.py +++ b/test/mitmproxy/addons/test_core.py @@ -69,9 +69,6 @@ def test_flow_set(): f = tflow.tflow(resp=True) assert sa.flow_set_options() - with pytest.raises(exceptions.CommandError): - sa.flow_set([f], "flibble", "post") - assert f.request.method != "post" sa.flow_set([f], "method", "post") assert f.request.method == "POST" @@ -126,9 +123,6 @@ def test_encoding(): sa.encode_toggle([f], "request") assert "content-encoding" not in f.request.headers - with pytest.raises(exceptions.CommandError): - sa.encode([f], "request", "invalid") - def test_options(tmpdir): p = str(tmpdir.join("path")) diff --git a/test/mitmproxy/addons/test_cut.py b/test/mitmproxy/addons/test_cut.py index 242c6c2f..56568f21 100644 --- a/test/mitmproxy/addons/test_cut.py +++ b/test/mitmproxy/addons/test_cut.py @@ -7,89 +7,58 @@ from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.test import tutils import pytest +import pyperclip from unittest import mock def test_extract(): tf = tflow.tflow(resp=True) tests = [ - ["q.method", "GET"], - ["q.scheme", "http"], - ["q.host", "address"], - ["q.port", "22"], - ["q.path", "/path"], - ["q.url", "http://address:22/path"], - ["q.text", "content"], - ["q.content", b"content"], - ["q.raw_content", b"content"], - ["q.header[header]", "qvalue"], - - ["s.status_code", "200"], - ["s.reason", "OK"], - ["s.text", "message"], - ["s.content", b"message"], - ["s.raw_content", b"message"], - ["s.header[header-response]", "svalue"], - - ["cc.address.port", "22"], - ["cc.address.host", "127.0.0.1"], - ["cc.tls_version", "TLSv1.2"], - ["cc.sni", "address"], - ["cc.ssl_established", "false"], - - ["sc.address.port", "22"], - ["sc.address.host", "address"], - ["sc.ip_address.host", "192.168.0.1"], - ["sc.tls_version", "TLSv1.2"], - ["sc.sni", "address"], - ["sc.ssl_established", "false"], + ["request.method", "GET"], + ["request.scheme", "http"], + ["request.host", "address"], + ["request.http_version", "HTTP/1.1"], + ["request.port", "22"], + ["request.path", "/path"], + ["request.url", "http://address:22/path"], + ["request.text", "content"], + ["request.content", b"content"], + ["request.raw_content", b"content"], + ["request.timestamp_start", "946681200"], + ["request.timestamp_end", "946681201"], + ["request.header[header]", "qvalue"], + + ["response.status_code", "200"], + ["response.reason", "OK"], + ["response.text", "message"], + ["response.content", b"message"], + ["response.raw_content", b"message"], + ["response.header[header-response]", "svalue"], + ["response.timestamp_start", "946681202"], + ["response.timestamp_end", "946681203"], + + ["client_conn.address.port", "22"], + ["client_conn.address.host", "127.0.0.1"], + ["client_conn.tls_version", "TLSv1.2"], + ["client_conn.sni", "address"], + ["client_conn.tls_established", "false"], + + ["server_conn.address.port", "22"], + ["server_conn.address.host", "address"], + ["server_conn.ip_address.host", "192.168.0.1"], + ["server_conn.tls_version", "TLSv1.2"], + ["server_conn.sni", "address"], + ["server_conn.tls_established", "false"], ] - for t in tests: - ret = cut.extract(t[0], tf) - if ret != t[1]: - raise AssertionError("%s: Expected %s, got %s" % (t[0], t[1], ret)) + for spec, expected in tests: + ret = cut.extract(spec, tf) + assert spec and ret == expected with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() - c1 = certs.SSLCert.from_pem(d) + c1 = certs.Cert.from_pem(d) tf.server_conn.cert = c1 - assert "CERTIFICATE" in cut.extract("sc.cert", tf) - - -def test_parse_cutspec(): - tests = [ - ("", None, True), - ("req.method", ("@all", ["req.method"]), False), - ( - "req.method,req.host", - ("@all", ["req.method", "req.host"]), - False - ), - ( - "req.method,req.host|~b foo", - ("~b foo", ["req.method", "req.host"]), - False - ), - ( - "req.method,req.host|~b foo | ~b bar", - ("~b foo | ~b bar", ["req.method", "req.host"]), - False - ), - ( - "req.method, req.host | ~b foo | ~b bar", - ("~b foo | ~b bar", ["req.method", "req.host"]), - False - ), - ] - for cutspec, output, err in tests: - try: - assert cut.parse_cutspec(cutspec) == output - except exceptions.CommandError: - if not err: - raise - else: - if err: - raise AssertionError("Expected error.") + assert "CERTIFICATE" in cut.extract("server_conn.cert", tf) def test_headername(): @@ -110,69 +79,95 @@ def test_cut_clip(): v.add([tflow.tflow(resp=True)]) with mock.patch('pyperclip.copy') as pc: - tctx.command(c.clip, "q.method|@all") + tctx.command(c.clip, "@all", "request.method") assert pc.called with mock.patch('pyperclip.copy') as pc: - tctx.command(c.clip, "q.content|@all") + tctx.command(c.clip, "@all", "request.content") assert pc.called with mock.patch('pyperclip.copy') as pc: - tctx.command(c.clip, "q.method,q.content|@all") + tctx.command(c.clip, "@all", "request.method,request.content") assert pc.called + with mock.patch('pyperclip.copy') as pc: + log_message = "Pyperclip could not find a " \ + "copy/paste mechanism for your system." + pc.side_effect = pyperclip.PyperclipException(log_message) + tctx.command(c.clip, "@all", "request.method") + assert tctx.master.has_log(log_message, level="error") + -def test_cut_file(tmpdir): +def test_cut_save(tmpdir): f = str(tmpdir.join("path")) v = view.View() c = cut.Cut() with taddons.context() as tctx: tctx.master.addons.add(v, c) - v.add([tflow.tflow(resp=True)]) - tctx.command(c.save, "q.method|@all", f) + tctx.command(c.save, "@all", "request.method", f) assert qr(f) == b"GET" - tctx.command(c.save, "q.content|@all", f) + tctx.command(c.save, "@all", "request.content", f) assert qr(f) == b"content" - tctx.command(c.save, "q.content|@all", "+" + f) + tctx.command(c.save, "@all", "request.content", "+" + f) assert qr(f) == b"content\ncontent" v.add([tflow.tflow(resp=True)]) - tctx.command(c.save, "q.method|@all", f) + tctx.command(c.save, "@all", "request.method", f) assert qr(f).splitlines() == [b"GET", b"GET"] - tctx.command(c.save, "q.method,q.content|@all", f) + tctx.command(c.save, "@all", "request.method,request.content", f) assert qr(f).splitlines() == [b"GET,content", b"GET,content"] -def test_cut(): +@pytest.mark.parametrize("exception, log_message", [ + (PermissionError, "Permission denied"), + (IsADirectoryError, "Is a directory"), + (FileNotFoundError, "No such file or directory") +]) +def test_cut_save_open(exception, log_message, tmpdir): + f = str(tmpdir.join("path")) v = view.View() c = cut.Cut() with taddons.context() as tctx: - v.add([tflow.tflow(resp=True)]) tctx.master.addons.add(v, c) - assert c.cut("q.method|@all") == [["GET"]] - assert c.cut("q.scheme|@all") == [["http"]] - assert c.cut("q.host|@all") == [["address"]] - assert c.cut("q.port|@all") == [["22"]] - assert c.cut("q.path|@all") == [["/path"]] - assert c.cut("q.url|@all") == [["http://address:22/path"]] - assert c.cut("q.content|@all") == [[b"content"]] - assert c.cut("q.header[header]|@all") == [["qvalue"]] - assert c.cut("q.header[unknown]|@all") == [[""]] - - assert c.cut("s.status_code|@all") == [["200"]] - assert c.cut("s.reason|@all") == [["OK"]] - assert c.cut("s.content|@all") == [[b"message"]] - assert c.cut("s.header[header-response]|@all") == [["svalue"]] - assert c.cut("moo") == [[""]] + v.add([tflow.tflow(resp=True)]) + + with mock.patch("mitmproxy.addons.cut.open") as m: + m.side_effect = exception(log_message) + tctx.command(c.save, "@all", "request.method", f) + assert tctx.master.has_log(log_message, level="error") + + +def test_cut(): + c = cut.Cut() + with taddons.context(): + tflows = [tflow.tflow(resp=True)] + assert c.cut(tflows, ["request.method"]) == [["GET"]] + assert c.cut(tflows, ["request.scheme"]) == [["http"]] + assert c.cut(tflows, ["request.host"]) == [["address"]] + assert c.cut(tflows, ["request.port"]) == [["22"]] + assert c.cut(tflows, ["request.path"]) == [["/path"]] + assert c.cut(tflows, ["request.url"]) == [["http://address:22/path"]] + assert c.cut(tflows, ["request.content"]) == [[b"content"]] + assert c.cut(tflows, ["request.header[header]"]) == [["qvalue"]] + assert c.cut(tflows, ["request.header[unknown]"]) == [[""]] + + assert c.cut(tflows, ["response.status_code"]) == [["200"]] + assert c.cut(tflows, ["response.reason"]) == [["OK"]] + assert c.cut(tflows, ["response.content"]) == [[b"message"]] + assert c.cut(tflows, ["response.header[header-response]"]) == [["svalue"]] + assert c.cut(tflows, ["moo"]) == [[""]] with pytest.raises(exceptions.CommandError): - assert c.cut("__dict__") == [[""]] + assert c.cut(tflows, ["__dict__"]) == [[""]] + + with taddons.context(): + tflows = [tflow.tflow(resp=False)] + assert c.cut(tflows, ["response.reason"]) == [[""]] + assert c.cut(tflows, ["response.header[key]"]) == [[""]] - v = view.View() c = cut.Cut() - with taddons.context() as tctx: - tctx.master.addons.add(v, c) - v.add([tflow.ttcpflow()]) - assert c.cut("q.method|@all") == [[""]] - assert c.cut("s.status|@all") == [[""]] + with taddons.context(): + tflows = [tflow.ttcpflow()] + assert c.cut(tflows, ["request.method"]) == [[""]] + assert c.cut(tflows, ["response.status"]) == [[""]] diff --git a/test/mitmproxy/addons/test_eventstore.py b/test/mitmproxy/addons/test_eventstore.py index f54b9980..8ac26b05 100644 --- a/test/mitmproxy/addons/test_eventstore.py +++ b/test/mitmproxy/addons/test_eventstore.py @@ -30,3 +30,18 @@ def test_simple(): assert not sig_add.called assert sig_refresh.called + + +def test_max_size(): + store = eventstore.EventStore(3) + assert store.size == 3 + store.log(log.LogEntry("foo", "info")) + store.log(log.LogEntry("bar", "info")) + store.log(log.LogEntry("baz", "info")) + assert len(store.data) == 3 + assert ["foo", "bar", "baz"] == [x.msg for x in store.data] + + # overflow + store.log(log.LogEntry("boo", "info")) + assert len(store.data) == 3 + assert ["bar", "baz", "boo"] == [x.msg for x in store.data] diff --git a/test/mitmproxy/addons/test_export.py b/test/mitmproxy/addons/test_export.py index 233c62d5..07227a7a 100644 --- a/test/mitmproxy/addons/test_export.py +++ b/test/mitmproxy/addons/test_export.py @@ -1,6 +1,8 @@ -import pytest import os +import pytest +import pyperclip + from mitmproxy import exceptions from mitmproxy.addons import export # heh from mitmproxy.test import tflow @@ -94,9 +96,24 @@ def test_export(tmpdir): os.unlink(f) +@pytest.mark.parametrize("exception, log_message", [ + (PermissionError, "Permission denied"), + (IsADirectoryError, "Is a directory"), + (FileNotFoundError, "No such file or directory") +]) +def test_export_open(exception, log_message, tmpdir): + f = str(tmpdir.join("path")) + e = export.Export() + with taddons.context() as tctx: + with mock.patch("mitmproxy.addons.export.open") as m: + m.side_effect = exception(log_message) + e.file("raw", tflow.tflow(resp=True), f) + assert tctx.master.has_log(log_message, level="error") + + def test_clip(tmpdir): e = export.Export() - with taddons.context(): + with taddons.context() as tctx: with pytest.raises(exceptions.CommandError): e.clip("nonexistent", tflow.tflow(resp=True)) @@ -107,3 +124,10 @@ def test_clip(tmpdir): with mock.patch('pyperclip.copy') as pc: e.clip("curl", tflow.tflow(resp=True)) assert pc.called + + with mock.patch('pyperclip.copy') as pc: + log_message = "Pyperclip could not find a " \ + "copy/paste mechanism for your system." + pc.side_effect = pyperclip.PyperclipException(log_message) + e.clip("raw", tflow.tflow(resp=True)) + assert tctx.master.has_log(log_message, level="error") diff --git a/test/mitmproxy/addons/test_proxyauth.py b/test/mitmproxy/addons/test_proxyauth.py index 1d05e137..97259d1c 100644 --- a/test/mitmproxy/addons/test_proxyauth.py +++ b/test/mitmproxy/addons/test_proxyauth.py @@ -190,7 +190,7 @@ class TestProxyAuth: with pytest.raises(exceptions.OptionsError): ctx.configure(up, proxyauth="ldap:test:test:test") - with pytest.raises(IndexError): + with pytest.raises(exceptions.OptionsError): ctx.configure(up, proxyauth="ldap:fake_serveruid=?dc=example,dc=com:person") with pytest.raises(exceptions.OptionsError): diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index a4e425cd..2dee708f 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -44,6 +44,19 @@ def test_tcp(tmpdir): assert rd(p) +def test_websocket(tmpdir): + sa = save.Save() + with taddons.context() as tctx: + p = str(tmpdir.join("foo")) + tctx.configure(sa, save_stream_file=p) + + f = tflow.twebsocketflow() + sa.websocket_start(f) + sa.websocket_end(f) + tctx.configure(sa, save_stream_file=None) + assert rd(p) + + def test_save_command(tmpdir): sa = save.Save() with taddons.context() as tctx: diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index c4fe6b43..78a5be6c 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -68,6 +68,18 @@ class TestScript: with pytest.raises(exceptions.OptionsError): script.Script("nonexistent") + def test_quotes_around_filename(self): + """ + Test that a script specified as '"foo.py"' works to support the calling convention of + mitmproxy 2.0, as e.g. used by Cuckoo Sandbox. + """ + path = tutils.test_data.path("mitmproxy/data/addonscripts/recorder/recorder.py") + + s = script.Script( + '"{}"'.format(path) + ) + assert '"' not in s.fullpath + def test_simple(self): with taddons.context() as tctx: sc = script.Script( diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 1e0c3b55..6f2a9ca5 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -30,7 +30,7 @@ def test_order_refresh(): with taddons.context() as tctx: tctx.configure(v, view_order="time") v.add([tf]) - tf.request.timestamp_start = 1 + tf.request.timestamp_start = 10 assert not sargs v.update([tf]) assert sargs @@ -41,7 +41,7 @@ def test_order_generators(): tf = tflow.tflow(resp=True) rs = view.OrderRequestStart(v) - assert rs.generate(tf) == 0 + assert rs.generate(tf) == 946681200 rm = view.OrderRequestMethod(v) assert rm.generate(tf) == tf.request.method @@ -147,6 +147,10 @@ def test_create(): assert v[0].request.url == "http://foo.com/" v.create("get", "http://foo.com") assert len(v) == 2 + with pytest.raises(exceptions.CommandError, match="Invalid URL"): + v.create("get", "http://foo.com\\") + with pytest.raises(exceptions.CommandError, match="Invalid URL"): + v.create("get", "http://") def test_orders(): @@ -175,6 +179,10 @@ def test_load(tmpdir): v.load_file("nonexistent_file_path") except IOError: assert False + with open(path, "wb") as f: + f.write(b"invalidflows") + v.load_file(path) + assert tctx.master.has_log("Invalid data format.") def test_resolve(): diff --git a/test/mitmproxy/contentviews/test_auto.py b/test/mitmproxy/contentviews/test_auto.py index 2ff43139..cd888a2d 100644 --- a/test/mitmproxy/contentviews/test_auto.py +++ b/test/mitmproxy/contentviews/test_auto.py @@ -1,6 +1,6 @@ from mitmproxy.contentviews import auto from mitmproxy.net import http -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import full_eval diff --git a/test/mitmproxy/contentviews/test_base.py b/test/mitmproxy/contentviews/test_base.py index 777ab4dd..c94d8be2 100644 --- a/test/mitmproxy/contentviews/test_base.py +++ b/test/mitmproxy/contentviews/test_base.py @@ -1 +1,17 @@ -# TODO: write tests +import pytest +from mitmproxy.contentviews import base + + +def test_format_dict(): + d = {"one": "two", "three": "four"} + f_d = base.format_dict(d) + assert next(f_d) + + d = {"adsfa": ""} + f_d = base.format_dict(d) + assert next(f_d) + + d = {} + f_d = base.format_dict(d) + with pytest.raises(StopIteration): + next(f_d) diff --git a/test/mitmproxy/contentviews/test_query.py b/test/mitmproxy/contentviews/test_query.py index d2bddd05..741b23f1 100644 --- a/test/mitmproxy/contentviews/test_query.py +++ b/test/mitmproxy/contentviews/test_query.py @@ -1,5 +1,5 @@ from mitmproxy.contentviews import query -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import full_eval diff --git a/test/mitmproxy/types/__init__.py b/test/mitmproxy/coretypes/__init__.py index e69de29b..e69de29b 100644 --- a/test/mitmproxy/types/__init__.py +++ b/test/mitmproxy/coretypes/__init__.py diff --git a/test/mitmproxy/types/test_basethread.py b/test/mitmproxy/coretypes/test_basethread.py index a91588eb..4a383fea 100644 --- a/test/mitmproxy/types/test_basethread.py +++ b/test/mitmproxy/coretypes/test_basethread.py @@ -1,5 +1,5 @@ import re -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread def test_basethread(): diff --git a/test/mitmproxy/types/test_bidi.py b/test/mitmproxy/coretypes/test_bidi.py index e3a259fd..3bdad3c2 100644 --- a/test/mitmproxy/types/test_bidi.py +++ b/test/mitmproxy/coretypes/test_bidi.py @@ -1,5 +1,5 @@ import pytest -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi def test_bidi(): diff --git a/test/mitmproxy/types/test_multidict.py b/test/mitmproxy/coretypes/test_multidict.py index c76cd753..273d8ca2 100644 --- a/test/mitmproxy/types/test_multidict.py +++ b/test/mitmproxy/coretypes/test_multidict.py @@ -1,6 +1,6 @@ import pytest -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict class _TMulti: diff --git a/test/mitmproxy/types/test_serializable.py b/test/mitmproxy/coretypes/test_serializable.py index 390d17e1..a316f876 100644 --- a/test/mitmproxy/types/test_serializable.py +++ b/test/mitmproxy/coretypes/test_serializable.py @@ -1,6 +1,6 @@ import copy -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable class SerializableDummy(serializable.Serializable): diff --git a/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py b/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py index 2a7d300c..b52f55c5 100644 --- a/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py +++ b/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py @@ -5,7 +5,7 @@ from mitmproxy.script import concurrent class ConcurrentClass: @concurrent - def request(flow): + def request(self, flow): time.sleep(0.1) diff --git a/test/mitmproxy/net/http/test_cookies.py b/test/mitmproxy/net/http/test_cookies.py index 77549d9e..e12b0f00 100644 --- a/test/mitmproxy/net/http/test_cookies.py +++ b/test/mitmproxy/net/http/test_cookies.py @@ -7,6 +7,10 @@ from mitmproxy.net.http import cookies cookie_pairs = [ [ + "=uno", + [["", "uno"]] + ], + [ "", [] ], @@ -16,7 +20,7 @@ cookie_pairs = [ ], [ "one", - [["one", None]] + [["one", ""]] ], [ "one=uno; two=due", @@ -36,7 +40,7 @@ cookie_pairs = [ ], [ "one=uno; two; three=tre", - [["one", "uno"], ["two", None], ["three", "tre"]] + [["one", "uno"], ["two", ""], ["three", "tre"]] ], [ "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " @@ -79,8 +83,12 @@ def test_read_quoted_string(): def test_read_cookie_pairs(): vals = [ [ + "=uno", + [["", "uno"]] + ], + [ "one", - [["one", None]] + [["one", ""]] ], [ "one=two", @@ -100,7 +108,7 @@ def test_read_cookie_pairs(): ], [ 'one="two"; three=four; five', - [["one", "two"], ["three", "four"], ["five", None]] + [["one", "two"], ["three", "four"], ["five", ""]] ], [ 'one="\\"two"; three=four', @@ -135,6 +143,12 @@ def test_cookie_roundtrips(): def test_parse_set_cookie_pairs(): pairs = [ [ + "=uno", + [[ + ["", "uno"] + ]] + ], + [ "one=uno", [[ ["one", "uno"] @@ -150,7 +164,7 @@ def test_parse_set_cookie_pairs(): "one=uno; foo", [[ ["one", "uno"], - ["foo", None] + ["foo", ""] ]] ], [ @@ -200,6 +214,12 @@ def test_parse_set_cookie_header(): ";", [] ], [ + "=uno", + [ + ("", "uno", ()) + ] + ], + [ "one=uno", [ ("one", "uno", ()) diff --git a/test/mitmproxy/net/http/test_response.py b/test/mitmproxy/net/http/test_response.py index fa1770fe..f3470384 100644 --- a/test/mitmproxy/net/http/test_response.py +++ b/test/mitmproxy/net/http/test_response.py @@ -113,7 +113,7 @@ class TestResponseUtils: assert attrs["domain"] == "example.com" assert attrs["expires"] == "Wed Oct 21 16:29:41 2015" assert attrs["path"] == "/" - assert attrs["httponly"] is None + assert attrs["httponly"] == "" def test_get_cookies_no_value(self): resp = tresp() @@ -150,10 +150,10 @@ class TestResponseUtils: n = time.time() r.headers["date"] = email.utils.formatdate(n) pre = r.headers["date"] - r.refresh(n) + r.refresh(946681202) assert pre == r.headers["date"] - r.refresh(n + 60) + r.refresh(946681262) d = email.utils.parsedate_tz(r.headers["date"]) d = email.utils.mktime_tz(d) # Weird that this is not exact... diff --git a/test/mitmproxy/net/http/test_url.py b/test/mitmproxy/net/http/test_url.py index 2064aab8..c9f61faf 100644 --- a/test/mitmproxy/net/http/test_url.py +++ b/test/mitmproxy/net/http/test_url.py @@ -108,6 +108,7 @@ def test_empty_key_trailing_equal_sign(): def test_encode(): assert url.encode([('foo', 'bar')]) assert url.encode([('foo', surrogates)]) + assert not url.encode([], similar_to="justatext") def test_decode(): diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index 3e27929d..8c012e42 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -1,4 +1,5 @@ from io import BytesIO +import re import queue import time import socket @@ -95,7 +96,13 @@ class TestServerBind(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): - self.wfile.write(str(self.connection.getpeername()).encode()) + # We may get an ipv4-mapped ipv6 address here, e.g. ::ffff:127.0.0.1. + # Those still appear as "127.0.0.1" in the table, so we need to strip the prefix. + peername = self.connection.getpeername() + address = re.sub("^::ffff:(?=\d+.\d+.\d+.\d+$)", "", peername[0]) + port = peername[1] + + self.wfile.write(str((address, port)).encode()) self.wfile.flush() def test_bind(self): @@ -171,7 +178,7 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) + c.convert_to_tls(sni="foo.com", options=SSL.OP_ALL) testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -181,7 +188,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): assert not c.get_current_cipher() - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -198,7 +205,7 @@ class TestSSLv3Only(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") class TestInvalidTrustFile(tservers.ServerTestBase): @@ -206,7 +213,7 @@ class TestInvalidTrustFile(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/generate.py") @@ -224,7 +231,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): def test_mode_default_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() # Verification errors should be saved even if connection isn't aborted # aborted @@ -238,7 +245,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): def test_mode_none_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(verify=SSL.VERIFY_NONE) + c.convert_to_tls(verify=SSL.VERIFY_NONE) # Verification errors should be saved even if connection isn't aborted assert c.ssl_verification_error @@ -252,7 +259,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.InvalidCertificateException): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -277,7 +284,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl( + c.convert_to_tls( verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") ) @@ -285,7 +292,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): def test_mode_none_should_pass_without_sni(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( verify=SSL.VERIFY_NONE, ca_path=tutils.test_data.path("mitmproxy/net/data/verificationcerts/") ) @@ -296,7 +303,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.InvalidCertificateException): - c.convert_to_ssl( + c.convert_to_tls( sni="mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -315,7 +322,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): def test_mode_strict_w_pemfile_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -331,7 +338,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): def test_mode_strict_w_cadir_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( sni="example.mitmproxy.org", verify=SSL.VERIFY_PEER, ca_path=tutils.test_data.path("mitmproxy/net/data/verificationcerts/") @@ -365,7 +372,7 @@ class TestSSLClientCert(tservers.ServerTestBase): def test_clientcert(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl( + c.convert_to_tls( cert=tutils.test_data.path("mitmproxy/net/data/clientcert/client.pem")) assert c.rfile.readline().strip() == b"1" @@ -373,7 +380,7 @@ class TestSSLClientCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(exceptions.TlsException): - c.convert_to_ssl(cert=tutils.test_data.path("mitmproxy/net/data/clientcert/make")) + c.convert_to_tls(cert=tutils.test_data.path("mitmproxy/net/data/clientcert/make")) class TestSNI(tservers.ServerTestBase): @@ -393,15 +400,15 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") assert c.sni == "foo.com" assert c.rfile.readline() == b"foo.com" def test_idn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="mitmproxyäöüß.example.com") - assert c.ssl_established + c.convert_to_tls(sni="mitmproxyäöüß.example.com") + assert c.tls_established assert "doesn't match" not in str(c.ssl_verification_error) @@ -414,7 +421,7 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") expected = b"['AES256-GCM-SHA384']" assert c.rfile.read(len(expected) + 2) == expected @@ -435,7 +442,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") assert b'AES256-GCM-SHA384' in c.rfile.readline() @@ -449,7 +456,7 @@ class TestServerCipherListError(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(Exception, match="handshake error"): - c.convert_to_ssl(sni="foo.com") + c.convert_to_tls(sni="foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -462,7 +469,7 @@ class TestClientCipherListError(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): with pytest.raises(Exception, match="cipher specification"): - c.convert_to_ssl(sni="foo.com", cipher_list="bogus") + c.convert_to_tls(sni="foo.com", cipher_list="bogus") class TestSSLDisconnect(tservers.ServerTestBase): @@ -477,7 +484,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() # Excercise SSL.ZeroReturnError c.rfile.read(10) c.close() @@ -494,7 +501,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() # Exercise SSL.SysCallError c.rfile.read(10) c.close() @@ -558,7 +565,7 @@ class TestALPNClient(tservers.ServerTestBase): def test_alpn(self, monkeypatch, alpn_protos, expected_negotiated, expected_response): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(alpn_protos=alpn_protos) + c.convert_to_tls(alpn_protos=alpn_protos) assert c.get_alpn_proto_negotiated() == expected_negotiated assert c.rfile.readline().strip() == expected_response @@ -580,7 +587,7 @@ class TestSSLTimeOut(tservers.ServerTestBase): def test_timeout_client(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() c.settimeout(0.1) with pytest.raises(exceptions.TcpTimeout): c.rfile.read(10) @@ -598,7 +605,7 @@ class TestDHParams(tservers.ServerTestBase): def test_dhparams(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() ret = c.get_current_cipher() assert ret[0] == "DHE-RSA-AES256-SHA" @@ -794,5 +801,5 @@ class TestPeekSSL(TestPeek): def _connect(self, c): with c.connect() as conn: - c.convert_to_ssl() + c.convert_to_tls() return conn.pop() diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index d0583d34..489bf89f 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -1,3 +1,5 @@ +import io + import pytest from mitmproxy import exceptions @@ -6,6 +8,17 @@ from mitmproxy.net.tcp import TCPClient from test.mitmproxy.net.test_tcp import EchoHandler from . import tservers +CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( + "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" + "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" + "61006200640100" +) +FULL_CLIENT_HELLO_NO_EXTENSIONS = ( + b"\x16\x03\x03\x00\x65" # record layer + b"\x01\x00\x00\x61" + # handshake header + CLIENT_HELLO_NO_EXTENSIONS +) + class TestMasterSecretLogger(tservers.ServerTestBase): handler = EchoHandler @@ -22,7 +35,7 @@ class TestMasterSecretLogger(tservers.ServerTestBase): c = TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -53,3 +66,92 @@ class TestTLSInvalid: with pytest.raises(exceptions.TlsException, match="ALPN error"): tls.create_client_context(alpn_select="foo", alpn_select_callback="bar") + + +def test_is_record_magic(): + assert not tls.is_tls_record_magic(b"POST /") + assert not tls.is_tls_record_magic(b"\x16\x03") + assert not tls.is_tls_record_magic(b"\x16\x03\x04") + assert tls.is_tls_record_magic(b"\x16\x03\x00") + assert tls.is_tls_record_magic(b"\x16\x03\x01") + assert tls.is_tls_record_magic(b"\x16\x03\x02") + assert tls.is_tls_record_magic(b"\x16\x03\x03") + + +def test_get_client_hello(): + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS + )) + assert tls.get_client_hello(rfile) + + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS[:30] + )) + with pytest.raises(exceptions.TlsProtocolException, message="Unexpected EOF"): + tls.get_client_hello(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"GET /" + )) + with pytest.raises(exceptions.TlsProtocolException, message="Expected TLS record"): + tls.get_client_hello(rfile) + + +class TestClientHello: + def test_no_extensions(self): + c = tls.ClientHello(CLIENT_HELLO_NO_EXTENSIONS) + assert repr(c) + assert c.sni is None + assert c.cipher_suites == [53, 47, 10, 5, 4, 9, 3, 6, 8, 96, 97, 98, 100] + assert c.alpn_protocols == [] + assert c.extensions == [] + + def test_extensions(self): + data = bytes.fromhex( + "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" + "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" + "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" + "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" + "170018" + ) + c = tls.ClientHello(data) + assert repr(c) + assert c.sni == 'example.com' + assert c.cipher_suites == [ + 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161, + 49171, 49162, 49172, 156, 157, 47, 53, 10 + ] + assert c.alpn_protocols == [b'h2', b'http/1.1'] + assert c.extensions == [ + (65281, b'\x00'), + (0, b'\x00\x0e\x00\x00\x0bexample.com'), + (23, b''), + (35, b''), + (13, b'\x00\x10\x06\x01\x06\x03\x05\x01\x05\x03\x04\x01\x04\x03\x02\x01\x02\x03'), + (5, b'\x01\x00\x00\x00\x00'), + (18, b''), + (16, b'\x00\x0c\x02h2\x08http/1.1'), + (30032, b''), + (11, b'\x01\x00'), + (10, b'\x00\x06\x00\x1d\x00\x17\x00\x18') + ] + + def test_from_file(self): + rfile = io.BufferedReader(io.BytesIO( + FULL_CLIENT_HELLO_NO_EXTENSIONS + )) + assert tls.ClientHello.from_file(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"" + )) + with pytest.raises(exceptions.TlsProtocolException): + tls.ClientHello.from_file(rfile) + + rfile = io.BufferedReader(io.BytesIO( + b"\x16\x03\x03\x00\x07" # record layer + b"\x01\x00\x00\x03" + # handshake header + b"foo" + )) + with pytest.raises(exceptions.TlsProtocolException, message='Cannot parse Client Hello'): + tls.ClientHello.from_file(rfile) diff --git a/test/mitmproxy/net/tools/getcertnames b/test/mitmproxy/net/tools/getcertnames index d64e5ff5..9349415f 100644 --- a/test/mitmproxy/net/tools/getcertnames +++ b/test/mitmproxy/net/tools/getcertnames @@ -7,7 +7,7 @@ from mitmproxy.net import tcp def get_remote_cert(host, port, sni): c = tcp.TCPClient((host, port)) c.connect() - c.convert_to_ssl(sni=sni) + c.convert_to_tls(sni=sni) return c.cert if len(sys.argv) > 2: diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index 44701aa5..22e195e3 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -60,7 +60,7 @@ class _TServer(tcp.TCPServer): else: method = OpenSSL.SSL.SSLv23_METHOD options = None - h.convert_to_ssl( + h.convert_to_tls( cert, key, method=method, diff --git a/test/mitmproxy/platform/test_pf.py b/test/mitmproxy/platform/test_pf.py index 3292d345..b048a697 100644 --- a/test/mitmproxy/platform/test_pf.py +++ b/test/mitmproxy/platform/test_pf.py @@ -15,6 +15,7 @@ class TestLookup: d = f.read() assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80) + assert pf.lookup("::ffff:192.168.1.111", 40000, d) == ("5.5.5.5", 80) with pytest.raises(Exception, match="Could not resolve original destination"): pf.lookup("192.168.1.112", 40000, d) with pytest.raises(Exception, match="Could not resolve original destination"): diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index 4f161ef5..194a57c9 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -141,7 +141,7 @@ class _Http2TestBase: while self.client.rfile.readline() != b"\r\n": pass - self.client.convert_to_ssl(alpn_protos=[b'h2']) + self.client.convert_to_tls(alpn_protos=[b'h2']) config = h2.config.H2Configuration( client_side=True, diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py index e17ee46f..e69de29b 100644 --- a/test/mitmproxy/proxy/protocol/test_tls.py +++ b/test/mitmproxy/proxy/protocol/test_tls.py @@ -1,26 +0,0 @@ -from mitmproxy.proxy.protocol.tls import TlsClientHello - - -class TestClientHello: - - def test_no_extensions(self): - data = bytes.fromhex( - "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" - "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" - "61006200640100" - ) - c = TlsClientHello(data) - assert c.sni is None - assert c.alpn_protocols == [] - - def test_extensions(self): - data = bytes.fromhex( - "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" - "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" - "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" - "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" - "170018" - ) - c = TlsClientHello(data) - assert c.sni == 'example.com' - assert c.alpn_protocols == [b'h2', b'http/1.1'] diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 460d85f8..5cd9601c 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -1,5 +1,6 @@ import pytest import os +import struct import tempfile import traceback @@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase): connection='upgrade', upgrade='websocket', sec_websocket_accept=b'', + sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else '' ), content=b'', ) @@ -80,7 +82,7 @@ class _WebSocketTestBase: if self.client: self.client.close() - def setup_connection(self): + def setup_connection(self, extension=False): self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) self.client.connect() @@ -99,8 +101,8 @@ class _WebSocketTestBase: response = http.http1.read_response(self.client.rfile, request) if self.ssl: - self.client.convert_to_ssl() - assert self.client.ssl_established + self.client.convert_to_tls() + assert self.client.tls_established request = http.Request( "relative", @@ -115,6 +117,7 @@ class _WebSocketTestBase: upgrade="websocket", sec_websocket_version="13", sec_websocket_key="1234", + sec_websocket_extensions="permessage-deflate" if extension else "" ), content=b'') self.client.wfile.write(http.http1.assemble_request(request)) @@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() @pytest.mark.parametrize('streaming', [True, False]) @@ -164,36 +167,78 @@ class TestSimple(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() assert len(self.master.state.flows) == 2 assert isinstance(self.master.state.flows[0], HTTPFlow) assert isinstance(self.master.state.flows[1], WebSocketFlow) assert len(self.master.state.flows[1].messages) == 5 - assert self.master.state.flows[1].messages[0].content == b'server-foobar' + assert self.master.state.flows[1].messages[0].content == 'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[1].content == 'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[2].content == 'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY + def test_change_payload(self): + class Addon: + def websocket_message(self, f): + f.messages[-1].content = "foo" + + self.master.addons.add(Addon()) + self.setup_connection() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + +class TestKillFlow(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + def test_kill(self): + class KillFlow: + def websocket_message(self, f): + f.kill() + + self.master.addons.add(KillFlow()) + self.setup_connection() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(self.client.rfile) + class TestSimpleTLS(_WebSocketTest): ssl = True @@ -204,7 +249,7 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() def test_simple_tls(self): @@ -213,13 +258,13 @@ class TestSimpleTLS(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() @@ -234,22 +279,24 @@ class TestPing(_WebSocketTest): assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done'))) + wfile.flush() + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) wfile.flush() + websockets.Frame.from_file(rfile) def test_ping(self): self.setup_connection() frame = websockets.Frame.from_file(self.client.rfile) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' + assert frame.payload == b'' # We don't send payload to other end - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - self.client.wfile.flush() - - frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.TEXT - assert frame.payload == b'pong-received' + assert self.master.has_log("Pong Received from server", "info") class TestPong(_WebSocketTest): @@ -258,20 +305,29 @@ class TestPong(_WebSocketTest): def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' + assert frame.payload == b'' wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) wfile.flush() + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + websockets.Frame.from_file(rfile) + def test_pong(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() + assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' + assert self.master.has_log("Pong Received from server", "info") class TestClose(_WebSocketTest): @@ -279,7 +335,7 @@ class TestClose(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) wfile.flush() @@ -289,7 +345,7 @@ class TestClose(_WebSocketTest): def test_close(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -299,7 +355,7 @@ class TestClose(_WebSocketTest): def test_close_payload_1(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -309,7 +365,7 @@ class TestClose(_WebSocketTest): def test_close_payload_2(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -329,8 +385,9 @@ class TestInvalidFrame(_WebSocketTest): # with pytest.raises(exceptions.TcpDisconnect): frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == 15 - assert frame.payload == b'foobar' + code, = struct.unpack('!H', frame.payload[:2]) + assert code == 1002 + assert frame.payload[2:].startswith(b'Invalid opcode') class TestStreaming(_WebSocketTest): @@ -360,3 +417,51 @@ class TestStreaming(_WebSocketTest): assert frame assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received + + +class TestExtension(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00') + wfile.flush() + + def test_extension(self): + self.setup_connection(True) + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + assert len(self.master.state.flows[1].messages) == 5 + assert self.master.state.flows[1].messages[0].content == 'server-foobar' + assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[1].content == 'client-foobar' + assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[2].content == 'client-foobar' + assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY + assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index 8dce9bcd..87ec443a 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -143,9 +143,9 @@ class TcpMixin: # Test that we get the original SSL cert if self.ssl: - i_cert = certs.SSLCert(i.sslinfo.certchain[0]) - i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) - n_cert = certs.SSLCert(n.sslinfo.certchain[0]) + i_cert = certs.Cert(i.sslinfo.certchain[0]) + i2_cert = certs.Cert(i2.sslinfo.certchain[0]) + n_cert = certs.Cert(n.sslinfo.certchain[0]) assert i_cert == i2_cert assert i_cert != n_cert @@ -188,9 +188,9 @@ class TcpMixin: # Test that we get the original SSL cert if self.ssl: - i_cert = certs.SSLCert(i.sslinfo.certchain[0]) - i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) - n_cert = certs.SSLCert(n.sslinfo.certchain[0]) + i_cert = certs.Cert(i.sslinfo.certchain[0]) + i2_cert = certs.Cert(i2.sslinfo.certchain[0]) + n_cert = certs.Cert(n.sslinfo.certchain[0]) assert i_cert == i2_cert assert i_cert != n_cert @@ -511,6 +511,14 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin): req = self.master.state.flows[0].request assert req.host_header == "127.0.0.1" + def test_selfconnection(self): + self.options.mode = "reverse:http://127.0.0.1:0" + + p = self.pathoc() + with p.connect(): + p.request("get:/") + assert self.master.has_log("The proxy shall not connect to itself.") + class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin): reverse = True @@ -579,7 +587,7 @@ class TestSocks5SSL(tservers.SocksModeTest): p = self.pathoc_raw() with p.connect(): p.socks_connect(("localhost", self.server.port)) - p.convert_to_ssl() + p.convert_to_tls() f = p.request("get:/p/200") assert f.status_code == 200 @@ -709,7 +717,7 @@ class TestProxy(tservers.HTTPProxyTest): first_flow = self.master.state.flows[0] second_flow = self.master.state.flows[1] assert first_flow.server_conn.timestamp_tcp_setup - assert first_flow.server_conn.timestamp_ssl_setup is None + assert first_flow.server_conn.timestamp_tls_setup is None assert second_flow.server_conn.timestamp_tcp_setup assert first_flow.server_conn.timestamp_tcp_setup == second_flow.server_conn.timestamp_tcp_setup @@ -723,12 +731,13 @@ class TestProxy(tservers.HTTPProxyTest): class TestProxySSL(tservers.HTTPProxyTest): ssl = True - def test_request_ssl_setup_timestamp_presence(self): + def test_request_tls_attribute_presence(self): # tests that the ssl timestamp is present when ssl is used f = self.pathod("304:b@10k") assert f.status_code == 304 first_flow = self.master.state.flows[0] - assert first_flow.server_conn.timestamp_ssl_setup + assert first_flow.server_conn.timestamp_tls_setup + assert first_flow.client_conn.tls_extensions def test_via(self): # tests that the ssl timestamp is present when ssl is used @@ -1149,7 +1158,7 @@ class AddUpstreamCertsToClientChainMixin: def test_add_upstream_certs_to_client_chain(self): with open(self.servercert, "rb") as f: d = f.read() - upstreamCert = certs.SSLCert.from_pem(d) + upstreamCert = certs.Cert.from_pem(d) p = self.pathoc() with p.connect(): upstream_cert_found_in_client_chain = False diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 693bebc6..dcc185c0 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -136,18 +136,18 @@ class TestDummyCert: assert r.altnames == [] -class TestSSLCert: +class TestCert: def test_simple(self): with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() - c1 = certs.SSLCert.from_pem(d) + c1 = certs.Cert.from_pem(d) assert c1.cn == b"google.com" assert len(c1.altnames) == 436 with open(tutils.test_data.path("mitmproxy/net/data/text_cert_2"), "rb") as f: d = f.read() - c2 = certs.SSLCert.from_pem(d) + c2 = certs.Cert.from_pem(d) assert c2.cn == b"www.inode.co.nz" assert len(c2.altnames) == 2 assert c2.digest("sha1") @@ -165,20 +165,20 @@ class TestSSLCert: def test_err_broken_sans(self): with open(tutils.test_data.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f: d = f.read() - c = certs.SSLCert.from_pem(d) + c = certs.Cert.from_pem(d) # This breaks unless we ignore a decoding error. assert c.altnames is not None def test_der(self): with open(tutils.test_data.path("mitmproxy/net/data/dercert"), "rb") as f: d = f.read() - s = certs.SSLCert.from_der(d) + s = certs.Cert.from_der(d) assert s.cn def test_state(self): with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() - c = certs.SSLCert.from_pem(d) + c = certs.Cert.from_pem(d) c.get_state() c2 = c.copy() @@ -188,6 +188,6 @@ class TestSSLCert: assert c == c2 assert c is not c2 - x = certs.SSLCert('') + x = certs.Cert('') x.set_state(a) assert x == c diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index 43b97742..c777192d 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -4,27 +4,56 @@ from mitmproxy import flow from mitmproxy import exceptions from mitmproxy.test import tflow from mitmproxy.test import taddons +import mitmproxy.types import io import pytest class TAddon: + @command.command("cmd1") def cmd1(self, foo: str) -> str: """cmd1 help""" return "ret " + foo + @command.command("cmd2") def cmd2(self, foo: str) -> str: return 99 + @command.command("cmd3") def cmd3(self, foo: int) -> int: return foo + @command.command("cmd4") + def cmd4(self, a: int, b: str, c: mitmproxy.types.Path) -> str: + return "ok" + + @command.command("subcommand") + def subcommand(self, cmd: mitmproxy.types.Cmd, *args: mitmproxy.types.Arg) -> str: + return "ok" + + @command.command("empty") def empty(self) -> None: pass + @command.command("varargs") def varargs(self, one: str, *var: str) -> typing.Sequence[str]: return list(var) + def choices(self) -> typing.Sequence[str]: + return ["one", "two", "three"] + + @command.argument("arg", type=mitmproxy.types.Choice("choices")) + def choose(self, arg: str) -> typing.Sequence[str]: + return ["one", "two", "three"] + + @command.command("path") + def path(self, arg: mitmproxy.types.Path) -> None: + pass + + @command.command("flow") + def flow(self, f: flow.Flow, s: str) -> None: + pass + class TestCommand: def test_varargs(self): @@ -52,6 +81,144 @@ class TestCommand: c = command.Command(cm, "cmd.three", a.cmd3) assert c.call(["1"]) == 1 + def test_parse_partial(self): + tests = [ + [ + "foo bar", + [ + command.ParseResult( + value = "foo", type = mitmproxy.types.Cmd, valid = False + ), + command.ParseResult( + value = "bar", type = mitmproxy.types.Unknown, valid = False + ) + ], + [], + ], + [ + "cmd1 'bar", + [ + command.ParseResult(value = "cmd1", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "'bar", type = str, valid = True) + ], + [], + ], + [ + "a", + [command.ParseResult(value = "a", type = mitmproxy.types.Cmd, valid = False)], + [], + ], + [ + "", + [command.ParseResult(value = "", type = mitmproxy.types.Cmd, valid = False)], + [] + ], + [ + "cmd3 1", + [ + command.ParseResult(value = "cmd3", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "1", type = int, valid = True), + ], + [] + ], + [ + "cmd3 ", + [ + command.ParseResult(value = "cmd3", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = int, valid = False), + ], + [] + ], + [ + "subcommand ", + [ + command.ParseResult( + value = "subcommand", type = mitmproxy.types.Cmd, valid = True, + ), + command.ParseResult(value = "", type = mitmproxy.types.Cmd, valid = False), + ], + ["arg"], + ], + [ + "subcommand cmd3 ", + [ + command.ParseResult(value = "subcommand", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "cmd3", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = int, valid = False), + ], + [] + ], + [ + "cmd4", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + ], + ["int", "str", "path"] + ], + [ + "cmd4 ", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = int, valid = False), + ], + ["str", "path"] + ], + [ + "cmd4 1", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "1", type = int, valid = True), + ], + ["str", "path"] + ], + [ + "cmd4 1", + [ + command.ParseResult(value = "cmd4", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "1", type = int, valid = True), + ], + ["str", "path"] + ], + [ + "flow", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + ], + ["flow", "str"] + ], + [ + "flow ", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "", type = flow.Flow, valid = False), + ], + ["str"] + ], + [ + "flow x", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "x", type = flow.Flow, valid = False), + ], + ["str"] + ], + [ + "flow x ", + [ + command.ParseResult(value = "flow", type = mitmproxy.types.Cmd, valid = True), + command.ParseResult(value = "x", type = flow.Flow, valid = False), + command.ParseResult(value = "", type = str, valid = True), + ], + [] + ], + ] + with taddons.context() as tctx: + tctx.master.addons.add(TAddon()) + for s, expected, expectedremain in tests: + current, remain = tctx.master.commands.parse_partial(s) + assert current == expected + assert expectedremain == remain + def test_simple(): with taddons.context() as tctx: @@ -64,7 +231,7 @@ def test_simple(): c.call("nonexistent") with pytest.raises(exceptions.CommandError, match="Invalid"): c.call("") - with pytest.raises(exceptions.CommandError, match="Usage"): + with pytest.raises(exceptions.CommandError, match="argument mismatch"): c.call("one.two too many args") c.add("empty", a.empty) @@ -76,15 +243,18 @@ def test_simple(): def test_typename(): - assert command.typename(str, True) == "str" - assert command.typename(typing.Sequence[flow.Flow], True) == "[flow]" - assert command.typename(typing.Sequence[flow.Flow], False) == "flowspec" + assert command.typename(str) == "str" + assert command.typename(typing.Sequence[flow.Flow]) == "[flow]" + + assert command.typename(mitmproxy.types.Data) == "[data]" + assert command.typename(mitmproxy.types.CutSpec) == "[cut]" - assert command.typename(command.Cuts, False) == "cutspec" - assert command.typename(command.Cuts, True) == "[cuts]" + assert command.typename(flow.Flow) == "flow" + assert command.typename(typing.Sequence[str]) == "[str]" - assert command.typename(flow.Flow, False) == "flow" - assert command.typename(typing.Sequence[str], False) == "[str]" + assert command.typename(mitmproxy.types.Choice("foo")) == "choice" + assert command.typename(mitmproxy.types.Path) == "path" + assert command.typename(mitmproxy.types.Cmd) == "cmd" class DummyConsole: @@ -94,7 +264,7 @@ class DummyConsole: return [tflow.tflow(resp=True)] * n @command.command("cut") - def cut(self, spec: str) -> command.Cuts: + def cut(self, spec: str) -> mitmproxy.types.Data: return [["test"]] @@ -102,38 +272,11 @@ def test_parsearg(): with taddons.context() as tctx: tctx.master.addons.add(DummyConsole()) assert command.parsearg(tctx.master.commands, "foo", str) == "foo" - - assert command.parsearg(tctx.master.commands, "1", int) == 1 + with pytest.raises(exceptions.CommandError, match="Unsupported"): + command.parsearg(tctx.master.commands, "foo", type) with pytest.raises(exceptions.CommandError): command.parsearg(tctx.master.commands, "foo", int) - assert command.parsearg(tctx.master.commands, "true", bool) is True - assert command.parsearg(tctx.master.commands, "false", bool) is False - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "flobble", bool) - - assert len(command.parsearg( - tctx.master.commands, "2", typing.Sequence[flow.Flow] - )) == 2 - assert command.parsearg(tctx.master.commands, "1", flow.Flow) - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "2", flow.Flow) - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "0", flow.Flow) - with pytest.raises(exceptions.CommandError): - command.parsearg(tctx.master.commands, "foo", Exception) - - assert command.parsearg( - tctx.master.commands, "foo", command.Cuts - ) == [["test"]] - - assert command.parsearg( - tctx.master.commands, "foo", typing.Sequence[str] - ) == ["foo"] - assert command.parsearg( - tctx.master.commands, "foo, bar", typing.Sequence[str] - ) == ["foo", "bar"] - class TDec: @command.command("cmd1") @@ -169,4 +312,4 @@ def test_verify_arg_signature(): with pytest.raises(exceptions.CommandError): command.verify_arg_signature(lambda: None, [1, 2], {}) print('hello there') - command.verify_arg_signature(lambda a, b: None, [1, 2], {})
\ No newline at end of file + command.verify_arg_signature(lambda a, b: None, [1, 2], {}) diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 83f0bd34..00cdbc87 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -41,10 +41,10 @@ class TestClientConnection: def test_tls_established_property(self): c = tflow.tclient_conn() c.tls_established = True - assert c.ssl_established + assert c.tls_established assert c.tls_established c.tls_established = False - assert not c.ssl_established + assert not c.tls_established assert not c.tls_established def test_make_dummy(self): @@ -113,10 +113,10 @@ class TestServerConnection: def test_tls_established_property(self): c = tflow.tserver_conn() c.tls_established = True - assert c.ssl_established + assert c.tls_established assert c.tls_established c.tls_established = False - assert not c.ssl_established + assert not c.tls_established assert not c.tls_established def test_make_dummy(self): @@ -155,7 +155,7 @@ class TestServerConnection: def test_sni(self): c = connections.ServerConnection(('', 1234)) with pytest.raises(ValueError, matches='sni must be str, not '): - c.establish_ssl(None, b'foobar') + c.establish_tls(sni=b'foobar') def test_state(self): c = tflow.tserver_conn() @@ -206,7 +206,7 @@ class TestClientConnectionTLS: key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, raw_key) - c.convert_to_ssl(cert, key) + c.convert_to_tls(cert, key) assert c.connected() assert c.sni == sni assert c.tls_established @@ -222,17 +222,16 @@ class TestServerConnectionTLS(tservers.ServerTestBase): def handle(self): self.finish() - @pytest.mark.parametrize("clientcert", [ + @pytest.mark.parametrize("client_certs", [ None, tutils.test_data.path("mitmproxy/data/clientcert"), tutils.test_data.path("mitmproxy/data/clientcert/client.pem"), ]) - def test_tls(self, clientcert): + def test_tls(self, client_certs): c = connections.ServerConnection(("127.0.0.1", self.port)) c.connect() - c.establish_ssl(clientcert, "foo.com") + c.establish_tls(client_certs=client_certs) assert c.connected() - assert c.sni == "foo.com" assert c.tls_established c.close() c.finish() diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index fcc766b5..8cc11a16 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -97,7 +97,7 @@ class TestSerialize: class TestFlowMaster: - def test_load_flow_reverse(self): + def test_load_http_flow_reverse(self): s = tservers.TestState() opts = options.Options( mode="reverse:https://use-this-domain" @@ -108,6 +108,20 @@ class TestFlowMaster: fm.load_flow(f) assert s.flows[0].request.host == "use-this-domain" + def test_load_websocket_flow(self): + s = tservers.TestState() + opts = options.Options( + mode="reverse:https://use-this-domain" + ) + fm = master.Master(opts) + fm.addons.add(s) + f = tflow.twebsocketflow() + fm.load_flow(f.handshake_flow) + fm.load_flow(f) + assert s.flows[0].request.host == "use-this-domain" + assert s.flows[1].handshake_flow == f.handshake_flow + assert len(s.flows[1].messages) == len(f.messages) + def test_replay(self): opts = options.Options() fm = master.Master(opts) diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index c411258a..4eb37d81 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -420,6 +420,20 @@ class TestMatchingWebSocketFlow: e = self.err() assert self.q("~e", e) + def test_domain(self): + q = self.flow() + assert self.q("~d example.com", q) + assert not self.q("~d none", q) + + def test_url(self): + q = self.flow() + assert self.q("~u example.com", q) + assert self.q("~u example.com/ws", q) + assert not self.q("~u moo/path", q) + + q.handshake_flow = None + assert not self.q("~u example.com", q) + def test_body(self): f = self.flow() diff --git a/test/mitmproxy/test_http.py b/test/mitmproxy/test_http.py index 4463961a..49e61e25 100644 --- a/test/mitmproxy/test_http.py +++ b/test/mitmproxy/test_http.py @@ -203,6 +203,15 @@ class TestHTTPFlow: f.resume() assert f.reply.state == "committed" + def test_resume_duplicated(self): + f = tflow.tflow() + f.intercept() + f2 = f.copy() + assert f.intercepted is f2.intercepted is True + f.resume() + f2.resume() + assert f.intercepted is f2.intercepted is False + def test_replace_unicode(self): f = tflow.tflow(resp=True) f.response.content = b"\xc2foo" diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py index d8c7a8e9..bd5d1792 100644 --- a/test/mitmproxy/test_stateobject.py +++ b/test/mitmproxy/test_stateobject.py @@ -1,101 +1,146 @@ -from typing import List +import typing + import pytest from mitmproxy.stateobject import StateObject -class Child(StateObject): +class TObject(StateObject): def __init__(self, x): self.x = x - _stateobject_attributes = dict( - x=int - ) - @classmethod def from_state(cls, state): obj = cls(None) obj.set_state(state) return obj + +class Child(TObject): + _stateobject_attributes = dict( + x=int + ) + def __eq__(self, other): return isinstance(other, Child) and self.x == other.x -class Container(StateObject): - def __init__(self): - self.child = None - self.children = None - self.dictionary = None +class TTuple(TObject): + _stateobject_attributes = dict( + x=typing.Tuple[int, Child] + ) + + +class TList(TObject): + _stateobject_attributes = dict( + x=typing.List[Child] + ) + +class TDict(TObject): _stateobject_attributes = dict( - child=Child, - children=List[Child], - dictionary=dict, + x=typing.Dict[str, Child] ) - @classmethod - def from_state(cls, state): - obj = cls() - obj.set_state(state) - return obj + +class TAny(TObject): + _stateobject_attributes = dict( + x=typing.Any + ) + + +class TSerializableChild(TObject): + _stateobject_attributes = dict( + x=Child + ) def test_simple(): a = Child(42) + assert a.get_state() == {"x": 42} b = a.copy() - assert b.get_state() == {"x": 42} a.set_state({"x": 44}) assert a.x == 44 assert b.x == 42 -def test_container(): - a = Container() - a.child = Child(42) +def test_serializable_child(): + child = Child(42) + a = TSerializableChild(child) + assert a.get_state() == { + "x": {"x": 42} + } + a.set_state({ + "x": {"x": 43} + }) + assert a.x.x == 43 + assert a.x is child b = a.copy() - assert a.child.x == b.child.x - b.child.x = 44 - assert a.child.x != b.child.x + assert a.x == b.x + assert a.x is not b.x -def test_container_list(): - a = Container() - a.children = [Child(42), Child(44)] +def test_tuple(): + a = TTuple((42, Child(43))) assert a.get_state() == { - "child": None, - "children": [{"x": 42}, {"x": 44}], - "dictionary": None, + "x": (42, {"x": 43}) } - copy = a.copy() - assert len(copy.children) == 2 - assert copy.children is not a.children - assert copy.children[0] is not a.children[0] - assert Container.from_state(a.get_state()) + b = a.copy() + a.set_state({"x": (44, {"x": 45})}) + assert a.x == (44, Child(45)) + assert b.x == (42, Child(43)) + +def test_tuple_err(): + a = TTuple(None) + with pytest.raises(ValueError, msg="Invalid data"): + a.set_state({"x": (42,)}) -def test_container_dict(): - a = Container() - a.dictionary = dict() - a.dictionary['foo'] = 'bar' - a.dictionary['bar'] = Child(44) + +def test_list(): + a = TList([Child(1), Child(2)]) assert a.get_state() == { - "child": None, - "children": None, - "dictionary": {'bar': {'x': 44}, 'foo': 'bar'}, + "x": [{"x": 1}, {"x": 2}], } copy = a.copy() - assert len(copy.dictionary) == 2 - assert copy.dictionary is not a.dictionary - assert copy.dictionary['bar'] is not a.dictionary['bar'] + assert len(copy.x) == 2 + assert copy.x is not a.x + assert copy.x[0] is not a.x[0] + + +def test_dict(): + a = TDict({"foo": Child(42)}) + assert a.get_state() == { + "x": {"foo": {"x": 42}} + } + b = a.copy() + assert list(a.x.items()) == list(b.x.items()) + assert a.x is not b.x + assert a.x["foo"] is not b.x["foo"] + + +def test_any(): + a = TAny(42) + b = a.copy() + assert a.x == b.x + + a = TAny(object()) + with pytest.raises(AssertionError): + a.get_state() def test_too_much_state(): - a = Container() - a.child = Child(42) + a = Child(42) s = a.get_state() s['foo'] = 'bar' - b = Container() with pytest.raises(RuntimeWarning): - b.set_state(s) + a.set_state(s) + + +def test_none(): + a = Child(None) + assert a.get_state() == {"x": None} + a = Child(42) + a.set_state({"x": None}) + assert a.x is None diff --git a/test/mitmproxy/test_typemanager.py b/test/mitmproxy/test_typemanager.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/mitmproxy/test_typemanager.py diff --git a/test/mitmproxy/test_types.py b/test/mitmproxy/test_types.py new file mode 100644 index 00000000..72492fa9 --- /dev/null +++ b/test/mitmproxy/test_types.py @@ -0,0 +1,237 @@ +import pytest +import os +import typing +import contextlib + +from mitmproxy.test import tutils +import mitmproxy.exceptions +import mitmproxy.types +from mitmproxy.test import taddons +from mitmproxy.test import tflow +from mitmproxy import command +from mitmproxy import flow + +from . import test_command + + +@contextlib.contextmanager +def chdir(path: str): + old_dir = os.getcwd() + os.chdir(path) + yield + os.chdir(old_dir) + + +def test_bool(): + with taddons.context() as tctx: + b = mitmproxy.types._BoolType() + assert b.completion(tctx.master.commands, bool, "b") == ["false", "true"] + assert b.parse(tctx.master.commands, bool, "true") is True + assert b.parse(tctx.master.commands, bool, "false") is False + assert b.is_valid(tctx.master.commands, bool, True) is True + assert b.is_valid(tctx.master.commands, bool, "foo") is False + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, bool, "foo") + + +def test_str(): + with taddons.context() as tctx: + b = mitmproxy.types._StrType() + assert b.is_valid(tctx.master.commands, str, "foo") is True + assert b.is_valid(tctx.master.commands, str, 1) is False + assert b.completion(tctx.master.commands, str, "") == [] + assert b.parse(tctx.master.commands, str, "foo") == "foo" + + +def test_unknown(): + with taddons.context() as tctx: + b = mitmproxy.types._UnknownType() + assert b.is_valid(tctx.master.commands, mitmproxy.types.Unknown, "foo") is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.Unknown, 1) is False + assert b.completion(tctx.master.commands, mitmproxy.types.Unknown, "") == [] + assert b.parse(tctx.master.commands, mitmproxy.types.Unknown, "foo") == "foo" + + +def test_int(): + with taddons.context() as tctx: + b = mitmproxy.types._IntType() + assert b.is_valid(tctx.master.commands, int, "foo") is False + assert b.is_valid(tctx.master.commands, int, 1) is True + assert b.completion(tctx.master.commands, int, "b") == [] + assert b.parse(tctx.master.commands, int, "1") == 1 + assert b.parse(tctx.master.commands, int, "999") == 999 + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, int, "foo") + + +def test_path(): + with taddons.context() as tctx: + b = mitmproxy.types._PathType() + assert b.parse(tctx.master.commands, mitmproxy.types.Path, "/foo") == "/foo" + assert b.parse(tctx.master.commands, mitmproxy.types.Path, "/bar") == "/bar" + assert b.is_valid(tctx.master.commands, mitmproxy.types.Path, "foo") is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Path, 3) is False + + def normPathOpts(prefix, match): + ret = [] + for s in b.completion(tctx.master.commands, mitmproxy.types.Path, match): + s = s[len(prefix):] + s = s.replace(os.sep, "/") + ret.append(s) + return ret + + cd = os.path.normpath(tutils.test_data.path("mitmproxy/completion")) + assert normPathOpts(cd, cd) == ['/aaa', '/aab', '/aac', '/bbb/'] + assert normPathOpts(cd, os.path.join(cd, "a")) == ['/aaa', '/aab', '/aac'] + with chdir(cd): + assert normPathOpts("", "./") == ['./aaa', './aab', './aac', './bbb/'] + assert normPathOpts("", "") == ['./aaa', './aab', './aac', './bbb/'] + assert b.completion( + tctx.master.commands, mitmproxy.types.Path, "nonexistent" + ) == ["nonexistent"] + + +def test_cmd(): + with taddons.context() as tctx: + tctx.master.addons.add(test_command.TAddon()) + b = mitmproxy.types._CmdType() + assert b.is_valid(tctx.master.commands, mitmproxy.types.Cmd, "foo") is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.Cmd, "cmd1") is True + assert b.parse(tctx.master.commands, mitmproxy.types.Cmd, "cmd1") == "cmd1" + with pytest.raises(mitmproxy.exceptions.TypeError): + assert b.parse(tctx.master.commands, mitmproxy.types.Cmd, "foo") + assert len( + b.completion(tctx.master.commands, mitmproxy.types.Cmd, "") + ) == len(tctx.master.commands.commands.keys()) + + +def test_cutspec(): + with taddons.context() as tctx: + b = mitmproxy.types._CutSpecType() + b.parse(tctx.master.commands, mitmproxy.types.CutSpec, "foo,bar") == ["foo", "bar"] + assert b.is_valid(tctx.master.commands, mitmproxy.types.CutSpec, 1) is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.CutSpec, "foo") is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.CutSpec, "request.path") is True + + assert b.completion( + tctx.master.commands, mitmproxy.types.CutSpec, "request.p" + ) == b.valid_prefixes + ret = b.completion(tctx.master.commands, mitmproxy.types.CutSpec, "request.port,f") + assert ret[0].startswith("request.port,") + assert len(ret) == len(b.valid_prefixes) + + +def test_arg(): + with taddons.context() as tctx: + b = mitmproxy.types._ArgType() + assert b.completion(tctx.master.commands, mitmproxy.types.Arg, "") == [] + assert b.parse(tctx.master.commands, mitmproxy.types.Arg, "foo") == "foo" + assert b.is_valid(tctx.master.commands, mitmproxy.types.Arg, "foo") is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Arg, 1) is False + + +def test_strseq(): + with taddons.context() as tctx: + b = mitmproxy.types._StrSeqType() + assert b.completion(tctx.master.commands, typing.Sequence[str], "") == [] + assert b.parse(tctx.master.commands, typing.Sequence[str], "foo") == ["foo"] + assert b.parse(tctx.master.commands, typing.Sequence[str], "foo,bar") == ["foo", "bar"] + assert b.is_valid(tctx.master.commands, typing.Sequence[str], ["foo"]) is True + assert b.is_valid(tctx.master.commands, typing.Sequence[str], ["a", "b", 3]) is False + assert b.is_valid(tctx.master.commands, typing.Sequence[str], 1) is False + assert b.is_valid(tctx.master.commands, typing.Sequence[str], "foo") is False + + +class DummyConsole: + @command.command("view.resolve") + def resolve(self, spec: str) -> typing.Sequence[flow.Flow]: + if spec == "err": + raise mitmproxy.exceptions.CommandError() + n = int(spec) + return [tflow.tflow(resp=True)] * n + + @command.command("cut") + def cut(self, spec: str) -> mitmproxy.types.Data: + return [["test"]] + + @command.command("options") + def options(self) -> typing.Sequence[str]: + return ["one", "two", "three"] + + +def test_flow(): + with taddons.context() as tctx: + tctx.master.addons.add(DummyConsole()) + b = mitmproxy.types._FlowType() + assert len(b.completion(tctx.master.commands, flow.Flow, "")) == len(b.valid_prefixes) + assert b.parse(tctx.master.commands, flow.Flow, "1") + assert b.is_valid(tctx.master.commands, flow.Flow, tflow.tflow()) is True + assert b.is_valid(tctx.master.commands, flow.Flow, "xx") is False + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, flow.Flow, "0") + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, flow.Flow, "2") + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, flow.Flow, "err") + + +def test_flows(): + with taddons.context() as tctx: + tctx.master.addons.add(DummyConsole()) + b = mitmproxy.types._FlowsType() + assert len( + b.completion(tctx.master.commands, typing.Sequence[flow.Flow], "") + ) == len(b.valid_prefixes) + assert b.is_valid(tctx.master.commands, typing.Sequence[flow.Flow], [tflow.tflow()]) is True + assert b.is_valid(tctx.master.commands, typing.Sequence[flow.Flow], "xx") is False + assert b.is_valid(tctx.master.commands, typing.Sequence[flow.Flow], 0) is False + assert len(b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "0")) == 0 + assert len(b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "1")) == 1 + assert len(b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "2")) == 2 + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, typing.Sequence[flow.Flow], "err") + + +def test_data(): + with taddons.context() as tctx: + b = mitmproxy.types._DataType() + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, 0) is False + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, []) is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, [["x"]]) is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, [[b"x"]]) is True + assert b.is_valid(tctx.master.commands, mitmproxy.types.Data, [[1]]) is False + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, mitmproxy.types.Data, "foo") + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, mitmproxy.types.Data, "foo") + + +def test_choice(): + with taddons.context() as tctx: + tctx.master.addons.add(DummyConsole()) + b = mitmproxy.types._ChoiceType() + assert b.is_valid( + tctx.master.commands, + mitmproxy.types.Choice("options"), + "one", + ) is True + assert b.is_valid( + tctx.master.commands, + mitmproxy.types.Choice("options"), + "invalid", + ) is False + assert b.is_valid( + tctx.master.commands, + mitmproxy.types.Choice("nonexistent"), + "invalid", + ) is False + comp = b.completion(tctx.master.commands, mitmproxy.types.Choice("options"), "") + assert comp == ["one", "two", "three"] + assert b.parse(tctx.master.commands, mitmproxy.types.Choice("options"), "one") == "one" + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, mitmproxy.types.Choice("options"), "invalid") + + +def test_typemanager(): + assert mitmproxy.types.CommandTypes.get(bool, None) + assert mitmproxy.types.CommandTypes.get(mitmproxy.types.Choice("choide"), None) diff --git a/test/mitmproxy/test_version.py b/test/mitmproxy/test_version.py index f87b0851..8c176542 100644 --- a/test/mitmproxy/test_version.py +++ b/test/mitmproxy/test_version.py @@ -1,10 +1,36 @@ +import pathlib import runpy +import subprocess +from unittest import mock from mitmproxy import version def test_version(capsys): - runpy.run_module('mitmproxy.version', run_name='__main__') + here = pathlib.Path(__file__).absolute().parent + version_file = here / ".." / ".." / "mitmproxy" / "version.py" + runpy.run_path(str(version_file), run_name='__main__') stdout, stderr = capsys.readouterr() assert len(stdout) > 0 assert stdout.strip() == version.VERSION + + +def test_get_version_hardcoded(): + version.VERSION = "3.0.0.dev123-0xcafebabe" + assert version.get_version() == "3.0.0" + assert version.get_version(True) == "3.0.0.dev123" + assert version.get_version(True, True) == "3.0.0.dev123-0xcafebabe" + + +def test_get_version(): + version.VERSION = "3.0.0" + + with mock.patch('subprocess.check_output') as m: + m.return_value = b"tag-0-cafecafe" + assert version.get_version(True, True) == "3.0.0" + + m.return_value = b"tag-2-cafecafe" + assert version.get_version(True, True) == "3.0.0.dev2-0xcafecaf" + + m.side_effect = subprocess.CalledProcessError(-1, 'git describe --long') + assert version.get_version(True, True) == "3.0.0" diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index 7c53a4b0..fcacec36 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -3,6 +3,7 @@ import pytest from mitmproxy.io import tnetstring from mitmproxy import flowfilter +from mitmproxy.exceptions import Kill, ControlException from mitmproxy.test import tflow @@ -42,6 +43,20 @@ class TestWebSocketFlow: assert f.error.get_state() == f2.error.get_state() assert f.error is not f2.error + def test_kill(self): + f = tflow.twebsocketflow() + with pytest.raises(ControlException): + f.intercept() + f.resume() + f.kill() + + f = tflow.twebsocketflow() + f.intercept() + assert f.killable + f.kill() + assert not f.killable + assert f.reply.value == Kill + def test_match(self): f = tflow.twebsocketflow() assert not flowfilter.match("~b nonexistent", f) @@ -71,3 +86,9 @@ class TestWebSocketFlow: d = tflow.twebsocketflow().handshake_flow.get_state() tnetstring.dump(d, b) assert b.getvalue() + + def test_message_kill(self): + f = tflow.twebsocketflow() + assert not f.messages[-1].killed + f.messages[-1].kill() + assert f.messages[-1].killed diff --git a/test/mitmproxy/tools/console/test_commander.py b/test/mitmproxy/tools/console/test_commander.py new file mode 100644 index 00000000..2a96995d --- /dev/null +++ b/test/mitmproxy/tools/console/test_commander.py @@ -0,0 +1,98 @@ + +from mitmproxy.tools.console.commander import commander +from mitmproxy.test import taddons + + +class TestListCompleter: + def test_cycle(self): + tests = [ + [ + "", + ["a", "b", "c"], + ["a", "b", "c", "a"] + ], + [ + "xxx", + ["a", "b", "c"], + ["xxx", "xxx", "xxx"] + ], + [ + "b", + ["a", "b", "ba", "bb", "c"], + ["b", "ba", "bb", "b"] + ], + ] + for start, options, cycle in tests: + c = commander.ListCompleter(start, options) + for expected in cycle: + assert c.cycle() == expected + + +class TestCommandBuffer: + + def test_backspace(self): + tests = [ + [("", 0), ("", 0)], + [("1", 0), ("1", 0)], + [("1", 1), ("", 0)], + [("123", 3), ("12", 2)], + [("123", 2), ("13", 1)], + [("123", 0), ("123", 0)], + ] + with taddons.context() as tctx: + for start, output in tests: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = start[0], start[1] + cb.backspace() + assert cb.text == output[0] + assert cb.cursor == output[1] + + def test_left(self): + cursors = [3, 2, 1, 0, 0] + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = "abcd", 4 + for c in cursors: + cb.left() + assert cb.cursor == c + + def test_right(self): + cursors = [1, 2, 3, 4, 4] + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = "abcd", 0 + for c in cursors: + cb.right() + assert cb.cursor == c + + def test_insert(self): + tests = [ + [("", 0), ("x", 1)], + [("a", 0), ("xa", 1)], + [("xa", 2), ("xax", 3)], + ] + with taddons.context() as tctx: + for start, output in tests: + cb = commander.CommandBuffer(tctx.master) + cb.text, cb.cursor = start[0], start[1] + cb.insert("x") + assert cb.text == output[0] + assert cb.cursor == output[1] + + def test_cycle_completion(self): + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text = "foo bar" + cb.cursor = len(cb.text) + cb.cycle_completion() + + def test_render(self): + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + cb.text = "foo" + assert cb.render() + + def test_flatten(self): + with taddons.context() as tctx: + cb = commander.CommandBuffer(tctx.master) + assert cb.flatten("foo bar") == "foo bar" diff --git a/test/mitmproxy/tools/console/test_common.py b/test/mitmproxy/tools/console/test_common.py index 3ab4fd67..72438c49 100644 --- a/test/mitmproxy/tools/console/test_common.py +++ b/test/mitmproxy/tools/console/test_common.py @@ -1,12 +1,34 @@ +import urwid + from mitmproxy.test import tflow from mitmproxy.tools.console import common -from ....conftest import skip_appveyor - -@skip_appveyor def test_format_flow(): f = tflow.tflow(resp=True) assert common.format_flow(f, True) assert common.format_flow(f, True, hostheader=True) assert common.format_flow(f, True, extended=True) + + +def test_format_keyvals(): + assert common.format_keyvals( + [ + ("aa", "bb"), + ("cc", "dd"), + ("ee", None), + ] + ) + wrapped = urwid.BoxAdapter( + urwid.ListBox( + urwid.SimpleFocusListWalker( + common.format_keyvals([("foo", "bar")]) + ) + ), 1 + ) + assert wrapped.render((30, )) + assert common.format_keyvals( + [ + ("aa", wrapped) + ] + ) diff --git a/test/mitmproxy/tools/console/test_defaultkeys.py b/test/mitmproxy/tools/console/test_defaultkeys.py new file mode 100644 index 00000000..1f17c888 --- /dev/null +++ b/test/mitmproxy/tools/console/test_defaultkeys.py @@ -0,0 +1,23 @@ +from mitmproxy.test.tflow import tflow +from mitmproxy.tools.console import defaultkeys +from mitmproxy.tools.console import keymap +from mitmproxy.tools.console import master +from mitmproxy import command + + +def test_commands_exist(): + km = keymap.Keymap(None) + defaultkeys.map(km) + assert km.bindings + m = master.ConsoleMaster(None) + m.load_flow(tflow()) + + for binding in km.bindings: + cmd, *args = command.lexer(binding.command) + assert cmd in m.commands.commands + + cmd_obj = m.commands.commands[cmd] + try: + cmd_obj.prepare_args(args) + except Exception as e: + raise ValueError("Invalid command: {}".format(binding.command)) from e diff --git a/test/mitmproxy/tools/console/test_keymap.py b/test/mitmproxy/tools/console/test_keymap.py index 00e64991..7b475ff8 100644 --- a/test/mitmproxy/tools/console/test_keymap.py +++ b/test/mitmproxy/tools/console/test_keymap.py @@ -42,7 +42,7 @@ def test_join(): km = keymap.Keymap(tctx.master) km.add("key", "str", ["options"], "help1") km.add("key", "str", ["commands"]) - return + assert len(km.bindings) == 1 assert len(km.bindings[0].contexts) == 2 assert km.bindings[0].help == "help1" diff --git a/test/mitmproxy/tools/console/test_master.py b/test/mitmproxy/tools/console/test_master.py index fd9b301e..6f46ce9e 100644 --- a/test/mitmproxy/tools/console/test_master.py +++ b/test/mitmproxy/tools/console/test_master.py @@ -4,22 +4,9 @@ from mitmproxy import options from mitmproxy.test import tflow from mitmproxy.test import tutils from mitmproxy.tools import console -from mitmproxy.tools.console import common from ... import tservers -def test_format_keyvals(): - assert common.format_keyvals( - [ - ("aa", "bb"), - None, - ("cc", "dd"), - (None, "dd"), - (None, "dd"), - ] - ) - - def test_options(): assert options.Options(server_replay_kill_extra=True) diff --git a/test/mitmproxy/tools/console/test_pathedit.py b/test/mitmproxy/tools/console/test_pathedit.py deleted file mode 100644 index b9f51f5a..00000000 --- a/test/mitmproxy/tools/console/test_pathedit.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -from os.path import normpath -from unittest import mock - -from mitmproxy.tools.console import pathedit -from mitmproxy.test import tutils - - -class TestPathCompleter: - - def test_lookup_construction(self): - c = pathedit._PathCompleter() - - cd = os.path.normpath(tutils.test_data.path("mitmproxy/completion")) - ca = os.path.join(cd, "a") - assert c.complete(ca).endswith(normpath("/completion/aaa")) - assert c.complete(ca).endswith(normpath("/completion/aab")) - c.reset() - ca = os.path.join(cd, "aaa") - assert c.complete(ca).endswith(normpath("/completion/aaa")) - assert c.complete(ca).endswith(normpath("/completion/aaa")) - c.reset() - assert c.complete(cd).endswith(normpath("/completion/aaa")) - - def test_completion(self): - c = pathedit._PathCompleter(True) - c.reset() - c.lookup = [ - ("a", "x/a"), - ("aa", "x/aa"), - ] - assert c.complete("a") == "a" - assert c.final == "x/a" - assert c.complete("a") == "aa" - assert c.complete("a") == "a" - - c = pathedit._PathCompleter(True) - r = c.complete("l") - assert c.final.endswith(r) - - c.reset() - assert c.complete("/nonexistent") == "/nonexistent" - assert c.final == "/nonexistent" - c.reset() - assert c.complete("~") != "~" - - c.reset() - s = "thisisatotallynonexistantpathforsure" - assert c.complete(s) == s - assert c.final == s - - -class TestPathEdit: - - def test_keypress(self): - - pe = pathedit.PathEdit("", "") - - with mock.patch('urwid.widget.Edit.get_edit_text') as get_text, \ - mock.patch('urwid.widget.Edit.set_edit_text') as set_text: - - cd = os.path.normpath(tutils.test_data.path("mitmproxy/completion")) - get_text.return_value = os.path.join(cd, "a") - - # Pressing tab should set completed path - pe.keypress((1,), "tab") - set_text_called_with = set_text.call_args[0][0] - assert set_text_called_with.endswith(normpath("/completion/aaa")) - - # Pressing any other key should reset - pe.keypress((1,), "a") - assert pe.lookup is None diff --git a/test/mitmproxy/tools/web/test_app.py b/test/mitmproxy/tools/web/test_app.py index 248581b9..5afc0bca 100644 --- a/test/mitmproxy/tools/web/test_app.py +++ b/test/mitmproxy/tools/web/test_app.py @@ -322,7 +322,7 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): ws_client2 = yield websocket.websocket_connect(ws_url) ws_client2.close() - def test_generate_tflow_js(self): + def _test_generate_tflow_js(self): _tflow = app.flow_to_json(tflow.tflow(resp=True, err=True)) # Set some value as constant, so that _tflow.js would not change every time. _tflow['client_conn']['id'] = "4a18d1a0-50a1-48dd-9aa6-d45d74282939" diff --git a/test/mitmproxy/utils/test_debug.py b/test/mitmproxy/utils/test_debug.py index a8e1054d..0ca6ead0 100644 --- a/test/mitmproxy/utils/test_debug.py +++ b/test/mitmproxy/utils/test_debug.py @@ -1,5 +1,4 @@ import io -import subprocess import sys from unittest import mock import pytest @@ -14,18 +13,6 @@ def test_dump_system_info_precompiled(precompiled): assert ("binary" in debug.dump_system_info()) == precompiled -def test_dump_system_info_version(): - with mock.patch('subprocess.check_output') as m: - m.return_value = b"v2.0.0-0-cafecafe" - x = debug.dump_system_info() - assert 'dev' not in x - assert 'cafecafe' in x - - with mock.patch('subprocess.check_output') as m: - m.side_effect = subprocess.CalledProcessError(-1, 'git describe --tags --long') - assert 'dev' not in debug.dump_system_info() - - def test_dump_info(): cs = io.StringIO() debug.dump_info(None, None, file=cs, testing=True) diff --git a/test/mitmproxy/utils/test_human.py b/test/mitmproxy/utils/test_human.py index e8ffaad4..947cfa4a 100644 --- a/test/mitmproxy/utils/test_human.py +++ b/test/mitmproxy/utils/test_human.py @@ -54,3 +54,5 @@ def test_format_address(): assert human.format_address(("::ffff:127.0.0.1", "54010", "0", "0")) == "127.0.0.1:54010" assert human.format_address(("127.0.0.1", "54010")) == "127.0.0.1:54010" assert human.format_address(("example.com", "54010")) == "example.com:54010" + assert human.format_address(("::", "8080")) == "*:8080" + assert human.format_address(("0.0.0.0", "8080")) == "*:8080" diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py index 66b1884e..9cb4334e 100644 --- a/test/mitmproxy/utils/test_typecheck.py +++ b/test/mitmproxy/utils/test_typecheck.py @@ -4,7 +4,6 @@ from unittest import mock import pytest from mitmproxy.utils import typecheck -from mitmproxy import command class TBase: @@ -88,34 +87,14 @@ def test_check_any(): typecheck.check_option_type("foo", None, typing.Any) -def test_check_command_type(): - assert(typecheck.check_command_type("foo", str)) - assert(typecheck.check_command_type(["foo"], typing.Sequence[str])) - assert(not typecheck.check_command_type(["foo", 1], typing.Sequence[str])) - assert(typecheck.check_command_type(None, None)) - assert(not typecheck.check_command_type(["foo"], typing.Sequence[int])) - assert(not typecheck.check_command_type("foo", typing.Sequence[int])) - assert(typecheck.check_command_type([["foo", b"bar"]], command.Cuts)) - assert(not typecheck.check_command_type(["foo", b"bar"], command.Cuts)) - assert(not typecheck.check_command_type([["foo", 22]], command.Cuts)) - - # Python 3.5 only defines __parameters__ - m = mock.Mock() - m.__str__ = lambda self: "typing.Sequence" - m.__parameters__ = (int,) - - typecheck.check_command_type([10], m) - - # Python 3.5 only defines __union_params__ - m = mock.Mock() - m.__str__ = lambda self: "typing.Union" - m.__union_params__ = (int,) - assert not typecheck.check_command_type([22], m) - - def test_typesec_to_str(): assert(typecheck.typespec_to_str(str)) == "str" assert(typecheck.typespec_to_str(typing.Sequence[str])) == "sequence of str" assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str" with pytest.raises(NotImplementedError): typecheck.typespec_to_str(dict) + + +def test_mapping_types(): + # this is not covered by check_option_type, but still belongs in this module + assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int]) diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py index b1eebc73..95965cee 100644 --- a/test/pathod/protocols/test_http2.py +++ b/test/pathod/protocols/test_http2.py @@ -75,7 +75,7 @@ class TestCheckALPNMatch(net_tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(alpn_protos=[b'h2']) + c.convert_to_tls(alpn_protos=[b'h2']) protocol = HTTP2StateProtocol(c) assert protocol.check_alpn() @@ -89,7 +89,7 @@ class TestCheckALPNMismatch(net_tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(alpn_protos=[b'h2']) + c.convert_to_tls(alpn_protos=[b'h2']) protocol = HTTP2StateProtocol(c) with pytest.raises(NotImplementedError): protocol.check_alpn() @@ -207,7 +207,7 @@ class TestApplySettings(net_tservers.ServerTestBase): def test_apply_settings(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c) protocol._apply_settings({ @@ -302,7 +302,7 @@ class TestReadRequest(net_tservers.ServerTestBase): def test_read_request(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c, is_server=True) protocol.connection_preface_performed = True @@ -328,7 +328,7 @@ class TestReadRequestRelative(net_tservers.ServerTestBase): def test_asterisk_form(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c, is_server=True) protocol.connection_preface_performed = True @@ -351,7 +351,7 @@ class TestReadRequestAbsolute(net_tservers.ServerTestBase): def test_absolute_form(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c, is_server=True) protocol.connection_preface_performed = True @@ -378,7 +378,7 @@ class TestReadResponse(net_tservers.ServerTestBase): def test_read_response(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c) protocol.connection_preface_performed = True @@ -404,7 +404,7 @@ class TestReadEmptyResponse(net_tservers.ServerTestBase): def test_read_empty_response(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl() + c.convert_to_tls() protocol = HTTP2StateProtocol(c) protocol.connection_preface_performed = True diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index 4b50e2a7..297b54d4 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -238,11 +238,11 @@ class TestDaemonHTTP2(PathocTestDaemon): http2_skip_connection_preface=True, ) - tmp_convert_to_ssl = c.convert_to_ssl - c.convert_to_ssl = Mock() - c.convert_to_ssl.side_effect = tmp_convert_to_ssl + tmp_convert_to_tls = c.convert_to_tls + c.convert_to_tls = Mock() + c.convert_to_tls.side_effect = tmp_convert_to_tls with c.connect(): - _, kwargs = c.convert_to_ssl.call_args + _, kwargs = c.convert_to_tls.call_args assert set(kwargs['alpn_protos']) == set([b'http/1.1', b'h2']) def test_request(self): diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index c0011952..d6522cb6 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -153,7 +153,7 @@ class CommonTests(tservers.DaemonTests): c = tcp.TCPClient(("localhost", self.d.port)) with c.connect(): if self.ssl: - c.convert_to_ssl() + c.convert_to_tls() c.wfile.write(b"foo\n\n\n") c.wfile.flush() l = self.d.last_log() @@ -241,7 +241,7 @@ class TestDaemonSSL(CommonTests): with c.connect(): c.wfile.write(b"\0\0\0\0") with pytest.raises(exceptions.TlsException): - c.convert_to_ssl() + c.convert_to_tls() l = self.d.last_log() assert l["type"] == "error" assert "SSL" in l["msg"] |