diff options
author | Aldo Cortesi <aldo@corte.si> | 2014-09-07 13:04:18 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@corte.si> | 2014-09-07 13:04:18 +1200 |
commit | bf5fef1e0b52854683984abb9023a395521d003a (patch) | |
tree | a437207b26620616d0905d106e3c4972d1f9ef20 /test | |
parent | c1438050ed7263872fb64b19fbb06428bd4605ac (diff) | |
parent | 3d62e90dbf7ea05283e16752531a261e53a4bb47 (diff) | |
download | mitmproxy-bf5fef1e0b52854683984abb9023a395521d003a.tar.gz mitmproxy-bf5fef1e0b52854683984abb9023a395521d003a.tar.bz2 mitmproxy-bf5fef1e0b52854683984abb9023a395521d003a.zip |
Merge pull request #347 from mitmproxy/issue_341
Remove BackReferenceMixin
Diffstat (limited to 'test')
-rw-r--r-- | test/test_console.py | 12 | ||||
-rw-r--r-- | test/test_console_common.py | 2 | ||||
-rw-r--r-- | test/test_dump.py | 41 | ||||
-rw-r--r-- | test/test_examples.py | 2 | ||||
-rw-r--r-- | test/test_flow.py | 323 | ||||
-rw-r--r-- | test/test_protocol_http.py | 170 | ||||
-rw-r--r-- | test/test_proxy.py | 27 | ||||
-rw-r--r-- | test/test_script.py | 16 | ||||
-rw-r--r-- | test/test_server.py | 282 | ||||
-rw-r--r-- | test/tservers.py | 119 | ||||
-rw-r--r-- | test/tutils.py | 120 |
11 files changed, 589 insertions, 525 deletions
diff --git a/test/test_console.py b/test/test_console.py index 0c5b4591..3b6c941d 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -51,20 +51,20 @@ class TestConsoleState: assert c.get_focus() == (None, None) def _add_request(self, state): - r = tutils.treq() - return state.add_request(r) + f = tutils.tflow() + return state.add_request(f) def _add_response(self, state): f = self._add_request(state) - r = tutils.tresp(f.request) - state.add_response(r) + f.response = tutils.tresp() + state.add_response(f) def test_add_response(self): c = console.ConsoleState() f = self._add_request(c) - r = tutils.tresp(f.request) + f.response = tutils.tresp() c.focus = None - c.add_response(r) + c.add_response(f) def test_focus_view(self): c = console.ConsoleState() diff --git a/test/test_console_common.py b/test/test_console_common.py index d798e4dc..1949dad5 100644 --- a/test/test_console_common.py +++ b/test/test_console_common.py @@ -9,7 +9,7 @@ import tutils def test_format_flow(): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert common.format_flow(f, True) assert common.format_flow(f, True, hostheader=True) assert common.format_flow(f, True, extended=True) diff --git a/test/test_dump.py b/test/test_dump.py index 6f70450f..fd93cc03 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -10,31 +10,27 @@ def test_strfuncs(): t.is_replay = True dump.str_response(t) - t = tutils.treq() - t.flow.client_conn = None - t.stickycookie = True - assert "stickycookie" in dump.str_request(t, False) - assert "stickycookie" in dump.str_request(t, True) - assert "replay" in dump.str_request(t, False) - assert "replay" in dump.str_request(t, True) + f = tutils.tflow() + f.client_conn = None + f.request.stickycookie = True + assert "stickycookie" in dump.str_request(f, False) + assert "stickycookie" in dump.str_request(f, True) + assert "replay" in dump.str_request(f, False) + assert "replay" in dump.str_request(f, True) class TestDumpMaster: def _cycle(self, m, content): - req = tutils.treq(content=content) + f = tutils.tflow(req=tutils.treq(content)) l = Log("connect") l.reply = mock.MagicMock() m.handle_log(l) - cc = req.flow.client_conn - cc.reply = mock.MagicMock() - m.handle_clientconnect(cc) - sc = proxy.connection.ServerConnection((req.get_host(), req.get_port()), None) - sc.reply = mock.MagicMock() - m.handle_serverconnect(sc) - m.handle_request(req) - resp = tutils.tresp(req, content=content) - f = m.handle_response(resp) - m.handle_clientdisconnect(cc) + m.handle_clientconnect(f.client_conn) + m.handle_serverconnect(f.server_conn) + m.handle_request(f) + f.response = tutils.tresp(content) + f = m.handle_response(f) + m.handle_clientdisconnect(f.client_conn) return f def _dummy_cycle(self, n, filt, content, **options): @@ -49,8 +45,7 @@ class TestDumpMaster: def _flowfile(self, path): f = open(path, "wb") fw = flow.FlowWriter(f) - t = tutils.tflow_full() - t.response = tutils.tresp(t.request) + t = tutils.tflow(resp=True) fw.add(t) f.close() @@ -58,9 +53,9 @@ class TestDumpMaster: cs = StringIO() o = dump.Options(flow_detail=1) m = dump.DumpMaster(None, o, None, outfile=cs) - f = tutils.tflow_err() - m.handle_request(f.request) - assert m.handle_error(f.error) + f = tutils.tflow(err=True) + m.handle_request(f) + assert m.handle_error(f) assert "error" in cs.getvalue() def test_replay(self): diff --git a/test/test_examples.py b/test/test_examples.py index d18b5862..d557080e 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -12,6 +12,8 @@ def test_load_scripts(): tmaster = tservers.TestMaster(config.ProxyConfig()) for f in scripts: + if "iframe_injector" in f: + f += " foo" # one argument required if "modify_response_body" in f: f += " foo bar" # two arguments required script.Script(f, tmaster) # Loads the script file.
\ No newline at end of file diff --git a/test/test_flow.py b/test/test_flow.py index 88e7b9d7..914138c9 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -14,7 +14,8 @@ def test_app_registry(): ar.add("foo", "domain", 80) r = tutils.treq() - r.set_url("http://domain:80/") + r.host = "domain" + r.port = 80 assert ar.get(r) r.port = 81 @@ -32,8 +33,7 @@ def test_app_registry(): class TestStickyCookieState: def _response(self, cookie, host): s = flow.StickyCookieState(filt.parse(".*")) - f = tutils.tflow_full() - f.server_conn.address = tcp.Address((host, 80)) + f = tutils.tflow(req=tutils.treq(host=host, port=80), resp=True) f.response.headers["Set-Cookie"] = [cookie] s.handle_response(f) return s, f @@ -66,12 +66,12 @@ class TestStickyCookieState: class TestStickyAuthState: def test_handle_response(self): s = flow.StickyAuthState(filt.parse(".*")) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["authorization"] = ["foo"] s.handle_request(f) assert "address" in s.hosts - f = tutils.tflow_full() + f = tutils.tflow(resp=True) s.handle_request(f) assert f.request.headers["authorization"] == ["foo"] @@ -123,24 +123,24 @@ class TestServerPlaybackState: def test_headers(self): s = flow.ServerPlaybackState(["foo"], [], False, False) - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["foo"] = ["bar"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) assert not s._hash(r) == s._hash(r2) r2.request.headers["foo"] = ["bar"] assert s._hash(r) == s._hash(r2) r2.request.headers["oink"] = ["bar"] assert s._hash(r) == s._hash(r2) - r = tutils.tflow_full() - r2 = tutils.tflow_full() + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) assert s._hash(r) == s._hash(r2) def test_load(self): - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["key"] = ["one"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) r2.request.headers["key"] = ["two"] s = flow.ServerPlaybackState(None, [r, r2], False, False) @@ -158,10 +158,10 @@ class TestServerPlaybackState: assert not s.next_flow(r) def test_load_with_nopop(self): - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["key"] = ["one"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) r2.request.headers["key"] = ["two"] s = flow.ServerPlaybackState(None, [r, r2], False, True) @@ -173,7 +173,7 @@ class TestServerPlaybackState: class TestFlow: def test_copy(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) a0 = f._get_state() f2 = f.copy() a = f._get_state() @@ -188,7 +188,7 @@ class TestFlow: assert f.response == f2.response assert not f.response is f2.response - f = tutils.tflow_err() + f = tutils.tflow(err=True) f2 = f.copy() assert not f is f2 assert not f.request is f2.request @@ -198,12 +198,12 @@ class TestFlow: assert not f.error is f2.error def test_match(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert not f.match("~b test") assert f.match(None) assert not f.match("~b test") - f = tutils.tflow_err() + f = tutils.tflow(err=True) assert f.match("~e") tutils.raises(ValueError, f.match, "~") @@ -220,14 +220,14 @@ class TestFlow: assert f.request.content == "foo" def test_backup_idempotence(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.backup() f.revert() f.backup() f.revert() def test_getset_state(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) state = f._get_state() assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state() @@ -248,55 +248,42 @@ class TestFlow: s = flow.State() fm = flow.FlowMaster(None, s) f = tutils.tflow() - f.request = tutils.treq() - f.intercept() - assert not f.request.reply.acked - f.kill(fm) - assert f.request.reply.acked f.intercept() - f.response = tutils.tresp() - f.request.reply() - assert not f.response.reply.acked + assert not f.reply.acked f.kill(fm) - assert f.response.reply.acked + assert f.reply.acked def test_killall(self): s = flow.State() fm = flow.FlowMaster(None, s) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) for i in s.view: - assert not i.request.reply.acked + assert not i.reply.acked s.killall(fm) for i in s.view: - assert i.request.reply.acked + assert i.reply.acked def test_accept_intercept(self): f = tutils.tflow() - f.request = tutils.treq() - f.intercept() - assert not f.request.reply.acked - f.accept_intercept() - assert f.request.reply.acked - f.response = tutils.tresp() + f.intercept() - f.request.reply() - assert not f.response.reply.acked + assert not f.reply.acked f.accept_intercept() - assert f.response.reply.acked + assert f.reply.acked def test_replace_unicode(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.content = "\xc2foo" f.replace("foo", u"bar") def test_replace(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["foo"] = ["foo"] f.request.content = "afoob" @@ -311,7 +298,7 @@ class TestFlow: assert f.response.content == "abarb" def test_replace_encoded(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = "afoob" f.request.encode("gzip") f.response.content = "afoob" @@ -332,9 +319,8 @@ class TestFlow: class TestState: def test_backup(self): c = flow.State() - req = tutils.treq() - f = c.add_request(req) - + f = tutils.tflow() + c.add_request(f) f.backup() c.revert(f) @@ -344,72 +330,66 @@ class TestState: connect -> request -> response """ - bc = tutils.tclient_conn() c = flow.State() - - req = tutils.treq(bc) - f = c.add_request(req) + f = tutils.tflow() + c.add_request(f) assert f assert c.flow_count() == 1 assert c.active_flow_count() == 1 - newreq = tutils.treq() - assert c.add_request(newreq) + newf = tutils.tflow() + assert c.add_request(newf) assert c.active_flow_count() == 2 - resp = tutils.tresp(req) - assert c.add_response(resp) + f.response = tutils.tresp() + assert c.add_response(f) assert c.flow_count() == 2 assert c.active_flow_count() == 1 - unseen_resp = tutils.tresp() - unseen_resp.flow = None - assert not c.add_response(unseen_resp) + _ = tutils.tresp() + assert not c.add_response(None) assert c.active_flow_count() == 1 - resp = tutils.tresp(newreq) - assert c.add_response(resp) + newf.response = tutils.tresp() + assert c.add_response(newf) assert c.active_flow_count() == 0 def test_err(self): c = flow.State() - req = tutils.treq() - f = c.add_request(req) + f = tutils.tflow() + c.add_request(f) f.error = Error("message") - assert c.add_error(f.error) - - e = Error("message") - assert not c.add_error(e) + assert c.add_error(f) c = flow.State() - req = tutils.treq() - f = c.add_request(req) - e = tutils.terr() + f = tutils.tflow() + c.add_request(f) c.set_limit("~e") assert not c.view - assert c.add_error(e) + f.error = tutils.terr() + assert c.add_error(f) assert c.view def test_set_limit(self): c = flow.State() - req = tutils.treq() + f = tutils.tflow() assert len(c.view) == 0 - c.add_request(req) + c.add_request(f) assert len(c.view) == 1 c.set_limit("~s") assert c.limit_txt == "~s" assert len(c.view) == 0 - resp = tutils.tresp(req) - c.add_response(resp) + f.response = tutils.tresp() + c.add_response(f) assert len(c.view) == 1 c.set_limit(None) assert len(c.view) == 1 - req = tutils.treq() - c.add_request(req) + f = tutils.tflow() + c.add_request(f) assert len(c.view) == 2 c.set_limit("~q") assert len(c.view) == 1 @@ -427,20 +407,19 @@ class TestState: assert c.intercept_txt == None def _add_request(self, state): - req = tutils.treq() - f = state.add_request(req) + f = tutils.tflow() + state.add_request(f) return f def _add_response(self, state): - req = tutils.treq() - state.add_request(req) - resp = tutils.tresp(req) - state.add_response(resp) + f = tutils.tflow() + state.add_request(f) + f.response = tutils.tresp() + state.add_response(f) def _add_error(self, state): - req = tutils.treq() - f = state.add_request(req) - f.error = Error("msg") + f = tutils.tflow(err=True) + state.add_request(f) def test_clear(self): c = flow.State() @@ -479,10 +458,10 @@ class TestSerialize: sio = StringIO() w = flow.FlowWriter(sio) for i in range(3): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) w.add(f) for i in range(3): - f = tutils.tflow_err() + f = tutils.tflow(err=True) w.add(f) sio.seek(0) @@ -502,7 +481,7 @@ class TestSerialize: f2 = l[0] assert f2._get_state() == f._get_state() - assert f2.request._assemble() == f.request._assemble() + assert f2.request.assemble() == f.request.assemble() def test_load_flows(self): r = self._treader() @@ -516,11 +495,11 @@ class TestSerialize: fl = filt.parse("~c 200") w = flow.FilteredFlowWriter(sio, fl) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.code = 200 w.add(f) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.code = 201 w.add(f) @@ -565,7 +544,7 @@ class TestFlowMaster: def test_replay(self): s = flow.State() fm = flow.FlowMaster(None, s) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = CONTENT_MISSING assert "missing" in fm.replay_request(f) @@ -576,48 +555,44 @@ class TestFlowMaster: s = flow.State() fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/reqerr.py")) - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) - assert fm.handle_request(req) + f = tutils.tflow() + fm.handle_clientconnect(f.client_conn) + assert fm.handle_request(f) def test_script(self): s = flow.State() fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) + f = tutils.tflow(resp=True) + + fm.handle_clientconnect(f.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" - sc = ServerConnection((req.get_host(), req.get_port()), None) - sc.reply = controller.DummyReply() - fm.handle_serverconnect(sc) + fm.handle_serverconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "serverconnect" - f = fm.handle_request(req) + fm.handle_request(f) assert fm.scripts[0].ns["log"][-1] == "request" - resp = tutils.tresp(req) - fm.handle_response(resp) + fm.handle_response(f) assert fm.scripts[0].ns["log"][-1] == "response" #load second script assert not fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 - fm.handle_clientdisconnect(sc) + fm.handle_clientdisconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" - #unload first script fm.unload_scripts() assert len(fm.scripts) == 0 - assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - err = tutils.terr() - err.reply = controller.DummyReply() - fm.handle_error(err) + + f.error = tutils.terr() + fm.handle_error(f) assert fm.scripts[0].ns["log"][-1] == "error" def test_duplicate_flow(self): s = flow.State() fm = flow.FlowMaster(None, s) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f = fm.load_flow(f) assert s.flow_count() == 1 f2 = fm.duplicate_flow(f) @@ -630,25 +605,22 @@ class TestFlowMaster: fm = flow.FlowMaster(None, s) fm.anticache = True fm.anticomp = True - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) - - f = fm.handle_request(req) + f = tutils.tflow(req=None) + fm.handle_clientconnect(f.client_conn) + f.request = tutils.treq() + fm.handle_request(f) assert s.flow_count() == 1 - resp = tutils.tresp(req) - fm.handle_response(resp) + f.response = tutils.tresp() + fm.handle_response(f) + assert not fm.handle_response(None) assert s.flow_count() == 1 - rx = tutils.tresp() - rx.flow = None - assert not fm.handle_response(rx) - - fm.handle_clientdisconnect(req.flow.client_conn) + fm.handle_clientdisconnect(f.client_conn) f.error = Error("msg") f.error.reply = controller.DummyReply() - fm.handle_error(f.error) + fm.handle_error(f) fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() @@ -656,8 +628,8 @@ class TestFlowMaster: def test_client_playback(self): s = flow.State() - f = tutils.tflow_full() - pb = [tutils.tflow_full(), f] + f = tutils.tflow(resp=True) + pb = [tutils.tflow(resp=True), f] fm = flow.FlowMaster(None, s) assert not fm.start_server_playback(pb, False, [], False, False) assert not fm.start_client_playback(pb, False) @@ -668,8 +640,7 @@ class TestFlowMaster: assert fm.state.flow_count() f.error = Error("error") - f.error.reply = controller.DummyReply() - fm.handle_error(f.error) + fm.handle_error(f) def test_server_playback(self): s = flow.State() @@ -723,15 +694,15 @@ class TestFlowMaster: assert not fm.stickycookie_state fm.set_stickycookie(".*") - tf = tutils.tflow_full() - tf.response.headers["set-cookie"] = ["foo=bar"] - fm.handle_request(tf.request) - fm.handle_response(tf.response) + f = tutils.tflow(resp=True) + f.response.headers["set-cookie"] = ["foo=bar"] + fm.handle_request(f) + fm.handle_response(f) assert fm.stickycookie_state.jar - assert not "cookie" in tf.request.headers - tf = tf.copy() - fm.handle_request(tf.request) - assert tf.request.headers["cookie"] == ["foo=bar"] + assert not "cookie" in f.request.headers + f = f.copy() + fm.handle_request(f) + assert f.request.headers["cookie"] == ["foo=bar"] def test_stickyauth(self): s = flow.State() @@ -743,14 +714,14 @@ class TestFlowMaster: assert not fm.stickyauth_state fm.set_stickyauth(".*") - tf = tutils.tflow_full() - tf.request.headers["authorization"] = ["foo"] - fm.handle_request(tf.request) + f = tutils.tflow(resp=True) + f.request.headers["authorization"] = ["foo"] + fm.handle_request(f) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert fm.stickyauth_state.hosts assert not "authorization" in f.request.headers - fm.handle_request(f.request) + fm.handle_request(f) assert f.request.headers["authorization"] == ["foo"] def test_stream(self): @@ -762,61 +733,63 @@ class TestFlowMaster: s = flow.State() fm = flow.FlowMaster(None, s) - tf = tutils.tflow_full() + f = tutils.tflow(resp=True) fm.start_stream(file(p, "ab"), None) - fm.handle_request(tf.request) - fm.handle_response(tf.response) + fm.handle_request(f) + fm.handle_response(f) fm.stop_stream() assert r()[0].response - tf = tutils.tflow() + f = tutils.tflow() fm.start_stream(file(p, "ab"), None) - fm.handle_request(tf.request) + fm.handle_request(f) fm.shutdown() assert not r()[1].response class TestRequest: def test_simple(self): - r = tutils.treq() - u = r.get_url() - assert r.set_url(u) - assert not r.set_url("") - assert r.get_url() == u - assert r._assemble() - assert r.size() == len(r._assemble()) + f = tutils.tflow() + r = f.request + u = r.url + r.url = u + tutils.raises(ValueError, setattr, r, "url", "") + assert r.url == u + assert r.assemble() + assert r.size() == len(r.assemble()) r2 = r.copy() assert r == r2 r.content = None - assert r._assemble() - assert r.size() == len(r._assemble()) + assert r.assemble() + assert r.size() == len(r.assemble()) r.content = CONTENT_MISSING - tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) + tutils.raises("Cannot assemble flow with CONTENT_MISSING", r.assemble) def test_get_url(self): - r = tutils.tflow().request + r = tutils.treq() - assert r.get_url() == "http://address:22/path" + assert r.url == "http://address:22/path" - r.flow.server_conn.ssl_established = True - assert r.get_url() == "https://address:22/path" + r.scheme = "https" + assert r.url == "https://address:22/path" - r.flow.server_conn.address = tcp.Address(("host", 42)) - assert r.get_url() == "https://host:42/path" + r.host = "host" + r.port = 42 + assert r.url == "https://host:42/path" r.host = "address" r.port = 22 - assert r.get_url() == "https://address:22/path" + assert r.url== "https://address:22/path" - assert r.get_url(hostheader=True) == "https://address:22/path" + assert r.pretty_url(True) == "https://address:22/path" r.headers["Host"] = ["foo.com"] - assert r.get_url() == "https://address:22/path" - assert r.get_url(hostheader=True) == "https://foo.com:22/path" + assert r.pretty_url(False) == "https://address:22/path" + assert r.pretty_url(True) == "https://foo.com:22/path" def test_path_components(self): r = tutils.treq() @@ -979,8 +952,8 @@ class TestRequest: h["headername"] = ["headervalue"] r = tutils.treq() r.headers = h - result = len(r._assemble_headers()) - assert result == 62 + raw = r._assemble_headers() + assert len(raw) == 62 def test_get_content_type(self): h = flow.ODictCaseless() @@ -991,20 +964,20 @@ class TestRequest: class TestResponse: def test_simple(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) resp = f.response - assert resp._assemble() - assert resp.size() == len(resp._assemble()) + assert resp.assemble() + assert resp.size() == len(resp.assemble()) resp2 = resp.copy() assert resp2 == resp resp.content = None - assert resp._assemble() - assert resp.size() == len(resp._assemble()) + assert resp.assemble() + assert resp.size() == len(resp.assemble()) resp.content = CONTENT_MISSING - tutils.raises("Cannot assemble flow with CONTENT_MISSING", resp._assemble) + tutils.raises("Cannot assemble flow with CONTENT_MISSING", resp.assemble) def test_refresh(self): r = tutils.tresp() @@ -1227,7 +1200,7 @@ def test_replacehooks(): h.run(f) assert f.request.content == "foo" - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = "foo" f.response.content = "foo" h.run(f) @@ -1280,7 +1253,7 @@ def test_setheaders(): h.clear() h.add("~s", "one", "two") h.add("~s", "one", "three") - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["one"] = ["xxx"] f.response.headers["one"] = ["xxx"] h.run(f) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index 3b922c06..ea6cf3fd 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -1,5 +1,4 @@ from libmproxy.protocol.http import * -from libmproxy.protocol import KILL from cStringIO import StringIO import tutils, tservers @@ -26,42 +25,66 @@ def test_stripped_chunked_encoding_no_content(): class TestHTTPRequest: def test_asterisk_form(self): s = StringIO("OPTIONS * HTTP/1.1") - f = tutils.tflow_noreq() + f = tutils.tflow(req=None) f.request = HTTPRequest.from_stream(s) assert f.request.form_in == "relative" - x = f.request._assemble() - assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" + f.request.host = f.server_conn.address.host + f.request.port = f.server_conn.address.port + f.request.scheme = "http" + assert f.request.assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_origin_form(self): s = StringIO("GET /foo\xff HTTP/1.1") tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + s = StringIO("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") + r = HTTPRequest.from_stream(s) + assert r.headers["Upgrade"] == ["h2c"] + + raw = r._assemble_headers() + assert "Upgrade" not in raw + assert "Host" not in raw + + r.url = "http://example.com/foo" + + raw = r._assemble_headers() + assert "Host" in raw + assert not "Host" in r.headers + r.update_host_header() + assert "Host" in r.headers + def test_authority_form(self): s = StringIO("CONNECT oops-no-port.com HTTP/1.1") tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) s = StringIO("CONNECT address:22 HTTP/1.1") r = HTTPRequest.from_stream(s) - assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" + r.scheme, r.host, r.port = "http", "address", 22 + assert r.assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" + assert r.pretty_url(False) == "address:22" def test_absolute_form(self): s = StringIO("GET oops-no-protocol.com HTTP/1.1") tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) s = StringIO("GET http://address:22/ HTTP/1.1") r = HTTPRequest.from_stream(s) - assert r._assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n" + assert r.assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_assemble_unknown_form(self): r = tutils.treq() - tutils.raises("Invalid request form", r._assemble, "antiauthority") + tutils.raises("Invalid request form", r.assemble, "antiauthority") def test_set_url(self): r = tutils.treq_absolute() - r.set_url("https://otheraddress:42/ORLY") + r.url = "https://otheraddress:42/ORLY" assert r.scheme == "https" assert r.host == "otheraddress" assert r.port == 42 assert r.path == "/ORLY" + def test_repr(self): + r = tutils.treq() + assert repr(r) + class TestHTTPResponse: def test_read_from_stringio(self): @@ -83,6 +106,19 @@ class TestHTTPResponse: assert r.content == "" tutils.raises("Invalid server response: 'content", HTTPResponse.from_stream, s, "GET") + def test_repr(self): + r = tutils.tresp() + assert "unknown content type" in repr(r) + r.headers["content-type"] = ["foo"] + assert "foo" in repr(r) + assert repr(tutils.tresp(content=CONTENT_MISSING)) + + +class TestHTTPFlow(object): + def test_repr(self): + f = tutils.tflow(resp=True, err=True) + assert repr(f) + class TestInvalidRequests(tservers.HTTPProxTest): ssl = True @@ -97,120 +133,4 @@ class TestInvalidRequests(tservers.HTTPProxTest): p.connect() r = p.request("get:/p/200") assert r.status_code == 400 - assert "Invalid HTTP request form" in r.content - - -class TestProxyChaining(tservers.HTTPChainProxyTest): - def test_all(self): - self.chain[1].tmaster.replacehooks.add("~q", "foo", "bar") # replace in request - self.chain[0].tmaster.replacehooks.add("~q", "foo", "oh noes!") - self.proxy.tmaster.replacehooks.add("~q", "bar", "baz") - self.chain[0].tmaster.replacehooks.add("~s", "baz", "ORLY") # replace in response - - p = self.pathoc() - req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) - assert req.content == "ORLY" - assert req.status_code == 418 - -class TestProxyChainingSSL(tservers.HTTPChainProxyTest): - ssl = True - def test_simple(self): - p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") - assert req.content == "content" - assert req.status_code == 418 - - assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT from pathoc to chain[0], - # request from pathoc to chain[0] - assert self.chain[0].tmaster.state.flow_count() == 2 # CONNECT from chain[1] to proxy, - # request from chain[1] to proxy - assert self.proxy.tmaster.state.flow_count() == 1 # request from chain[0] (regular proxy doesn't store CONNECTs) - - def test_closing_connect_response(self): - """ - https://github.com/mitmproxy/mitmproxy/issues/313 - """ - def handle_request(r): - r.httpversion = (1,0) - del r.headers["Content-Length"] - r.reply() - _handle_request = self.chain[0].tmaster.handle_request - self.chain[0].tmaster.handle_request = handle_request - try: - assert self.pathoc().request("get:/p/418").status_code == 418 - finally: - self.chain[0].tmaster.handle_request = _handle_request - - def test_sni(self): - p = self.pathoc(sni="foo.com") - req = p.request("get:'/p/418:b\"content\"'") - assert req.content == "content" - assert req.status_code == 418 - -class TestProxyChainingSSLReconnect(tservers.HTTPChainProxyTest): - ssl = True - - def test_reconnect(self): - """ - Tests proper functionality of ConnectionHandler.server_reconnect mock. - If we have a disconnect on a secure connection that's transparently proxified to - an upstream http proxy, we need to send the CONNECT request again. - """ - def kill_requests(master, attr, exclude): - k = [0] # variable scope workaround: put into array - _func = getattr(master, attr) - def handler(r): - k[0] += 1 - if not (k[0] in exclude): - r.flow.client_conn.finish() - r.flow.error = Error("terminated") - r.reply(KILL) - return _func(r) - setattr(master, attr, handler) - - kill_requests(self.proxy.tmaster, "handle_request", - exclude=[ - # fail first request - 2, # allow second request - ]) - - kill_requests(self.chain[0].tmaster, "handle_request", - exclude=[ - 1, # CONNECT - # fail first request - 3, # reCONNECT - 4, # request - ]) - - p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") - assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT and request - assert self.chain[0].tmaster.state.flow_count() == 4 # CONNECT, failing request, - # reCONNECT, request - assert self.proxy.tmaster.state.flow_count() == 2 # failing request, request - # (doesn't store (repeated) CONNECTs from chain[0] - # as it is a regular proxy) - assert req.content == "content" - assert req.status_code == 418 - - assert not self.proxy.tmaster.state._flow_list[0].response # killed - assert self.proxy.tmaster.state._flow_list[1].response - - assert self.chain[1].tmaster.state._flow_list[0].request.form_in == "authority" - assert self.chain[1].tmaster.state._flow_list[1].request.form_in == "relative" - - assert self.chain[0].tmaster.state._flow_list[0].request.form_in == "authority" - assert self.chain[0].tmaster.state._flow_list[1].request.form_in == "relative" - assert self.chain[0].tmaster.state._flow_list[2].request.form_in == "authority" - assert self.chain[0].tmaster.state._flow_list[3].request.form_in == "relative" - - assert self.proxy.tmaster.state._flow_list[0].request.form_in == "relative" - assert self.proxy.tmaster.state._flow_list[1].request.form_in == "relative" - - req = p.request("get:'/p/418:b\"content2\"'") - - assert req.status_code == 502 - assert self.chain[1].tmaster.state.flow_count() == 3 # + new request - assert self.chain[0].tmaster.state.flow_count() == 6 # + new request, repeated CONNECT from chain[1] - # (both terminated) - assert self.proxy.tmaster.state.flow_count() == 2 # nothing happened here + assert "Invalid HTTP request form" in r.content
\ No newline at end of file diff --git a/test/test_proxy.py b/test/test_proxy.py index b33cdcfd..d13c7ba9 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -23,25 +23,34 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = ServerConnection((self.d.IFACE, self.d.port), None) + sc = ServerConnection((self.d.IFACE, self.d.port)) sc.connect() - r = tutils.treq() - r.flow.server_conn = sc - r.path = "/p/200:da" - sc.send(r._assemble()) - assert http.read_response(sc.rfile, r.method, 1000) + f = tutils.tflow() + f.server_conn = sc + f.request.path = "/p/200:da" + sc.send(f.request.assemble()) + assert http.read_response(sc.rfile, f.request.method, 1000) assert self.d.last_log() sc.finish() def test_terminate_error(self): - sc = ServerConnection((self.d.IFACE, self.d.port), None) + sc = ServerConnection((self.d.IFACE, self.d.port)) sc.connect() sc.connection = mock.Mock() sc.connection.recv = mock.Mock(return_value=False) sc.connection.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) sc.finish() + def test_repr(self): + sc = tutils.tserver_conn() + assert "address:22" in repr(sc) + assert "ssl" not in repr(sc) + sc.ssl_established = True + assert "ssl" in repr(sc) + sc.sni = "foo" + assert "foo" in repr(sc) + class TestProcessProxyOptions: def p(self, *args): @@ -112,7 +121,7 @@ class TestProcessProxyOptions: class TestProxyServer: - @tutils.SkipWindows # binding to 0.0.0.0:1 works without special permissions on Windows + @tutils.SkipWindows # binding to 0.0.0.0:1 works without special permissions on Windows def test_err(self): parser = argparse.ArgumentParser() cmdline.common_options(parser) @@ -138,4 +147,4 @@ class TestConnectionHandler: config = dict(get_upstream_server=mock.Mock(side_effect=RuntimeError)) c = ConnectionHandler(config, mock.MagicMock(), ("127.0.0.1", 8080), None, mock.MagicMock(), None) with tutils.capture_stderr(c.handle) as output: - assert "mitmproxy has crashed" in output
\ No newline at end of file + assert "mitmproxy has crashed" in output diff --git a/test/test_script.py b/test/test_script.py index 587c52d6..aed7def1 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -29,8 +29,8 @@ class TestScript: s = flow.State() fm = flow.FlowMaster(None, s) fm.load_script(tutils.test_data.path("scripts/duplicate_flow.py")) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) assert fm.state.flow_count() == 2 assert not fm.state.view[0].request.is_replay assert fm.state.view[1].request.is_replay @@ -65,12 +65,12 @@ class TestScript: fm.load_script(tutils.test_data.path("scripts/concurrent_decorator.py")) with mock.patch("libmproxy.controller.DummyReply.__call__") as m: - r1, r2 = tutils.treq(), tutils.treq() + f1, f2 = tutils.tflow(), tutils.tflow() t_start = time.time() - fm.handle_request(r1) - r1.reply() - fm.handle_request(r2) - r2.reply() + fm.handle_request(f1) + f1.reply() + fm.handle_request(f2) + f2.reply() # Two instantiations assert m.call_count == 0 # No calls yet. @@ -99,7 +99,7 @@ class TestScript: d = Dummy() assert s.run(hook, d)[0] d.reply() - while (time.time() - t_start) < 5 and m.call_count <= 5: + while (time.time() - t_start) < 20 and m.call_count <= 5: if m.call_count == 5: return time.sleep(0.001) diff --git a/test/test_server.py b/test/test_server.py index 3906cb8e..d33bcc89 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,10 +1,10 @@ import socket, time -import mock +from libmproxy.proxy.config import ProxyConfig from netlib import tcp, http_auth, http from libpathod import pathoc, pathod +from netlib.certutils import SSLCert import tutils, tservers -from libmproxy import flow -from libmproxy.protocol import KILL +from libmproxy.protocol import KILL, Error from libmproxy.protocol.http import CONTENT_MISSING """ @@ -21,8 +21,11 @@ class CommonMixin: def test_replay(self): assert self.pathod("304").status_code == 304 - assert len(self.master.state.view) == 1 - l = self.master.state.view[0] + if isinstance(self, tservers.HTTPUpstreamProxTest) and self.ssl: + assert len(self.master.state.view) == 2 + else: + assert len(self.master.state.view) == 1 + l = self.master.state.view[-1] assert l.response.code == 304 l.request.path = "/p/305" rt = self.master.replay_request(l, block=True) @@ -31,18 +34,28 @@ class CommonMixin: # Disconnect error l.request.path = "/p/305:d0" rt = self.master.replay_request(l, block=True) - assert l.error + assert not rt + if isinstance(self, tservers.HTTPUpstreamProxTest): + assert l.response.code == 502 + else: + assert l.error # Port error l.request.port = 1 - self.master.replay_request(l, block=True) - assert l.error + # In upstream mode, we get a 502 response from the upstream proxy server. + # In upstream mode with ssl, the replay will fail as we cannot establish SSL with the upstream proxy. + rt = self.master.replay_request(l, block=True) + assert not rt + if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl: + assert l.response.code == 502 + else: + assert l.error def test_http(self): f = self.pathod("304") assert f.status_code == 304 - l = self.master.state.view[0] + l = self.master.state.view[-1] # In Upstream mode with SSL, we may already have a previous CONNECT request. assert l.client_conn.address assert "host" in l.request.headers assert l.response.code == 304 @@ -55,6 +68,51 @@ class CommonMixin: line = t.rfile.readline() assert ("Bad Request" in line) or ("Bad Gateway" in line) + def test_sni(self): + if not self.ssl: + return + + f = self.pathod("304", sni="testserver.com") + assert f.status_code == 304 + log = self.server.last_log() + assert log["request"]["sni"] == "testserver.com" + +class TcpMixin: + def _ignore_on(self): + conf = ProxyConfig(ignore=[".+:%s" % self.server.port]) + self.config.ignore.append(conf.ignore[0]) + + def _ignore_off(self): + self.config.ignore.pop() + + def test_ignore(self): + spec = '304:h"Alternate-Protocol"="mitmproxy-will-remove-this"' + n = self.pathod(spec) + self._ignore_on() + i = self.pathod(spec) + i2 = self.pathod(spec) + self._ignore_off() + + assert i.status_code == i2.status_code == n.status_code == 304 + assert "Alternate-Protocol" in i.headers + assert "Alternate-Protocol" in i2.headers + assert "Alternate-Protocol" not in n.headers + + # Test that we get the original SSL cert + if self.ssl: + i_cert = SSLCert(i.sslinfo.certchain[0]) + i2_cert = SSLCert(i2.sslinfo.certchain[0]) + n_cert = SSLCert(n.sslinfo.certchain[0]) + + assert i_cert == i2_cert + assert i_cert != n_cert + + # Test Non-HTTP traffic + spec = "200:i0,@100:d0" # this results in just 100 random bytes + assert self.pathod(spec).status_code == 502 # mitmproxy responds with bad gateway + self._ignore_on() + tutils.raises("invalid server response", self.pathod, spec) # pathoc tries to parse answer as HTTP + self._ignore_off() class AppMixin: @@ -64,7 +122,6 @@ class AppMixin: assert "mitmproxy" in ret.content - class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): def test_app_err(self): p = self.pathoc() @@ -186,7 +243,7 @@ class TestHTTPConnectSSLError(tservers.HTTPProxTest): tutils.raises("502 - Bad Gateway", p.http_connect, dst) -class TestHTTPS(tservers.HTTPProxTest, CommonMixin): +class TestHTTPS(tservers.HTTPProxTest, CommonMixin, TcpMixin): ssl = True ssloptions = pathod.SSLOptions(request_client_cert=True) clientcerts = True @@ -195,12 +252,6 @@ class TestHTTPS(tservers.HTTPProxTest, CommonMixin): assert f.status_code == 304 assert self.server.last_log()["request"]["clientcert"]["keyinfo"] - def test_sni(self): - f = self.pathod("304", sni="testserver.com") - assert f.status_code == 304 - l = self.server.last_log() - assert self.server.last_log()["request"]["sni"] == "testserver.com" - def test_error_post_connect(self): p = self.pathoc() assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 @@ -228,21 +279,16 @@ class TestHTTPSNoCommonName(tservers.HTTPProxTest): assert f.sslinfo.certchain[0].get_subject().CN == "127.0.0.1" -class TestReverse(tservers.ReverseProxTest, CommonMixin): +class TestReverse(tservers.ReverseProxTest, CommonMixin, TcpMixin): reverse = True -class TestTransparent(tservers.TransparentProxTest, CommonMixin): +class TestTransparent(tservers.TransparentProxTest, CommonMixin, TcpMixin): ssl = False -class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): +class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin): ssl = True - def test_sni(self): - f = self.pathod("304", sni="testserver.com") - assert f.status_code == 304 - l = self.server.last_log() - assert l["request"]["sni"] == "testserver.com" def test_sslerr(self): p = pathoc.Pathoc(("localhost", self.proxy.port)) @@ -323,7 +369,7 @@ class TestProxy(tservers.HTTPProxTest): f = self.pathod("200:b@100") assert f.status_code == 200 f = self.master.state.view[0] - assert f.server_conn.peername == ("127.0.0.1", self.server.port) + assert f.server_conn.address == ("127.0.0.1", self.server.port) class TestProxySSL(tservers.HTTPProxTest): ssl=True @@ -331,29 +377,30 @@ class TestProxySSL(tservers.HTTPProxTest): # tests that the ssl timestamp is present when ssl is used f = self.pathod("304:b@10k") assert f.status_code == 304 - first_request = self.master.state.view[0].request - assert first_request.flow.server_conn.timestamp_ssl_setup + first_flow = self.master.state.view[0] + assert first_flow.server_conn.timestamp_ssl_setup class MasterRedirectRequest(tservers.TestMaster): - def handle_request(self, request): + redirect_port = None # Set by TestRedirectRequest + + def handle_request(self, f): + request = f.request if request.path == "/p/201": - url = request.get_url() + url = request.url new = "http://127.0.0.1:%s/p/201" % self.redirect_port - request.set_url(new) - request.set_url(new) - request.flow.live.change_server(("127.0.0.1", self.redirect_port), False) - request.set_url(url) - tutils.raises("SSL handshake error", request.flow.live.change_server, ("127.0.0.1", self.redirect_port), True) - request.set_url(new) - request.set_url(url) - request.set_url(new) - tservers.TestMaster.handle_request(self, request) + request.url = new + f.live.change_server(("127.0.0.1", self.redirect_port), False) + request.url = url + tutils.raises("SSL handshake error", f.live.change_server, ("127.0.0.1", self.redirect_port), True) + request.url = new + tservers.TestMaster.handle_request(self, f) - def handle_response(self, response): - response.content = str(response.flow.client_conn.address.port) - tservers.TestMaster.handle_response(self, response) + def handle_response(self, f): + f.response.content = str(f.client_conn.address.port) + f.response.headers["server-conn-id"] = [str(f.server_conn.source_address.port)] + tservers.TestMaster.handle_response(self, f) class TestRedirectRequest(tservers.HTTPProxTest): @@ -385,16 +432,17 @@ class TestRedirectRequest(tservers.HTTPProxTest): assert self.server.last_log() assert not self.server2.last_log() - assert r3.content == r2.content == r1.content + assert r1.content == r2.content == r3.content + assert r1.headers.get_first("server-conn-id") == r3.headers.get_first("server-conn-id") # Make sure that we actually use the same connection in this test case class MasterStreamRequest(tservers.TestMaster): """ Enables the stream flag on the flow for all requests """ - def handle_responseheaders(self, r): - r.stream = True - r.reply() + def handle_responseheaders(self, f): + f.response.stream = True + f.reply() class TestStreamRequest(tservers.HTTPProxTest): masterclass = MasterStreamRequest @@ -445,9 +493,9 @@ class TestStreamRequest(tservers.HTTPProxTest): class MasterFakeResponse(tservers.TestMaster): - def handle_request(self, m): + def handle_request(self, f): resp = tutils.tresp() - m.reply(resp) + f.reply(resp) class TestFakeResponse(tservers.HTTPProxTest): @@ -458,8 +506,8 @@ class TestFakeResponse(tservers.HTTPProxTest): class MasterKillRequest(tservers.TestMaster): - def handle_request(self, m): - m.reply(KILL) + def handle_request(self, f): + f.reply(KILL) class TestKillRequest(tservers.HTTPProxTest): @@ -471,8 +519,8 @@ class TestKillRequest(tservers.HTTPProxTest): class MasterKillResponse(tservers.TestMaster): - def handle_response(self, m): - m.reply(KILL) + def handle_response(self, f): + f.reply(KILL) class TestKillResponse(tservers.HTTPProxTest): @@ -495,10 +543,10 @@ class TestTransparentResolveError(tservers.TransparentProxTest): class MasterIncomplete(tservers.TestMaster): - def handle_request(self, m): + def handle_request(self, f): resp = tutils.tresp() resp.content = CONTENT_MISSING - m.reply(resp) + f.reply(resp) class TestIncompleteResponse(tservers.HTTPProxTest): @@ -510,6 +558,132 @@ class TestIncompleteResponse(tservers.HTTPProxTest): class TestCertForward(tservers.HTTPProxTest): certforward = True ssl = True + def test_app_err(self): tutils.raises("handshake error", self.pathod, "200:b@100") + +class TestUpstreamProxy(tservers.HTTPUpstreamProxTest, CommonMixin, AppMixin): + ssl = False + + def test_order(self): + self.proxy.tmaster.replacehooks.add("~q", "foo", "bar") # replace in request + self.chain[0].tmaster.replacehooks.add("~q", "bar", "baz") + self.chain[1].tmaster.replacehooks.add("~q", "foo", "oh noes!") + self.chain[0].tmaster.replacehooks.add("~s", "baz", "ORLY") # replace in response + + p = self.pathoc() + req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) + assert req.content == "ORLY" + assert req.status_code == 418 + + +class TestUpstreamProxySSL(tservers.HTTPUpstreamProxTest, CommonMixin, TcpMixin): + ssl = True + + def _ignore_on(self): + super(TestUpstreamProxySSL, self)._ignore_on() + conf = ProxyConfig(ignore=[".+:%s" % self.server.port]) + for proxy in self.chain: + proxy.tmaster.server.config.ignore.append(conf.ignore[0]) + + def _ignore_off(self): + super(TestUpstreamProxySSL, self)._ignore_off() + for proxy in self.chain: + proxy.tmaster.server.config.ignore.pop() + + def test_simple(self): + p = self.pathoc() + req = p.request("get:'/p/418:b\"content\"'") + assert req.content == "content" + assert req.status_code == 418 + + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT from pathoc to chain[0], + # request from pathoc to chain[0] + assert self.chain[0].tmaster.state.flow_count() == 2 # CONNECT from proxy to chain[1], + # request from proxy to chain[1] + assert self.chain[1].tmaster.state.flow_count() == 1 # request from chain[0] (regular proxy doesn't store CONNECTs) + + def test_closing_connect_response(self): + """ + https://github.com/mitmproxy/mitmproxy/issues/313 + """ + def handle_request(f): + f.request.httpversion = (1, 0) + del f.request.headers["Content-Length"] + f.reply() + _handle_request = self.chain[0].tmaster.handle_request + self.chain[0].tmaster.handle_request = handle_request + try: + assert self.pathoc().request("get:/p/418").status_code == 418 + finally: + self.chain[0].tmaster.handle_request = _handle_request + + +class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest): + ssl = True + + def test_reconnect(self): + """ + Tests proper functionality of ConnectionHandler.server_reconnect mock. + If we have a disconnect on a secure connection that's transparently proxified to + an upstream http proxy, we need to send the CONNECT request again. + """ + def kill_requests(master, attr, exclude): + k = [0] # variable scope workaround: put into array + _func = getattr(master, attr) + def handler(f): + k[0] += 1 + if not (k[0] in exclude): + f.client_conn.finish() + f.error = Error("terminated") + f.reply(KILL) + return _func(f) + setattr(master, attr, handler) + + kill_requests(self.chain[1].tmaster, "handle_request", + exclude=[ + # fail first request + 2, # allow second request + ]) + + kill_requests(self.chain[0].tmaster, "handle_request", + exclude=[ + 1, # CONNECT + # fail first request + 3, # reCONNECT + 4, # request + ]) + + p = self.pathoc() + req = p.request("get:'/p/418:b\"content\"'") + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request + assert self.chain[0].tmaster.state.flow_count() == 4 # CONNECT, failing request, + # reCONNECT, request + assert self.chain[1].tmaster.state.flow_count() == 2 # failing request, request + # (doesn't store (repeated) CONNECTs from chain[0] + # as it is a regular proxy) + assert req.content == "content" + assert req.status_code == 418 + + assert not self.chain[1].tmaster.state._flow_list[0].response # killed + assert self.chain[1].tmaster.state._flow_list[1].response + + assert self.proxy.tmaster.state._flow_list[0].request.form_in == "authority" + assert self.proxy.tmaster.state._flow_list[1].request.form_in == "relative" + + assert self.chain[0].tmaster.state._flow_list[0].request.form_in == "authority" + assert self.chain[0].tmaster.state._flow_list[1].request.form_in == "relative" + assert self.chain[0].tmaster.state._flow_list[2].request.form_in == "authority" + assert self.chain[0].tmaster.state._flow_list[3].request.form_in == "relative" + + assert self.chain[1].tmaster.state._flow_list[0].request.form_in == "relative" + assert self.chain[1].tmaster.state._flow_list[1].request.form_in == "relative" + + req = p.request("get:'/p/418:b\"content2\"'") + + assert req.status_code == 502 + assert self.proxy.tmaster.state.flow_count() == 3 # + new request + assert self.chain[0].tmaster.state.flow_count() == 6 # + new request, repeated CONNECT from chain[1] + # (both terminated) + assert self.chain[1].tmaster.state.flow_count() == 2 # nothing happened here diff --git a/test/tservers.py b/test/tservers.py index a12a440e..8a2e72a4 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -36,13 +36,13 @@ class TestMaster(flow.FlowMaster): self.apps.add(errapp, "errapp", 80) self.clear_log() - def handle_request(self, m): - flow.FlowMaster.handle_request(self, m) - m.reply() + def handle_request(self, f): + flow.FlowMaster.handle_request(self, f) + f.reply() - def handle_response(self, m): - flow.FlowMaster.handle_response(self, m) - m.reply() + def handle_response(self, f): + flow.FlowMaster.handle_response(self, f) + f.reply() def clear_log(self): self.log = [] @@ -84,29 +84,19 @@ class ProxTestBase(object): masterclass = TestMaster externalapp = False certforward = False + @classmethod def setupAll(cls): cls.server = libpathod.test.Daemon(ssl=cls.ssl, ssloptions=cls.ssloptions) cls.server2 = libpathod.test.Daemon(ssl=cls.ssl, ssloptions=cls.ssloptions) - pconf = cls.get_proxy_config() - cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy") - cls.config = ProxyConfig( - no_upstream_cert = cls.no_upstream_cert, - confdir = cls.confdir, - authenticator = cls.authenticator, - certforward = cls.certforward, - ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []), - **pconf - ) + + cls.config = ProxyConfig(**cls.get_proxy_config()) + tmaster = cls.masterclass(cls.config) tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp) cls.proxy = ProxyThread(tmaster) cls.proxy.start() - @property - def master(cls): - return cls.proxy.tmaster - @classmethod def teardownAll(cls): shutil.rmtree(cls.confdir) @@ -121,24 +111,20 @@ class ProxTestBase(object): self.server2.clear_log() @property - def scheme(self): - return "https" if self.ssl else "http" - - @property - def proxies(self): - """ - The URL base for the server instance. - """ - return ( - (self.scheme, ("127.0.0.1", self.proxy.port)) - ) + def master(self): + return self.proxy.tmaster @classmethod def get_proxy_config(cls): - d = dict() - if cls.clientcerts: - d["clientcerts"] = tutils.test_data.path("data/clientcert") - return d + cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return dict( + no_upstream_cert = cls.no_upstream_cert, + confdir = cls.confdir, + authenticator = cls.authenticator, + certforward = cls.certforward, + ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []), + clientcerts = tutils.test_data.path("data/clientcert") if cls.clientcerts else None + ) class HTTPProxTest(ProxTestBase): @@ -265,49 +251,50 @@ class ReverseProxTest(ProxTestBase): class ChainProxTest(ProxTestBase): """ - Chain n instances of mitmproxy in a row - because we can. + Chain three instances of mitmproxy in a row to test upstream mode. + Proxy order is cls.proxy -> cls.chain[0] -> cls.chain[1] + cls.proxy and cls.chain[0] are in upstream mode, + cls.chain[1] is in regular mode. """ + chain = None n = 2 - chain_config = [lambda port, sslports: ProxyConfig( - upstream_server= (False, False, "127.0.0.1", port), - http_form_in = "absolute", - http_form_out = "absolute", - ssl_ports=sslports - )] * n + @classmethod def setupAll(cls): - super(ChainProxTest, cls).setupAll() cls.chain = [] - for i in range(cls.n): - sslports = [cls.server.port, cls.server2.port] - config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port, - sslports) + super(ChainProxTest, cls).setupAll() + for _ in range(cls.n): + config = ProxyConfig(**cls.get_proxy_config()) tmaster = cls.masterclass(config) - tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp) - cls.chain.append(ProxyThread(tmaster)) - cls.chain[-1].start() + proxy = ProxyThread(tmaster) + proxy.start() + cls.chain.insert(0, proxy) + + # Patch the orginal proxy to upstream mode + cls.config = cls.proxy.tmaster.config = cls.proxy.tmaster.server.config = ProxyConfig(**cls.get_proxy_config()) + @classmethod def teardownAll(cls): super(ChainProxTest, cls).teardownAll() - for p in cls.chain: - p.tmaster.shutdown() + for proxy in cls.chain: + proxy.shutdown() def setUp(self): super(ChainProxTest, self).setUp() - for p in self.chain: - p.tmaster.clear_log() - p.tmaster.state.clear() + for proxy in self.chain: + proxy.tmaster.clear_log() + proxy.tmaster.state.clear() + @classmethod + def get_proxy_config(cls): + d = super(ChainProxTest, cls).get_proxy_config() + if cls.chain: # First proxy is in normal mode. + d.update( + mode="upstream", + upstream_server=(False, False, "127.0.0.1", cls.chain[0].port) + ) + return d -class HTTPChainProxyTest(ChainProxTest): - def pathoc(self, sni=None): - """ - Returns a connected Pathoc instance. - """ - p = libpathod.pathoc.Pathoc(("localhost", self.chain[-1].port), ssl=self.ssl, sni=sni) - if self.ssl: - p.connect(("127.0.0.1", self.server.port)) - else: - p.connect() - return p +class HTTPUpstreamProxTest(ChainProxTest, HTTPProxTest): + pass
\ No newline at end of file diff --git a/test/tutils.py b/test/tutils.py index 05f65a21..69f79a91 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -23,7 +23,38 @@ def SkipWindows(fn): return fn +def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None): + """ + @type client_conn: bool | None | libmproxy.proxy.connection.ClientConnection + @type server_conn: bool | None | libmproxy.proxy.connection.ServerConnection + @type req: bool | None | libmproxy.protocol.http.HTTPRequest + @type resp: bool | None | libmproxy.protocol.http.HTTPResponse + @type err: bool | None | libmproxy.protocol.primitives.Error + @return: bool | None | libmproxy.protocol.http.HTTPFlow + """ + if client_conn is True: + client_conn = tclient_conn() + if server_conn is True: + server_conn = tserver_conn() + if req is True: + req = treq() + if resp is True: + resp = tresp() + if err is True: + err = terr() + + f = http.HTTPFlow(client_conn, server_conn) + f.request = req + f.response = resp + f.error = err + f.reply = controller.DummyReply() + return f + + def tclient_conn(): + """ + @return: libmproxy.proxy.connection.ClientConnection + """ c = ClientConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), clientcert=None @@ -33,6 +64,9 @@ def tclient_conn(): def tserver_conn(): + """ + @return: libmproxy.proxy.connection.ServerConnection + """ c = ServerConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), state=[], @@ -43,75 +77,46 @@ def tserver_conn(): return c -def treq_absolute(conn=None, content="content"): - r = treq(conn, content) +def treq(content="content", scheme="http", host="address", port=22): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + headers = flow.ODictCaseless() + headers["header"] = ["qvalue"] + req = http.HTTPRequest("relative", "GET", scheme, host, port, "/path", (1, 1), headers, content, + None, None, None) + return req + +def treq_absolute(content="content"): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + r = treq(content) r.form_in = r.form_out = "absolute" r.host = "address" r.port = 22 r.scheme = "http" return r -def treq(conn=None, content="content"): - if not conn: - conn = tclient_conn() - server_conn = tserver_conn() - headers = flow.ODictCaseless() - headers["header"] = ["qvalue"] - f = http.HTTPFlow(conn, server_conn) - f.request = http.HTTPRequest("relative", "GET", None, None, None, "/path", (1, 1), headers, content, - None, None, None) - f.request.reply = controller.DummyReply() - return f.request - - -def tresp(req=None, content="message"): - if not req: - req = treq() - f = req.flow +def tresp(content="message"): + """ + @return: libmproxy.protocol.http.HTTPResponse + """ headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"), "rb").read()) - f.server_conn = ServerConnection._from_state(dict( - address=dict(address=("address", 22), use_ipv6=True), - state=[], - source_address=None, - cert=cert.to_pem())) - f.response = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) - f.response.reply = controller.DummyReply() - return f.response + resp = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) + return resp -def terr(req=None): - if not req: - req = treq() - f = req.flow - f.error = Error("error") - f.error.reply = controller.DummyReply() - return f.error - -def tflow_noreq(): - f = tflow() - f.request = None - return f -def tflow(req=None): - if not req: - req = treq() - return req.flow - - -def tflow_full(): - f = tflow() - f.response = tresp(f.request) - return f - - -def tflow_err(): - f = tflow() - f.error = terr(f.request) - return f +def terr(content="error"): + """ + @return: libmproxy.protocol.primitives.Error + """ + err = Error(content) + return err def tflowview(request_contents=None): m = Mock() @@ -119,8 +124,7 @@ def tflowview(request_contents=None): if request_contents == None: flow = tflow() else: - req = treq(None, request_contents) - flow = tflow(req) + flow = tflow(req=treq(request_contents)) fv = FlowView(m, cs, flow) return fv |