diff options
57 files changed, 1178 insertions, 1191 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 13782ee8..7fa65e1b 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -25,7 +25,7 @@ install: - "pip install -U tox" test_script: - - ps: "tox -- --cov netlib --cov mitmproxy --cov pathod | Select-String -NotMatch Cryptography_locking_cb" + - ps: "tox -- --cov netlib --cov mitmproxy --cov pathod -v" deploy_script: ps: | diff --git a/.travis.yml b/.travis.yml index a114dafb..f7c5e839 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,7 +53,7 @@ install: fi - pip install tox -script: tox -- --cov netlib --cov mitmproxy --cov pathod +script: tox -- --cov netlib --cov mitmproxy --cov pathod -v after_success: - | diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py index 329d1215..2658c0af 100644 --- a/mitmproxy/addons.py +++ b/mitmproxy/addons.py @@ -4,7 +4,7 @@ import pprint def _get_name(itm): - return getattr(itm, "name", itm.__class__.__name__) + return getattr(itm, "name", itm.__class__.__name__.lower()) class Addons(object): @@ -13,6 +13,16 @@ class Addons(object): self.master = master master.options.changed.connect(self.options_update) + def get(self, name): + """ + Retrieve an addon by name. Addon names are equal to the .name + attribute on the instance, or the lower case class name if that + does not exist. + """ + for i in self.chain: + if name == _get_name(i): + return i + def options_update(self, options, updated): for i in self.chain: with self.master.handlecontext(): @@ -39,14 +49,6 @@ class Addons(object): for i in self.chain: self.invoke_with_context(i, "done") - def has_addon(self, name): - """ - Is an addon with this name registered? - """ - for i in self.chain: - if _get_name(i) == name: - return True - def __len__(self): return len(self.chain) diff --git a/mitmproxy/builtins/__init__.py b/mitmproxy/builtins/__init__.py index 3974d736..26e9dfbd 100644 --- a/mitmproxy/builtins/__init__.py +++ b/mitmproxy/builtins/__init__.py @@ -8,6 +8,8 @@ from mitmproxy.builtins import stickycookie from mitmproxy.builtins import script from mitmproxy.builtins import replace from mitmproxy.builtins import setheaders +from mitmproxy.builtins import serverplayback +from mitmproxy.builtins import clientplayback def default_addons(): @@ -20,4 +22,6 @@ def default_addons(): filestreamer.FileStreamer(), replace.Replace(), setheaders.SetHeaders(), + serverplayback.ServerPlayback(), + clientplayback.ClientPlayback(), ] diff --git a/mitmproxy/builtins/clientplayback.py b/mitmproxy/builtins/clientplayback.py new file mode 100644 index 00000000..c40d1904 --- /dev/null +++ b/mitmproxy/builtins/clientplayback.py @@ -0,0 +1,39 @@ +from mitmproxy import exceptions, flow, ctx + + +class ClientPlayback: + def __init__(self): + self.flows = None + self.current = None + self.keepserving = None + self.has_replayed = False + + def count(self): + if self.flows: + return len(self.flows) + return 0 + + def load(self, flows): + self.flows = flows + + def configure(self, options, updated): + if "client_replay" in updated: + if options.client_replay: + try: + flows = flow.read_flows_from_paths(options.client_replay) + except exceptions.FlowReadException as e: + raise exceptions.OptionsError(str(e)) + self.load(flows) + else: + self.flows = None + self.keepserving = options.keepserving + + def tick(self): + if self.current and not self.current.is_alive(): + self.current = None + if self.flows and not self.current: + self.current = ctx.master.replay_request(self.flows.pop(0)) + self.has_replayed = True + if self.has_replayed: + if not self.flows and not self.current and not self.keepserving: + ctx.master.shutdown() diff --git a/mitmproxy/builtins/dumper.py b/mitmproxy/builtins/dumper.py index 743ca72e..60d00518 100644 --- a/mitmproxy/builtins/dumper.py +++ b/mitmproxy/builtins/dumper.py @@ -232,6 +232,14 @@ class Dumper(object): if self.match(f): self.echo_flow(f) + def tcp_error(self, f): + self.echo( + "Error in TCP connection to {}: {}".format( + repr(f.server_conn.address), f.error + ), + fg="red" + ) + def tcp_message(self, f): if not self.match(f): return diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py index c938d683..df3cab04 100644 --- a/mitmproxy/builtins/replace.py +++ b/mitmproxy/builtins/replace.py @@ -36,9 +36,9 @@ class Replace: for rex, s, cpatt in self.lst: if cpatt(f): if f.response: - f.response.replace(rex, s) + f.response.replace(rex, s, flags=re.DOTALL) else: - f.request.replace(rex, s) + f.request.replace(rex, s, flags=re.DOTALL) def request(self, flow): if not flow.reply.has_message: diff --git a/mitmproxy/builtins/script.py b/mitmproxy/builtins/script.py index ae1d1b91..1ebec873 100644 --- a/mitmproxy/builtins/script.py +++ b/mitmproxy/builtins/script.py @@ -10,6 +10,7 @@ import traceback from mitmproxy import exceptions from mitmproxy import controller from mitmproxy import ctx +from mitmproxy.flow import master as flowmaster import watchdog.events @@ -67,7 +68,11 @@ def scriptenv(path, args): tb = tb.tb_next if not os.path.abspath(s[0]).startswith(scriptdir): break - ctx.log.error("Script error: %s" % "".join(traceback.format_exception(etype, value, tb))) + ctx.log.error( + "Script error: %s" % "".join( + traceback.format_exception(etype, value, tb) + ) + ) finally: sys.argv = oldargs sys.path.pop() @@ -189,6 +194,15 @@ class ScriptLoader(): """ An addon that manages loading scripts from options. """ + def run_once(self, command, flows): + sc = Script(command) + sc.load_script() + for f in flows: + for evt, o in flowmaster.event_sequence(f): + sc.run(evt, o) + sc.done() + return sc + def configure(self, options, updated): if "scripts" in updated: for s in options.scripts: diff --git a/mitmproxy/builtins/serverplayback.py b/mitmproxy/builtins/serverplayback.py new file mode 100644 index 00000000..29fc95ef --- /dev/null +++ b/mitmproxy/builtins/serverplayback.py @@ -0,0 +1,122 @@ +from __future__ import absolute_import, print_function, division +from six.moves import urllib +import hashlib + +from netlib import strutils +from mitmproxy import exceptions, flow, ctx + + +class ServerPlayback(object): + def __init__(self): + self.options = None + + self.flowmap = {} + self.stop = False + self.final_flow = None + + def load(self, flows): + for i in flows: + if i.response: + l = self.flowmap.setdefault(self._hash(i), []) + l.append(i) + + def clear(self): + self.flowmap = {} + + def count(self): + return sum([len(i) for i in self.flowmap.values()]) + + def _hash(self, flow): + """ + Calculates a loose hash of the flow request. + """ + r = flow.request + + _, _, path, _, query, _ = urllib.parse.urlparse(r.url) + queriesArray = urllib.parse.parse_qsl(query, keep_blank_values=True) + + key = [str(r.port), str(r.scheme), str(r.method), str(path)] + if not self.options.server_replay_ignore_content: + form_contents = r.urlencoded_form or r.multipart_form + if self.options.server_replay_ignore_payload_params and form_contents: + params = [ + strutils.always_bytes(i) + for i in self.options.server_replay_ignore_payload_params + ] + for p in form_contents.items(multi=True): + if p[0] not in params: + key.append(p) + else: + key.append(str(r.raw_content)) + + if not self.options.server_replay_ignore_host: + key.append(r.host) + + filtered = [] + ignore_params = self.options.server_replay_ignore_params or [] + for p in queriesArray: + if p[0] not in ignore_params: + filtered.append(p) + for p in filtered: + key.append(p[0]) + key.append(p[1]) + + if self.options.server_replay_use_headers: + headers = [] + for i in self.options.server_replay_use_headers: + v = r.headers.get(i) + headers.append((i, v)) + key.append(headers) + return hashlib.sha256( + repr(key).encode("utf8", "surrogateescape") + ).digest() + + def next_flow(self, request): + """ + Returns the next flow object, or None if no matching flow was + found. + """ + hsh = self._hash(request) + if hsh in self.flowmap: + if self.options.server_replay_nopop: + return self.flowmap[hsh][0] + else: + ret = self.flowmap[hsh].pop(0) + if not self.flowmap[hsh]: + del self.flowmap[hsh] + return ret + + def configure(self, options, updated): + self.options = options + if "server_replay" in updated: + self.clear() + if options.server_replay: + try: + flows = flow.read_flows_from_paths(options.server_replay) + except exceptions.FlowReadException as e: + raise exceptions.OptionsError(str(e)) + self.load(flows) + + def tick(self): + if self.stop and not self.final_flow.live: + ctx.master.shutdown() + + def request(self, f): + if self.flowmap: + rflow = self.next_flow(f) + if rflow: + response = rflow.response.copy() + response.is_replay = True + if self.options.refresh_server_playback: + response.refresh() + f.response = response + if not self.flowmap and not self.options.keepserving: + self.final_flow = f + self.stop = True + elif self.options.replay_kill_extra: + ctx.log.warn( + "server_playback: killed non-replay request {}".format( + f.request.url + ) + ) + f.reply.kill() diff --git a/mitmproxy/cmdline.py b/mitmproxy/cmdline.py index b5b939b8..9fb4a561 100644 --- a/mitmproxy/cmdline.py +++ b/mitmproxy/cmdline.py @@ -218,10 +218,10 @@ def get_common_options(args): anticache=args.anticache, anticomp=args.anticomp, client_replay=args.client_replay, - kill=args.kill, + replay_kill_extra=args.replay_kill_extra, no_server=args.no_server, refresh_server_playback=not args.norefresh, - rheaders=args.rheaders, + server_replay_use_headers=args.server_replay_use_headers, rfile=args.rfile, replacements=reps, setheaders=setheaders, @@ -233,11 +233,11 @@ def get_common_options(args): showhost=args.showhost, outfile=args.outfile, verbosity=args.verbose, - nopop=args.nopop, - replay_ignore_content=args.replay_ignore_content, - replay_ignore_params=args.replay_ignore_params, - replay_ignore_payload_params=args.replay_ignore_payload_params, - replay_ignore_host=args.replay_ignore_host, + server_replay_nopop=args.server_replay_nopop, + server_replay_ignore_content=args.server_replay_ignore_content, + server_replay_ignore_params=args.server_replay_ignore_params, + server_replay_ignore_payload_params=args.server_replay_ignore_payload_params, + server_replay_ignore_host=args.server_replay_ignore_host, auth_nonanonymous = args.auth_nonanonymous, auth_singleuser = args.auth_singleuser, @@ -600,13 +600,13 @@ def server_replay(parser): help="Replay server responses from a saved file." ) group.add_argument( - "-k", "--kill", - action="store_true", dest="kill", default=False, + "-k", "--replay-kill-extra", + action="store_true", dest="replay_kill_extra", default=False, help="Kill extra requests during replay." ) group.add_argument( - "--rheader", - action="append", dest="rheaders", type=str, + "--server-replay-use-header", + action="append", dest="server_replay_use_headers", type=str, help="Request headers to be considered during replay. " "Can be passed multiple times." ) @@ -620,21 +620,21 @@ def server_replay(parser): ) group.add_argument( "--no-pop", - action="store_true", dest="nopop", default=False, + action="store_true", dest="server_replay_nopop", default=False, help="Disable response pop from response flow. " "This makes it possible to replay same response multiple times." ) payload = group.add_mutually_exclusive_group() payload.add_argument( "--replay-ignore-content", - action="store_true", dest="replay_ignore_content", default=False, + action="store_true", dest="server_replay_ignore_content", default=False, help=""" Ignore request's content while searching for a saved flow to replay """ ) payload.add_argument( "--replay-ignore-payload-param", - action="append", dest="replay_ignore_payload_params", type=str, + action="append", dest="server_replay_ignore_payload_params", type=str, help=""" Request's payload parameters (application/x-www-form-urlencoded or multipart/form-data) to be ignored while searching for a saved flow to replay. @@ -644,7 +644,7 @@ def server_replay(parser): group.add_argument( "--replay-ignore-param", - action="append", dest="replay_ignore_params", type=str, + action="append", dest="server_replay_ignore_params", type=str, help=""" Request's parameters to be ignored while searching for a saved flow to replay. Can be passed multiple times. @@ -653,7 +653,7 @@ def server_replay(parser): group.add_argument( "--replay-ignore-host", action="store_true", - dest="replay_ignore_host", + dest="server_replay_ignore_host", default=False, help="Ignore request's destination host while searching for a saved flow to replay") diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py index 11e8fc99..c052de7b 100644 --- a/mitmproxy/console/flowlist.py +++ b/mitmproxy/console/flowlist.py @@ -18,14 +18,15 @@ def _mkhelp(): ("d", "delete flow"), ("D", "duplicate flow"), ("e", "toggle eventlog"), + ("E", "export flow to file"), ("f", "filter view"), ("F", "toggle follow flow list"), ("L", "load saved flows"), ("m", "toggle flow mark"), ("M", "toggle marked flow view"), ("n", "create a new request"), - ("E", "export flow to file"), ("r", "replay request"), + ("S", "server replay request/s"), ("U", "unmark all marked flows"), ("V", "revert changes to request"), ("w", "save flows "), @@ -140,36 +141,13 @@ class ConnectionItem(urwid.WidgetWrap): args = (self.flow,) ) - def stop_server_playback_prompt(self, a): - if a != "n": - self.master.stop_server_playback() - def server_replay_prompt(self, k): + a = self.master.addons.get("serverplayback") if k == "a": - self.master.start_server_playback( - [i.copy() for i in self.master.state.view], - self.master.options.kill, self.master.options.rheaders, - False, self.master.options.nopop, - self.master.options.replay_ignore_params, - self.master.options.replay_ignore_content, - self.master.options.replay_ignore_payload_params, - self.master.options.replay_ignore_host - ) + a.load([i.copy() for i in self.master.state.view]) elif k == "t": - self.master.start_server_playback( - [self.flow.copy()], - self.master.options.kill, self.master.options.rheaders, - False, self.master.options.nopop, - self.master.options.replay_ignore_params, - self.master.options.replay_ignore_content, - self.master.options.replay_ignore_payload_params, - self.master.options.replay_ignore_host - ) - else: - signals.status_prompt_path.send( - prompt = "Server replay path", - callback = self.master.server_playback_path - ) + a.load([self.flow.copy()]) + signals.update_settings.send(self) def mouse_event(self, size, event, button, col, row, focus): if event == "mouse press" and button == 1: @@ -202,29 +180,30 @@ class ConnectionItem(urwid.WidgetWrap): self.state.enable_marked_filter() signals.flowlist_change.send(self) elif key == "r": - r = self.master.replay_request(self.flow) - if r: - signals.status_message.send(message=r) + self.master.replay_request(self.flow) signals.flowlist_change.send(self) elif key == "S": - if not self.master.server_playback: + def stop_server_playback(response): + if response == "y": + self.master.options.server_replay = [] + a = self.master.addons.get("serverplayback") + if a.count(): signals.status_prompt_onekey.send( - prompt = "Server Replay", + prompt = "Stop current server replay?", keys = ( - ("all flows", "a"), - ("this flow", "t"), - ("file", "f"), + ("yes", "y"), + ("no", "n"), ), - callback = self.server_replay_prompt, + callback = stop_server_playback, ) else: signals.status_prompt_onekey.send( - prompt = "Stop current server replay?", + prompt = "Server Replay", keys = ( - ("yes", "y"), - ("no", "n"), + ("all flows", "a"), + ("this flow", "t"), ), - callback = self.stop_server_playback_prompt, + callback = self.server_replay_prompt, ) elif key == "U": for f in self.state.flows: diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index 15be379b..add10527 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -544,9 +544,7 @@ class FlowView(tabs.Tabs): elif key == "p": self.view_prev_flow(self.flow) elif key == "r": - r = self.master.replay_request(self.flow) - if r: - signals.status_message.send(message=r) + self.master.replay_request(self.flow) signals.flow_change.send(self, flow = self.flow) elif key == "V": if self.flow.modified(): diff --git a/mitmproxy/console/help.py b/mitmproxy/console/help.py index 8024dc31..e3e2f54c 100644 --- a/mitmproxy/console/help.py +++ b/mitmproxy/console/help.py @@ -53,7 +53,7 @@ class HelpView(urwid.ListBox): ("o", "options"), ("q", "quit / return to previous page"), ("Q", "quit without confirm prompt"), - ("R", "replay of HTTP requests/responses"), + ("R", "replay of requests/responses from file"), ] text.extend( common.format_keyvals(keys, key="key", val="text", indent=4) diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index a6942ca4..8ded1ea1 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -22,7 +22,6 @@ from mitmproxy import contentviews from mitmproxy import controller from mitmproxy import exceptions from mitmproxy import flow -from mitmproxy import script from mitmproxy import utils import mitmproxy.options from mitmproxy.console import flowlist @@ -67,11 +66,13 @@ class ConsoleState(flow.State): def add_flow(self, f): super(ConsoleState, self).add_flow(f) + signals.flowlist_change.send(self) self.update_focus() return f def update_flow(self, f): super(ConsoleState, self).update_flow(f) + signals.flowlist_change.send(self) self.update_focus() return f @@ -245,12 +246,6 @@ class ConsoleMaster(flow.FlowMaster): self.logbuffer = urwid.SimpleListWalker([]) self.follow = options.follow - if options.client_replay: - self.client_playback_path(options.client_replay) - - if options.server_replay: - self.server_playback_path(options.server_replay) - self.view_stack = [] if options.app: @@ -332,39 +327,13 @@ class ConsoleMaster(flow.FlowMaster): self.loop.widget = window self.loop.draw_screen() - def _run_script_method(self, method, s, f): - status, val = s.run(method, f) - if val: - if status: - signals.add_log("Method %s return: %s" % (method, val), "debug") - else: - signals.add_log( - "Method %s error: %s" % - (method, val[1]), "error") - def run_script_once(self, command, f): - if not command: - return - signals.add_log("Running script on flow: %s" % command, "debug") - + sc = self.addons.get("scriptloader") try: - s = script.Script(command) - s.load() - except script.ScriptException as e: - signals.status_message.send( - message='Error loading "{}".'.format(command) - ) - signals.add_log('Error loading "{}":\n{}'.format(command, e), "error") - return - - if f.request: - self._run_script_method("request", s, f) - if f.response: - self._run_script_method("response", s, f) - if f.error: - self._run_script_method("error", s, f) - s.unload() - signals.flow_change.send(self, flow = f) + with self.handlecontext(): + sc.run_once(command, [f]) + except mitmproxy.exceptions.AddonError as e: + signals.add_log("Script error: %s" % e, "warn") def toggle_eventlog(self): self.options.eventlog = not self.options.eventlog @@ -384,28 +353,6 @@ class ConsoleMaster(flow.FlowMaster): except exceptions.FlowReadException as e: signals.status_message.send(message=str(e)) - def client_playback_path(self, path): - if not isinstance(path, list): - path = [path] - flows = self._readflows(path) - if flows: - self.start_client_playback(flows, False) - - def server_playback_path(self, path): - if not isinstance(path, list): - path = [path] - flows = self._readflows(path) - if flows: - self.start_server_playback( - flows, - self.options.kill, self.options.rheaders, - False, self.options.nopop, - self.options.replay_ignore_params, - self.options.replay_ignore_content, - self.options.replay_ignore_payload_params, - self.options.replay_ignore_host - ) - def spawn_editor(self, data): text = not isinstance(data, bytes) fd, name = tempfile.mkstemp('', "mproxy", text=text) diff --git a/mitmproxy/console/options.py b/mitmproxy/console/options.py index f7fb2f90..97313bf4 100644 --- a/mitmproxy/console/options.py +++ b/mitmproxy/console/options.py @@ -114,8 +114,8 @@ class Options(urwid.WidgetWrap): select.Option( "Kill Extra", "x", - lambda: master.options.kill, - master.options.toggler("kill") + lambda: master.options.replay_kill_extra, + master.options.toggler("replay_kill_extra") ), select.Option( "No Refresh", @@ -165,7 +165,7 @@ class Options(urwid.WidgetWrap): anticomp = False, ignore_hosts = (), tcp_hosts = (), - kill = False, + replay_kill_extra = False, no_upstream_cert = False, refresh_server_playback = True, replacements = [], diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py index 43d68d51..bbfb41ab 100644 --- a/mitmproxy/console/statusbar.py +++ b/mitmproxy/console/statusbar.py @@ -136,6 +136,9 @@ class StatusBar(urwid.WidgetWrap): def get_status(self): r = [] + sreplay = self.master.addons.get("serverplayback") + creplay = self.master.addons.get("clientplayback") + if len(self.master.options.setheaders): r.append("[") r.append(("heading_key", "H")) @@ -144,17 +147,14 @@ class StatusBar(urwid.WidgetWrap): r.append("[") r.append(("heading_key", "R")) r.append("eplacing]") - if self.master.client_playback: + if creplay.count(): r.append("[") r.append(("heading_key", "cplayback")) - r.append(":%s to go]" % self.master.client_playback.count()) - if self.master.server_playback: + r.append(":%s]" % creplay.count()) + if sreplay.count(): r.append("[") r.append(("heading_key", "splayback")) - if self.master.options.nopop: - r.append(":%s in file]" % self.master.server_playback.count()) - else: - r.append(":%s to go]" % self.master.server_playback.count()) + r.append(":%s]" % sreplay.count()) if self.master.options.ignore_hosts: r.append("[") r.append(("heading_key", "I")) @@ -193,7 +193,7 @@ class StatusBar(urwid.WidgetWrap): opts.append("showhost") if not self.master.options.refresh_server_playback: opts.append("norefresh") - if self.master.options.kill: + if self.master.options.replay_kill_extra: opts.append("killextra") if self.master.options.no_upstream_cert: opts.append("no-upstream-cert") diff --git a/mitmproxy/console/window.py b/mitmproxy/console/window.py index 35593643..159f68ed 100644 --- a/mitmproxy/console/window.py +++ b/mitmproxy/console/window.py @@ -57,13 +57,11 @@ class Window(urwid.Frame): callback = self.master.stop_client_playback_prompt, ) elif k == "s": - if not self.master.server_playback: - signals.status_prompt_path.send( - self, - prompt = "Server replay path", - callback = self.master.server_playback_path - ) - else: + a = self.master.addons.get("serverplayback") + if a.count(): + def stop_server_playback(response): + if response == "y": + self.master.options.server_replay = [] signals.status_prompt_onekey.send( self, prompt = "Stop current server replay?", @@ -71,7 +69,13 @@ class Window(urwid.Frame): ("yes", "y"), ("no", "n"), ), - callback = self.master.stop_server_playback_prompt, + callback = stop_server_playback + ) + else: + signals.status_prompt_path.send( + self, + prompt = "Server playback path", + callback = lambda x: self.master.options.setter("server_replay")([x]) ) def keypress(self, size, k): diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py index 51124224..778968b9 100644 --- a/mitmproxy/dump.py +++ b/mitmproxy/dump.py @@ -46,37 +46,17 @@ class DumpMaster(flow.FlowMaster): self.addons.add(options, dumper.Dumper()) # This line is just for type hinting self.options = self.options # type: Options - self.replay_ignore_params = options.replay_ignore_params - self.replay_ignore_content = options.replay_ignore_content - self.replay_ignore_host = options.replay_ignore_host - self.refresh_server_playback = options.refresh_server_playback - self.replay_ignore_payload_params = options.replay_ignore_payload_params - self.set_stream_large_bodies(options.stream_large_bodies) + print("Proxy server listening at http://%s:%d" % ( + (options.listen_host or "0.0.0.0"), + options.listen_port)) + if self.server and self.options.http2 and not tcp.HAS_ALPN: # pragma: no cover print("ALPN support missing (OpenSSL 1.0.2+ required)!\n" "HTTP/2 is disabled. Use --no-http2 to silence this warning.", file=sys.stderr) - if options.server_replay: - self.start_server_playback( - self._readflow(options.server_replay), - options.kill, options.rheaders, - not options.keepserving, - options.nopop, - options.replay_ignore_params, - options.replay_ignore_content, - options.replay_ignore_payload_params, - options.replay_ignore_host - ) - - if options.client_replay: - self.start_client_playback( - self._readflow(options.client_replay), - not options.keepserving - ) - if options.rfile: try: self.load_flows_file(options.rfile) diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py index 94876514..6873215c 100644 --- a/mitmproxy/exceptions.py +++ b/mitmproxy/exceptions.py @@ -7,9 +7,6 @@ See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ """ from __future__ import absolute_import, print_function, division -import sys -import traceback - class ProxyException(Exception): @@ -30,6 +27,10 @@ class Kill(ProxyException): class ProtocolException(ProxyException): + """ + ProtocolExceptions are caused by invalid user input, unavailable network resources, + or other events that are outside of our influence. + """ pass @@ -62,6 +63,10 @@ class Http2ProtocolException(ProtocolException): pass +class Http2ZombieException(ProtocolException): + pass + + class ServerException(ProxyException): pass @@ -74,27 +79,6 @@ class ReplayException(ProxyException): pass -class ScriptException(ProxyException): - - @classmethod - def from_exception_context(cls, cut_tb=1): - """ - Must be called while the current stack handles an exception. - - Args: - cut_tb: remove N frames from the stack trace to hide internal calls. - """ - exc_type, exc_value, exc_traceback = sys.exc_info() - - while cut_tb > 0: - exc_traceback = exc_traceback.tb_next - cut_tb -= 1 - - tb = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - - return cls(tb) - - class FlowReadException(ProxyException): pass @@ -113,3 +97,7 @@ class OptionsError(Exception): class AddonError(Exception): pass + + +class ReplayError(Exception): + pass diff --git a/mitmproxy/filt.py b/mitmproxy/filt.py index 67915e5b..eb3e392b 100644 --- a/mitmproxy/filt.py +++ b/mitmproxy/filt.py @@ -244,6 +244,7 @@ class FHeadResponse(_Rex): class FBod(_Rex): code = "b" help = "Body" + flags = re.DOTALL @only(HTTPFlow, TCPFlow) def __call__(self, f): @@ -264,6 +265,7 @@ class FBod(_Rex): class FBodRequest(_Rex): code = "bq" help = "Request body" + flags = re.DOTALL @only(HTTPFlow, TCPFlow) def __call__(self, f): @@ -280,6 +282,7 @@ class FBodRequest(_Rex): class FBodResponse(_Rex): code = "bs" help = "Response body" + flags = re.DOTALL @only(HTTPFlow, TCPFlow) def __call__(self, f): diff --git a/mitmproxy/flow/__init__.py b/mitmproxy/flow/__init__.py index 8a64180e..cb79482c 100644 --- a/mitmproxy/flow/__init__.py +++ b/mitmproxy/flow/__init__.py @@ -4,16 +4,12 @@ from mitmproxy.flow import export, modules from mitmproxy.flow.io import FlowWriter, FilteredFlowWriter, FlowReader, read_flows_from_paths from mitmproxy.flow.master import FlowMaster from mitmproxy.flow.modules import ( - AppRegistry, StreamLargeBodies, ClientPlaybackState, ServerPlaybackState + AppRegistry, StreamLargeBodies ) from mitmproxy.flow.state import State, FlowView -# TODO: We may want to remove the imports from .modules and just expose "modules" - __all__ = [ "export", "modules", "FlowWriter", "FilteredFlowWriter", "FlowReader", "read_flows_from_paths", - "FlowMaster", - "AppRegistry", "StreamLargeBodies", "ClientPlaybackState", - "ServerPlaybackState", "State", "FlowView", + "FlowMaster", "AppRegistry", "StreamLargeBodies", "State", "FlowView", ] diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 9cdcc8dd..94b46f3f 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -15,6 +15,30 @@ from mitmproxy.onboarding import app from mitmproxy.protocol import http_replay +def event_sequence(f): + if isinstance(f, models.HTTPFlow): + if f.request: + yield "request", f + if f.response: + yield "responseheaders", f + yield "response", f + if f.error: + yield "error", f + elif isinstance(f, models.TCPFlow): + messages = f.messages + f.messages = [] + f.reply = controller.DummyReply() + yield "tcp_open", f + while messages: + f.messages.append(messages.pop(0)) + yield "tcp_message", f + if f.error: + yield "tcp_error", f + yield "tcp_close", f + else: + raise NotImplementedError + + class FlowMaster(controller.Master): @property @@ -29,23 +53,11 @@ class FlowMaster(controller.Master): if server: self.add_server(server) self.state = state - self.server_playback = None # type: Optional[modules.ServerPlaybackState] - self.client_playback = None # type: Optional[modules.ClientPlaybackState] - self.kill_nonreplay = False - self.stream_large_bodies = None # type: Optional[modules.StreamLargeBodies] - self.replay_ignore_params = False - self.replay_ignore_content = None - self.replay_ignore_host = False - self.apps = modules.AppRegistry() def start_app(self, host, port): - self.apps.add( - app.mapp, - host, - port - ) + self.apps.add(app.mapp, host, port) def set_stream_large_bodies(self, max_size): if max_size is not None: @@ -53,92 +65,6 @@ class FlowMaster(controller.Master): else: self.stream_large_bodies = False - def start_client_playback(self, flows, exit): - """ - flows: List of flows. - """ - self.client_playback = modules.ClientPlaybackState(flows, exit) - - def stop_client_playback(self): - self.client_playback = None - - def start_server_playback( - self, - flows, - kill, - headers, - exit, - nopop, - ignore_params, - ignore_content, - ignore_payload_params, - ignore_host): - """ - flows: List of flows. - kill: Boolean, should we kill requests not part of the replay? - ignore_params: list of parameters to ignore in server replay - ignore_content: true if request content should be ignored in server replay - ignore_payload_params: list of content params to ignore in server replay - ignore_host: true if request host should be ignored in server replay - """ - self.server_playback = modules.ServerPlaybackState( - headers, - flows, - exit, - nopop, - ignore_params, - ignore_content, - ignore_payload_params, - ignore_host) - self.kill_nonreplay = kill - - def stop_server_playback(self): - self.server_playback = None - - def do_server_playback(self, flow): - """ - This method should be called by child classes in the request - handler. Returns True if playback has taken place, None if not. - """ - if self.server_playback: - rflow = self.server_playback.next_flow(flow) - if not rflow: - return None - response = rflow.response.copy() - response.is_replay = True - if self.options.refresh_server_playback: - response.refresh() - flow.response = response - return True - return None - - def tick(self, timeout): - if self.client_playback: - stop = ( - self.client_playback.done() and - self.state.active_flow_count() == 0 - ) - exit = self.client_playback.exit - if stop: - self.stop_client_playback() - if exit: - self.shutdown() - else: - self.client_playback.tick(self) - - if self.server_playback: - stop = ( - self.server_playback.count() == 0 and - self.state.active_flow_count() == 0 and - not self.kill_nonreplay - ) - exit = self.server_playback.exit - if stop: - self.stop_server_playback() - if exit: - self.shutdown() - return super(FlowMaster, self).tick(timeout) - def duplicate_flow(self, f): """ Duplicate flow, and insert it into state without triggering any of @@ -182,28 +108,9 @@ class FlowMaster(controller.Master): f.request.host = self.server.config.upstream_server.address.host f.request.port = self.server.config.upstream_server.address.port f.request.scheme = self.server.config.upstream_server.scheme - - f.reply = controller.DummyReply() - if f.request: - self.request(f) - if f.response: - self.responseheaders(f) - self.response(f) - if f.error: - self.error(f) - elif isinstance(f, models.TCPFlow): - messages = f.messages - f.messages = [] - f.reply = controller.DummyReply() - self.tcp_open(f) - while messages: - f.messages.append(messages.pop(0)) - self.tcp_message(f) - if f.error: - self.tcp_error(f) - self.tcp_close(f) - else: - raise NotImplementedError() + f.reply = controller.DummyReply() + for e, o in event_sequence(f): + getattr(self, e)(o) def load_flows(self, fr): """ @@ -229,43 +136,49 @@ class FlowMaster(controller.Master): except IOError as v: raise exceptions.FlowReadException(v.strerror) - def process_new_request(self, f): - if self.server_playback: - pb = self.do_server_playback(f) - if not pb and self.kill_nonreplay: - self.add_log("Killed {}".format(f.request.url), "info") - f.reply.kill() - def replay_request(self, f, block=False): """ - Returns None if successful, or error message if not. + Returns an http_replay.RequestReplayThred object. + May raise exceptions.ReplayError. """ if f.live: - return "Can't replay live request." + raise exceptions.ReplayError( + "Can't replay live flow." + ) if f.intercepted: - return "Can't replay while intercepting..." + raise exceptions.ReplayError( + "Can't replay intercepted flow." + ) if f.request.raw_content is None: - return "Can't replay request with missing content..." - if f.request: - f.backup() - f.request.is_replay = True - - # TODO: We should be able to remove this. - if "Content-Length" in f.request.headers: - f.request.headers["Content-Length"] = str(len(f.request.raw_content)) - - f.response = None - f.error = None - self.process_new_request(f) - rt = http_replay.RequestReplayThread( - self.server.config, - f, - self.event_queue, - self.should_exit + raise exceptions.ReplayError( + "Can't replay flow with missing content." + ) + if not f.request: + raise exceptions.ReplayError( + "Can't replay flow with missing request." ) - rt.start() # pragma: no cover - if block: - rt.join() + + f.backup() + f.request.is_replay = True + + # TODO: We should be able to remove this. + if "Content-Length" in f.request.headers: + f.request.headers["Content-Length"] = str(len(f.request.raw_content)) + + f.response = None + f.error = None + # FIXME: process through all addons? + # self.process_new_request(f) + rt = http_replay.RequestReplayThread( + self.server.config, + f, + self.event_queue, + self.should_exit + ) + rt.start() # pragma: no cover + if block: + rt.join() + return rt @controller.handler def log(self, l): @@ -294,9 +207,6 @@ class FlowMaster(controller.Master): @controller.handler def error(self, f): self.state.update_flow(f) - if self.client_playback: - self.client_playback.clear(f) - return f @controller.handler def request(self, f): @@ -314,8 +224,6 @@ class FlowMaster(controller.Master): return if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) - self.process_new_request(f) - return f @controller.handler def responseheaders(self, f): @@ -325,18 +233,14 @@ class FlowMaster(controller.Master): except netlib.exceptions.HttpException: f.reply.kill() return - return f @controller.handler def response(self, f): self.state.update_flow(f) - if self.client_playback: - self.client_playback.clear(f) - return f @controller.handler def websockets_handshake(self, f): - return f + pass def handle_intercept(self, f): self.state.update_flow(f) @@ -356,10 +260,7 @@ class FlowMaster(controller.Master): @controller.handler def tcp_error(self, flow): - self.add_log("Error in TCP connection to {}: {}".format( - repr(flow.server_conn.address), - flow.error - ), "info") + pass @controller.handler def tcp_close(self, flow): diff --git a/mitmproxy/flow/modules.py b/mitmproxy/flow/modules.py index fb3c52da..7d8a282e 100644 --- a/mitmproxy/flow/modules.py +++ b/mitmproxy/flow/modules.py @@ -1,13 +1,7 @@ from __future__ import absolute_import, print_function, division -import hashlib - -from six.moves import urllib - -from mitmproxy import controller from netlib import wsgi from netlib import version -from netlib import strutils from netlib.http import http1 @@ -50,129 +44,3 @@ class StreamLargeBodies(object): if not r.raw_content and not (0 <= expected_size <= self.max_size): # r.stream may already be a callable, which we want to preserve. r.stream = r.stream or True - - -class ClientPlaybackState: - def __init__(self, flows, exit): - self.flows, self.exit = flows, exit - self.current = None - self.testing = False # Disables actual replay for testing. - - def count(self): - return len(self.flows) - - def done(self): - if len(self.flows) == 0 and not self.current: - return True - return False - - def clear(self, flow): - """ - A request has returned in some way - if this is the one we're - servicing, go to the next flow. - """ - if flow is self.current: - self.current = None - - def tick(self, master): - if self.flows and not self.current: - self.current = self.flows.pop(0).copy() - if not self.testing: - master.replay_request(self.current) - else: - self.current.reply = controller.DummyReply() - master.request(self.current) - if self.current.response: - master.response(self.current) - - -class ServerPlaybackState: - def __init__( - self, - headers, - flows, - exit, - nopop, - ignore_params, - ignore_content, - ignore_payload_params, - ignore_host): - """ - headers: Case-insensitive list of request headers that should be - included in request-response matching. - """ - self.headers = headers - self.exit = exit - self.nopop = nopop - self.ignore_params = ignore_params - self.ignore_content = ignore_content - self.ignore_payload_params = [strutils.always_bytes(x) for x in (ignore_payload_params or ())] - self.ignore_host = ignore_host - self.fmap = {} - for i in flows: - if i.response: - l = self.fmap.setdefault(self._hash(i), []) - l.append(i) - - def count(self): - return sum(len(i) for i in self.fmap.values()) - - def _hash(self, flow): - """ - Calculates a loose hash of the flow request. - """ - r = flow.request - - _, _, path, _, query, _ = urllib.parse.urlparse(r.url) - queriesArray = urllib.parse.parse_qsl(query, keep_blank_values=True) - - key = [ - str(r.port), - str(r.scheme), - str(r.method), - str(path), - ] - - if not self.ignore_content: - form_contents = r.urlencoded_form or r.multipart_form - if self.ignore_payload_params and form_contents: - key.extend( - p for p in form_contents.items(multi=True) - if p[0] not in self.ignore_payload_params - ) - else: - key.append(str(r.raw_content)) - - if not self.ignore_host: - key.append(r.host) - - filtered = [] - ignore_params = self.ignore_params or [] - for p in queriesArray: - if p[0] not in ignore_params: - filtered.append(p) - for p in filtered: - key.append(p[0]) - key.append(p[1]) - - if self.headers: - headers = [] - for i in self.headers: - v = r.headers.get(i) - headers.append((i, v)) - key.append(headers) - return hashlib.sha256(repr(key).encode("utf8", "surrogateescape")).digest() - - def next_flow(self, request): - """ - Returns the next flow object, or None if no matching flow was - found. - """ - l = self.fmap.get(self._hash(request)) - if not l: - return None - - if self.nopop: - return l[0] - else: - return l.pop(0) diff --git a/mitmproxy/options.py b/mitmproxy/options.py index 0ac44cd8..ba4ed0c7 100644 --- a/mitmproxy/options.py +++ b/mitmproxy/options.py @@ -30,15 +30,16 @@ class Options(optmanager.OptManager): anticache=False, # type: bool anticomp=False, # type: bool client_replay=None, # type: Optional[str] - kill=False, # type: bool + replay_kill_extra=False, # type: bool + keepserving=True, # type: bool no_server=False, # type: bool - nopop=False, # type: bool + server_replay_nopop=False, # type: bool refresh_server_playback=False, # type: bool rfile=None, # type: Optional[str] scripts=(), # type: Sequence[str] showhost=False, # type: bool replacements=(), # type: Sequence[Tuple[str, str, str]] - rheaders=(), # type: Sequence[str] + server_replay_use_headers=(), # type: Sequence[str] setheaders=(), # type: Sequence[Tuple[str, str, str]] server_replay=None, # type: Optional[str] stickycookie=None, # type: Optional[str] @@ -46,10 +47,10 @@ class Options(optmanager.OptManager): stream_large_bodies=None, # type: Optional[str] verbosity=2, # type: int outfile=None, # type: Tuple[str, str] - replay_ignore_content=False, # type: bool - replay_ignore_params=(), # type: Sequence[str] - replay_ignore_payload_params=(), # type: Sequence[str] - replay_ignore_host=False, # type: bool + server_replay_ignore_content=False, # type: bool + server_replay_ignore_params=(), # type: Sequence[str] + server_replay_ignore_payload_params=(), # type: Sequence[str] + server_replay_ignore_host=False, # type: bool # Proxy options auth_nonanonymous=False, # type: bool @@ -88,15 +89,16 @@ class Options(optmanager.OptManager): self.anticache = anticache self.anticomp = anticomp self.client_replay = client_replay - self.kill = kill + self.keepserving = keepserving + self.replay_kill_extra = replay_kill_extra self.no_server = no_server - self.nopop = nopop + self.server_replay_nopop = server_replay_nopop self.refresh_server_playback = refresh_server_playback self.rfile = rfile self.scripts = scripts self.showhost = showhost self.replacements = replacements - self.rheaders = rheaders + self.server_replay_use_headers = server_replay_use_headers self.setheaders = setheaders self.server_replay = server_replay self.stickycookie = stickycookie @@ -104,10 +106,10 @@ class Options(optmanager.OptManager): self.stream_large_bodies = stream_large_bodies self.verbosity = verbosity self.outfile = outfile - self.replay_ignore_content = replay_ignore_content - self.replay_ignore_params = replay_ignore_params - self.replay_ignore_payload_params = replay_ignore_payload_params - self.replay_ignore_host = replay_ignore_host + self.server_replay_ignore_content = server_replay_ignore_content + self.server_replay_ignore_params = server_replay_ignore_params + self.server_replay_ignore_payload_params = server_replay_ignore_payload_params + self.server_replay_ignore_host = server_replay_ignore_host # Proxy options self.auth_nonanonymous = auth_nonanonymous diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index 1418d6e9..1632e66f 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -153,12 +153,13 @@ class HttpLayer(base.Layer): # We optimistically guess there might be an HTTP client on the # other end self.send_error_response(400, repr(e)) - self.log( - "request", - "warn", - "HTTP protocol error in client request: %s" % e + six.reraise( + exceptions.ProtocolException, + exceptions.ProtocolException( + "HTTP protocol error in client request: {}".format(e) + ), + sys.exc_info()[2] ) - return self.log("request", "debug", [repr(request)]) diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 0e42d619..1595fb61 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -96,15 +96,17 @@ class Http2Layer(base.Layer): self.server_to_client_stream_ids = dict([(0, 0)]) self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False, header_encoding=False) - # make sure that we only pass actual SSL.Connection objects in here, - # because otherwise ssl_read_select fails! - self.active_conns = [self.client_conn.connection] - def _initiate_server_conn(self): - self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False) - self.server_conn.h2.initiate_connection() - self.server_conn.send(self.server_conn.h2.data_to_send()) - self.active_conns.append(self.server_conn.connection) + if self.server_conn: + self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False) + self.server_conn.h2.initiate_connection() + self.server_conn.send(self.server_conn.h2.data_to_send()) + + def _complete_handshake(self): + preamble = self.client_conn.rfile.read(24) + self.client_conn.h2.initiate_connection() + self.client_conn.h2.receive_data(preamble) + self.client_conn.send(self.client_conn.h2.data_to_send()) def next_layer(self): # pragma: no cover # WebSockets over HTTP/2? @@ -126,7 +128,7 @@ class Http2Layer(base.Layer): eid = event.stream_id if isinstance(event, events.RequestReceived): - return self._handle_request_received(eid, event) + return self._handle_request_received(eid, event, source_conn.h2) elif isinstance(event, events.ResponseReceived): return self._handle_response_received(eid, event) elif isinstance(event, events.DataReceived): @@ -138,9 +140,9 @@ class Http2Layer(base.Layer): elif isinstance(event, events.RemoteSettingsChanged): return self._handle_remote_settings_changed(event, other_conn) elif isinstance(event, events.ConnectionTerminated): - return self._handle_connection_terminated(event) + return self._handle_connection_terminated(event, is_server) elif isinstance(event, events.PushedStreamReceived): - return self._handle_pushed_stream_received(event) + return self._handle_pushed_stream_received(event, source_conn.h2) elif isinstance(event, events.PriorityUpdated): return self._handle_priority_updated(eid, event) elif isinstance(event, events.TrailersReceived): @@ -149,9 +151,9 @@ class Http2Layer(base.Layer): # fail-safe for unhandled events return True - def _handle_request_received(self, eid, event): + def _handle_request_received(self, eid, event, h2_connection): headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) + self.streams[eid] = Http2SingleStreamLayer(self, h2_connection, eid, headers) self.streams[eid].timestamp_start = time.time() self.streams[eid].no_body = (event.stream_ended is not None) if event.priority_updated is not None: @@ -173,7 +175,7 @@ class Http2Layer(base.Layer): def _handle_data_received(self, eid, event, source_conn): bsl = self.config.options.body_size_limit if bsl and self.streams[eid].queued_data_length > bsl: - self.streams[eid].zombie = time.time() + self.streams[eid].kill() source_conn.h2.safe_reset_stream( event.stream_id, h2.errors.REFUSED_STREAM @@ -194,7 +196,7 @@ class Http2Layer(base.Layer): return True def _handle_stream_reset(self, eid, event, is_server, other_conn): - self.streams[eid].zombie = time.time() + self.streams[eid].kill() if eid in self.streams and event.error_code == h2.errors.CANCEL: if is_server: other_stream_id = self.streams[eid].client_stream_id @@ -209,7 +211,13 @@ class Http2Layer(base.Layer): other_conn.h2.safe_update_settings(new_settings) return True - def _handle_connection_terminated(self, event): + def _handle_connection_terminated(self, event, is_server): + self.log("HTTP/2 connection terminated by {}: error code: {}, last stream id: {}, additional data: {}".format( + "server" if is_server else "client", + event.error_code, + event.last_stream_id, + event.additional_data), "info") + if event.error_code != h2.errors.NO_ERROR: # Something terrible has happened - kill everything! self.client_conn.h2.close_connection( @@ -226,7 +234,7 @@ class Http2Layer(base.Layer): """ return False - def _handle_pushed_stream_received(self, event): + def _handle_pushed_stream_received(self, event, h2_connection): # pushed stream ids should be unique and not dependent on race conditions # only the parent stream id must be looked up first parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] @@ -235,7 +243,7 @@ class Http2Layer(base.Layer): self.client_conn.send(self.client_conn.h2.data_to_send()) headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) + self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, h2_connection, event.pushed_stream_id, headers) self.streams[event.pushed_stream_id].timestamp_start = time.time() self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].parent_stream_id = parent_eid @@ -253,7 +261,7 @@ class Http2Layer(base.Layer): with self.server_conn.h2.lock: mapped_stream_id = event.stream_id if mapped_stream_id in self.streams and self.streams[mapped_stream_id].server_stream_id: - # if the stream is already up and running and was sent to the server + # if the stream is already up and running and was sent to the server, # use the mapped server stream id to update priority information mapped_stream_id = self.streams[mapped_stream_id].server_stream_id @@ -294,37 +302,36 @@ class Http2Layer(base.Layer): def _kill_all_streams(self): for stream in self.streams.values(): - if not stream.zombie: - stream.zombie = time.time() - stream.request_data_finished.set() - stream.response_arrived.set() - stream.data_finished.set() + stream.kill() def __call__(self): - if self.server_conn: - self._initiate_server_conn() + self._initiate_server_conn() + self._complete_handshake() - preamble = self.client_conn.rfile.read(24) - self.client_conn.h2.initiate_connection() - self.client_conn.h2.receive_data(preamble) - self.client_conn.send(self.client_conn.h2.data_to_send()) + client = self.client_conn.connection + server = self.server_conn.connection + conns = [client, server] try: while True: - r = tcp.ssl_read_select(self.active_conns, 1) + r = tcp.ssl_read_select(conns, 1) for conn in r: - source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn - other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + source_conn = self.client_conn if conn == client else self.server_conn + other_conn = self.server_conn if conn == client else self.client_conn is_server = (conn == self.server_conn.connection) with source_conn.h2.lock: try: - raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile)) + raw_frame = b''.join(http2.read_raw_frame(source_conn.rfile)) except: # read frame failed: connection closed self._kill_all_streams() return + if source_conn.h2.state_machine.state == h2.connection.ConnectionState.CLOSED: + self.log("HTTP/2 connection entered closed state already", "debug") + return + incoming_events = source_conn.h2.receive_data(raw_frame) source_conn.send(source_conn.h2.data_to_send()) @@ -354,10 +361,11 @@ def detect_zombie_stream(func): class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread): - def __init__(self, ctx, stream_id, request_headers): + def __init__(self, ctx, h2_connection, stream_id, request_headers): super(Http2SingleStreamLayer, self).__init__( ctx, name="Http2SingleStreamLayer-{}".format(stream_id) ) + self.h2_connection = h2_connection self.zombie = None self.client_stream_id = stream_id self.server_stream_id = None @@ -365,6 +373,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.response_headers = None self.pushed = False + self.timestamp_start = None + self.timestamp_end = None + self.request_data_queue = queue.Queue() self.request_queued_data_length = 0 self.request_data_finished = threading.Event() @@ -381,6 +392,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.priority_weight = None self.handled_priority_event = None + def kill(self): + if not self.zombie: + self.zombie = time.time() + self.request_data_finished.set() + self.response_arrived.set() + self.response_data_finished.set() + def connect(self): # pragma: no cover raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.") @@ -421,10 +439,11 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.request_queued_data_length = v def raise_zombie(self, pre_command=None): - if self.zombie is not None: + connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED + if self.zombie is not None or not hasattr(self.server_conn, 'h2') or connection_closed: if pre_command is not None: pre_command() - raise exceptions.Http2ProtocolException("Zombie Stream") + raise exceptions.Http2ZombieException("Connection already dead") @detect_zombie_stream def read_request(self): @@ -508,6 +527,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except Exception as e: # pragma: no cover raise e finally: + self.raise_zombie() self.server_conn.h2.lock.release() if not self.no_body: @@ -581,6 +601,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) try: layer() + except exceptions.Http2ZombieException as e: # pragma: no cover + pass except exceptions.ProtocolException as e: # pragma: no cover self.log(repr(e), "info") self.log(traceback.format_exc(), "debug") @@ -589,5 +611,4 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except exceptions.Kill: self.log("Connection killed", "info") - if not self.zombie: - self.zombie = time.time() + self.kill() diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py index bfde06c5..877eaa22 100644 --- a/mitmproxy/protocol/http_replay.py +++ b/mitmproxy/protocol/http_replay.py @@ -33,6 +33,7 @@ class RequestReplayThread(basethread.BaseThread): def run(self): r = self.flow.request first_line_format_backup = r.first_line_format + server = None try: self.flow.response = None @@ -103,3 +104,5 @@ class RequestReplayThread(basethread.BaseThread): self.channel.tell("log", Log(traceback.format_exc(), "error")) finally: r.first_line_format = first_line_format_backup + if server: + server.finish() diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 4fd5755a..c5fd5f9e 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -132,7 +132,7 @@ class ConnectionHandler(object): self.log(str(e), "warn") self.log("Invalid certificate, closing connection. Pass --insecure to disable validation.", "warn") else: - self.log(repr(e), "warn") + self.log(str(e), "warn") self.log(traceback.format_exc(), "debug") # If an error propagates to the topmost level, diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py index c92ba4d3..5498c2d9 100644 --- a/mitmproxy/web/app.py +++ b/mitmproxy/web/app.py @@ -116,7 +116,7 @@ class RequestHandler(BasicAuth, tornado.web.RequestHandler): def json(self): if not self.request.headers.get("Content-Type").startswith("application/json"): return None - return json.loads(self.request.body) + return json.loads(self.request.body.decode()) @property def state(self): diff --git a/netlib/debug.py b/netlib/debug.py index 29c7f655..f9c700de 100644 --- a/netlib/debug.py +++ b/netlib/debug.py @@ -37,7 +37,7 @@ def sysinfo(): return "\n".join(data) -def dump_info(sig, frm, file=sys.stdout): # pragma: no cover +def dump_info(signal=None, frame=None, file=sys.stdout): # pragma: no cover print("****************************************************", file=file) print("Summary", file=file) print("=======", file=file) @@ -81,7 +81,7 @@ def dump_info(sig, frm, file=sys.stdout): # pragma: no cover print("****************************************************", file=file) -def dump_stacks(signal, frame, file=sys.stdout): +def dump_stacks(signal=None, frame=None, file=sys.stdout): id2name = dict([(th.ident, th.name) for th in threading.enumerate()]) code = [] for threadId, stack in sys._current_frames().items(): diff --git a/netlib/encoding.py b/netlib/encoding.py index 9c8acff7..9b8b3868 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,9 @@ from __future__ import absolute_import import codecs import collections +import six from io import BytesIO + import gzip import zlib import brotli @@ -32,6 +34,9 @@ def decode(encoded, encoding, errors='strict'): Raises: ValueError, if decoding fails. """ + if len(encoded) == 0: + return encoded + global _cache cached = ( isinstance(encoded, bytes) and @@ -49,11 +54,14 @@ def decode(encoded, encoding, errors='strict'): if encoding in ("gzip", "deflate", "br"): _cache = CachedDecode(encoded, encoding, errors, decoded) return decoded + except TypeError: + raise except Exception as e: - raise ValueError("{} when decoding {} with {}".format( + raise ValueError("{} when decoding {} with {}: {}".format( type(e).__name__, repr(encoded)[:10], repr(encoding), + repr(e), )) @@ -68,6 +76,9 @@ def encode(decoded, encoding, errors='strict'): Raises: ValueError, if encoding fails. """ + if len(decoded) == 0: + return decoded + global _cache cached = ( isinstance(decoded, bytes) and @@ -79,17 +90,23 @@ def encode(decoded, encoding, errors='strict'): return _cache.encoded try: try: - encoded = custom_encode[encoding](decoded) + value = decoded + if not six.PY2 and isinstance(value, six.string_types): + value = decoded.encode() + encoded = custom_encode[encoding](value) except KeyError: encoded = codecs.encode(decoded, encoding, errors) if encoding in ("gzip", "deflate", "br"): _cache = CachedDecode(encoded, encoding, errors, decoded) return encoded + except TypeError: + raise except Exception as e: - raise ValueError("{} when encoding {} with {}".format( + raise ValueError("{} when encoding {} with {}: {}".format( type(e).__name__, repr(decoded)[:10], repr(encoding), + repr(e), )) @@ -145,12 +162,14 @@ def encode_deflate(content): custom_decode = { + "none": identity, "identity": identity, "gzip": decode_gzip, "deflate": decode_deflate, "br": decode_brotli, } custom_encode = { + "none": identity, "identity": identity, "gzip": encode_gzip, "deflate": encode_deflate, diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 131e8ce5..b55874ca 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,6 +14,7 @@ if six.PY2: # pragma: no cover return x def _always_bytes(x): + strutils.always_bytes(x, "utf-8", "replace") # raises a TypeError if x != str/bytes/None. return x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 60064190..7f84a1ab 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,8 +1,10 @@ from __future__ import absolute_import, print_function, division -from netlib.http.http2 import framereader + +from netlib.http.http2.framereader import read_raw_frame, parse_frame from netlib.http.http2.utils import parse_headers __all__ = [ - "framereader", + "read_raw_frame", + "parse_frame", "parse_headers", ] diff --git a/netlib/http/http2/framereader.py b/netlib/http/http2/framereader.py index eb9b069a..8b7cfffb 100644 --- a/netlib/http/http2/framereader.py +++ b/netlib/http/http2/framereader.py @@ -4,7 +4,7 @@ import hyperframe from ...exceptions import HttpException -def http2_read_raw_frame(rfile): +def read_raw_frame(rfile): header = rfile.safe_read(9) length = int(codecs.encode(header[:3], 'hex_codec'), 16) @@ -15,8 +15,11 @@ def http2_read_raw_frame(rfile): return [header, body] -def http2_read_frame(rfile): - header, body = http2_read_raw_frame(rfile) +def parse_frame(header, body=None): + if body is None: + body = header[9:] + header = header[:9] + frame, length = hyperframe.frame.Frame.parse_frame_header(header) frame.parse_body(memoryview(body)) return frame diff --git a/netlib/http/response.py b/netlib/http/response.py index ae29298f..385e233a 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -84,15 +84,6 @@ class Response(message.Message): (), None ) - # Assign this manually to update the content-length header. - if isinstance(content, bytes): - resp.content = content - elif isinstance(content, str): - resp.text = content - else: - raise TypeError("Expected content to be str or bytes, but is {}.".format( - type(content).__name__ - )) # Headers can be list or dict, we differentiate here. if isinstance(headers, dict): @@ -104,6 +95,16 @@ class Response(message.Message): type(headers).__name__ )) + # Assign this manually to update the content-length header. + if isinstance(content, bytes): + resp.content = content + elif isinstance(content, str): + resp.text = content + else: + raise TypeError("Expected content to be str or bytes, but is {}.".format( + type(content).__name__ + )) + return resp @property diff --git a/netlib/strutils.py b/netlib/strutils.py index 4cb3b805..d43c2aab 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -8,7 +8,10 @@ import six def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes + elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None: + return unicode_or_bytes + else: + raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__)) def native(s, *encoding_opts): diff --git a/pathod/language/http2.py b/pathod/language/http2.py index c0313baa..519ee699 100644 --- a/pathod/language/http2.py +++ b/pathod/language/http2.py @@ -189,7 +189,7 @@ class Response(_HTTP2Message): resp = http.Response( b'HTTP/2.0', - self.status_code.string(), + int(self.status_code.string()), b'', headers, body, diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index 5ad120de..7b162664 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -6,7 +6,7 @@ import time import hyperframe.frame from hpack.hpack import Encoder, Decoder -from netlib import utils, strutils +from netlib import utils from netlib.http import http2 import netlib.http.headers import netlib.http.response @@ -201,7 +201,7 @@ class HTTP2StateProtocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.insert(0, b':status', strutils.always_bytes(response.status_code)) + headers.insert(0, b':status', str(response.status_code).encode()) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -254,7 +254,7 @@ class HTTP2StateProtocol(object): def read_frame(self, hide=False): while True: - frm = http2.framereader.http2_read_frame(self.tcp_handler.rfile) + frm = http2.parse_frame(*http2.read_raw_frame(self.tcp_handler.rfile)) if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) @@ -85,7 +85,7 @@ setup( "tornado>=4.3, <4.5", "urwid>=1.3.1, <1.4", "watchdog>=0.8.3, <0.9", - "brotlipy>=0.3.0, <0.5", + "brotlipy>=0.5.1, <0.7", ], extras_require={ ':sys_platform == "win32"': [ @@ -107,6 +107,7 @@ setup( "pytest-cov>=2.2.1, <3", "pytest-timeout>=1.0.0, <2", "pytest-xdist>=1.14, <2", + "pytest-faulthandler>=1.3.0, <2", "sphinx>=1.3.5, <1.5", "sphinx-autobuild>=0.5.2, <0.7", "sphinxcontrib-documentedlist>=0.4.0, <0.5", diff --git a/test/mitmproxy/builtins/test_clientplayback.py b/test/mitmproxy/builtins/test_clientplayback.py new file mode 100644 index 00000000..15702340 --- /dev/null +++ b/test/mitmproxy/builtins/test_clientplayback.py @@ -0,0 +1,37 @@ +import mock + +from mitmproxy.builtins import clientplayback +from mitmproxy import options + +from .. import tutils, mastertest + + +class TestClientPlayback: + def test_playback(self): + cp = clientplayback.ClientPlayback() + cp.configure(options.Options(), []) + assert cp.count() == 0 + f = tutils.tflow(resp=True) + cp.load([f]) + assert cp.count() == 1 + RP = "mitmproxy.protocol.http_replay.RequestReplayThread" + with mock.patch(RP) as rp: + assert not cp.current + with mastertest.mockctx(): + cp.tick() + rp.assert_called() + assert cp.current + + cp.keepserving = False + cp.flows = None + cp.current = None + with mock.patch("mitmproxy.controller.Master.shutdown") as sd: + with mastertest.mockctx(): + cp.tick() + sd.assert_called() + + def test_configure(self): + cp = clientplayback.ClientPlayback() + cp.configure( + options.Options(), [] + ) diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py index 0bac6ca0..09e5bc92 100644 --- a/test/mitmproxy/builtins/test_script.py +++ b/test/mitmproxy/builtins/test_script.py @@ -137,6 +137,31 @@ class TestScript(mastertest.MasterTest): class TestScriptLoader(mastertest.MasterTest): + def test_run_once(self): + s = state.State() + o = options.Options(scripts=[]) + m = master.FlowMaster(o, None, s) + sl = script.ScriptLoader() + m.addons.add(o, sl) + + f = tutils.tflow(resp=True) + with m.handlecontext(): + sc = sl.run_once( + tutils.test_data.path( + "data/addonscripts/recorder.py" + ), [f] + ) + evts = [i[1] for i in sc.ns.call_log] + assert evts == ['start', 'request', 'responseheaders', 'response', 'done'] + + with m.handlecontext(): + tutils.raises( + "file not found", + sl.run_once, + "nonexistent", + [f] + ) + def test_simple(self): s = state.State() o = options.Options(scripts=[]) diff --git a/test/mitmproxy/builtins/test_serverplayback.py b/test/mitmproxy/builtins/test_serverplayback.py new file mode 100644 index 00000000..4db509da --- /dev/null +++ b/test/mitmproxy/builtins/test_serverplayback.py @@ -0,0 +1,284 @@ +from .. import tutils, mastertest + +import netlib.tutils +from mitmproxy.builtins import serverplayback +from mitmproxy import options +from mitmproxy import exceptions +from mitmproxy import flow + + +class TestServerPlayback: + def test_server_playback(self): + sp = serverplayback.ServerPlayback() + sp.configure(options.Options(), []) + f = tutils.tflow(resp=True) + + assert not sp.flowmap + + sp.load([f]) + assert sp.flowmap + assert sp.next_flow(f) + assert not sp.flowmap + + def test_ignore_host(self): + sp = serverplayback.ServerPlayback() + sp.configure(options.Options(server_replay_ignore_host=True), []) + + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) + + r.request.host = "address" + r2.request.host = "address" + assert sp._hash(r) == sp._hash(r2) + r2.request.host = "wrong_address" + assert sp._hash(r) == sp._hash(r2) + + def test_ignore_content(self): + s = serverplayback.ServerPlayback() + s.configure(options.Options(server_replay_ignore_content=False), []) + + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) + + r.request.content = b"foo" + r2.request.content = b"foo" + assert s._hash(r) == s._hash(r2) + r2.request.content = b"bar" + assert not s._hash(r) == s._hash(r2) + + s.configure(options.Options(server_replay_ignore_content=True), []) + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) + r.request.content = b"foo" + r2.request.content = b"foo" + assert s._hash(r) == s._hash(r2) + r2.request.content = b"bar" + assert s._hash(r) == s._hash(r2) + r2.request.content = b"" + assert s._hash(r) == s._hash(r2) + r2.request.content = None + assert s._hash(r) == s._hash(r2) + + def test_ignore_content_wins_over_params(self): + s = serverplayback.ServerPlayback() + s.configure( + options.Options( + server_replay_ignore_content=True, + server_replay_ignore_payload_params=[ + "param1", "param2" + ] + ), + [] + ) + # NOTE: parameters are mutually exclusive in options + + r = tutils.tflow(resp=True) + r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" + r.request.content = b"paramx=y" + + r2 = tutils.tflow(resp=True) + r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" + r2.request.content = b"paramx=x" + + # same parameters + assert s._hash(r) == s._hash(r2) + + def test_ignore_payload_params_other_content_type(self): + s = serverplayback.ServerPlayback() + s.configure( + options.Options( + server_replay_ignore_content=False, + server_replay_ignore_payload_params=[ + "param1", "param2" + ] + ), + [] + + ) + r = tutils.tflow(resp=True) + r.request.headers["Content-Type"] = "application/json" + r.request.content = b'{"param1":"1"}' + r2 = tutils.tflow(resp=True) + r2.request.headers["Content-Type"] = "application/json" + r2.request.content = b'{"param1":"1"}' + # same content + assert s._hash(r) == s._hash(r2) + # distint content (note only x-www-form-urlencoded payload is analysed) + r2.request.content = b'{"param1":"2"}' + assert not s._hash(r) == s._hash(r2) + + def test_hash(self): + s = serverplayback.ServerPlayback() + s.configure(options.Options(), []) + + r = tutils.tflow() + r2 = tutils.tflow() + + assert s._hash(r) + assert s._hash(r) == s._hash(r2) + r.request.headers["foo"] = "bar" + assert s._hash(r) == s._hash(r2) + r.request.path = "voing" + assert s._hash(r) != s._hash(r2) + + r.request.path = "path?blank_value" + r2.request.path = "path?" + assert s._hash(r) != s._hash(r2) + + def test_headers(self): + s = serverplayback.ServerPlayback() + s.configure(options.Options(server_replay_use_headers=["foo"]), []) + + r = tutils.tflow(resp=True) + r.request.headers["foo"] = "bar" + r2 = tutils.tflow(resp=True) + assert not s._hash(r) == s._hash(r2) + r2.request.headers["foo"] = "bar" + assert s._hash(r) == s._hash(r2) + r2.request.headers["oink"] = "bar" + assert s._hash(r) == s._hash(r2) + + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) + assert s._hash(r) == s._hash(r2) + + def test_load(self): + s = serverplayback.ServerPlayback() + s.configure(options.Options(), []) + + r = tutils.tflow(resp=True) + r.request.headers["key"] = "one" + + r2 = tutils.tflow(resp=True) + r2.request.headers["key"] = "two" + + s.load([r, r2]) + + assert s.count() == 2 + + n = s.next_flow(r) + assert n.request.headers["key"] == "one" + assert s.count() == 1 + + n = s.next_flow(r) + assert n.request.headers["key"] == "two" + assert not s.flowmap + assert s.count() == 0 + + assert not s.next_flow(r) + + def test_load_with_server_replay_nopop(self): + s = serverplayback.ServerPlayback() + s.configure(options.Options(server_replay_nopop=True), []) + + r = tutils.tflow(resp=True) + r.request.headers["key"] = "one" + + r2 = tutils.tflow(resp=True) + r2.request.headers["key"] = "two" + + s.load([r, r2]) + + assert s.count() == 2 + s.next_flow(r) + assert s.count() == 2 + + def test_ignore_params(self): + s = serverplayback.ServerPlayback() + s.configure( + options.Options( + server_replay_ignore_params=["param1", "param2"] + ), + [] + ) + + r = tutils.tflow(resp=True) + r.request.path = "/test?param1=1" + r2 = tutils.tflow(resp=True) + r2.request.path = "/test" + assert s._hash(r) == s._hash(r2) + r2.request.path = "/test?param1=2" + assert s._hash(r) == s._hash(r2) + r2.request.path = "/test?param2=1" + assert s._hash(r) == s._hash(r2) + r2.request.path = "/test?param3=2" + assert not s._hash(r) == s._hash(r2) + + def test_ignore_payload_params(self): + s = serverplayback.ServerPlayback() + s.configure( + options.Options( + server_replay_ignore_payload_params=["param1", "param2"] + ), + [] + ) + + r = tutils.tflow(resp=True) + r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" + r.request.content = b"paramx=x¶m1=1" + r2 = tutils.tflow(resp=True) + r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" + r2.request.content = b"paramx=x¶m1=1" + # same parameters + assert s._hash(r) == s._hash(r2) + # ignored parameters != + r2.request.content = b"paramx=x¶m1=2" + assert s._hash(r) == s._hash(r2) + # missing parameter + r2.request.content = b"paramx=x" + assert s._hash(r) == s._hash(r2) + # ignorable parameter added + r2.request.content = b"paramx=x¶m1=2" + assert s._hash(r) == s._hash(r2) + # not ignorable parameter changed + r2.request.content = b"paramx=y¶m1=1" + assert not s._hash(r) == s._hash(r2) + # not ignorable parameter missing + r2.request.content = b"param1=1" + assert not s._hash(r) == s._hash(r2) + + def test_server_playback_full(self): + state = flow.State() + s = serverplayback.ServerPlayback() + o = options.Options(refresh_server_playback = True, keepserving=False) + m = mastertest.RecordingMaster(o, None, state) + m.addons.add(o, s) + + f = tutils.tflow() + f.response = netlib.tutils.tresp(content=f.request.content) + s.load([f, f]) + + tf = tutils.tflow() + assert not tf.response + m.request(tf) + assert tf.response == f.response + + tf = tutils.tflow() + tf.request.content = b"gibble" + assert not tf.response + m.request(tf) + assert not tf.response + + assert not s.stop + s.tick() + assert not s.stop + + tf = tutils.tflow() + m.request(tutils.tflow()) + assert s.stop + + def test_server_playback_kill(self): + state = flow.State() + s = serverplayback.ServerPlayback() + o = options.Options(refresh_server_playback = True, replay_kill_extra=True) + m = mastertest.RecordingMaster(o, None, state) + m.addons.add(o, s) + + f = tutils.tflow() + f.response = netlib.tutils.tresp(content=f.request.content) + s.load([f]) + + f = tutils.tflow() + f.request.host = "nonexistent" + m.request(f) + assert f.reply.value == exceptions.Kill diff --git a/test/mitmproxy/console/test_master.py b/test/mitmproxy/console/test_master.py index fcb87e1b..8388a6bd 100644 --- a/test/mitmproxy/console/test_master.py +++ b/test/mitmproxy/console/test_master.py @@ -107,7 +107,7 @@ def test_format_keyvals(): def test_options(): - assert console.master.Options(kill=True) + assert console.master.Options(replay_kill_extra=True) class TestMaster(mastertest.MasterTest): diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py index 08659d19..a14fe02a 100644 --- a/test/mitmproxy/mastertest.py +++ b/test/mitmproxy/mastertest.py @@ -1,8 +1,14 @@ +import contextlib + from . import tutils import netlib.tutils from mitmproxy.flow import master -from mitmproxy import flow, proxy, models, controller +from mitmproxy import flow, proxy, models, controller, options + + +class TestMaster: + pass class MasterTest: @@ -16,7 +22,9 @@ class MasterTest: master.serverconnect(f.server_conn) master.request(f) if not f.error: - f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content)) + f.response = models.HTTPResponse.wrap( + netlib.tutils.tresp(content=content) + ) master.response(f) master.clientdisconnect(f) return f @@ -41,3 +49,12 @@ class RecordingMaster(master.FlowMaster): def add_log(self, e, level): self.event_log.append((level, e)) + + +@contextlib.contextmanager +def mockctx(): + state = flow.State() + o = options.Options(refresh_server_playback = True, keepserving=False) + m = RecordingMaster(o, proxy.DummyServer(o), state) + with m.handlecontext(): + yield diff --git a/test/mitmproxy/protocol/test_http1.py b/test/mitmproxy/protocol/test_http1.py index 7d04c56b..2fc4ac63 100644 --- a/test/mitmproxy/protocol/test_http1.py +++ b/test/mitmproxy/protocol/test_http1.py @@ -18,14 +18,15 @@ class TestInvalidRequests(tservers.HTTPProxyTest): def test_double_connect(self): p = self.pathoc() - r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) + with p.connect(): + r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) assert r.status_code == 400 assert b"Invalid HTTP request form" in r.content def test_relative_request(self): p = self.pathoc_raw() - p.connect() - r = p.request("get:/p/200") + with p.connect(): + r = p.request("get:/p/200") assert r.status_code == 400 assert b"Invalid HTTP request form" in r.content @@ -61,5 +62,8 @@ class TestHeadContentLength(tservers.HTTPProxyTest): def test_head_content_length(self): p = self.pathoc() - resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase) + with p.connect(): + resp = p.request( + """head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase + ) assert resp.headers["Content-Length"] == "42" diff --git a/test/mitmproxy/protocol/test_http2.py b/test/mitmproxy/protocol/test_http2.py index 1eabebf1..c4bd2049 100644 --- a/test/mitmproxy/protocol/test_http2.py +++ b/test/mitmproxy/protocol/test_http2.py @@ -15,7 +15,7 @@ from mitmproxy.proxy.config import ProxyConfig import netlib from ...netlib import tservers as netlib_tservers from netlib.exceptions import HttpException -from netlib.http.http2 import framereader +from netlib.http import http1, http2 from .. import tservers @@ -33,6 +33,11 @@ requires_alpn = pytest.mark.skipif( reason='requires OpenSSL with ALPN support') +# inspect the log: +# for msg in self.proxy.tmaster.tlog: +# print(msg) + + class _Http2ServerBase(netlib_tservers.ServerTestBase): ssl = dict(alpn_select=b'h2') @@ -55,7 +60,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(self.rfile)) + raw = b''.join(http2.read_raw_frame(self.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -124,11 +129,17 @@ class _Http2TestBase(object): client.connect() # send CONNECT request - client.wfile.write( - b"CONNECT localhost:%d HTTP/1.1\r\n" - b"Host: localhost:%d\r\n" - b"\r\n" % (self.server.server.address.port, self.server.server.address.port) - ) + client.wfile.write(http1.assemble_request(netlib.http.Request( + 'authority', + b'CONNECT', + b'', + b'localhost', + self.server.server.address.port, + b'/', + b'HTTP/1.1', + [(b'host', b'localhost:%d' % self.server.server.address.port)], + b'', + ))) client.wfile.flush() # read CONNECT response @@ -242,7 +253,7 @@ class TestSimple(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -319,7 +330,7 @@ class TestRequestWithPriority(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -358,7 +369,7 @@ class TestRequestWithPriority(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -430,7 +441,7 @@ class TestPriority(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -507,7 +518,7 @@ class TestPriorityWithExistingStream(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -558,7 +569,7 @@ class TestStreamResetFromServer(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -608,7 +619,7 @@ class TestBodySizeLimit(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -693,7 +704,7 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -746,7 +757,7 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -806,7 +817,7 @@ class TestConnectionLost(_Http2Test): done = False while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) h2_conn.receive_data(raw) except HttpException: print(traceback.format_exc()) @@ -863,7 +874,7 @@ class TestMaxConcurrentStreams(_Http2Test): ended_streams = 0 while ended_streams != len(new_streams): try: - header, body = framereader.http2_read_raw_frame(client.rfile) + header, body = http2.read_raw_frame(client.rfile) events = h2_conn.receive_data(b''.join([header, body])) except: break @@ -909,7 +920,7 @@ class TestConnectionTerminated(_Http2Test): connection_terminated_event = None while not done: try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) for event in events: if isinstance(event, h2.events.ConnectionTerminated): diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py index a5085ea0..52d7f07f 100644 --- a/test/mitmproxy/test_addons.py +++ b/test/mitmproxy/test_addons.py @@ -17,5 +17,5 @@ def test_simple(): m = controller.Master(o) a = addons.Addons(m) a.add(o, TAddon("one")) - assert a.has_addon("one") - assert not a.has_addon("two") + assert a.get("one") + assert not a.get("two") diff --git a/test/mitmproxy/test_dump.py b/test/mitmproxy/test_dump.py index 90f33264..06f39e9d 100644 --- a/test/mitmproxy/test_dump.py +++ b/test/mitmproxy/test_dump.py @@ -45,18 +45,17 @@ class TestDumpMaster(mastertest.MasterTest): m = dump.DumpMaster(None, o) f = tutils.tflow(err=True) m.error(f) - assert m.error(f) assert "error" in o.tfile.getvalue() def test_replay(self): - o = dump.Options(server_replay=["nonexistent"], kill=True) - tutils.raises(dump.DumpError, dump.DumpMaster, None, o) + o = dump.Options(server_replay=["nonexistent"], replay_kill_extra=True) + tutils.raises(exceptions.OptionsError, dump.DumpMaster, None, o) with tutils.tmpdir() as t: p = os.path.join(t, "rep") self.flowfile(p) - o = dump.Options(server_replay=[p], kill=True) + o = dump.Options(server_replay=[p], replay_kill_extra=True) o.verbosity = 0 o.flow_detail = 0 m = dump.DumpMaster(None, o) @@ -64,13 +63,13 @@ class TestDumpMaster(mastertest.MasterTest): self.cycle(m, b"content") self.cycle(m, b"content") - o = dump.Options(server_replay=[p], kill=False) + o = dump.Options(server_replay=[p], replay_kill_extra=False) o.verbosity = 0 o.flow_detail = 0 m = dump.DumpMaster(None, o) self.cycle(m, b"nonexistent") - o = dump.Options(client_replay=[p], kill=False) + o = dump.Options(client_replay=[p], replay_kill_extra=False) o.verbosity = 0 o.flow_detail = 0 m = dump.DumpMaster(None, o) diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 1caeb100..0fe45afb 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -37,261 +37,6 @@ def test_app_registry(): assert ar.get(r) -class TestClientPlaybackState: - - def test_tick(self): - first = tutils.tflow() - s = flow.State() - fm = flow.FlowMaster(None, None, s) - fm.start_client_playback([first, tutils.tflow()], True) - c = fm.client_playback - c.testing = True - - assert not c.done() - assert not s.flow_count() - assert c.count() == 2 - c.tick(fm) - assert s.flow_count() - assert c.count() == 1 - - c.tick(fm) - assert c.count() == 1 - - c.clear(c.current) - c.tick(fm) - assert c.count() == 0 - c.clear(c.current) - assert c.done() - - fm.state.clear() - fm.tick(timeout=0) - - fm.stop_client_playback() - assert not fm.client_playback - - -class TestServerPlaybackState: - - def test_hash(self): - s = flow.ServerPlaybackState( - None, - [], - False, - False, - None, - False, - None, - False) - r = tutils.tflow() - r2 = tutils.tflow() - - assert s._hash(r) - assert s._hash(r) == s._hash(r2) - r.request.headers["foo"] = "bar" - assert s._hash(r) == s._hash(r2) - r.request.path = "voing" - assert s._hash(r) != s._hash(r2) - - r.request.path = "path?blank_value" - r2.request.path = "path?" - assert s._hash(r) != s._hash(r2) - - def test_headers(self): - s = flow.ServerPlaybackState( - ["foo"], - [], - False, - False, - None, - False, - None, - False) - r = tutils.tflow(resp=True) - r.request.headers["foo"] = "bar" - r2 = tutils.tflow(resp=True) - assert not s._hash(r) == s._hash(r2) - r2.request.headers["foo"] = "bar" - assert s._hash(r) == s._hash(r2) - r2.request.headers["oink"] = "bar" - assert s._hash(r) == s._hash(r2) - - r = tutils.tflow(resp=True) - r2 = tutils.tflow(resp=True) - assert s._hash(r) == s._hash(r2) - - def test_load(self): - r = tutils.tflow(resp=True) - r.request.headers["key"] = "one" - - r2 = tutils.tflow(resp=True) - r2.request.headers["key"] = "two" - - s = flow.ServerPlaybackState( - None, [ - r, r2], False, False, None, False, None, False) - assert s.count() == 2 - assert len(s.fmap.keys()) == 1 - - n = s.next_flow(r) - assert n.request.headers["key"] == "one" - assert s.count() == 1 - - n = s.next_flow(r) - assert n.request.headers["key"] == "two" - assert s.count() == 0 - - assert not s.next_flow(r) - - def test_load_with_nopop(self): - r = tutils.tflow(resp=True) - r.request.headers["key"] = "one" - - r2 = tutils.tflow(resp=True) - r2.request.headers["key"] = "two" - - s = flow.ServerPlaybackState( - None, [ - r, r2], False, True, None, False, None, False) - - assert s.count() == 2 - s.next_flow(r) - assert s.count() == 2 - - def test_ignore_params(self): - s = flow.ServerPlaybackState( - None, [], False, False, [ - "param1", "param2"], False, None, False) - r = tutils.tflow(resp=True) - r.request.path = "/test?param1=1" - r2 = tutils.tflow(resp=True) - r2.request.path = "/test" - assert s._hash(r) == s._hash(r2) - r2.request.path = "/test?param1=2" - assert s._hash(r) == s._hash(r2) - r2.request.path = "/test?param2=1" - assert s._hash(r) == s._hash(r2) - r2.request.path = "/test?param3=2" - assert not s._hash(r) == s._hash(r2) - - def test_ignore_payload_params(self): - s = flow.ServerPlaybackState( - None, [], False, False, None, False, [ - "param1", "param2"], False) - r = tutils.tflow(resp=True) - r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r.request.content = b"paramx=x¶m1=1" - r2 = tutils.tflow(resp=True) - r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r2.request.content = b"paramx=x¶m1=1" - # same parameters - assert s._hash(r) == s._hash(r2) - # ignored parameters != - r2.request.content = b"paramx=x¶m1=2" - assert s._hash(r) == s._hash(r2) - # missing parameter - r2.request.content = b"paramx=x" - assert s._hash(r) == s._hash(r2) - # ignorable parameter added - r2.request.content = b"paramx=x¶m1=2" - assert s._hash(r) == s._hash(r2) - # not ignorable parameter changed - r2.request.content = b"paramx=y¶m1=1" - assert not s._hash(r) == s._hash(r2) - # not ignorable parameter missing - r2.request.content = b"param1=1" - assert not s._hash(r) == s._hash(r2) - - def test_ignore_payload_params_other_content_type(self): - s = flow.ServerPlaybackState( - None, [], False, False, None, False, [ - "param1", "param2"], False) - r = tutils.tflow(resp=True) - r.request.headers["Content-Type"] = "application/json" - r.request.content = b'{"param1":"1"}' - r2 = tutils.tflow(resp=True) - r2.request.headers["Content-Type"] = "application/json" - r2.request.content = b'{"param1":"1"}' - # same content - assert s._hash(r) == s._hash(r2) - # distint content (note only x-www-form-urlencoded payload is analysed) - r2.request.content = b'{"param1":"2"}' - assert not s._hash(r) == s._hash(r2) - - def test_ignore_payload_wins_over_params(self): - # NOTE: parameters are mutually exclusive in options - s = flow.ServerPlaybackState( - None, [], False, False, None, True, [ - "param1", "param2"], False) - r = tutils.tflow(resp=True) - r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r.request.content = b"paramx=y" - r2 = tutils.tflow(resp=True) - r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r2.request.content = b"paramx=x" - # same parameters - assert s._hash(r) == s._hash(r2) - - def test_ignore_content(self): - s = flow.ServerPlaybackState( - None, - [], - False, - False, - None, - False, - None, - False) - r = tutils.tflow(resp=True) - r2 = tutils.tflow(resp=True) - - r.request.content = b"foo" - r2.request.content = b"foo" - assert s._hash(r) == s._hash(r2) - r2.request.content = b"bar" - assert not s._hash(r) == s._hash(r2) - - # now ignoring content - s = flow.ServerPlaybackState( - None, - [], - False, - False, - None, - True, - None, - False) - r = tutils.tflow(resp=True) - r2 = tutils.tflow(resp=True) - r.request.content = b"foo" - r2.request.content = b"foo" - assert s._hash(r) == s._hash(r2) - r2.request.content = b"bar" - assert s._hash(r) == s._hash(r2) - r2.request.content = b"" - assert s._hash(r) == s._hash(r2) - r2.request.content = None - assert s._hash(r) == s._hash(r2) - - def test_ignore_host(self): - s = flow.ServerPlaybackState( - None, - [], - False, - False, - None, - False, - None, - True) - r = tutils.tflow(resp=True) - r2 = tutils.tflow(resp=True) - - r.request.host = "address" - r2.request.host = "address" - assert s._hash(r) == s._hash(r2) - r2.request.host = "wrong_address" - assert s._hash(r) == s._hash(r2) - - class TestHTTPFlow(object): def test_copy(self): @@ -699,13 +444,13 @@ class TestFlowMaster: fm = flow.FlowMaster(None, None, s) f = tutils.tflow(resp=True) f.request.content = None - assert "missing" in fm.replay_request(f) + tutils.raises("missing", fm.replay_request, f) f.intercepted = True - assert "intercepting" in fm.replay_request(f) + tutils.raises("intercepted", fm.replay_request, f) f.live = True - assert "live" in fm.replay_request(f) + tutils.raises("live", fm.replay_request, f) def test_duplicate_flow(self): s = flow.State() @@ -743,103 +488,6 @@ class TestFlowMaster: fm.shutdown() - def test_client_playback(self): - s = flow.State() - - f = tutils.tflow(resp=True) - pb = [tutils.tflow(resp=True), f] - fm = flow.FlowMaster( - options.Options(), - DummyServer(ProxyConfig(options.Options())), - s - ) - assert not fm.start_server_playback( - pb, - False, - [], - False, - False, - None, - False, - None, - False) - assert not fm.start_client_playback(pb, False) - fm.client_playback.testing = True - - assert not fm.state.flow_count() - fm.tick(0) - assert fm.state.flow_count() - - f.error = Error("error") - fm.error(f) - - def test_server_playback(self): - s = flow.State() - - f = tutils.tflow() - f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request)) - pb = [f] - - fm = flow.FlowMaster(options.Options(), None, s) - fm.refresh_server_playback = True - assert not fm.do_server_playback(tutils.tflow()) - - fm.start_server_playback( - pb, - False, - [], - False, - False, - None, - False, - None, - False) - assert fm.do_server_playback(tutils.tflow()) - - fm.start_server_playback( - pb, - False, - [], - True, - False, - None, - False, - None, - False) - r = tutils.tflow() - r.request.content = b"gibble" - assert not fm.do_server_playback(r) - assert fm.do_server_playback(tutils.tflow()) - - fm.tick(0) - assert fm.should_exit.is_set() - - fm.stop_server_playback() - assert not fm.server_playback - - def test_server_playback_kill(self): - s = flow.State() - f = tutils.tflow() - f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request)) - pb = [f] - fm = flow.FlowMaster(None, None, s) - fm.refresh_server_playback = True - fm.start_server_playback( - pb, - True, - [], - False, - False, - None, - False, - None, - False) - - f = tutils.tflow() - f.request.host = "nonexistent" - fm.request(f) - assert f.reply.value == Kill - class TestRequest: diff --git a/test/mitmproxy/test_fuzzing.py b/test/mitmproxy/test_fuzzing.py index 27ea36a6..905ba1cd 100644 --- a/test/mitmproxy/test_fuzzing.py +++ b/test/mitmproxy/test_fuzzing.py @@ -11,17 +11,20 @@ class TestFuzzy(tservers.HTTPProxyTest): def test_idna_err(self): req = r'get:"http://localhost:%s":i10,"\xc6"' p = self.pathoc() - assert p.request(req % self.server.port).status_code == 400 + with p.connect(): + assert p.request(req % self.server.port).status_code == 400 def test_nullbytes(self): req = r'get:"http://localhost:%s":i19,"\x00"' p = self.pathoc() - assert p.request(req % self.server.port).status_code == 400 + with p.connect(): + assert p.request(req % self.server.port).status_code == 400 def test_invalid_ipv6_url(self): req = 'get:"http://localhost:%s":i13,"["' p = self.pathoc() - resp = p.request(req % self.server.port) + with p.connect(): + resp = p.request(req % self.server.port) assert resp.status_code == 400 # def test_invalid_upstream(self): diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index e0a8da47..c5a5bb71 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -60,7 +60,7 @@ class CommonMixin: # Disconnect error l.request.path = "/p/305:d0" rt = self.master.replay_request(l, block=True) - assert not rt + assert rt if isinstance(self, tservers.HTTPUpstreamProxyTest): assert l.response.status_code == 502 else: @@ -72,7 +72,7 @@ class CommonMixin: # In upstream mode with ssl, the replay will fail as we cannot establish # SSL with the upstream proxy. rt = self.master.replay_request(l, block=True) - assert not rt + assert rt if isinstance(self, tservers.HTTPUpstreamProxyTest): assert l.response.status_code == 502 else: @@ -91,11 +91,11 @@ class CommonMixin: def test_invalid_http(self): t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - t.connect() - t.wfile.write(b"invalid\r\n\r\n") - t.wfile.flush() - line = t.rfile.readline() - assert (b"Bad Request" in line) or (b"Bad Gateway" in line) + with t.connect(): + t.wfile.write(b"invalid\r\n\r\n") + t.wfile.flush() + line = t.rfile.readline() + assert (b"Bad Request" in line) or (b"Bad Gateway" in line) def test_sni(self): if not self.ssl: @@ -208,20 +208,22 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): def test_app_err(self): p = self.pathoc() - ret = p.request("get:'http://errapp/'") + with p.connect(): + ret = p.request("get:'http://errapp/'") assert ret.status_code == 500 assert b"ValueError" in ret.content def test_invalid_connect(self): t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - t.connect() - t.wfile.write(b"CONNECT invalid\n\n") - t.wfile.flush() - assert b"Bad Request" in t.rfile.readline() + with t.connect(): + t.wfile.write(b"CONNECT invalid\n\n") + t.wfile.flush() + assert b"Bad Request" in t.rfile.readline() def test_upstream_ssl_error(self): p = self.pathoc() - ret = p.request("get:'https://localhost:%s/'" % self.server.port) + with p.connect(): + ret = p.request("get:'https://localhost:%s/'" % self.server.port) assert ret.status_code == 400 def test_connection_close(self): @@ -232,25 +234,28 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): # Lets sanity check that the connection does indeed stay open by # issuing two requests over the same connection p = self.pathoc() - assert p.request("get:'%s'" % response) - assert p.request("get:'%s'" % response) + with p.connect(): + assert p.request("get:'%s'" % response) + assert p.request("get:'%s'" % response) # Now check that the connection is closed as the client specifies p = self.pathoc() - assert p.request("get:'%s':h'Connection'='close'" % response) - # There's a race here, which means we can get any of a number of errors. - # Rather than introduce yet another sleep into the test suite, we just - # relax the Exception specification. - with raises(Exception): - p.request("get:'%s'" % response) + with p.connect(): + assert p.request("get:'%s':h'Connection'='close'" % response) + # There's a race here, which means we can get any of a number of errors. + # Rather than introduce yet another sleep into the test suite, we just + # relax the Exception specification. + with raises(Exception): + p.request("get:'%s'" % response) def test_reconnect(self): req = "get:'%s/p/200:b@1:da'" % self.server.urlbase p = self.pathoc() - assert p.request(req) - # Server has disconnected. Mitmproxy should detect this, and reconnect. - assert p.request(req) - assert p.request(req) + with p.connect(): + assert p.request(req) + # Server has disconnected. Mitmproxy should detect this, and reconnect. + assert p.request(req) + assert p.request(req) def test_get_connection_switching(self): def switched(l): @@ -260,18 +265,21 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): req = "get:'%s/p/200:b@1'" p = self.pathoc() - assert p.request(req % self.server.urlbase) - assert p.request(req % self.server2.urlbase) + with p.connect(): + assert p.request(req % self.server.urlbase) + assert p.request(req % self.server2.urlbase) assert switched(self.proxy.tlog) def test_blank_leading_line(self): p = self.pathoc() - req = "get:'%s/p/201':i0,'\r\n'" - assert p.request(req % self.server.urlbase).status_code == 201 + with p.connect(): + req = "get:'%s/p/201':i0,'\r\n'" + assert p.request(req % self.server.urlbase).status_code == 201 def test_invalid_headers(self): p = self.pathoc() - resp = p.request("get:'http://foo':h':foo'='bar'") + with p.connect(): + resp = p.request("get:'http://foo':h':foo'='bar'") assert resp.status_code == 400 def test_stream(self): @@ -301,15 +309,16 @@ class TestHTTPAuth(tservers.HTTPProxyTest): self.master.options.auth_singleuser = "test:test" assert self.pathod("202").status_code == 407 p = self.pathoc() - ret = p.request(""" - get - 'http://localhost:%s/p/202' - h'%s'='%s' - """ % ( - self.server.port, - http.authentication.BasicProxyAuth.AUTH_HEADER, - authentication.assemble_http_basic_auth("basic", "test", "test") - )) + with p.connect(): + ret = p.request(""" + get + 'http://localhost:%s/p/202' + h'%s'='%s' + """ % ( + self.server.port, + http.authentication.BasicProxyAuth.AUTH_HEADER, + authentication.assemble_http_basic_auth("basic", "test", "test") + )) assert ret.status_code == 202 @@ -318,14 +327,15 @@ class TestHTTPReverseAuth(tservers.ReverseProxyTest): self.master.options.auth_singleuser = "test:test" assert self.pathod("202").status_code == 401 p = self.pathoc() - ret = p.request(""" - get - '/p/202' - h'%s'='%s' - """ % ( - http.authentication.BasicWebsiteAuth.AUTH_HEADER, - authentication.assemble_http_basic_auth("basic", "test", "test") - )) + with p.connect(): + ret = p.request(""" + get + '/p/202' + h'%s'='%s' + """ % ( + http.authentication.BasicWebsiteAuth.AUTH_HEADER, + authentication.assemble_http_basic_auth("basic", "test", "test") + )) assert ret.status_code == 202 @@ -354,7 +364,8 @@ class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin): def test_error_post_connect(self): p = self.pathoc() - assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 + with p.connect(): + assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 class TestHTTPSCertfile(tservers.HTTPProxyTest, CommonMixin): @@ -389,7 +400,8 @@ class TestHTTPSUpstreamServerVerificationWTrustedCert(tservers.HTTPProxyTest): def _request(self): p = self.pathoc(sni="example.mitmproxy.org") - return p.request("get:/p/242") + with p.connect(): + return p.request("get:/p/242") def test_verification_w_cadir(self): self.config.options.update( @@ -426,7 +438,8 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxyTest): def _request(self): p = self.pathoc(sni="example.mitmproxy.org") - return p.request("get:/p/242") + with p.connect(): + return p.request("get:/p/242") @classmethod def get_options(cls): @@ -481,13 +494,15 @@ class TestSocks5(tservers.SocksModeTest): def test_simple(self): p = self.pathoc() - p.socks_connect(("localhost", self.server.port)) - f = p.request("get:/p/200") + with p.connect(): + p.socks_connect(("localhost", self.server.port)) + f = p.request("get:/p/200") assert f.status_code == 200 def test_with_authentication_only(self): p = self.pathoc() - f = p.request("get:/p/200") + with p.connect(): + f = p.request("get:/p/200") assert f.status_code == 502 assert b"SOCKS5 mode failure" in f.content @@ -496,21 +511,21 @@ class TestSocks5(tservers.SocksModeTest): mitmproxy doesn't support UDP or BIND SOCKS CMDs """ p = self.pathoc() - - socks.ClientGreeting( - socks.VERSION.SOCKS5, - [socks.METHOD.NO_AUTHENTICATION_REQUIRED] - ).to_file(p.wfile) - socks.Message( - socks.VERSION.SOCKS5, - socks.CMD.BIND, - socks.ATYP.DOMAINNAME, - ("example.com", 8080) - ).to_file(p.wfile) - - p.wfile.flush() - p.rfile.read(2) # read server greeting - f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway. + with p.connect(): + socks.ClientGreeting( + socks.VERSION.SOCKS5, + [socks.METHOD.NO_AUTHENTICATION_REQUIRED] + ).to_file(p.wfile) + socks.Message( + socks.VERSION.SOCKS5, + socks.CMD.BIND, + socks.ATYP.DOMAINNAME, + ("example.com", 8080) + ).to_file(p.wfile) + + p.wfile.flush() + p.rfile.read(2) # read server greeting + f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway. assert f.status_code == 502 assert b"SOCKS5 mode failure" in f.content @@ -531,21 +546,23 @@ class TestHttps2Http(tservers.ReverseProxyTest): p = pathoc.Pathoc( ("localhost", self.proxy.port), ssl=True, sni=sni, fp=None ) - p.connect() return p def test_all(self): p = self.pathoc(ssl=True) - assert p.request("get:'/p/200'").status_code == 200 + with p.connect(): + assert p.request("get:'/p/200'").status_code == 200 def test_sni(self): p = self.pathoc(ssl=True, sni="example.com") - assert p.request("get:'/p/200'").status_code == 200 - assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) + with p.connect(): + assert p.request("get:'/p/200'").status_code == 200 + assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) def test_http(self): p = self.pathoc(ssl=False) - assert p.request("get:'/p/200'").status_code == 200 + with p.connect(): + assert p.request("get:'/p/200'").status_code == 200 class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin): @@ -703,29 +720,29 @@ class TestRedirectRequest(tservers.HTTPProxyTest): self.master.redirect_port = self.server2.port p = self.pathoc() - - self.server.clear_log() - self.server2.clear_log() - r1 = p.request("get:'/p/200'") - assert r1.status_code == 200 - assert self.server.last_log() - assert not self.server2.last_log() - - self.server.clear_log() - self.server2.clear_log() - r2 = p.request("get:'/p/201'") - assert r2.status_code == 201 - assert not self.server.last_log() - assert self.server2.last_log() - - self.server.clear_log() - self.server2.clear_log() - r3 = p.request("get:'/p/202'") - assert r3.status_code == 202 - assert self.server.last_log() - assert not self.server2.last_log() - - assert r1.content == r2.content == r3.content + with p.connect(): + self.server.clear_log() + self.server2.clear_log() + r1 = p.request("get:'/p/200'") + assert r1.status_code == 200 + assert self.server.last_log() + assert not self.server2.last_log() + + self.server.clear_log() + self.server2.clear_log() + r2 = p.request("get:'/p/201'") + assert r2.status_code == 201 + assert not self.server.last_log() + assert self.server2.last_log() + + self.server.clear_log() + self.server2.clear_log() + r3 = p.request("get:'/p/202'") + assert r3.status_code == 202 + assert self.server.last_log() + assert not self.server2.last_log() + + assert r1.content == r2.content == r3.content class MasterStreamRequest(tservers.TestMaster): @@ -743,22 +760,22 @@ class TestStreamRequest(tservers.HTTPProxyTest): def test_stream_simple(self): p = self.pathoc() - - # a request with 100k of data but without content-length - r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) - assert r1.status_code == 200 - assert len(r1.content) > 100000 + with p.connect(): + # a request with 100k of data but without content-length + r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) + assert r1.status_code == 200 + assert len(r1.content) > 100000 def test_stream_multiple(self): p = self.pathoc() + with p.connect(): + # simple request with streaming turned on + r1 = p.request("get:'%s/p/200'" % self.server.urlbase) + assert r1.status_code == 200 - # simple request with streaming turned on - r1 = p.request("get:'%s/p/200'" % self.server.urlbase) - assert r1.status_code == 200 - - # now send back 100k of data, streamed but not chunked - r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase) - assert r1.status_code == 201 + # now send back 100k of data, streamed but not chunked + r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase) + assert r1.status_code == 201 def test_stream_chunked(self): connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -887,7 +904,8 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest, CommonMixin, AppMixin): ("~s", "baz", "ORLY") ] p = self.pathoc() - req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) + with p.connect(): + req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) assert req.content == b"ORLY" assert req.status_code == 418 @@ -948,7 +966,8 @@ class TestUpstreamProxySSL( def test_simple(self): p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") + with p.connect(): + req = p.request("get:'/p/418:b\"content\"'") assert req.content == b"content" assert req.status_code == 418 @@ -1006,48 +1025,49 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): ]) p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") - assert req.content == b"content" - assert req.status_code == 418 - - assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request - # CONNECT, failing request, - assert self.chain[0].tmaster.state.flow_count() == 4 - # reCONNECT, request - # failing request, request - assert self.chain[1].tmaster.state.flow_count() == 2 - # (doesn't store (repeated) CONNECTs from chain[0] - # as it is a regular proxy) - - assert not self.chain[1].tmaster.state.flows[0].response # killed - assert self.chain[1].tmaster.state.flows[1].response - - assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" - assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" - - assert self.chain[0].tmaster.state.flows[ - 0].request.first_line_format == "authority" - assert self.chain[0].tmaster.state.flows[ - 1].request.first_line_format == "relative" - assert self.chain[0].tmaster.state.flows[ - 2].request.first_line_format == "authority" - assert self.chain[0].tmaster.state.flows[ - 3].request.first_line_format == "relative" - - assert self.chain[1].tmaster.state.flows[ - 0].request.first_line_format == "relative" - assert self.chain[1].tmaster.state.flows[ - 1].request.first_line_format == "relative" - - req = p.request("get:'/p/418:b\"content2\"'") - - assert req.status_code == 502 - assert self.proxy.tmaster.state.flow_count() == 3 # + new request - # + new request, repeated CONNECT from chain[1] - assert self.chain[0].tmaster.state.flow_count() == 6 - # (both terminated) - # nothing happened here - assert self.chain[1].tmaster.state.flow_count() == 2 + with p.connect(): + req = p.request("get:'/p/418:b\"content\"'") + assert req.content == b"content" + assert req.status_code == 418 + + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request + # CONNECT, failing request, + assert self.chain[0].tmaster.state.flow_count() == 4 + # reCONNECT, request + # failing request, request + assert self.chain[1].tmaster.state.flow_count() == 2 + # (doesn't store (repeated) CONNECTs from chain[0] + # as it is a regular proxy) + + assert not self.chain[1].tmaster.state.flows[0].response # killed + assert self.chain[1].tmaster.state.flows[1].response + + assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" + assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" + + assert self.chain[0].tmaster.state.flows[ + 0].request.first_line_format == "authority" + assert self.chain[0].tmaster.state.flows[ + 1].request.first_line_format == "relative" + assert self.chain[0].tmaster.state.flows[ + 2].request.first_line_format == "authority" + assert self.chain[0].tmaster.state.flows[ + 3].request.first_line_format == "relative" + + assert self.chain[1].tmaster.state.flows[ + 0].request.first_line_format == "relative" + assert self.chain[1].tmaster.state.flows[ + 1].request.first_line_format == "relative" + + req = p.request("get:'/p/418:b\"content2\"'") + + assert req.status_code == 502 + assert self.proxy.tmaster.state.flow_count() == 3 # + new request + # + new request, repeated CONNECT from chain[1] + assert self.chain[0].tmaster.state.flow_count() == 6 + # (both terminated) + # nothing happened here + assert self.chain[1].tmaster.state.flow_count() == 2 class AddUpstreamCertsToClientChainMixin: @@ -1066,12 +1086,13 @@ class AddUpstreamCertsToClientChainMixin: d = f.read() upstreamCert = SSLCert.from_pem(d) p = self.pathoc() - upstream_cert_found_in_client_chain = False - for receivedCert in p.server_certs: - if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): - upstream_cert_found_in_client_chain = True - break - assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain) + with p.connect(): + upstream_cert_found_in_client_chain = False + for receivedCert in p.server_certs: + if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): + upstream_cert_found_in_client_chain = True + break + assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain) class TestHTTPSAddUpstreamCertsToClientChainTrue( diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 1597f59c..4291f743 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -3,6 +3,7 @@ import threading import tempfile import flask import mock +import sys from mitmproxy.proxy.config import ProxyConfig from mitmproxy.proxy.server import ProxyServer @@ -10,6 +11,7 @@ import pathod.test import pathod.pathoc from mitmproxy import flow, controller, options from mitmproxy import builtins +import netlib.exceptions testapp = flask.Flask(__name__) @@ -104,6 +106,14 @@ class ProxyTestBase(object): cls.server.shutdown() cls.server2.shutdown() + def teardown(self): + try: + self.server.wait_for_silence() + except netlib.exceptions.Timeout: + # FIXME: Track down the Windows sync issues + if sys.platform != "win32": + raise + def setup(self): self.master.clear_log() self.master.state.clear() @@ -125,6 +135,15 @@ class ProxyTestBase(object): ) +class LazyPathoc(pathod.pathoc.Pathoc): + def __init__(self, lazy_connect, *args, **kwargs): + self.lazy_connect = lazy_connect + pathod.pathoc.Pathoc.__init__(self, *args, **kwargs) + + def connect(self): + return pathod.pathoc.Pathoc.connect(self, self.lazy_connect) + + class HTTPProxyTest(ProxyTestBase): def pathoc_raw(self): @@ -134,14 +153,14 @@ class HTTPProxyTest(ProxyTestBase): """ Returns a connected Pathoc instance. """ - p = pathod.pathoc.Pathoc( - ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None - ) if self.ssl: - p.connect(("127.0.0.1", self.server.port)) + conn = ("127.0.0.1", self.server.port) else: - p.connect() - return p + conn = None + return LazyPathoc( + conn, + ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None + ) def pathod(self, spec, sni=None): """ @@ -152,18 +171,20 @@ class HTTPProxyTest(ProxyTestBase): q = "get:'/p/%s'" % spec else: q = "get:'%s/p/%s'" % (self.server.urlbase, spec) - return p.request(q) + with p.connect(): + return p.request(q) def app(self, page): if self.ssl: p = pathod.pathoc.Pathoc( ("127.0.0.1", self.proxy.port), True, fp=None ) - p.connect((options.APP_HOST, options.APP_PORT)) - return p.request("get:'%s'" % page) + with p.connect((options.APP_HOST, options.APP_PORT)): + return p.request("get:'%s'" % page) else: p = self.pathoc() - return p.request("get:'http://%s%s'" % (options.APP_HOST, page)) + with p.connect(): + return p.request("get:'http://%s%s'" % (options.APP_HOST, page)) class TResolver: @@ -210,7 +231,8 @@ class TransparentProxyTest(ProxyTestBase): else: p = self.pathoc() q = "get:'/p/%s'" % spec - return p.request(q) + with p.connect(): + return p.request(q) def pathoc(self, sni=None): """ @@ -219,7 +241,6 @@ class TransparentProxyTest(ProxyTestBase): p = pathod.pathoc.Pathoc( ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ) - p.connect() return p @@ -247,7 +268,6 @@ class ReverseProxyTest(ProxyTestBase): p = pathod.pathoc.Pathoc( ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ) - p.connect() return p def pathod(self, spec, sni=None): @@ -260,7 +280,8 @@ class ReverseProxyTest(ProxyTestBase): else: p = self.pathoc() q = "get:'/p/%s'" % spec - return p.request(q) + with p.connect(): + return p.request(q) class SocksModeTest(HTTPProxyTest): diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index ad2bc548..e8752c52 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -43,6 +43,15 @@ class TestHeaders(object): with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) + def test_set(self): + headers = Headers() + headers[u"foo"] = u"1" + headers[b"bar"] = b"2" + headers["baz"] = b"3" + with raises(TypeError): + headers["foobar"] = 42 + assert len(headers) == 3 + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index c7b1b646..e97cc419 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -34,6 +34,11 @@ class TestResponseCore(object): assert r.status_code == 200 assert r.content == b"" + r = Response.make(418, "teatime") + assert r.status_code == 418 + assert r.content == b"teatime" + assert r.headers["content-length"] == "7" + Response.make(content=b"foo") Response.make(content="foo") with raises(TypeError): diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py index 08e69ec5..e1175ef0 100644 --- a/test/netlib/test_encoding.py +++ b/test/netlib/test_encoding.py @@ -1,55 +1,46 @@ import mock +import pytest + from netlib import encoding, tutils -def test_identity(): - assert b"string" == encoding.decode(b"string", "identity") - assert b"string" == encoding.encode(b"string", "identity") +@pytest.mark.parametrize("encoder", [ + 'identity', + 'none', +]) +def test_identity(encoder): + assert b"string" == encoding.decode(b"string", encoder) + assert b"string" == encoding.encode(b"string", encoder) with tutils.raises(ValueError): encoding.encode(b"string", "nonexistent encoding") -def test_gzip(): - assert b"string" == encoding.decode( - encoding.encode( - b"string", - "gzip" - ), - "gzip" - ) - with tutils.raises(ValueError): - encoding.decode(b"bogus", "gzip") +@pytest.mark.parametrize("encoder", [ + 'gzip', + 'br', + 'deflate', +]) +def test_encoders(encoder): + assert "" == encoding.decode("", encoder) + assert b"" == encoding.decode(b"", encoder) - -def test_brotli(): - assert b"string" == encoding.decode( + assert "string" == encoding.decode( encoding.encode( - b"string", - "br" + "string", + encoder ), - "br" + encoder ) - with tutils.raises(ValueError): - encoding.decode(b"bogus", "br") - - -def test_deflate(): assert b"string" == encoding.decode( encoding.encode( b"string", - "deflate" + encoder ), - "deflate" - ) - assert b"string" == encoding.decode( - encoding.encode( - b"string", - "deflate" - )[2:-4], - "deflate" + encoder ) + with tutils.raises(ValueError): - encoding.decode(b"bogus", "deflate") + encoding.decode(b"foobar", encoder) def test_cache(): diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 5be254a3..0f58cac5 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -8,6 +8,8 @@ def test_always_bytes(): assert strutils.always_bytes("foo") == b"foo" with tutils.raises(ValueError): strutils.always_bytes(u"\u2605", "ascii") + with tutils.raises(TypeError): + strutils.always_bytes(42, "ascii") def test_native(): diff --git a/test/pathod/test_protocols_http2.py b/test/pathod/test_protocols_http2.py index 8d7efc82..7f65c0eb 100644 --- a/test/pathod/test_protocols_http2.py +++ b/test/pathod/test_protocols_http2.py @@ -5,7 +5,7 @@ import hyperframe from netlib import tcp, http from netlib.tutils import raises from netlib.exceptions import TcpDisconnect -from netlib.http.http2 import framereader +from netlib.http import http2 from ..netlib import tservers as netlib_tservers @@ -112,11 +112,11 @@ class TestPerformServerConnectionPreface(netlib_tservers.ServerTestBase): self.wfile.flush() # check empty settings frame - raw = framereader.http2_read_raw_frame(self.rfile) + raw = http2.read_raw_frame(self.rfile) assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') # check settings acknowledgement - raw = framereader.http2_read_raw_frame(self.rfile) + raw = http2.read_raw_frame(self.rfile) assert raw == codecs.decode('000000040100000000', 'hex_codec') # send settings acknowledgement |