diff options
-rwxr-xr-x[-rw-r--r--] | examples/flowbasic | 11 | ||||
-rwxr-xr-x[-rw-r--r--] | examples/stickycookies | 8 | ||||
-rw-r--r-- | mitmproxy/console/__init__.py | 29 | ||||
-rw-r--r-- | mitmproxy/controller.py | 162 | ||||
-rw-r--r-- | mitmproxy/dump.py | 17 | ||||
-rw-r--r-- | mitmproxy/flow.py | 81 | ||||
-rw-r--r-- | mitmproxy/models/flow.py | 2 | ||||
-rw-r--r-- | mitmproxy/proxy/root_context.py | 1 | ||||
-rw-r--r-- | mitmproxy/web/__init__.py | 14 | ||||
-rw-r--r-- | netlib/tcp.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/test_controller.py | 38 | ||||
-rw-r--r-- | test/mitmproxy/test_dump.py | 22 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 64 | ||||
-rw-r--r-- | test/mitmproxy/test_script.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/test_server.py | 66 | ||||
-rw-r--r-- | test/mitmproxy/tservers.py | 16 |
16 files changed, 301 insertions, 234 deletions
diff --git a/examples/flowbasic b/examples/flowbasic index 4a87b86a..d8d55caa 100644..100755 --- a/examples/flowbasic +++ b/examples/flowbasic @@ -8,7 +8,7 @@ Note that request and response messages are not automatically replied to, so we need to implement handlers to do this. """ -from mitmproxy import flow +from mitmproxy import flow, controller from mitmproxy.proxy import ProxyServer, ProxyConfig @@ -19,18 +19,15 @@ class MyMaster(flow.FlowMaster): except KeyboardInterrupt: self.shutdown() + @controller.handler def handle_request(self, f): f = flow.FlowMaster.handle_request(self, f) - if f: - f.reply() - return f + print(f) + @controller.handler def handle_response(self, f): f = flow.FlowMaster.handle_response(self, f) - if f: - f.reply() print(f) - return f config = ProxyConfig( diff --git a/examples/stickycookies b/examples/stickycookies index 8f11de8d..43e5371d 100644..100755 --- a/examples/stickycookies +++ b/examples/stickycookies @@ -21,19 +21,19 @@ class StickyMaster(controller.Master): except KeyboardInterrupt: self.shutdown() - def handle_request(self, flow): + @controller.handler + def request(self, flow): hid = (flow.request.host, flow.request.port) if "cookie" in flow.request.headers: self.stickyhosts[hid] = flow.request.headers.get_all("cookie") elif hid in self.stickyhosts: flow.request.headers.set_all("cookie", self.stickyhosts[hid]) - flow.reply() - def handle_response(self, flow): + @controller.handler + def response(self, flow): hid = (flow.request.host, flow.request.port) if "set-cookie" in flow.response.headers: self.stickyhosts[hid] = flow.response.headers.get_all("set-cookie") - flow.reply() config = proxy.ProxyConfig(port=8080) diff --git a/mitmproxy/console/__init__.py b/mitmproxy/console/__init__.py index 1dd032be..c3157292 100644 --- a/mitmproxy/console/__init__.py +++ b/mitmproxy/console/__init__.py @@ -16,7 +16,7 @@ import weakref from netlib import tcp -from .. import flow, script, contentviews +from .. import flow, script, contentviews, controller from . import flowlist, flowview, help, window, signals, options from . import grideditor, palettes, statusbar, palettepicker from ..exceptions import FlowReadException, ScriptException @@ -713,14 +713,15 @@ class ConsoleMaster(flow.FlowMaster): ) def process_flow(self, f): - if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: + should_intercept = any( + [ + self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay, + f.intercepted, + ] + ) + if should_intercept: f.intercept(self) - else: - # check if flow was intercepted within an inline script by flow.intercept() - if f.intercepted: - f.intercept(self) - else: - f.reply() + f.reply.take() signals.flowlist_change.send(self) signals.flow_change.send(self, flow = f) @@ -728,25 +729,29 @@ class ConsoleMaster(flow.FlowMaster): self.eventlist[:] = [] # Handlers - def handle_error(self, f): + @controller.handler + def error(self, f): f = flow.FlowMaster.handle_error(self, f) if f: self.process_flow(f) return f - def handle_request(self, f): + @controller.handler + def request(self, f): f = flow.FlowMaster.handle_request(self, f) if f: self.process_flow(f) return f - def handle_response(self, f): + @controller.handler + def response(self, f): f = flow.FlowMaster.handle_response(self, f) if f: self.process_flow(f) return f - def handle_script_change(self, script): + @controller.handler + def script_change(self, script): if super(ConsoleMaster, self).handle_script_change(script): signals.status_message.send(message='"{}" reloaded.'.format(script.filename)) else: diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index af8a77bd..dcf920ef 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -1,22 +1,58 @@ from __future__ import absolute_import from six.moves import queue import threading +import functools +import sys -from .exceptions import Kill +from . import exceptions +Events = frozenset([ + "clientconnect", + "clientdisconnect", + "serverconnect", + "serverdisconnect", -class Master(object): + "tcp_open", + "tcp_message", + "tcp_error", + "tcp_close", + + "request", + "response", + "responseheaders", + + "next_layer", + + "error", + "log", +]) + +class ControlError(Exception): + pass + + +class Master(object): """ - The master handles mitmproxy's main event loop. + The master handles mitmproxy's main event loop. """ - - def __init__(self): + def __init__(self, *servers): self.event_queue = queue.Queue() self.should_exit = threading.Event() + self.servers = [] + for i in servers: + self.add_server(i) + + def add_server(self, server): + # We give a Channel to the server which can be used to communicate with the master + channel = Channel(self.event_queue, self.should_exit) + server.set_channel(channel) + self.servers.append(server) def start(self): self.should_exit.clear() + for server in self.servers: + ServerThread(server).start() def run(self): self.start() @@ -36,7 +72,17 @@ class Master(object): # exception is thrown. while True: mtype, obj = self.event_queue.get(timeout=timeout) - handle_func = getattr(self, "handle_" + mtype) + if mtype not in Events: + raise ControlError("Unknown event %s"%repr(mtype)) + handle_func = getattr(self, mtype) + if not hasattr(handle_func, "func_dict"): + raise ControlError("Handler %s not a function"%mtype) + if not handle_func.func_dict.get("__handler"): + raise ControlError( + "Handler function %s is not decorated with controller.handler"%( + handle_func + ) + ) handle_func(obj) self.event_queue.task_done() changed = True @@ -45,38 +91,12 @@ class Master(object): return changed def shutdown(self): - self.should_exit.set() - - -class ServerMaster(Master): - - """ - The ServerMaster adds server thread support to the master. - """ - - def __init__(self): - super(ServerMaster, self).__init__() - self.servers = [] - - def add_server(self, server): - # We give a Channel to the server which can be used to communicate with the master - channel = Channel(self.event_queue, self.should_exit) - server.set_channel(channel) - self.servers.append(server) - - def start(self): - super(ServerMaster, self).start() - for server in self.servers: - ServerThread(server).start() - - def shutdown(self): for server in self.servers: server.shutdown() - super(ServerMaster, self).shutdown() + self.should_exit.set() class ServerThread(threading.Thread): - def __init__(self, server): self.server = server super(ServerThread, self).__init__() @@ -88,12 +108,10 @@ class ServerThread(threading.Thread): class Channel(object): - """ - The only way for the proxy server to communicate with the master - is to use the channel it has been given. + The only way for the proxy server to communicate with the master + is to use the channel it has been given. """ - def __init__(self, q, should_exit): self.q = q self.should_exit = should_exit @@ -104,7 +122,7 @@ class Channel(object): master. Then wait for a response. Raises: - Kill: All connections should be closed immediately. + exceptions.Kill: All connections should be closed immediately. """ m.reply = Reply(m) self.q.put((mtype, m)) @@ -114,11 +132,10 @@ class Channel(object): g = m.reply.q.get(timeout=0.5) except queue.Empty: # pragma: no cover continue - if g == Kill: - raise Kill() + if g == exceptions.Kill: + raise exceptions.Kill() return g - - raise Kill() + raise exceptions.Kill() def tell(self, mtype, m): """ @@ -130,14 +147,17 @@ class Channel(object): class DummyReply(object): - """ A reply object that does nothing. Useful when we need an object to seem like it has a channel, and during testing. """ - def __init__(self): self.acked = False + self.taken = False + self.handled = False + + def take(self): + self.taken = True def __call__(self, msg=False): self.acked = True @@ -147,23 +167,63 @@ class DummyReply(object): NO_REPLY = object() -class Reply(object): +def handler(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + if len(args) == 1: + message = args[0] + elif len(args) == 2: + message = args[1] + else: + raise ControlError("Handler takes one argument: a message") + + if not hasattr(message, "reply"): + raise ControlError("Message %s has no reply attribute"%message) + + handling = False + # We're the first handler - ack responsibility is ours + if not message.reply.handled: + handling = True + message.reply.handled = True + + ret = f(*args, **kwargs) + if handling and not message.reply.acked and not message.reply.taken: + message.reply() + return ret + wrapper.func_dict["__handler"] = True + return wrapper + + +class Reply(object): """ Messages sent through a channel are decorated with a "reply" attribute. This object is used to respond to the message through the return channel. """ - def __init__(self, obj): self.obj = obj self.q = queue.Queue() + # Has this message been acked? self.acked = False + # Has the user taken responsibility for ack-ing? + self.taken = False + # Has a handler taken responsibility for ack-ing? + self.handled = False + + def take(self): + self.taken = True def __call__(self, msg=NO_REPLY): + if self.acked: + raise ControlError("Message already acked.") + self.acked = True + if msg is NO_REPLY: + self.q.put(self.obj) + else: + self.q.put(msg) + + def __del__(self): if not self.acked: - self.acked = True - if msg is NO_REPLY: - self.q.put(self.obj) - else: - self.q.put(msg) + # This will be ignored by the interpreter, but emit a warning + raise ControlError("Un-acked message") diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py index 8f9488be..4443995a 100644 --- a/mitmproxy/dump.py +++ b/mitmproxy/dump.py @@ -6,7 +6,7 @@ import itertools from netlib import tcp from netlib.utils import bytes_to_escaped_str, pretty_size -from . import flow, filt, contentviews +from . import flow, filt, contentviews, controller from .exceptions import ContentViewException, FlowReadException, ScriptException @@ -325,22 +325,25 @@ class DumpMaster(flow.FlowMaster): self.echo_flow(f) - def handle_request(self, f): - flow.FlowMaster.handle_request(self, f) + @controller.handler + def request(self, f): + flow.FlowMaster.request(self, f) self.state.delete_flow(f) if f: f.reply() return f - def handle_response(self, f): - flow.FlowMaster.handle_response(self, f) + @controller.handler + def response(self, f): + flow.FlowMaster.response(self, f) if f: f.reply() self._process_flow(f) return f - def handle_error(self, f): - flow.FlowMaster.handle_error(self, f) + @controller.handler + def error(self, f): + flow.FlowMaster.error(self, f) if f: self._process_flow(f) return f diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index d70ec2d9..407f0d7b 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -208,9 +208,9 @@ class ClientPlaybackState: master.replay_request(self.current) else: self.current.reply = controller.DummyReply() - master.handle_request(self.current) + master.request(self.current) if self.current.response: - master.handle_response(self.current) + master.response(self.current) class ServerPlaybackState: @@ -546,7 +546,8 @@ class FlowStore(FlowList): def kill_all(self, master): for f in self._list: - f.kill(master) + if not f.reply.acked: + f.kill(master) class State(object): @@ -637,7 +638,7 @@ class State(object): self.flows.kill_all(master) -class FlowMaster(controller.ServerMaster): +class FlowMaster(controller.Master): @property def server(self): @@ -893,23 +894,23 @@ class FlowMaster(controller.ServerMaster): f.reply = controller.DummyReply() if f.request: - self.handle_request(f) + self.request(f) if f.response: - self.handle_responseheaders(f) - self.handle_response(f) + self.responseheaders(f) + self.response(f) if f.error: - self.handle_error(f) + self.error(f) elif isinstance(f, TCPFlow): messages = f.messages f.messages = [] f.reply = controller.DummyReply() - self.handle_tcp_open(f) + self.tcp_open(f) while messages: f.messages.append(messages.pop(0)) - self.handle_tcp_message(f) + self.tcp_message(f) if f.error: - self.handle_tcp_error(f) - self.handle_tcp_close(f) + self.tcp_error(f) + self.tcp_close(f) else: raise NotImplementedError() @@ -985,39 +986,40 @@ class FlowMaster(controller.ServerMaster): if block: rt.join() - def handle_log(self, l): + @controller.handler + def log(self, l): self.add_event(l.msg, l.level) - l.reply() - def handle_clientconnect(self, root_layer): + @controller.handler + def clientconnect(self, root_layer): self.run_script_hook("clientconnect", root_layer) - root_layer.reply() - def handle_clientdisconnect(self, root_layer): + @controller.handler + def clientdisconnect(self, root_layer): self.run_script_hook("clientdisconnect", root_layer) - root_layer.reply() - def handle_serverconnect(self, server_conn): + @controller.handler + def serverconnect(self, server_conn): self.run_script_hook("serverconnect", server_conn) - server_conn.reply() - def handle_serverdisconnect(self, server_conn): + @controller.handler + def serverdisconnect(self, server_conn): self.run_script_hook("serverdisconnect", server_conn) - server_conn.reply() - def handle_next_layer(self, top_layer): + @controller.handler + def next_layer(self, top_layer): self.run_script_hook("next_layer", top_layer) - top_layer.reply() - def handle_error(self, f): + @controller.handler + def error(self, f): self.state.update_flow(f) self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - f.reply() return f - def handle_request(self, f): + @controller.handler + def request(self, f): if f.live: app = self.apps.get(f.request) if app: @@ -1039,20 +1041,19 @@ class FlowMaster(controller.ServerMaster): self.run_script_hook("request", f) return f - def handle_responseheaders(self, f): + @controller.handler + def responseheaders(self, f): try: if self.stream_large_bodies: self.stream_large_bodies.run(f, False) except HttpException: f.reply(Kill) return - self.run_script_hook("responseheaders", f) - - f.reply() return f - def handle_response(self, f): + @controller.handler + def response(self, f): self.active_flows.discard(f) self.state.update_flow(f) self.replacehooks.run(f) @@ -1099,14 +1100,15 @@ class FlowMaster(controller.ServerMaster): self.add_event('"{}" reloaded.'.format(s.filename), 'info') return ok - def handle_tcp_open(self, flow): + @controller.handler + def tcp_open(self, flow): # TODO: This would break mitmproxy currently. # self.state.add_flow(flow) self.active_flows.add(flow) self.run_script_hook("tcp_open", flow) - flow.reply() - def handle_tcp_message(self, flow): + @controller.handler + def tcp_message(self, flow): self.run_script_hook("tcp_message", flow) message = flow.messages[-1] direction = "->" if message.from_client else "<-" @@ -1116,22 +1118,21 @@ class FlowMaster(controller.ServerMaster): direction=direction, ), "info") self.add_event(clean_bin(message.content), "debug") - flow.reply() - def handle_tcp_error(self, flow): + @controller.handler + def tcp_error(self, flow): self.add_event("Error in TCP connection to {}: {}".format( repr(flow.server_conn.address), flow.error ), "info") self.run_script_hook("tcp_error", flow) - flow.reply() - def handle_tcp_close(self, flow): + @controller.handler + def tcp_close(self, flow): self.active_flows.discard(flow) if self.stream: self.stream.add(flow) self.run_script_hook("tcp_close", flow) - flow.reply() def shutdown(self): super(FlowMaster, self).shutdown() diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index 1019c9fb..8797fcd8 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -152,7 +152,7 @@ class Flow(stateobject.StateObject): self.error = Error("Connection killed") self.intercepted = False self.reply(Kill) - master.handle_error(self) + master.error(self) def intercept(self, master): """ diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 96e7aab6..9b4e2963 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -132,7 +132,6 @@ class RootContext(object): class Log(object): - def __init__(self, msg, level="info"): self.msg = msg self.level = level diff --git a/mitmproxy/web/__init__.py b/mitmproxy/web/__init__.py index 956d221d..f8102ed8 100644 --- a/mitmproxy/web/__init__.py +++ b/mitmproxy/web/__init__.py @@ -6,7 +6,7 @@ import sys from netlib.http import authentication -from .. import flow +from .. import flow, controller from ..exceptions import FlowReadException from . import app @@ -194,18 +194,20 @@ class WebMaster(flow.FlowMaster): if self.state.intercept and self.state.intercept( f) and not f.request.is_replay: f.intercept(self) - else: - f.reply() + f.reply.take() - def handle_request(self, f): + @controller.handler + def request(self, f): super(WebMaster, self).handle_request(f) self._process_flow(f) - def handle_response(self, f): + @controller.handler + def response(self, f): super(WebMaster, self).handle_response(f) self._process_flow(f) - def handle_error(self, f): + @controller.handler + def error(self, f): super(WebMaster, self).handle_error(f) self._process_flow(f) diff --git a/netlib/tcp.py b/netlib/tcp.py index ad75cff8..c7231dbb 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -901,7 +901,7 @@ class TCPServer(object): """ # If a thread has persisted after interpreter exit, the module might be # none. - if traceback: + if traceback and six: exc = six.text_type(traceback.format_exc()) print(u'-' * 40, file=fp) print( diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index f7bf615a..83ad428e 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -2,7 +2,7 @@ from threading import Thread, Event from mock import Mock -from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master +from mitmproxy import controller from six.moves import queue from mitmproxy.exceptions import Kill @@ -10,11 +10,15 @@ from mitmproxy.proxy import DummyServer from netlib.tutils import raises +class TMsg: + pass + + class TestMaster(object): def test_simple(self): - - class DummyMaster(Master): - def handle_panic(self, _): + class DummyMaster(controller.Master): + @controller.handler + def log(self, _): m.should_exit.set() def tick(self, timeout): @@ -23,14 +27,14 @@ class TestMaster(object): m = DummyMaster() assert not m.should_exit.is_set() - m.event_queue.put(("panic", 42)) + msg = TMsg() + msg.reply = controller.DummyReply() + m.event_queue.put(("log", msg)) m.run() assert m.should_exit.is_set() - -class TestServerMaster(object): - def test_simple(self): - m = ServerMaster() + def test_server_simple(self): + m = controller.Master() s = DummyServer(None) m.add_server(s) m.start() @@ -42,7 +46,7 @@ class TestServerMaster(object): class TestServerThread(object): def test_simple(self): m = Mock() - t = ServerThread(m) + t = controller.ServerThread(m) t.run() assert m.serve_forever.called @@ -50,7 +54,7 @@ class TestServerThread(object): class TestChannel(object): def test_tell(self): q = queue.Queue() - channel = Channel(q, Event()) + channel = controller.Channel(q, Event()) m = Mock() channel.tell("test", m) assert q.get() == ("test", m) @@ -66,21 +70,21 @@ class TestChannel(object): Thread(target=reply).start() - channel = Channel(q, Event()) + channel = controller.Channel(q, Event()) assert channel.ask("test", Mock()) == 42 def test_ask_shutdown(self): q = queue.Queue() done = Event() done.set() - channel = Channel(q, done) + channel = controller.Channel(q, done) with raises(Kill): channel.ask("test", Mock()) class TestDummyReply(object): def test_simple(self): - reply = DummyReply() + reply = controller.DummyReply() assert not reply.acked reply() assert reply.acked @@ -88,18 +92,18 @@ class TestDummyReply(object): class TestReply(object): def test_simple(self): - reply = Reply(42) + reply = controller.Reply(42) assert not reply.acked reply("foo") assert reply.acked assert reply.q.get() == "foo" def test_default(self): - reply = Reply(42) + reply = controller.Reply(42) reply() assert reply.q.get() == 42 def test_reply_none(self): - reply = Reply(42) + reply = controller.Reply(42) reply(None) assert reply.q.get() is None diff --git a/test/mitmproxy/test_dump.py b/test/mitmproxy/test_dump.py index ad4cee53..7d625c34 100644 --- a/test/mitmproxy/test_dump.py +++ b/test/mitmproxy/test_dump.py @@ -64,14 +64,14 @@ class TestDumpMaster: f = tutils.tflow(req=netlib.tutils.treq(content=content)) l = Log("connect") l.reply = mock.MagicMock() - m.handle_log(l) - m.handle_clientconnect(f.client_conn) - m.handle_serverconnect(f.server_conn) - m.handle_request(f) + m.log(l) + m.clientconnect(f.client_conn) + m.serverconnect(f.server_conn) + m.request(f) if not f.error: f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=content)) - f = m.handle_response(f) - m.handle_clientdisconnect(f.client_conn) + f = m.response(f) + m.clientdisconnect(f.client_conn) return f def _dummy_cycle(self, n, filt, content, **options): @@ -95,8 +95,8 @@ class TestDumpMaster: o = dump.Options(flow_detail=1) m = dump.DumpMaster(None, o, outfile=cs) f = tutils.tflow(err=True) - m.handle_request(f) - assert m.handle_error(f) + m.request(f) + assert m.error(f) assert "error" in cs.getvalue() def test_missing_content(self): @@ -105,10 +105,10 @@ class TestDumpMaster: m = dump.DumpMaster(None, o, outfile=cs) f = tutils.tflow() f.request.content = None - m.handle_request(f) + m.request(f) f.response = HTTPResponse.wrap(netlib.tutils.tresp()) f.response.content = None - m.handle_response(f) + m.response(f) assert "content missing" in cs.getvalue() def test_replay(self): @@ -160,7 +160,7 @@ class TestDumpMaster: assert o.verbosity == 2 def test_filter(self): - assert not "GET" in self._dummy_cycle(1, "~u foo", "", verbosity=1) + assert "GET" not in self._dummy_cycle(1, "~u foo", "", verbosity=1) def test_app(self): o = dump.Options(app=True) diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 5441ea59..c5e39966 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -53,7 +53,7 @@ class TestStickyCookieState: assert s.domain_match("www.google.com", ".google.com") assert s.domain_match("google.com", ".google.com") - def test_handle_response(self): + def test_response(self): c = "SSID=mooo; domain=.google.com, FOO=bar; Domain=.google.com; Path=/; "\ "Expires=Wed, 13-Jan-2021 22:23:01 GMT; Secure; " @@ -100,7 +100,7 @@ class TestStickyCookieState: assert len(s.jar[googlekey].keys()) == 1 assert s.jar[googlekey]["somecookie"].items()[0][1] == "newvalue" - def test_handle_request(self): + def test_request(self): s, f = self._response("SSID=mooo", "www.google.com") assert "cookie" not in f.request.headers s.handle_request(f) @@ -109,7 +109,7 @@ class TestStickyCookieState: class TestStickyAuthState: - def test_handle_response(self): + def test_response(self): s = flow.StickyAuthState(filt.parse(".*")) f = tutils.tflow(resp=True) f.request.headers["authorization"] = "foo" @@ -460,25 +460,20 @@ class TestFlow(object): fm = flow.FlowMaster(None, s) f = tutils.tflow() f.intercept(mock.Mock()) - assert not f.reply.acked f.kill(fm) - assert f.reply.acked + for i in s.view: + assert "killed" in str(i.error) def test_killall(self): s = flow.State() fm = flow.FlowMaster(None, s) f = tutils.tflow() - fm.handle_request(f) + f.intercept(fm) - f = tutils.tflow() - fm.handle_request(f) - - for i in s.view: - assert not i.reply.acked s.killall(fm) for i in s.view: - assert i.reply.acked + assert "killed" in str(i.error) def test_accept_intercept(self): f = tutils.tflow() @@ -803,8 +798,8 @@ class TestFlowMaster: fm = flow.FlowMaster(None, s) fm.load_script(tutils.test_data.path("scripts/reqerr.py")) f = tutils.tflow() - fm.handle_clientconnect(f.client_conn) - assert fm.handle_request(f) + fm.clientconnect(f.client_conn) + assert fm.request(f) def test_script(self): s = flow.State() @@ -812,18 +807,18 @@ class TestFlowMaster: fm.load_script(tutils.test_data.path("scripts/all.py")) f = tutils.tflow(resp=True) - fm.handle_clientconnect(f.client_conn) + fm.clientconnect(f.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" - fm.handle_serverconnect(f.server_conn) + fm.serverconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "serverconnect" - fm.handle_request(f) + fm.request(f) assert fm.scripts[0].ns["log"][-1] == "request" - fm.handle_response(f) + fm.response(f) assert fm.scripts[0].ns["log"][-1] == "response" # load second script fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 - fm.handle_clientdisconnect(f.server_conn) + fm.clientdisconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" @@ -833,7 +828,7 @@ class TestFlowMaster: fm.load_script(tutils.test_data.path("scripts/all.py")) f.error = tutils.terr() - fm.handle_error(f) + fm.error(f) assert fm.scripts[0].ns["log"][-1] == "error" def test_duplicate_flow(self): @@ -858,21 +853,20 @@ class TestFlowMaster: fm.anticache = True fm.anticomp = True f = tutils.tflow(req=None) - fm.handle_clientconnect(f.client_conn) + fm.clientconnect(f.client_conn) f.request = HTTPRequest.wrap(netlib.tutils.treq()) - fm.handle_request(f) + fm.request(f) assert s.flow_count() == 1 f.response = HTTPResponse.wrap(netlib.tutils.tresp()) - fm.handle_response(f) - assert not fm.handle_response(None) + fm.response(f) assert s.flow_count() == 1 - fm.handle_clientdisconnect(f.client_conn) + fm.clientdisconnect(f.client_conn) f.error = Error("msg") f.error.reply = controller.DummyReply() - fm.handle_error(f) + fm.error(f) fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() @@ -901,7 +895,7 @@ class TestFlowMaster: assert fm.state.flow_count() f.error = Error("error") - fm.handle_error(f) + fm.error(f) def test_server_playback(self): s = flow.State() @@ -982,12 +976,12 @@ class TestFlowMaster: fm.set_stickycookie(".*") f = tutils.tflow(resp=True) f.response.headers["set-cookie"] = "foo=bar" - fm.handle_request(f) - fm.handle_response(f) + fm.request(f) + fm.response(f) assert fm.stickycookie_state.jar assert not "cookie" in f.request.headers f = f.copy() - fm.handle_request(f) + fm.request(f) assert f.request.headers["cookie"] == "foo=bar" def test_stickyauth(self): @@ -1002,12 +996,12 @@ class TestFlowMaster: fm.set_stickyauth(".*") f = tutils.tflow(resp=True) f.request.headers["authorization"] = "foo" - fm.handle_request(f) + fm.request(f) f = tutils.tflow(resp=True) assert fm.stickyauth_state.hosts assert not "authorization" in f.request.headers - fm.handle_request(f) + fm.request(f) assert f.request.headers["authorization"] == "foo" def test_stream(self): @@ -1023,15 +1017,15 @@ class TestFlowMaster: f = tutils.tflow(resp=True) fm.start_stream(file(p, "ab"), None) - fm.handle_request(f) - fm.handle_response(f) + fm.request(f) + fm.response(f) fm.stop_stream() assert r()[0].response f = tutils.tflow() fm.start_stream(file(p, "ab"), None) - fm.handle_request(f) + fm.request(f) fm.shutdown() assert not r()[1].response diff --git a/test/mitmproxy/test_script.py b/test/mitmproxy/test_script.py index f321d15c..dd6f51ae 100644 --- a/test/mitmproxy/test_script.py +++ b/test/mitmproxy/test_script.py @@ -7,7 +7,7 @@ def test_duplicate_flow(): fm = flow.FlowMaster(None, s) fm.load_script(tutils.test_data.path("scripts/duplicate_flow.py")) f = tutils.tflow() - fm.handle_request(f) + fm.request(f) assert fm.state.flow_count() == 2 assert not fm.state.view[0].request.is_replay assert fm.state.view[1].request.is_replay diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 0701d52b..bb4949e1 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -12,6 +12,7 @@ from netlib.http import authentication, http1 from netlib.tutils import raises from pathod import pathoc, pathod +from mitmproxy import controller from mitmproxy.proxy.config import HostMatcher from mitmproxy.exceptions import Kill from mitmproxy.models import Error, HTTPResponse, HTTPFlow @@ -190,8 +191,8 @@ class TcpMixin: assert i_cert == i2_cert == n_cert # Make sure that TCP messages are in the event log. - assert any("305" in m for m in self.master.log) - assert any("306" in m for m in self.master.log) + assert any("305" in m for m in self.master.tlog) + assert any("306" in m for m in self.master.tlog) class AppMixin: @@ -260,7 +261,7 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): p = self.pathoc() assert p.request(req % self.server.urlbase) assert p.request(req % self.server2.urlbase) - assert switched(self.proxy.log) + assert switched(self.proxy.tlog) def test_blank_leading_line(self): p = self.pathoc() @@ -499,7 +500,7 @@ class TestHttps2Http(tservers.ReverseProxyTest): def test_sni(self): p = self.pathoc(ssl=True, sni="example.com") assert p.request("get:'/p/200'").status_code == 200 - assert all("Error in handle_sni" not in msg for msg in self.proxy.log) + assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) def test_http(self): p = self.pathoc(ssl=False) @@ -623,7 +624,8 @@ class TestProxySSL(tservers.HTTPProxyTest): class MasterRedirectRequest(tservers.TestMaster): redirect_port = None # Set by TestRedirectRequest - def handle_request(self, f): + @controller.handler + def request(self, f): if f.request.path == "/p/201": # This part should have no impact, but it should also not cause any exceptions. @@ -634,12 +636,13 @@ class MasterRedirectRequest(tservers.TestMaster): # This is the actual redirection. f.request.port = self.redirect_port - super(MasterRedirectRequest, self).handle_request(f) + super(MasterRedirectRequest, self).request(f) - def handle_response(self, f): + @controller.handler + def response(self, f): f.response.content = str(f.client_conn.address.port) f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) - super(MasterRedirectRequest, self).handle_response(f) + super(MasterRedirectRequest, self).response(f) class TestRedirectRequest(tservers.HTTPProxyTest): @@ -689,10 +692,9 @@ class MasterStreamRequest(tservers.TestMaster): """ Enables the stream flag on the flow for all requests """ - - def handle_responseheaders(self, f): + @controller.handler + def responseheaders(self, f): f.response.stream = True - f.reply() class TestStreamRequest(tservers.HTTPProxyTest): @@ -739,8 +741,8 @@ class TestStreamRequest(tservers.HTTPProxyTest): class MasterFakeResponse(tservers.TestMaster): - - def handle_request(self, f): + @controller.handler + def request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) f.reply(resp) @@ -761,13 +763,14 @@ class TestServerConnect(tservers.HTTPProxyTest): def test_unnecessary_serverconnect(self): """A replayed/fake response with no_upstream_cert should not connect to an upstream server""" assert self.pathod("200").status_code == 200 - for msg in self.proxy.tmaster.log: + for msg in self.proxy.tmaster.tlog: assert "serverconnect" not in msg class MasterKillRequest(tservers.TestMaster): - def handle_request(self, f): + @controller.handler + def request(self, f): f.reply(Kill) @@ -783,7 +786,8 @@ class TestKillRequest(tservers.HTTPProxyTest): class MasterKillResponse(tservers.TestMaster): - def handle_response(self, f): + @controller.handler + def response(self, f): f.reply(Kill) @@ -812,7 +816,8 @@ class TestTransparentResolveError(tservers.TransparentProxyTest): class MasterIncomplete(tservers.TestMaster): - def handle_request(self, f): + @controller.handler + def request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp.content = None f.reply(resp) @@ -930,7 +935,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): k = [0] # variable scope workaround: put into array _func = getattr(master, attr) - def handler(f): + @controller.handler + def handler(*args): + f = args[-1] k[0] += 1 if not (k[0] in exclude): f.client_conn.finish() @@ -940,13 +947,16 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): setattr(master, attr, handler) - kill_requests(self.chain[1].tmaster, "handle_request", - exclude=[ - # fail first request - 2, # allow second request - ]) + kill_requests( + self.chain[1].tmaster, + "request", + exclude = [ + # fail first request + 2, # allow second request + ] + ) - kill_requests(self.chain[0].tmaster, "handle_request", + kill_requests(self.chain[0].tmaster, "request", exclude=[ 1, # CONNECT # fail first request @@ -1004,10 +1014,10 @@ class AddUpstreamCertsToClientChainMixin: ssl = True servercert = tutils.test_data.path("data/trusted-server.crt") ssloptions = pathod.SSLOptions( - cn="trusted-cert", - certs=[ - ("trusted-cert", servercert) - ] + cn="trusted-cert", + certs=[ + ("trusted-cert", servercert) + ] ) def test_add_upstream_certs_to_client_chain(self): diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index c9d68cfd..24ebb476 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -39,19 +39,11 @@ class TestMaster(flow.FlowMaster): self.apps.add(errapp, "errapp", 80) self.clear_log() - def handle_request(self, f): - flow.FlowMaster.handle_request(self, f) - f.reply() - - def handle_response(self, f): - flow.FlowMaster.handle_response(self, f) - f.reply() - def clear_log(self): - self.log = [] + self.tlog = [] def add_event(self, message, level=None): - self.log.append(message) + self.tlog.append(message) class ProxyThread(threading.Thread): @@ -68,8 +60,8 @@ class ProxyThread(threading.Thread): return self.tmaster.server.address.port @property - def log(self): - return self.tmaster.log + def tlog(self): + return self.tmaster.tlog def run(self): self.tmaster.run() |