diff options
author | Maximilian Hils <git@maximilianhils.com> | 2016-07-25 15:16:16 -0700 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2016-07-25 15:16:16 -0700 |
commit | 79ebcb046e8669f80357a6c3046ec76c6adf49be (patch) | |
tree | 441981a16f1be1e620584e4a47f41767ce5585b2 | |
parent | 3254595584e1d711e7ae292ad34753a52f7a0fc1 (diff) | |
parent | 56796aeda25dda66621ce78af227ff46049ef811 (diff) | |
download | mitmproxy-79ebcb046e8669f80357a6c3046ec76c6adf49be.tar.gz mitmproxy-79ebcb046e8669f80357a6c3046ec76c6adf49be.tar.bz2 mitmproxy-79ebcb046e8669f80357a6c3046ec76c6adf49be.zip |
Merge remote-tracking branch 'origin/master' into flow_editing_v2
73 files changed, 933 insertions, 741 deletions
diff --git a/.travis.yml b/.travis.yml index e832d058..e9566ebe 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,10 +20,10 @@ matrix: include: - python: 3.5 env: TOXENV=lint -# - os: osx -# osx_image: xcode7.3 -# language: generic -# env: TOXENV=py35 + - os: osx + osx_image: xcode7.3 + language: generic + env: TOXENV=py35 - python: 3.5 env: TOXENV=py35 - python: 3.5 diff --git a/examples/filt.py b/examples/filt.py index 21744edd..9ccf9fa1 100644 --- a/examples/filt.py +++ b/examples/filt.py @@ -1,18 +1,20 @@ -# This scripts demonstrates how to use mitmproxy's filter pattern in inline scripts. +# This scripts demonstrates how to use mitmproxy's filter pattern in scripts. # Usage: mitmdump -s "filt.py FILTER" import sys from mitmproxy import filt -state = {} + +class Filter: + def __init__(self, spec): + self.filter = filt.parse(spec) + + def response(self, flow): + if flow.match(self.filter): + print("Flow matches filter:") + print(flow) def start(): if len(sys.argv) != 2: raise ValueError("Usage: -s 'filt.py FILTER'") - state["filter"] = filt.parse(sys.argv[1]) - - -def response(flow): - if flow.match(state["filter"]): - print("Flow matches filter:") - print(flow) + return Filter(sys.argv[1]) diff --git a/examples/flowwriter.py b/examples/flowwriter.py index 07c7ca20..df2e5a40 100644 --- a/examples/flowwriter.py +++ b/examples/flowwriter.py @@ -3,20 +3,21 @@ import sys from mitmproxy.flow import FlowWriter -state = {} + +class Writer: + def __init__(self, path): + if path == "-": + f = sys.stdout + else: + f = open(path, "wb") + self.w = FlowWriter(f) + + def response(self, flow): + if random.choice([True, False]): + self.w.add(flow) def start(): if len(sys.argv) != 2: raise ValueError('Usage: -s "flowriter.py filename"') - - if sys.argv[1] == "-": - f = sys.stdout - else: - f = open(sys.argv[1], "wb") - state["flow_writer"] = FlowWriter(f) - - -def response(flow): - if random.choice([True, False]): - state["flow_writer"].add(flow) + return Writer(sys.argv[1]) diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py index 352c3c24..33d18bbd 100644 --- a/examples/iframe_injector.py +++ b/examples/iframe_injector.py @@ -3,26 +3,27 @@ import sys from bs4 import BeautifulSoup -iframe_url = None + +class Injector: + def __init__(self, iframe_url): + self.iframe_url = iframe_url + + def response(self, flow): + if flow.request.host in self.iframe_url: + return + html = BeautifulSoup(flow.response.content, "lxml") + if html.body: + iframe = html.new_tag( + "iframe", + src=self.iframe_url, + frameborder=0, + height=0, + width=0) + html.body.insert(0, iframe) + flow.response.content = str(html).encode("utf8") def start(): if len(sys.argv) != 2: raise ValueError('Usage: -s "iframe_injector.py url"') - global iframe_url - iframe_url = sys.argv[1] - - -def response(flow): - if flow.request.host in iframe_url: - return - html = BeautifulSoup(flow.response.content, "lxml") - if html.body: - iframe = html.new_tag( - "iframe", - src=iframe_url, - frameborder=0, - height=0, - width=0) - html.body.insert(0, iframe) - flow.response.content = str(html).encode("utf8") + return Injector(sys.argv[1]) diff --git a/examples/remote_debug.py b/examples/remote_debug.py new file mode 100644 index 00000000..fb864f78 --- /dev/null +++ b/examples/remote_debug.py @@ -0,0 +1,19 @@ +""" +This script enables remote debugging of the mitmproxy *UI* with PyCharm. +For general debugging purposes, it is easier to just debug mitmdump within PyCharm. + +Usage: + - pip install pydevd on the mitmproxy machine + - Open the Run/Debug Configuration dialog box in PyCharm, and select the Python Remote Debug configuration type. + - Debugging works in the way that mitmproxy connects to the debug server on startup. + Specify host and port that mitmproxy can use to reach your PyCharm instance on startup. + - Adjust this inline script accordingly. + - Start debug server in PyCharm + - Set breakpoints + - Start mitmproxy -s remote_debug.py +""" + + +def start(): + import pydevd + pydevd.settrace("localhost", port=5678, stdoutToServer=True, stderrToServer=True) diff --git a/examples/stub.py b/examples/stub.py index e5b4a39a..4f5061e2 100644 --- a/examples/stub.py +++ b/examples/stub.py @@ -11,7 +11,7 @@ def start(): mitmproxy.ctx.log("start") -def configure(options): +def configure(options, updated): """ Called once on script startup before any other events, and whenever options changes. """ diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py index c779aaf8..a4bea9fa 100644 --- a/mitmproxy/addons.py +++ b/mitmproxy/addons.py @@ -13,16 +13,23 @@ class Addons(object): self.master = master master.options.changed.connect(self.options_update) - def options_update(self, options): + def options_update(self, options, updated): for i in self.chain: with self.master.handlecontext(): - i.configure(options) + i.configure(options, updated) - def add(self, *addons): + def add(self, options, *addons): + if not addons: + raise ValueError("No adons specified.") self.chain.extend(addons) for i in addons: self.invoke_with_context(i, "start") - self.invoke_with_context(i, "configure", self.master.options) + self.invoke_with_context( + i, + "configure", + self.master.options, + self.master.options.keys() + ) def remove(self, addon): self.chain = [i for i in self.chain if i is not addon] diff --git a/mitmproxy/builtins/anticache.py b/mitmproxy/builtins/anticache.py index f208e2fb..41a5ed95 100644 --- a/mitmproxy/builtins/anticache.py +++ b/mitmproxy/builtins/anticache.py @@ -5,7 +5,7 @@ class AntiCache: def __init__(self): self.enabled = False - def configure(self, options): + def configure(self, options, updated): self.enabled = options.anticache def request(self, flow): diff --git a/mitmproxy/builtins/anticomp.py b/mitmproxy/builtins/anticomp.py index 50bd1b73..823e960c 100644 --- a/mitmproxy/builtins/anticomp.py +++ b/mitmproxy/builtins/anticomp.py @@ -5,7 +5,7 @@ class AntiComp: def __init__(self): self.enabled = False - def configure(self, options): + def configure(self, options, updated): self.enabled = options.anticomp def request(self, flow): diff --git a/mitmproxy/builtins/dumper.py b/mitmproxy/builtins/dumper.py index 239630fb..74c2e6b2 100644 --- a/mitmproxy/builtins/dumper.py +++ b/mitmproxy/builtins/dumper.py @@ -5,6 +5,8 @@ import traceback import click +import typing # noqa + from mitmproxy import contentviews from mitmproxy import ctx from mitmproxy import exceptions @@ -19,12 +21,25 @@ def indent(n, text): return "\n".join(pad + i for i in l) -class Dumper(): +class Dumper(object): def __init__(self): - self.filter = None - self.flow_detail = None - self.outfp = None - self.showhost = None + self.filter = None # type: filt.TFilter + self.flow_detail = None # type: int + self.outfp = None # type: typing.io.TextIO + self.showhost = None # type: bool + + def configure(self, options, updated): + if options.filtstr: + self.filter = filt.parse(options.filtstr) + if not self.filter: + raise exceptions.OptionsError( + "Invalid filter expression: %s" % options.filtstr + ) + else: + self.filter = None + self.flow_detail = options.flow_detail + self.outfp = options.tfile + self.showhost = options.showhost def echo(self, text, ident=None, **style): if ident: @@ -59,7 +74,7 @@ class Dumper(): self.echo("") try: - type, lines = contentviews.get_content_view( + _, lines = contentviews.get_content_view( contentviews.get("Auto"), content, headers=getattr(message, "headers", None) @@ -67,7 +82,7 @@ class Dumper(): except exceptions.ContentViewException: s = "Content viewer failed: \n" + traceback.format_exc() ctx.log.debug(s) - type, lines = contentviews.get_content_view( + _, lines = contentviews.get_content_view( contentviews.get("Raw"), content, headers=getattr(message, "headers", None) @@ -114,9 +129,8 @@ class Dumper(): if flow.client_conn: client = click.style( strutils.escape_control_characters( - flow.client_conn.address.host - ), - bold=True + repr(flow.client_conn.address) + ) ) elif flow.request.is_replay: client = click.style("[replay]", fg="yellow", bold=True) @@ -139,17 +153,23 @@ class Dumper(): url = flow.request.url url = click.style(strutils.escape_control_characters(url), bold=True) - httpversion = "" + http_version = "" if flow.request.http_version not in ("HTTP/1.1", "HTTP/1.0"): # We hide "normal" HTTP 1. - httpversion = " " + flow.request.http_version + http_version = " " + flow.request.http_version - line = "{stickycookie}{client} {method} {url}{httpversion}".format( - stickycookie=stickycookie, + if self.flow_detail >= 2: + linebreak = "\n " + else: + linebreak = "" + + line = "{client}: {linebreak}{stickycookie}{method} {url}{http_version}".format( client=client, + stickycookie=stickycookie, + linebreak=linebreak, method=method, url=url, - httpversion=httpversion + http_version=http_version ) self.echo(line) @@ -185,9 +205,14 @@ class Dumper(): size = human.pretty_size(len(flow.response.raw_content)) size = click.style(size, bold=True) - arrows = click.style(" <<", bold=True) + arrows = click.style(" <<", bold=True) + if self.flow_detail == 1: + # This aligns the HTTP response code with the HTTP request method: + # 127.0.0.1:59519: GET http://example.com/ + # << 304 Not Modified 0b + arrows = " " * (len(repr(flow.client_conn.address)) - 2) + arrows - line = "{replay} {arrows} {code} {reason} {size}".format( + line = "{replay}{arrows} {code} {reason} {size}".format( replay=replay, arrows=arrows, code=code, @@ -211,25 +236,12 @@ class Dumper(): def match(self, f): if self.flow_detail == 0: return False - if not self.filt: + if not self.filter: return True - elif f.match(self.filt): + elif f.match(self.filter): return True return False - def configure(self, options): - if options.filtstr: - self.filt = filt.parse(options.filtstr) - if not self.filt: - raise exceptions.OptionsError( - "Invalid filter expression: %s" % options.filtstr - ) - else: - self.filt = None - self.flow_detail = options.flow_detail - self.outfp = options.tfile - self.showhost = options.showhost - def response(self, f): if self.match(f): self.echo_flow(f) @@ -239,8 +251,7 @@ class Dumper(): self.echo_flow(f) def tcp_message(self, f): - # FIXME: Filter should be applied here - if self.options.flow_detail == 0: + if not self.match(f): return message = f.messages[-1] direction = "->" if message.from_client else "<-" diff --git a/mitmproxy/builtins/filestreamer.py b/mitmproxy/builtins/filestreamer.py index 97ddc7c4..ffa565ac 100644 --- a/mitmproxy/builtins/filestreamer.py +++ b/mitmproxy/builtins/filestreamer.py @@ -19,7 +19,7 @@ class FileStreamer: self.stream = io.FilteredFlowWriter(f, filt) self.active_flows = set() - def configure(self, options): + def configure(self, options, updated): # We're already streaming - stop the previous stream and restart if self.stream: self.done() diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py index 83b96cee..74d30c05 100644 --- a/mitmproxy/builtins/replace.py +++ b/mitmproxy/builtins/replace.py @@ -8,7 +8,7 @@ class Replace: def __init__(self): self.lst = [] - def configure(self, options): + def configure(self, options, updated): """ .replacements is a list of tuples (fpat, rex, s): diff --git a/mitmproxy/builtins/script.py b/mitmproxy/builtins/script.py index ab068e47..c960dd1c 100644 --- a/mitmproxy/builtins/script.py +++ b/mitmproxy/builtins/script.py @@ -16,6 +16,19 @@ import watchdog.events from watchdog.observers import polling +class NS: + def __init__(self, ns): + self.__dict__["ns"] = ns + + def __getattr__(self, key): + if key not in self.ns: + raise AttributeError("No such element: %s", key) + return self.ns[key] + + def __setattr__(self, key, value): + self.__dict__["ns"][key] = value + + def parse_command(command): """ Returns a (path, args) tuple. @@ -74,18 +87,27 @@ def load_script(path, args): ns = {'__file__': os.path.abspath(path)} with scriptenv(path, args): exec(code, ns, ns) - return ns + return NS(ns) class ReloadHandler(watchdog.events.FileSystemEventHandler): def __init__(self, callback): self.callback = callback + def filter(self, event): + if event.is_directory: + return False + if os.path.basename(event.src_path).startswith("."): + return False + return True + def on_modified(self, event): - self.callback() + if self.filter(event): + self.callback() def on_created(self, event): - self.callback() + if self.filter(event): + self.callback() class Script: @@ -118,29 +140,35 @@ class Script: # It's possible for ns to be un-initialised if we failed during # configure if self.ns is not None and not self.dead: - func = self.ns.get(name) + func = getattr(self.ns, name, None) if func: with scriptenv(self.path, self.args): - func(*args, **kwargs) + return func(*args, **kwargs) def reload(self): self.should_reload.set() + def load_script(self): + self.ns = load_script(self.path, self.args) + ret = self.run("start") + if ret: + self.ns = ret + self.run("start") + def tick(self): if self.should_reload.is_set(): self.should_reload.clear() ctx.log.info("Reloading script: %s" % self.name) self.ns = load_script(self.path, self.args) self.start() - self.configure(self.last_options) + self.configure(self.last_options, self.last_options.keys()) else: self.run("tick") def start(self): - self.ns = load_script(self.path, self.args) - self.run("start") + self.load_script() - def configure(self, options): + def configure(self, options, updated): self.last_options = options if not self.observer: self.observer = polling.PollingObserver() @@ -150,7 +178,7 @@ class Script: os.path.dirname(self.path) or "." ) self.observer.start() - self.run("configure", options) + self.run("configure", options, updated) def done(self): self.run("done") @@ -161,26 +189,27 @@ class ScriptLoader(): """ An addon that manages loading scripts from options. """ - def configure(self, options): - for s in options.scripts: - if options.scripts.count(s) > 1: - raise exceptions.OptionsError("Duplicate script: %s" % s) - - for a in ctx.master.addons.chain[:]: - if isinstance(a, Script) and a.name not in options.scripts: - ctx.log.info("Un-loading script: %s" % a.name) - ctx.master.addons.remove(a) - - current = {} - for a in ctx.master.addons.chain[:]: - if isinstance(a, Script): - current[a.name] = a - ctx.master.addons.chain.remove(a) - - for s in options.scripts: - if s in current: - ctx.master.addons.chain.append(current[s]) - else: - ctx.log.info("Loading script: %s" % s) - sc = Script(s) - ctx.master.addons.add(sc) + def configure(self, options, updated): + if "scripts" in updated: + for s in options.scripts: + if options.scripts.count(s) > 1: + raise exceptions.OptionsError("Duplicate script: %s" % s) + + for a in ctx.master.addons.chain[:]: + if isinstance(a, Script) and a.name not in options.scripts: + ctx.log.info("Un-loading script: %s" % a.name) + ctx.master.addons.remove(a) + + current = {} + for a in ctx.master.addons.chain[:]: + if isinstance(a, Script): + current[a.name] = a + ctx.master.addons.chain.remove(a) + + for s in options.scripts: + if s in current: + ctx.master.addons.chain.append(current[s]) + else: + ctx.log.info("Loading script: %s" % s) + sc = Script(s) + ctx.master.addons.add(options, sc) diff --git a/mitmproxy/builtins/setheaders.py b/mitmproxy/builtins/setheaders.py index 6bda3f55..4a784a1d 100644 --- a/mitmproxy/builtins/setheaders.py +++ b/mitmproxy/builtins/setheaders.py @@ -6,7 +6,7 @@ class SetHeaders: def __init__(self): self.lst = [] - def configure(self, options): + def configure(self, options, updated): """ options.setheaders is a tuple of (fpatt, header, value) diff --git a/mitmproxy/builtins/stickyauth.py b/mitmproxy/builtins/stickyauth.py index 1309911c..98fb65ed 100644 --- a/mitmproxy/builtins/stickyauth.py +++ b/mitmproxy/builtins/stickyauth.py @@ -10,7 +10,7 @@ class StickyAuth: self.flt = None self.hosts = {} - def configure(self, options): + def configure(self, options, updated): if options.stickyauth: flt = filt.parse(options.stickyauth) if not flt: diff --git a/mitmproxy/builtins/stickycookie.py b/mitmproxy/builtins/stickycookie.py index dc699bb4..88333d5c 100644 --- a/mitmproxy/builtins/stickycookie.py +++ b/mitmproxy/builtins/stickycookie.py @@ -32,7 +32,7 @@ class StickyCookie: self.jar = collections.defaultdict(dict) self.flt = None - def configure(self, options): + def configure(self, options, updated): if options.stickycookie: flt = filt.parse(options.stickycookie) if not flt: diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py index 281fd658..9fb8b5c9 100644 --- a/mitmproxy/console/common.py +++ b/mitmproxy/console/common.py @@ -134,7 +134,11 @@ def save_data(path, data): if not path: return try: - with open(path, "wb") as f: + if isinstance(data, bytes): + mode = "wb" + else: + mode = "w" + with open(path, mode) as f: f.write(data) except IOError as v: signals.status_message.send(message=v.strerror) @@ -193,10 +197,9 @@ def ask_scope_and_callback(flow, cb, *args): def copy_to_clipboard_or_prompt(data): # pyperclip calls encode('utf-8') on data to be copied without checking. # if data are already encoded that way UnicodeDecodeError is thrown. - toclip = "" - try: - toclip = data.decode('utf-8') - except (UnicodeDecodeError): + if isinstance(data, bytes): + toclip = data.decode("utf8", "replace") + else: toclip = data try: @@ -216,7 +219,7 @@ def copy_to_clipboard_or_prompt(data): def format_flow_data(key, scope, flow): - data = "" + data = b"" if scope in ("q", "b"): request = flow.request.copy() request.decode(strict=False) @@ -230,7 +233,7 @@ def format_flow_data(key, scope, flow): raise ValueError("Unknown key: {}".format(key)) if scope == "b" and flow.request.raw_content and flow.response: # Add padding between request and response - data += "\r\n" * 2 + data += b"\r\n" * 2 if scope in ("s", "b") and flow.response: response = flow.response.copy() response.decode(strict=False) @@ -293,7 +296,7 @@ def ask_save_body(scope, flow): ) elif scope == "b" and request_has_content and response_has_content: ask_save_path( - (flow.request.get_content(strict=False) + "\n" + + (flow.request.get_content(strict=False) + b"\n" + flow.response.get_content(strict=False)), "Save request & response content to" ) @@ -407,7 +410,7 @@ def raw_format_flow(f, focus, extended): return urwid.Pile(pile) -def format_flow(f, focus, extended=False, hostheader=False, marked=False): +def format_flow(f, focus, extended=False, hostheader=False): d = dict( intercepted = f.intercepted, acked = f.reply.acked, @@ -420,7 +423,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False): err_msg = f.error.msg if f.error else None, - marked = marked, + marked = f.marked, ) if f.response: if f.response.raw_content: diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py index 53e934f1..43742083 100644 --- a/mitmproxy/console/flowlist.py +++ b/mitmproxy/console/flowlist.py @@ -120,23 +120,17 @@ class ConnectionItem(urwid.WidgetWrap): self.flow, self.f, hostheader = self.master.options.showhost, - marked=self.state.flow_marked(self.flow) ) def selectable(self): return True def save_flows_prompt(self, k): - if k == "a": + if k == "l": signals.status_prompt_path.send( - prompt = "Save all flows to", + prompt = "Save listed flows to", callback = self.master.save_flows ) - elif k == "m": - signals.status_prompt_path.send( - prompt = "Save marked flows to", - callback = self.master.save_marked_flows - ) else: signals.status_prompt_path.send( prompt = "Save this flow to", @@ -188,17 +182,16 @@ class ConnectionItem(urwid.WidgetWrap): self.flow.accept_intercept(self.master) signals.flowlist_change.send(self) elif key == "d": - self.flow.kill(self.master) + if not self.flow.reply.acked: + self.flow.kill(self.master) self.state.delete_flow(self.flow) signals.flowlist_change.send(self) elif key == "D": f = self.master.duplicate_flow(self.flow) - self.master.view_flow(f) + self.master.state.set_focus_flow(f) + signals.flowlist_change.send(self) elif key == "m": - if self.state.flow_marked(self.flow): - self.state.set_flow_marked(self.flow, False) - else: - self.state.set_flow_marked(self.flow, True) + self.flow.marked = not self.flow.marked signals.flowlist_change.send(self) elif key == "M": if self.state.mark_filter: @@ -233,7 +226,7 @@ class ConnectionItem(urwid.WidgetWrap): ) elif key == "U": for f in self.state.flows: - self.state.set_flow_marked(f, False) + f.marked = False signals.flowlist_change.send(self) elif key == "V": if not self.flow.modified(): @@ -247,14 +240,14 @@ class ConnectionItem(urwid.WidgetWrap): self, prompt = "Save", keys = ( - ("all flows", "a"), + ("listed flows", "l"), ("this flow", "t"), - ("marked flows", "m"), ), callback = self.save_flows_prompt, ) elif key == "X": - self.flow.kill(self.master) + if not self.flow.reply.acked: + self.flow.kill(self.master) elif key == "enter": if self.flow.request: self.master.view_flow(self.flow) @@ -356,7 +349,8 @@ class FlowListBox(urwid.ListBox): return scheme, host, port, path = parts f = self.master.create_request(method, scheme, host, port, path) - self.master.view_flow(f) + self.master.state.set_focus_flow(f) + signals.flowlist_change.send(self) def keypress(self, size, key): key = common.shortcuts(key) diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index 938c8e86..c354563f 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -6,6 +6,7 @@ import sys import traceback import urwid +from typing import Optional, Union # noqa from mitmproxy import contentviews from mitmproxy import controller @@ -38,7 +39,7 @@ def _mkhelp(): ("d", "delete flow"), ("e", "edit request/response"), ("f", "load full body data"), - ("m", "change body display mode for this entity"), + ("m", "change body display mode for this entity\n(default mode can be changed in the options)"), (None, common.highlight_key("automatic", "a") + [("text", ": automatic detection")] @@ -75,7 +76,6 @@ def _mkhelp(): common.highlight_key("xml", "x") + [("text", ": XML")] ), - ("M", "change default body display mode"), ("E", "export flow to file"), ("r", "replay request"), ("V", "revert changes to request"), @@ -105,7 +105,8 @@ footer = [ class FlowViewHeader(urwid.WidgetWrap): def __init__(self, master, f): - self.master, self.flow = master, f + self.master = master # type: "mitmproxy.console.master.ConsoleMaster" + self.flow = f # type: models.HTTPFlow self._w = common.format_flow( f, False, @@ -135,14 +136,15 @@ class FlowView(tabs.Tabs): def __init__(self, master, state, flow, tab_offset): self.master, self.state, self.flow = master, state, flow - tabs.Tabs.__init__(self, - [ - (self.tab_request, self.view_request), - (self.tab_response, self.view_response), - (self.tab_details, self.view_details), - ], - tab_offset - ) + super(FlowView, self).__init__( + [ + (self.tab_request, self.view_request), + (self.tab_response, self.view_response), + (self.tab_details, self.view_details), + ], + tab_offset + ) + self.show() self.last_displayed_body = None signals.flow_change.connect(self.sig_flow_change) @@ -189,15 +191,21 @@ class FlowView(tabs.Tabs): limit = sys.maxsize else: limit = contentviews.VIEW_CUTOFF + + flow_modify_cache_invalidation = hash(( + message.raw_content, + message.headers.fields, + getattr(message, "path", None), + )) return cache.get( - self._get_content_view, + # We move message into this partial function as it is not hashable. + lambda *args: self._get_content_view(message, *args), viewmode, - message, limit, - message # Cache invalidation + flow_modify_cache_invalidation ) - def _get_content_view(self, viewmode, message, max_lines, _): + def _get_content_view(self, message, viewmode, max_lines, _): try: content = message.content @@ -396,7 +404,7 @@ class FlowView(tabs.Tabs): if not self.flow.response: self.flow.response = models.HTTPResponse( self.flow.request.http_version, - 200, "OK", Headers(), "" + 200, b"OK", Headers(), b"" ) self.flow.response.reply = controller.DummyReply() message = self.flow.response @@ -524,30 +532,24 @@ class FlowView(tabs.Tabs): ) signals.flow_change.send(self, flow = self.flow) - def delete_body(self, t): + def keypress(self, size, key): + conn = None # type: Optional[Union[models.HTTPRequest, models.HTTPResponse]] if self.tab_offset == TAB_REQ: - self.flow.request.content = None - else: - self.flow.response.content = None - signals.flow_change.send(self, flow = self.flow) + conn = self.flow.request + elif self.tab_offset == TAB_RESP: + conn = self.flow.response - def keypress(self, size, key): key = super(self.__class__, self).keypress(size, key) + # Special case: Space moves over to the next flow. + # We need to catch that before applying common.shortcuts() if key == " ": self.view_next_flow(self.flow) return key = common.shortcuts(key) - if self.tab_offset == TAB_REQ: - conn = self.flow.request - elif self.tab_offset == TAB_RESP: - conn = self.flow.response - else: - conn = None - if key in ("up", "down", "page up", "page down"): - # Why doesn't this just work?? + # Pass scroll events to the wrapped widget self._w.keypress(size, key) elif key == "a": self.flow.accept_intercept(self.master) @@ -563,10 +565,12 @@ class FlowView(tabs.Tabs): else: self.view_next_flow(self.flow) f = self.flow - f.kill(self.master) + if not f.reply.acked: + f.kill(self.master) self.state.delete_flow(f) elif key == "D": f = self.master.duplicate_flow(self.flow) + signals.pop_view_state.send(self) self.master.view_flow(f) signals.status_message.send(message="Duplicated.") elif key == "p": @@ -577,12 +581,12 @@ class FlowView(tabs.Tabs): signals.status_message.send(message=r) signals.flow_change.send(self, flow = self.flow) elif key == "V": - if not self.flow.modified(): + if self.flow.modified(): + self.state.revert(self.flow) + signals.flow_change.send(self, flow = self.flow) + signals.status_message.send(message="Reverted.") + else: signals.status_message.send(message="Flow not modified.") - return - self.state.revert(self.flow) - signals.flow_change.send(self, flow = self.flow) - signals.status_message.send(message="Reverted.") elif key == "W": signals.status_prompt_path.send( prompt = "Save this flow", @@ -595,133 +599,128 @@ class FlowView(tabs.Tabs): callback = self.master.run_script_once, args = (self.flow,) ) - - if not conn and key in set(list("befgmxvzEC")): + elif key == "e": + if self.tab_offset == TAB_REQ: + signals.status_prompt_onekey.send( + prompt="Edit request", + keys=( + ("cookies", "c"), + ("query", "q"), + ("path", "p"), + ("url", "u"), + ("header", "h"), + ("form", "f"), + ("raw body", "r"), + ("method", "m"), + ), + callback=self.edit + ) + elif self.tab_offset == TAB_RESP: + signals.status_prompt_onekey.send( + prompt="Edit response", + keys=( + ("cookies", "c"), + ("code", "o"), + ("message", "m"), + ("header", "h"), + ("raw body", "r"), + ), + callback=self.edit + ) + else: + signals.status_message.send( + message="Tab to the request or response", + expire=1 + ) + elif key in set("bfgmxvzEC") and not conn: signals.status_message.send( message = "Tab to the request or response", expire = 1 ) - elif conn: - if key == "b": - if self.tab_offset == TAB_REQ: - common.ask_save_body( - "q", self.master, self.state, self.flow - ) + return + elif key == "b": + if self.tab_offset == TAB_REQ: + common.ask_save_body("q", self.flow) + else: + common.ask_save_body("s", self.flow) + elif key == "f": + signals.status_message.send(message="Loading all body data...") + self.state.add_flow_setting( + self.flow, + (self.tab_offset, "fullcontents"), + True + ) + signals.flow_change.send(self, flow = self.flow) + signals.status_message.send(message="") + elif key == "m": + p = list(contentviews.view_prompts) + p.insert(0, ("Clear", "C")) + signals.status_prompt_onekey.send( + self, + prompt = "Display mode", + keys = p, + callback = self.change_this_display_mode + ) + elif key == "E": + if self.tab_offset == TAB_REQ: + scope = "q" + else: + scope = "s" + signals.status_prompt_onekey.send( + self, + prompt = "Export to file", + keys = [(e[0], e[1]) for e in export.EXPORTERS], + callback = common.export_to_clip_or_file, + args = (scope, self.flow, common.ask_save_path) + ) + elif key == "C": + if self.tab_offset == TAB_REQ: + scope = "q" + else: + scope = "s" + signals.status_prompt_onekey.send( + self, + prompt = "Export to clipboard", + keys = [(e[0], e[1]) for e in export.EXPORTERS], + callback = common.export_to_clip_or_file, + args = (scope, self.flow, common.copy_to_clipboard_or_prompt) + ) + elif key == "x": + conn.content = None + signals.flow_change.send(self, flow=self.flow) + elif key == "v": + if conn.raw_content: + t = conn.headers.get("content-type") + if "EDITOR" in os.environ or "PAGER" in os.environ: + self.master.spawn_external_viewer(conn.get_content(strict=False), t) else: - common.ask_save_body( - "s", self.master, self.state, self.flow - ) - elif key == "e": - if self.tab_offset == TAB_REQ: - signals.status_prompt_onekey.send( - prompt = "Edit request", - keys = ( - ("cookies", "c"), - ("query", "q"), - ("path", "p"), - ("url", "u"), - ("header", "h"), - ("form", "f"), - ("raw body", "r"), - ("method", "m"), - ), - callback = self.edit + signals.status_message.send( + message = "Error! Set $EDITOR or $PAGER." ) - else: - signals.status_prompt_onekey.send( - prompt = "Edit response", - keys = ( - ("cookies", "c"), - ("code", "o"), - ("message", "m"), - ("header", "h"), - ("raw body", "r"), - ), - callback = self.edit + elif key == "z": + self.flow.backup() + e = conn.headers.get("content-encoding", "identity") + if e != "identity": + try: + conn.decode() + except ValueError: + signals.status_message.send( + message = "Could not decode - invalid data?" ) - key = None - elif key == "f": - signals.status_message.send(message="Loading all body data...") - self.state.add_flow_setting( - self.flow, - (self.tab_offset, "fullcontents"), - True - ) - signals.flow_change.send(self, flow = self.flow) - signals.status_message.send(message="") - elif key == "m": - p = list(contentviews.view_prompts) - p.insert(0, ("Clear", "C")) - signals.status_prompt_onekey.send( - self, - prompt = "Display mode", - keys = p, - callback = self.change_this_display_mode - ) - key = None - elif key == "E": - if self.tab_offset == TAB_REQ: - scope = "q" - else: - scope = "s" - signals.status_prompt_onekey.send( - self, - prompt = "Export to file", - keys = [(e[0], e[1]) for e in export.EXPORTERS], - callback = common.export_to_clip_or_file, - args = (scope, self.flow, common.ask_save_path) - ) - elif key == "C": - if self.tab_offset == TAB_REQ: - scope = "q" - else: - scope = "s" - signals.status_prompt_onekey.send( - self, - prompt = "Export to clipboard", - keys = [(e[0], e[1]) for e in export.EXPORTERS], - callback = common.export_to_clip_or_file, - args = (scope, self.flow, common.copy_to_clipboard_or_prompt) - ) - elif key == "x": + else: signals.status_prompt_onekey.send( - prompt = "Delete body", + prompt = "Select encoding: ", keys = ( - ("completely", "c"), - ("mark as missing", "m"), + ("gzip", "z"), + ("deflate", "d"), ), - callback = self.delete_body + callback = self.encode_callback, + args = (conn,) ) - key = None - elif key == "v": - if conn.raw_content: - t = conn.headers.get("content-type") - if "EDITOR" in os.environ or "PAGER" in os.environ: - self.master.spawn_external_viewer(conn.get_content(strict=False), t) - else: - signals.status_message.send( - message = "Error! Set $EDITOR or $PAGER." - ) - elif key == "z": - self.flow.backup() - e = conn.headers.get("content-encoding", "identity") - if e != "identity": - if not conn.decode(): - signals.status_message.send( - message = "Could not decode - invalid data?" - ) - else: - signals.status_prompt_onekey.send( - prompt = "Select encoding: ", - keys = ( - ("gzip", "z"), - ("deflate", "d"), - ), - callback = self.encode_callback, - args = (conn,) - ) - signals.flow_change.send(self, flow = self.flow) - return key + signals.flow_change.send(self, flow = self.flow) + else: + # Key is not handled here. + return key def encode_callback(self, key, conn): encoding_map = { diff --git a/mitmproxy/console/help.py b/mitmproxy/console/help.py index 064d3cb5..ff4a072f 100644 --- a/mitmproxy/console/help.py +++ b/mitmproxy/console/help.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, print_function, division +import platform + import urwid from mitmproxy import filt @@ -9,7 +11,7 @@ from mitmproxy.console import signals from netlib import version footer = [ - ("heading", 'mitmproxy v%s ' % version.VERSION), + ("heading", 'mitmproxy {} (Python {}) '.format(version.VERSION, platform.python_version())), ('heading_key', "q"), ":back ", ] diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index 4fd6cb78..db414147 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -34,6 +34,7 @@ from mitmproxy.console import palettes from mitmproxy.console import signals from mitmproxy.console import statusbar from mitmproxy.console import window +from mitmproxy.filt import FMarked from netlib import tcp, strutils EVENTLOG_SIZE = 500 @@ -48,7 +49,7 @@ class ConsoleState(flow.State): self.default_body_view = contentviews.get("Auto") self.flowsettings = weakref.WeakKeyDictionary() self.last_search = None - self.last_filter = None + self.last_filter = "" self.mark_filter = False def __setattr__(self, name, value): @@ -66,7 +67,6 @@ class ConsoleState(flow.State): def add_flow(self, f): super(ConsoleState, self).add_flow(f) self.update_focus() - self.set_flow_marked(f, False) return f def update_flow(self, f): @@ -86,10 +86,10 @@ class ConsoleState(flow.State): def set_focus(self, idx): if self.view: - if idx >= len(self.view): - idx = len(self.view) - 1 - elif idx < 0: + if idx is None or idx < 0: idx = 0 + elif idx >= len(self.view): + idx = len(self.view) - 1 self.focus = idx else: self.focus = None @@ -123,48 +123,71 @@ class ConsoleState(flow.State): self.set_focus(self.focus) return ret - def filter_marked(self, m): - def actual_func(x): - if x.id in m: - return True - return False - return actual_func + def get_nearest_matching_flow(self, flow, filt): + fidx = self.view.index(flow) + dist = 1 + + fprev = fnext = True + while fprev or fnext: + fprev, _ = self.get_from_pos(fidx - dist) + fnext, _ = self.get_from_pos(fidx + dist) + + if fprev and fprev.match(filt): + return fprev + elif fnext and fnext.match(filt): + return fnext + + dist += 1 + + return None def enable_marked_filter(self): + marked_flows = [f for f in self.flows if f.marked] + if not marked_flows: + return + + marked_filter = "~%s" % FMarked.code + + # Save Focus + last_focus, _ = self.get_focus() + nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter) + self.last_filter = self.limit_txt - marked_flows = [] - for f in self.flows: - if self.flow_marked(f): - marked_flows.append(f.id) - if len(marked_flows) > 0: - f = self.filter_marked(marked_flows) - self.view._close() - self.view = flow.FlowView(self.flows, f) - self.focus = 0 - self.set_focus(self.focus) - self.mark_filter = True + self.set_limit(marked_filter) + + # Restore Focus + if last_focus.marked: + self.set_focus_flow(last_focus) + else: + self.set_focus_flow(nearest_marked) + + self.mark_filter = True def disable_marked_filter(self): - if self.last_filter is None: - self.view = flow.FlowView(self.flows, None) + marked_filter = "~%s" % FMarked.code + + # Save Focus + last_focus, _ = self.get_focus() + nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter) + + self.set_limit(self.last_filter) + self.last_filter = "" + + # Restore Focus + if last_focus.marked: + self.set_focus_flow(last_focus) else: - self.set_limit(self.last_filter) - self.focus = 0 - self.set_focus(self.focus) - self.last_filter = None + self.set_focus_flow(nearest_marked) + self.mark_filter = False def clear(self): - marked_flows = [] - for f in self.flows: - if self.flow_marked(f): - marked_flows.append(f) - + marked_flows = [f for f in self.view if f.marked] super(ConsoleState, self).clear() for f in marked_flows: self.add_flow(f) - self.set_flow_marked(f, True) + f.marked = True if len(self.flows.views) == 0: self.focus = None @@ -172,12 +195,6 @@ class ConsoleState(flow.State): self.focus = 0 self.set_focus(self.focus) - def flow_marked(self, flow): - return self.get_flow_setting(flow, "marked", False) - - def set_flow_marked(self, flow, marked): - self.add_flow_setting(flow, "marked", marked) - class Options(mitmproxy.options.Options): def __init__( @@ -242,7 +259,7 @@ class ConsoleMaster(flow.FlowMaster): signals.pop_view_state.connect(self.sig_pop_view_state) signals.push_view_state.connect(self.sig_push_view_state) signals.sig_add_log.connect(self.sig_add_log) - self.addons.add(*builtins.default_addons()) + self.addons.add(options, *builtins.default_addons()) def __setattr__(self, name, value): self.__dict__[name] = value @@ -254,10 +271,6 @@ class ConsoleMaster(flow.FlowMaster): expire=1 ) - def load_script(self, command, use_reloader=True): - # We default to using the reloader in the console ui. - return super(ConsoleMaster, self).load_script(command, use_reloader) - def sig_add_log(self, sender, e, level): if self.options.verbosity < utils.log_tier(level): return @@ -352,7 +365,7 @@ class ConsoleMaster(flow.FlowMaster): try: return flow.read_flows_from_paths(path) except exceptions.FlowReadException as e: - signals.status_message.send(message=e.strerror) + signals.status_message.send(message=str(e)) def client_playback_path(self, path): if not isinstance(path, list): @@ -619,13 +632,6 @@ class ConsoleMaster(flow.FlowMaster): def save_flows(self, path): return self._write_flows(path, self.state.view) - def save_marked_flows(self, path): - marked_flows = [] - for f in self.state.view: - if self.state.flow_marked(f): - marked_flows.append(f) - return self._write_flows(path, marked_flows) - def load_flows_callback(self, path): if not path: return @@ -748,10 +754,3 @@ class ConsoleMaster(flow.FlowMaster): direction=direction, ), "info") self.add_log(strutils.bytes_to_escaped_str(message.content), "debug") - - @controller.handler - def script_change(self, script): - if super(ConsoleMaster, self).script_change(script): - signals.status_message.send(message='"{}" reloaded.'.format(script.path)) - else: - signals.status_message.send(message='Error reloading "{}".'.format(script.path)) diff --git a/mitmproxy/console/options.py b/mitmproxy/console/options.py index 62564a60..f9fc3764 100644 --- a/mitmproxy/console/options.py +++ b/mitmproxy/console/options.py @@ -140,7 +140,7 @@ class Options(urwid.WidgetWrap): ) self.master.loop.widget.footer.update("") signals.update_settings.connect(self.sig_update_settings) - master.options.changed.connect(self.sig_update_settings) + master.options.changed.connect(lambda sender, updated: self.sig_update_settings(sender)) def sig_update_settings(self, sender): self.lb.walker._modified() diff --git a/mitmproxy/console/searchable.py b/mitmproxy/console/searchable.py index c60d1cd9..d58d3d13 100644 --- a/mitmproxy/console/searchable.py +++ b/mitmproxy/console/searchable.py @@ -78,9 +78,9 @@ class Searchable(urwid.ListBox): return # Start search at focus + 1 if backwards: - rng = xrange(len(self.body) - 1, -1, -1) + rng = range(len(self.body) - 1, -1, -1) else: - rng = xrange(1, len(self.body) + 1) + rng = range(1, len(self.body) + 1) for i in rng: off = (self.focus_position + i) % len(self.body) w = self.body[off] diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py index 3120fa71..156d1176 100644 --- a/mitmproxy/console/statusbar.py +++ b/mitmproxy/console/statusbar.py @@ -124,7 +124,7 @@ class StatusBar(urwid.WidgetWrap): super(StatusBar, self).__init__(urwid.Pile([self.ib, self.master.ab])) signals.update_settings.connect(self.sig_update_settings) signals.flowlist_change.connect(self.sig_update_settings) - master.options.changed.connect(self.sig_update_settings) + master.options.changed.connect(lambda sender, updated: self.sig_update_settings(sender)) self.redraw() def sig_update_settings(self, sender): @@ -171,10 +171,6 @@ class StatusBar(urwid.WidgetWrap): r.append("[") r.append(("heading_key", "l")) r.append(":%s]" % self.master.state.limit_txt) - if self.master.state.mark_filter: - r.append("[") - r.append(("heading_key", "Marked Flows")) - r.append("]") if self.master.options.stickycookie: r.append("[") r.append(("heading_key", "t")) diff --git a/mitmproxy/console/tabs.py b/mitmproxy/console/tabs.py index bfcdeba3..a5e9c510 100644 --- a/mitmproxy/console/tabs.py +++ b/mitmproxy/console/tabs.py @@ -25,7 +25,7 @@ class Tab(urwid.WidgetWrap): class Tabs(urwid.WidgetWrap): def __init__(self, tabs, tab_offset=0): - urwid.WidgetWrap.__init__(self, "") + super(Tabs, self).__init__("") self.tab_offset = tab_offset self.tabs = tabs self.show() diff --git a/mitmproxy/contentviews.py b/mitmproxy/contentviews.py index afdaad7f..e155bc01 100644 --- a/mitmproxy/contentviews.py +++ b/mitmproxy/contentviews.py @@ -20,6 +20,8 @@ import logging import subprocess import sys +from typing import Mapping # noqa + import html2text import lxml.etree import lxml.html @@ -76,6 +78,7 @@ def pretty_json(s): def format_dict(d): + # type: (Mapping[Union[str,bytes], Union[str,bytes]]) -> Generator[Tuple[Union[str,bytes], Union[str,bytes]]] """ Helper function that transforms the given dictionary into a list of ("key", key ) @@ -85,7 +88,7 @@ def format_dict(d): max_key_len = max(len(k) for k in d.keys()) max_key_len = min(max_key_len, KEY_MAX) for key, value in d.items(): - key += ":" + key += b":" if isinstance(key, bytes) else u":" key = key.ljust(max_key_len + 2) yield [ ("header", key), @@ -106,12 +109,16 @@ class View(object): prompt = () content_types = [] - def __call__(self, data, **metadata): + def __call__( + self, + data, # type: bytes + **metadata + ): """ Transform raw data into human-readable output. Args: - data: the data to decode/format as bytes. + data: the data to decode/format. metadata: optional keyword-only arguments for metadata. Implementations must not rely on a given argument being present. @@ -278,6 +285,10 @@ class ViewURLEncoded(View): content_types = ["application/x-www-form-urlencoded"] def __call__(self, data, **metadata): + try: + data = data.decode("ascii", "strict") + except ValueError: + return None d = url.decode(data) return "URLEncoded form", format_dict(multidict.MultiDict(d)) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 070ec862..35817a85 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -37,8 +37,6 @@ Events = frozenset([ "configure", "done", "tick", - - "script_change", ]) diff --git a/mitmproxy/ctx.py b/mitmproxy/ctx.py index fcfdfd0b..5d2905fa 100644 --- a/mitmproxy/ctx.py +++ b/mitmproxy/ctx.py @@ -1,4 +1,4 @@ from typing import Callable # noqa master = None # type: "mitmproxy.flow.FlowMaster" -log = None # type: Callable[[str], None] +log = None # type: "mitmproxy.controller.Log" diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py index 4f34ab95..83f44d87 100644 --- a/mitmproxy/dump.py +++ b/mitmproxy/dump.py @@ -42,8 +42,8 @@ class DumpMaster(flow.FlowMaster): def __init__(self, server, options): flow.FlowMaster.__init__(self, options, server, flow.State()) self.has_errored = False - self.addons.add(*builtins.default_addons()) - self.addons.add(dumper.Dumper()) + self.addons.add(options, *builtins.default_addons()) + self.addons.add(options, dumper.Dumper()) # This line is just for type hinting self.options = self.options # type: Options self.replay_ignore_params = options.replay_ignore_params diff --git a/mitmproxy/filt.py b/mitmproxy/filt.py index 8b647b22..67915e5b 100644 --- a/mitmproxy/filt.py +++ b/mitmproxy/filt.py @@ -39,9 +39,12 @@ import functools from mitmproxy.models.http import HTTPFlow from mitmproxy.models.tcp import TCPFlow +from mitmproxy.models.flow import Flow + from netlib import strutils import pyparsing as pp +from typing import Callable def only(*types): @@ -80,6 +83,14 @@ class FErr(_Action): return True if f.error else False +class FMarked(_Action): + code = "marked" + help = "Match marked flows" + + def __call__(self, f): + return f.marked + + class FHTTP(_Action): code = "http" help = "Match HTTP flows" @@ -398,6 +409,7 @@ filt_unary = [ FAsset, FErr, FHTTP, + FMarked, FReq, FResp, FTCP, @@ -471,7 +483,11 @@ def _make(): bnf = _make() +TFilter = Callable[[Flow], bool] + + def parse(s): + # type: (str) -> TFilter try: filt = bnf.parseString(s, parseAll=True)[0] filt.pattern = s diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py index 8cd883c3..061bf16d 100644 --- a/mitmproxy/flow/io_compat.py +++ b/mitmproxy/flow/io_compat.py @@ -60,6 +60,7 @@ def convert_017_018(data): data = convert_unicode(data) data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") + data["marked"] = False data["version"] = (0, 18) return data diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index f4993b7a..f4a2b54b 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -8,6 +8,8 @@ from mitmproxy import stateobject from mitmproxy.models.connections import ClientConnection from mitmproxy.models.connections import ServerConnection +import six + from netlib import version from typing import Optional # noqa @@ -79,6 +81,7 @@ class Flow(stateobject.StateObject): self.intercepted = False # type: bool self._backup = None # type: Optional[Flow] self.reply = None + self.marked = False # type: bool _stateobject_attributes = dict( id=str, @@ -86,7 +89,8 @@ class Flow(stateobject.StateObject): client_conn=ClientConnection, server_conn=ServerConnection, type=str, - intercepted=bool + intercepted=bool, + marked=bool, ) def get_state(self): @@ -173,3 +177,21 @@ class Flow(stateobject.StateObject): self.intercepted = False self.reply.ack() master.handle_accept_intercept(self) + + def match(self, f): + """ + Match this flow against a compiled filter expression. Returns True + if matched, False if not. + + If f is a string, it will be compiled as a filter expression. If + the expression is invalid, ValueError is raised. + """ + if isinstance(f, six.string_types): + from .. import filt + + f = filt.parse(f) + if not f: + raise ValueError("Invalid filter expression.") + if f: + return f(self) + return True diff --git a/mitmproxy/models/http.py b/mitmproxy/models/http.py index 1fd28f00..7781e61f 100644 --- a/mitmproxy/models/http.py +++ b/mitmproxy/models/http.py @@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division import cgi import warnings -import six from mitmproxy.models.flow import Flow from netlib import version @@ -211,24 +210,6 @@ class HTTPFlow(Flow): f.response = self.response.copy() return f - def match(self, f): - """ - Match this flow against a compiled filter expression. Returns True - if matched, False if not. - - If f is a string, it will be compiled as a filter expression. If - the expression is invalid, ValueError is raised. - """ - if isinstance(f, six.string_types): - from .. import filt - - f = filt.parse(f) - if not f: - raise ValueError("Invalid filter expression.") - if f: - return f(self) - return True - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both request and diff --git a/mitmproxy/models/tcp.py b/mitmproxy/models/tcp.py index 6650141d..e33475c2 100644 --- a/mitmproxy/models/tcp.py +++ b/mitmproxy/models/tcp.py @@ -7,8 +7,6 @@ from typing import List import netlib.basetypes from mitmproxy.models.flow import Flow -import six - class TCPMessage(netlib.basetypes.Serializable): @@ -55,22 +53,3 @@ class TCPFlow(Flow): def __repr__(self): return "<TCPFlow ({} messages)>".format(len(self.messages)) - - def match(self, f): - """ - Match this flow against a compiled filter expression. Returns True - if matched, False if not. - - If f is a string, it will be compiled as a filter expression. If - the expression is invalid, ValueError is raised. - """ - if isinstance(f, six.string_types): - from .. import filt - - f = filt.parse(f) - if not f: - raise ValueError("Invalid filter expression.") - if f: - return f(self) - - return True diff --git a/mitmproxy/optmanager.py b/mitmproxy/optmanager.py index e94ef51d..140c7ca8 100644 --- a/mitmproxy/optmanager.py +++ b/mitmproxy/optmanager.py @@ -35,7 +35,7 @@ class OptManager(object): self.__dict__["_initialized"] = True @contextlib.contextmanager - def rollback(self): + def rollback(self, updated): old = self._opts.copy() try: yield @@ -44,7 +44,7 @@ class OptManager(object): self.errored.send(self, exc=e) # Rollback self.__dict__["_opts"] = old - self.changed.send(self) + self.changed.send(self, updated=updated) def __eq__(self, other): return self._opts == other._opts @@ -62,22 +62,22 @@ class OptManager(object): if not self._initialized: self._opts[attr] = value return - if attr not in self._opts: - raise KeyError("No such option: %s" % attr) - with self.rollback(): - self._opts[attr] = value - self.changed.send(self) + self.update(**{attr: value}) + + def keys(self): + return set(self._opts.keys()) def get(self, k, d=None): return self._opts.get(k, d) def update(self, **kwargs): + updated = set(kwargs.keys()) for k in kwargs: if k not in self._opts: raise KeyError("No such option: %s" % k) - with self.rollback(): + with self.rollback(updated): self._opts.update(kwargs) - self.changed.send(self) + self.changed.send(self, updated=updated) def setter(self, attr): """ diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 1285e10e..8308f44d 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -584,6 +584,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except exceptions.ProtocolException as e: # pragma: no cover self.log(repr(e), "info") self.log(traceback.format_exc(), "debug") + except exceptions.Kill: + self.log("Connection killed", "info") if not self.zombie: self.zombie = time.time() diff --git a/mitmproxy/proxy/config.py b/mitmproxy/proxy/config.py index 7aa4c736..a74ba7e2 100644 --- a/mitmproxy/proxy/config.py +++ b/mitmproxy/proxy/config.py @@ -79,10 +79,10 @@ class ProxyConfig: self.certstore = None self.clientcerts = None self.openssl_verification_mode_server = None - self.configure(options) + self.configure(options, set(options.keys())) options.changed.connect(self.configure) - def configure(self, options): + def configure(self, options, updated): conflict = all( [ options.add_upstream_certs_to_client_chain, diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py index 8ccc21c5..f8f85f3d 100644 --- a/mitmproxy/web/app.py +++ b/mitmproxy/web/app.py @@ -234,7 +234,8 @@ class AcceptFlow(RequestHandler): class FlowHandler(RequestHandler): def delete(self, flow_id): - self.flow.kill(self.master) + if not self.flow.reply.acked: + self.flow.kill(self.master) self.state.delete_flow(self.flow) def put(self, flow_id): diff --git a/mitmproxy/web/master.py b/mitmproxy/web/master.py index 3d384612..9ddb61d4 100644 --- a/mitmproxy/web/master.py +++ b/mitmproxy/web/master.py @@ -136,7 +136,7 @@ class WebMaster(flow.FlowMaster): def __init__(self, server, options): super(WebMaster, self).__init__(options, server, WebState()) - self.addons.add(*builtins.default_addons()) + self.addons.add(options, *builtins.default_addons()) self.app = app.Application( self, self.options.wdebug, self.options.wauthenticator ) diff --git a/netlib/encoding.py b/netlib/encoding.py index e3cf5f30..da282194 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -4,6 +4,7 @@ Utility functions for decoding response bodies. from __future__ import absolute_import import codecs +import collections from io import BytesIO import gzip import zlib @@ -11,7 +12,15 @@ import zlib from typing import Union # noqa -def decode(obj, encoding, errors='strict'): +# We have a shared single-element cache for encoding and decoding. +# This is quite useful in practice, e.g. +# flow.request.content = flow.request.content.replace(b"foo", b"bar") +# does not require an .encode() call if content does not contain b"foo" +CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded") +_cache = CachedDecode(None, None, None, None) + + +def decode(encoded, encoding, errors='strict'): # type: (Union[str, bytes], str, str) -> Union[str, bytes] """ Decode the given input object @@ -22,20 +31,32 @@ def decode(obj, encoding, errors='strict'): Raises: ValueError, if decoding fails. """ + global _cache + cached = ( + isinstance(encoded, bytes) and + _cache.encoded == encoded and + _cache.encoding == encoding and + _cache.errors == errors + ) + if cached: + return _cache.decoded try: try: - return custom_decode[encoding](obj) + decoded = custom_decode[encoding](encoded) except KeyError: - return codecs.decode(obj, encoding, errors) + decoded = codecs.decode(encoded, encoding, errors) + if encoding in ("gzip", "deflate"): + _cache = CachedDecode(encoded, encoding, errors, decoded) + return decoded except Exception as e: raise ValueError("{} when decoding {} with {}".format( type(e).__name__, - repr(obj)[:10], + repr(encoded)[:10], repr(encoding), )) -def encode(obj, encoding, errors='strict'): +def encode(decoded, encoding, errors='strict'): # type: (Union[str, bytes], str, str) -> Union[str, bytes] """ Encode the given input object @@ -46,15 +67,27 @@ def encode(obj, encoding, errors='strict'): Raises: ValueError, if encoding fails. """ + global _cache + cached = ( + isinstance(decoded, bytes) and + _cache.decoded == decoded and + _cache.encoding == encoding and + _cache.errors == errors + ) + if cached: + return _cache.encoded try: try: - return custom_encode[encoding](obj) + encoded = custom_encode[encoding](decoded) except KeyError: - return codecs.encode(obj, encoding, errors) + encoded = codecs.encode(decoded, encoding, errors) + if encoding in ("gzip", "deflate"): + _cache = CachedDecode(encoded, encoding, errors, decoded) + return encoded except Exception as e: raise ValueError("{} when encoding {} with {}".format( type(e).__name__, - repr(obj)[:10], + repr(decoded)[:10], repr(encoding), )) diff --git a/netlib/http/message.py b/netlib/http/message.py index 34709f0a..be35b8d1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -32,9 +32,6 @@ class MessageData(basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(frozenset(self.__dict__.items())) - def set_state(self, state): for k, v in state.items(): if k == "headers": @@ -52,23 +49,7 @@ class MessageData(basetypes.Serializable): return cls(**state) -class CachedDecode(object): - __slots__ = ["encoded", "encoding", "strict", "decoded"] - - def __init__(self, object, encoding, strict, decoded): - self.encoded = object - self.encoding = encoding - self.strict = strict - self.decoded = decoded - -no_cached_decode = CachedDecode(None, None, None, None) - - class Message(basetypes.Serializable): - def __init__(self): - self._content_cache = no_cached_decode # type: CachedDecode - self._text_cache = no_cached_decode # type: CachedDecode - def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -77,9 +58,6 @@ class Message(basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(self.data) ^ 1 - def get_state(self): return self.data.get_state() @@ -132,25 +110,15 @@ class Message(basetypes.Serializable): if self.raw_content is None: return None ce = self.headers.get("content-encoding") - cached = ( - self._content_cache.encoded == self.raw_content and - (self._content_cache.strict or not strict) and - self._content_cache.encoding == ce - ) - if not cached: - is_strict = True - if ce: - try: - decoded = encoding.decode(self.raw_content, ce) - except ValueError: - if strict: - raise - is_strict = False - decoded = self.raw_content - else: - decoded = self.raw_content - self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded) - return self._content_cache.decoded + if ce: + try: + return encoding.decode(self.raw_content, ce) + except ValueError: + if strict: + raise + return self.raw_content + else: + return self.raw_content def set_content(self, value): if value is None: @@ -163,22 +131,13 @@ class Message(basetypes.Serializable): .format(type(value).__name__) ) ce = self.headers.get("content-encoding") - cached = ( - self._content_cache.decoded == value and - self._content_cache.encoding == ce and - self._content_cache.strict - ) - if not cached: - try: - encoded = encoding.encode(value, ce or "identity") - except ValueError: - # So we have an invalid content-encoding? - # Let's remove it! - del self.headers["content-encoding"] - ce = None - encoded = value - self._content_cache = CachedDecode(encoded, ce, True, value) - self.raw_content = self._content_cache.encoded + try: + self.raw_content = encoding.encode(value, ce or "identity") + except ValueError: + # So we have an invalid content-encoding? + # Let's remove it! + del self.headers["content-encoding"] + self.raw_content = value self.headers["content-length"] = str(len(self.raw_content)) content = property(get_content, set_content) @@ -250,22 +209,12 @@ class Message(basetypes.Serializable): enc = self._guess_encoding() content = self.get_content(strict) - cached = ( - self._text_cache.encoded == content and - (self._text_cache.strict or not strict) and - self._text_cache.encoding == enc - ) - if not cached: - is_strict = self._content_cache.strict - try: - decoded = encoding.decode(content, enc) - except ValueError: - if strict: - raise - is_strict = False - decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape") - self._text_cache = CachedDecode(content, enc, is_strict, decoded) - return self._text_cache.decoded + try: + return encoding.decode(content, enc) + except ValueError: + if strict: + raise + return content.decode("utf8", "replace" if six.PY2 else "surrogateescape") def set_text(self, text): if text is None: @@ -273,23 +222,15 @@ class Message(basetypes.Serializable): return enc = self._guess_encoding() - cached = ( - self._text_cache.decoded == text and - self._text_cache.encoding == enc and - self._text_cache.strict - ) - if not cached: - try: - encoded = encoding.encode(text, enc) - except ValueError: - # Fall back to UTF-8 and update the content-type header. - ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) - ct[2]["charset"] = "utf-8" - self.headers["content-type"] = headers.assemble_content_type(*ct) - enc = "utf8" - encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape") - self._text_cache = CachedDecode(encoded, enc, True, text) - self.content = self._text_cache.encoded + try: + self.content = encoding.encode(text, enc) + except ValueError: + # Fall back to UTF-8 and update the content-type header. + ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) + ct[2]["charset"] = "utf-8" + self.headers["content-type"] = headers.assemble_content_type(*ct) + enc = "utf8" + self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape") text = property(get_text, set_text) diff --git a/netlib/http/request.py b/netlib/http/request.py index ecaa9b79..061217a3 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -253,14 +253,13 @@ class Request(message.Message): ) def _get_query(self): - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + query = urllib.parse.urlparse(self.url).query return tuple(netlib.http.url.decode(query)) - def _set_query(self, value): - query = netlib.http.url.encode(value) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = netlib.http.url.parse( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + def _set_query(self, query_data): + query = netlib.http.url.encode(query_data) + _, _, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) @query.setter def query(self, value): @@ -296,19 +295,18 @@ class Request(message.Message): The URL's path components as a tuple of strings. Components are unquoted. """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + path = urllib.parse.urlparse(self.url).path # This needs to be a tuple so that it's immutable. # Otherwise, this would fail silently: # request.path_components.append("foo") - return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) + return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) + components = map(lambda x: netlib.http.url.quote(x, safe=""), components) path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = netlib.http.url.parse( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + _, _, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) def anticache(self): """ @@ -365,13 +363,13 @@ class Request(message.Message): pass return () - def _set_urlencoded_form(self, value): + def _set_urlencoded_form(self, form_data): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(value).encode() + self.content = netlib.http.url.encode(form_data).encode() @urlencoded_form.setter def urlencoded_form(self, value): diff --git a/netlib/http/url.py b/netlib/http/url.py index 2fc6e7ee..076854b9 100644 --- a/netlib/http/url.py +++ b/netlib/http/url.py @@ -82,18 +82,51 @@ def unparse(scheme, host, port, path=""): def encode(s): + # type: Sequence[Tuple[str,str]] -> str """ Takes a list of (key, value) tuples and returns a urlencoded string. """ - s = [tuple(i) for i in s] - return urllib.parse.urlencode(s, False) + if six.PY2: + return urllib.parse.urlencode(s, False) + else: + return urllib.parse.urlencode(s, False, errors="surrogateescape") def decode(s): """ - Takes a urlencoded string and returns a list of (key, value) tuples. + Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples. + """ + if six.PY2: + return urllib.parse.parse_qsl(s, keep_blank_values=True) + else: + return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape') + + +def quote(b, safe="/"): + """ + Returns: + An ascii-encodable str. + """ + # type: (str) -> str + if six.PY2: + return urllib.parse.quote(b, safe=safe) + else: + return urllib.parse.quote(b, safe=safe, errors="surrogateescape") + + +def unquote(s): """ - return urllib.parse.parse_qsl(s, keep_blank_values=True) + Args: + s: A surrogate-escaped str + Returns: + A surrogate-escaped str + """ + # type: (str) -> str + + if six.PY2: + return urllib.parse.unquote(s) + else: + return urllib.parse.unquote(s, errors="surrogateescape") def hostport(scheme, host, port): diff --git a/netlib/multidict.py b/netlib/multidict.py index 51053ff6..e9fec155 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -79,9 +79,6 @@ class _MultiDict(MutableMapping, basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(self.fields) - def get_all(self, key): """ Return the list of all values for a given key. @@ -241,6 +238,9 @@ class ImmutableMultiDict(MultiDict): __delitem__ = set_all = insert = _immutable + def __hash__(self): + return hash(self.fields) + def with_delitem(self, key): """ Returns: diff --git a/netlib/strutils.py b/netlib/strutils.py index 32e77927..8f27ebb7 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -51,8 +51,7 @@ else: def escape_control_characters(text, keep_spacing=True): """ - Replace all unicode C1 control characters from the given text with their respective control pictures. - For example, a null byte is replaced with the unicode character "\u2400". + Replace all unicode C1 control characters from the given text with a single "." Args: keep_spacing: If True, tabs and newlines will not be replaced. @@ -99,6 +98,9 @@ def bytes_to_escaped_str(data, keep_spacing=False): def escaped_str_to_bytes(data): """ Take an escaped string and return the unescaped bytes equivalent. + + Raises: + ValueError, if the escape sequence is invalid. """ if not isinstance(data, six.string_types): if six.PY2: diff --git a/test/mitmproxy/builtins/test_anticache.py b/test/mitmproxy/builtins/test_anticache.py index 5a00af03..ac321e26 100644 --- a/test/mitmproxy/builtins/test_anticache.py +++ b/test/mitmproxy/builtins/test_anticache.py @@ -8,9 +8,10 @@ from mitmproxy import options class TestAntiCache(mastertest.MasterTest): def test_simple(self): s = state.State() - m = master.FlowMaster(options.Options(anticache = True), None, s) + o = options.Options(anticache = True) + m = master.FlowMaster(o, None, s) sa = anticache.AntiCache() - m.addons.add(sa) + m.addons.add(o, sa) f = tutils.tflow(resp=True) self.invoke(m, "request", f) diff --git a/test/mitmproxy/builtins/test_anticomp.py b/test/mitmproxy/builtins/test_anticomp.py index 6bfd54bb..a5f5a270 100644 --- a/test/mitmproxy/builtins/test_anticomp.py +++ b/test/mitmproxy/builtins/test_anticomp.py @@ -8,9 +8,10 @@ from mitmproxy import options class TestAntiComp(mastertest.MasterTest): def test_simple(self): s = state.State() - m = master.FlowMaster(options.Options(anticomp = True), None, s) + o = options.Options(anticomp = True) + m = master.FlowMaster(o, None, s) sa = anticomp.AntiComp() - m.addons.add(sa) + m.addons.add(o, sa) f = tutils.tflow(resp=True) self.invoke(m, "request", f) diff --git a/test/mitmproxy/builtins/test_dumper.py b/test/mitmproxy/builtins/test_dumper.py index 57e3d036..6287fe86 100644 --- a/test/mitmproxy/builtins/test_dumper.py +++ b/test/mitmproxy/builtins/test_dumper.py @@ -15,26 +15,27 @@ class TestDumper(mastertest.MasterTest): d = dumper.Dumper() sio = StringIO() - d.configure(dump.Options(tfile = sio, flow_detail = 0)) + updated = set(["tfile", "flow_detail"]) + d.configure(dump.Options(tfile = sio, flow_detail = 0), updated) d.response(tutils.tflow()) assert not sio.getvalue() - d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.configure(dump.Options(tfile = sio, flow_detail = 4), updated) d.response(tutils.tflow()) assert sio.getvalue() sio = StringIO() - d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.configure(dump.Options(tfile = sio, flow_detail = 4), updated) d.response(tutils.tflow(resp=True)) assert "<<" in sio.getvalue() sio = StringIO() - d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.configure(dump.Options(tfile = sio, flow_detail = 4), updated) d.response(tutils.tflow(err=True)) assert "<<" in sio.getvalue() sio = StringIO() - d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.configure(dump.Options(tfile = sio, flow_detail = 4), updated) flow = tutils.tflow() flow.request = netlib.tutils.treq() flow.request.stickycookie = True @@ -47,7 +48,7 @@ class TestDumper(mastertest.MasterTest): assert sio.getvalue() sio = StringIO() - d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.configure(dump.Options(tfile = sio, flow_detail = 4), updated) flow = tutils.tflow(resp=netlib.tutils.tresp(content=b"{")) flow.response.headers["content-type"] = "application/json" flow.response.status_code = 400 @@ -55,7 +56,7 @@ class TestDumper(mastertest.MasterTest): assert sio.getvalue() sio = StringIO() - d.configure(dump.Options(tfile = sio)) + d.configure(dump.Options(tfile = sio), updated) flow = tutils.tflow() flow.request.content = None flow.response = models.HTTPResponse.wrap(netlib.tutils.tresp()) @@ -72,15 +73,13 @@ class TestContentView(mastertest.MasterTest): s = state.State() sio = StringIO() - m = mastertest.RecordingMaster( - dump.Options( - flow_detail=4, - verbosity=3, - tfile=sio, - ), - None, s + o = dump.Options( + flow_detail=4, + verbosity=3, + tfile=sio, ) + m = mastertest.RecordingMaster(o, None, s) d = dumper.Dumper() - m.addons.add(d) + m.addons.add(o, d) self.invoke(m, "response", tutils.tflow()) assert "Content viewer failed" in m.event_log[0][1] diff --git a/test/mitmproxy/builtins/test_filestreamer.py b/test/mitmproxy/builtins/test_filestreamer.py index c1d5947f..0e69b340 100644 --- a/test/mitmproxy/builtins/test_filestreamer.py +++ b/test/mitmproxy/builtins/test_filestreamer.py @@ -20,16 +20,13 @@ class TestStream(mastertest.MasterTest): return list(r.stream()) s = state.State() - m = master.FlowMaster( - options.Options( - outfile = (p, "wb") - ), - None, - s + o = options.Options( + outfile = (p, "wb") ) + m = master.FlowMaster(o, None, s) sa = filestreamer.FileStreamer() - m.addons.add(sa) + m.addons.add(o, sa) f = tutils.tflow(resp=True) self.invoke(m, "request", f) self.invoke(m, "response", f) @@ -39,7 +36,7 @@ class TestStream(mastertest.MasterTest): m.options.outfile = (p, "ab") - m.addons.add(sa) + m.addons.add(o, sa) f = tutils.tflow() self.invoke(m, "request", f) m.addons.remove(sa) diff --git a/test/mitmproxy/builtins/test_replace.py b/test/mitmproxy/builtins/test_replace.py index a0b4b722..5e70ce56 100644 --- a/test/mitmproxy/builtins/test_replace.py +++ b/test/mitmproxy/builtins/test_replace.py @@ -8,38 +8,38 @@ from mitmproxy import options class TestReplace(mastertest.MasterTest): def test_configure(self): r = replace.Replace() + updated = set(["replacements"]) r.configure(options.Options( replacements=[("one", "two", "three")] - )) + ), updated) tutils.raises( "invalid filter pattern", r.configure, options.Options( replacements=[("~b", "two", "three")] - ) + ), + updated ) tutils.raises( "invalid regular expression", r.configure, options.Options( replacements=[("foo", "+", "three")] - ) + ), + updated ) def test_simple(self): s = state.State() - m = master.FlowMaster( - options.Options( - replacements = [ - ("~q", "foo", "bar"), - ("~s", "foo", "bar"), - ] - ), - None, - s + o = options.Options( + replacements = [ + ("~q", "foo", "bar"), + ("~s", "foo", "bar"), + ] ) + m = master.FlowMaster(o, None, s) sa = replace.Replace() - m.addons.add(sa) + m.addons.add(o, sa) f = tutils.tflow() f.request.content = b"foo" diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py index f37c7f94..2870fd17 100644 --- a/test/mitmproxy/builtins/test_script.py +++ b/test/mitmproxy/builtins/test_script.py @@ -48,39 +48,41 @@ def test_load_script(): "data/addonscripts/recorder.py" ), [] ) - assert ns["configure"] + assert ns.start class TestScript(mastertest.MasterTest): def test_simple(self): s = state.State() - m = master.FlowMaster(options.Options(), None, s) + o = options.Options() + m = master.FlowMaster(o, None, s) sc = script.Script( tutils.test_data.path( "data/addonscripts/recorder.py" ) ) - m.addons.add(sc) - assert sc.ns["call_log"] == [ + m.addons.add(o, sc) + assert sc.ns.call_log == [ ("solo", "start", (), {}), - ("solo", "configure", (options.Options(),), {}) + ("solo", "configure", (o, o.keys()), {}) ] - sc.ns["call_log"] = [] + sc.ns.call_log = [] f = tutils.tflow(resp=True) self.invoke(m, "request", f) - recf = sc.ns["call_log"][0] + recf = sc.ns.call_log[0] assert recf[1] == "request" def test_reload(self): s = state.State() - m = mastertest.RecordingMaster(options.Options(), None, s) + o = options.Options() + m = mastertest.RecordingMaster(o, None, s) with tutils.tmpdir(): with open("foo.py", "w"): pass sc = script.Script("foo.py") - m.addons.add(sc) + m.addons.add(o, sc) for _ in range(100): with open("foo.py", "a") as f: @@ -93,19 +95,22 @@ class TestScript(mastertest.MasterTest): def test_exception(self): s = state.State() - m = mastertest.RecordingMaster(options.Options(), None, s) + o = options.Options() + m = mastertest.RecordingMaster(o, None, s) sc = script.Script( tutils.test_data.path("data/addonscripts/error.py") ) - m.addons.add(sc) + m.addons.add(o, sc) f = tutils.tflow(resp=True) self.invoke(m, "request", f) assert m.event_log[0][0] == "error" def test_duplicate_flow(self): s = state.State() - fm = master.FlowMaster(None, None, s) + o = options.Options() + fm = master.FlowMaster(o, None, s) fm.addons.add( + o, script.Script( tutils.test_data.path("data/addonscripts/duplicate_flow.py") ) @@ -116,6 +121,20 @@ class TestScript(mastertest.MasterTest): assert not fm.state.view[0].request.is_replay assert fm.state.view[1].request.is_replay + def test_addon(self): + s = state.State() + o = options.Options() + m = master.FlowMaster(o, None, s) + sc = script.Script( + tutils.test_data.path( + "data/addonscripts/addon.py" + ) + ) + m.addons.add(o, sc) + assert sc.ns.event_log == [ + 'scriptstart', 'addonstart', 'addonconfigure' + ] + class TestScriptLoader(mastertest.MasterTest): def test_simple(self): @@ -123,7 +142,7 @@ class TestScriptLoader(mastertest.MasterTest): o = options.Options(scripts=[]) m = master.FlowMaster(o, None, s) sc = script.ScriptLoader() - m.addons.add(sc) + m.addons.add(o, sc) assert len(m.addons) == 1 o.update( scripts = [ @@ -139,7 +158,7 @@ class TestScriptLoader(mastertest.MasterTest): o = options.Options(scripts=["one", "one"]) m = master.FlowMaster(o, None, s) sc = script.ScriptLoader() - tutils.raises(exceptions.OptionsError, m.addons.add, sc) + tutils.raises(exceptions.OptionsError, m.addons.add, o, sc) def test_order(self): rec = tutils.test_data.path("data/addonscripts/recorder.py") @@ -154,7 +173,7 @@ class TestScriptLoader(mastertest.MasterTest): ) m = mastertest.RecordingMaster(o, None, s) sc = script.ScriptLoader() - m.addons.add(sc) + m.addons.add(o, sc) debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"] assert debug == [ diff --git a/test/mitmproxy/builtins/test_setheaders.py b/test/mitmproxy/builtins/test_setheaders.py index 4465719d..41c18360 100644 --- a/test/mitmproxy/builtins/test_setheaders.py +++ b/test/mitmproxy/builtins/test_setheaders.py @@ -8,19 +8,20 @@ from mitmproxy import options class TestSetHeaders(mastertest.MasterTest): def mkmaster(self, **opts): s = state.State() - m = mastertest.RecordingMaster(options.Options(**opts), None, s) + o = options.Options(**opts) + m = mastertest.RecordingMaster(o, None, s) sh = setheaders.SetHeaders() - m.addons.add(sh) + m.addons.add(o, sh) return m, sh def test_configure(self): sh = setheaders.SetHeaders() + o = options.Options( + setheaders = [("~b", "one", "two")] + ) tutils.raises( "invalid setheader filter pattern", - sh.configure, - options.Options( - setheaders = [("~b", "one", "two")] - ) + sh.configure, o, o.keys() ) def test_setheaders(self): diff --git a/test/mitmproxy/builtins/test_stickyauth.py b/test/mitmproxy/builtins/test_stickyauth.py index 9233f435..5757fb2d 100644 --- a/test/mitmproxy/builtins/test_stickyauth.py +++ b/test/mitmproxy/builtins/test_stickyauth.py @@ -8,9 +8,10 @@ from mitmproxy import options class TestStickyAuth(mastertest.MasterTest): def test_simple(self): s = state.State() - m = master.FlowMaster(options.Options(stickyauth = ".*"), None, s) + o = options.Options(stickyauth = ".*") + m = master.FlowMaster(o, None, s) sa = stickyauth.StickyAuth() - m.addons.add(sa) + m.addons.add(o, sa) f = tutils.tflow(resp=True) f.request.headers["authorization"] = "foo" diff --git a/test/mitmproxy/builtins/test_stickycookie.py b/test/mitmproxy/builtins/test_stickycookie.py index 81b540db..e9d92c83 100644 --- a/test/mitmproxy/builtins/test_stickycookie.py +++ b/test/mitmproxy/builtins/test_stickycookie.py @@ -14,22 +14,23 @@ def test_domain_match(): class TestStickyCookie(mastertest.MasterTest): def mk(self): s = state.State() - m = master.FlowMaster(options.Options(stickycookie = ".*"), None, s) + o = options.Options(stickycookie = ".*") + m = master.FlowMaster(o, None, s) sc = stickycookie.StickyCookie() - m.addons.add(sc) + m.addons.add(o, sc) return s, m, sc def test_config(self): sc = stickycookie.StickyCookie() + o = options.Options(stickycookie = "~b") tutils.raises( "invalid filter", - sc.configure, - options.Options(stickycookie = "~b") + sc.configure, o, o.keys() ) def test_simple(self): s, m, sc = self.mk() - m.addons.add(sc) + m.addons.add(m.options, sc) f = tutils.tflow(resp=True) f.response.headers["set-cookie"] = "foo=bar" diff --git a/test/mitmproxy/data/addonscripts/addon.py b/test/mitmproxy/data/addonscripts/addon.py new file mode 100644 index 00000000..84173cb6 --- /dev/null +++ b/test/mitmproxy/data/addonscripts/addon.py @@ -0,0 +1,22 @@ +event_log = [] + + +class Addon: + @property + def event_log(self): + return event_log + + def start(self): + event_log.append("addonstart") + + def configure(self, options, updated): + event_log.append("addonconfigure") + + +def configure(options, updated): + event_log.append("addonconfigure") + + +def start(): + event_log.append("scriptstart") + return Addon() diff --git a/test/mitmproxy/data/addonscripts/recorder.py b/test/mitmproxy/data/addonscripts/recorder.py index b6ac8d89..890e6f4e 100644 --- a/test/mitmproxy/data/addonscripts/recorder.py +++ b/test/mitmproxy/data/addonscripts/recorder.py @@ -2,24 +2,24 @@ from mitmproxy import controller from mitmproxy import ctx import sys -call_log = [] -if len(sys.argv) > 1: - name = sys.argv[1] -else: - name = "solo" +class CallLogger: + call_log = [] -# Keep a log of all possible event calls -evts = list(controller.Events) + ["configure"] -for i in evts: - def mkprox(): - evt = i + def __init__(self, name = "solo"): + self.name = name - def prox(*args, **kwargs): - lg = (name, evt, args, kwargs) - if evt != "log": - ctx.log.info(str(lg)) - call_log.append(lg) - ctx.log.debug("%s %s" % (name, evt)) - return prox - globals()[i] = mkprox() + def __getattr__(self, attr): + if attr in controller.Events: + def prox(*args, **kwargs): + lg = (self.name, attr, args, kwargs) + if attr != "log": + ctx.log.info(str(lg)) + self.call_log.append(lg) + ctx.log.debug("%s %s" % (self.name, attr)) + return prox + raise AttributeError + + +def start(): + return CallLogger(*sys.argv[1:]) diff --git a/test/mitmproxy/data/dumpfile-011 b/test/mitmproxy/data/dumpfile-011 Binary files differindex 2534ad89..936ac0cc 100644 --- a/test/mitmproxy/data/dumpfile-011 +++ b/test/mitmproxy/data/dumpfile-011 diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index 080746e8..a5f76994 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -23,7 +23,7 @@ class TestConcurrent(mastertest.MasterTest): "data/addonscripts/concurrent_decorator.py" ) ) - m.addons.add(sc) + m.addons.add(m.options, sc) f1, f2 = tutils.tflow(), tutils.tflow() self.invoke(m, "request", f1) self.invoke(m, "request", f2) diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py index 1861d4ac..a5085ea0 100644 --- a/test/mitmproxy/test_addons.py +++ b/test/mitmproxy/test_addons.py @@ -13,8 +13,9 @@ class TAddon: def test_simple(): - m = controller.Master(options.Options()) + o = options.Options() + m = controller.Master(o) a = addons.Addons(m) - a.add(TAddon("one")) + a.add(o, TAddon("one")) assert a.has_addon("one") assert not a.has_addon("two") diff --git a/test/mitmproxy/test_contentview.py b/test/mitmproxy/test_contentview.py index 2db9ab40..aad53b37 100644 --- a/test/mitmproxy/test_contentview.py +++ b/test/mitmproxy/test_contentview.py @@ -59,10 +59,10 @@ class TestContentView: assert f[0] == "Query" def test_view_urlencoded(self): - d = url.encode([("one", "two"), ("three", "four")]) + d = url.encode([("one", "two"), ("three", "four")]).encode() v = cv.ViewURLEncoded() assert v(d) - d = url.encode([("adsfa", "")]) + d = url.encode([("adsfa", "")]).encode() v = cv.ViewURLEncoded() assert v(d) diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index 0ec85f52..34fcc261 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -27,10 +27,11 @@ class RaiseMaster(master.FlowMaster): def tscript(cmd, args=""): + o = options.Options() cmd = example_dir.path(cmd) + " " + args - m = RaiseMaster(options.Options(), None, state.State()) + m = RaiseMaster(o, None, state.State()) sc = script.Script(cmd) - m.addons.add(sc) + m.addons.add(o, sc) return m, sc diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 36b212a7..74992130 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -615,6 +615,7 @@ class TestSerialize: def test_roundtrip(self): sio = io.BytesIO() f = tutils.tflow() + f.marked = True f.request.content = bytes(bytearray(range(256))) w = flow.FlowWriter(sio) w.add(f) @@ -627,6 +628,7 @@ class TestSerialize: f2 = l[0] assert f2.get_state() == f.get_state() assert f2.request == f.request + assert f2.marked def test_load_flows(self): r = self._treader() diff --git a/test/mitmproxy/test_optmanager.py b/test/mitmproxy/test_optmanager.py index 67f76ecd..8414e6b5 100644 --- a/test/mitmproxy/test_optmanager.py +++ b/test/mitmproxy/test_optmanager.py @@ -15,6 +15,8 @@ class TO(optmanager.OptManager): def test_options(): o = TO(two="three") + assert o.keys() == set(["one", "two"]) + assert o.one is None assert o.two == "three" o.one = "one" @@ -29,7 +31,7 @@ def test_options(): rec = [] - def sub(opts): + def sub(opts, updated): rec.append(copy.copy(opts)) o.changed.connect(sub) @@ -68,7 +70,7 @@ def test_rollback(): rec = [] - def sub(opts): + def sub(opts, updated): rec.append(copy.copy(opts)) recerr = [] @@ -76,7 +78,7 @@ def test_rollback(): def errsub(opts, **kwargs): recerr.append(kwargs) - def err(opts): + def err(opts, updated): if opts.one == "ten": raise exceptions.OptionsError() diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index afbffb67..aa096a72 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -30,7 +30,7 @@ logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) requires_alpn = pytest.mark.skipif( not netlib.tcp.HAS_ALPN, - reason="requires OpenSSL with ALPN support") + reason='requires OpenSSL with ALPN support') class _Http2ServerBase(netlib_tservers.ServerTestBase): @@ -80,7 +80,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): print(traceback.format_exc()) break - def handle_server_event(self, h2_conn, rfile, wfile): + def handle_server_event(self, event, h2_conn, rfile, wfile): raise NotImplementedError() @@ -88,7 +88,6 @@ class _Http2TestBase(object): @classmethod def setup_class(cls): - cls.masteroptions = options.Options() opts = cls.get_options() cls.config = ProxyConfig(opts) @@ -145,12 +144,14 @@ class _Http2TestBase(object): wfile, h2_conn, stream_id=1, - headers=[], + headers=None, body=b'', end_stream=None, priority_exclusive=None, priority_depends_on=None, priority_weight=None): + if headers is None: + headers = [] if end_stream is None: end_stream = (len(body) == 0) @@ -172,12 +173,12 @@ class _Http2TestBase(object): class _Http2Test(_Http2TestBase, _Http2ServerBase): @classmethod - def setup_class(self): + def setup_class(cls): _Http2TestBase.setup_class() _Http2ServerBase.setup_class() @classmethod - def teardown_class(self): + def teardown_class(cls): _Http2TestBase.teardown_class() _Http2ServerBase.teardown_class() @@ -187,7 +188,7 @@ class TestSimple(_Http2Test): request_body_buffer = b'' @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): @@ -214,7 +215,7 @@ class TestSimple(_Http2Test): wfile.write(h2_conn.data_to_send()) wfile.flush() elif isinstance(event, h2.events.DataReceived): - self.request_body_buffer += event.data + cls.request_body_buffer += event.data return True def test_simple(self): @@ -225,7 +226,7 @@ class TestSimple(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -269,7 +270,7 @@ class TestSimple(_Http2Test): class TestRequestWithPriority(_Http2Test): @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): @@ -301,14 +302,14 @@ class TestRequestWithPriority(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), ], - priority_exclusive = True, - priority_depends_on = 42424242, - priority_weight = 42, + priority_exclusive=True, + priority_depends_on=42424242, + priority_weight=42, ) done = False @@ -343,7 +344,7 @@ class TestRequestWithPriority(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -381,11 +382,11 @@ class TestPriority(_Http2Test): priority_data = None @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.PriorityUpdated): - self.priority_data = (event.exclusive, event.depends_on, event.weight) + cls.priority_data = (event.exclusive, event.depends_on, event.weight) elif isinstance(event, h2.events.RequestReceived): import warnings with warnings.catch_warnings(): @@ -415,7 +416,7 @@ class TestPriority(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -451,11 +452,11 @@ class TestPriorityWithExistingStream(_Http2Test): priority_data = [] @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.PriorityUpdated): - self.priority_data.append((event.exclusive, event.depends_on, event.weight)) + cls.priority_data.append((event.exclusive, event.depends_on, event.weight)) elif isinstance(event, h2.events.RequestReceived): assert not event.priority_updated @@ -486,7 +487,7 @@ class TestPriorityWithExistingStream(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -527,7 +528,7 @@ class TestPriorityWithExistingStream(_Http2Test): class TestStreamResetFromServer(_Http2Test): @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): @@ -543,7 +544,7 @@ class TestStreamResetFromServer(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -578,7 +579,7 @@ class TestStreamResetFromServer(_Http2Test): class TestBodySizeLimit(_Http2Test): @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False return True @@ -592,7 +593,7 @@ class TestBodySizeLimit(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -627,7 +628,7 @@ class TestBodySizeLimit(_Http2Test): class TestPushPromise(_Http2Test): @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): @@ -637,14 +638,14 @@ class TestPushPromise(_Http2Test): h2_conn.send_headers(1, [(':status', '200')]) h2_conn.push_stream(1, 2, [ - (':authority', "127.0.0.1:%s" % self.port), + (':authority', "127.0.0.1:{}".format(cls.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/pushed_stream_foo'), ('foo', 'bar') ]) h2_conn.push_stream(1, 4, [ - (':authority', "127.0.0.1:%s" % self.port), + (':authority', "127.0.0.1:{}".format(cls.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/pushed_stream_bar'), @@ -675,7 +676,7 @@ class TestPushPromise(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -728,7 +729,7 @@ class TestPushPromise(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -780,7 +781,7 @@ class TestPushPromise(_Http2Test): class TestConnectionLost(_Http2Test): @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.RequestReceived): h2_conn.send_headers(1, [(':status', '200')]) wfile.write(h2_conn.data_to_send()) @@ -791,7 +792,7 @@ class TestConnectionLost(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -822,12 +823,12 @@ class TestConnectionLost(_Http2Test): class TestMaxConcurrentStreams(_Http2Test): @classmethod - def setup_class(self): + def setup_class(cls): _Http2TestBase.setup_class() _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): @@ -848,7 +849,7 @@ class TestMaxConcurrentStreams(_Http2Test): # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server self._send_request(client.wfile, h2_conn, stream_id=id, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -883,7 +884,7 @@ class TestMaxConcurrentStreams(_Http2Test): class TestConnectionTerminated(_Http2Test): @classmethod - def handle_server_event(self, event, h2_conn, rfile, wfile): + def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.RequestReceived): h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar') wfile.write(h2_conn.data_to_send()) @@ -894,7 +895,7 @@ class TestConnectionTerminated(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 233af597..6230fc1f 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -291,7 +291,7 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): s = script.Script( tutils.test_data.path("data/addonscripts/stream_modify.py") ) - self.master.addons.add(s) + self.master.addons.add(self.master.options, s) d = self.pathod('200:b"foo"') assert d.content == b"bar" self.master.addons.remove(s) @@ -523,7 +523,7 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin): s = script.Script( tutils.test_data.path("data/addonscripts/tcp_stream_modify.py") ) - self.master.addons.add(s) + self.master.addons.add(self.master.options, s) self._tcpproxy_on() d = self.pathod('200:b"foo"') self._tcpproxy_off() diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index f5119166..d364162c 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -34,7 +34,7 @@ class TestMaster(flow.FlowMaster): s = ProxyServer(config) state = flow.State() flow.FlowMaster.__init__(self, opts, s, state) - self.addons.add(*builtins.default_addons()) + self.addons.add(opts, *builtins.default_addons()) self.apps.add(testapp, "testapp", 80) self.apps.add(errapp, "errapp", 80) self.clear_log() diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py index deebd6f2..12e4706c 100644 --- a/test/netlib/http/test_message.py +++ b/test/netlib/http/test_message.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -import mock import six from netlib.tutils import tresp @@ -71,10 +70,6 @@ class TestMessage(object): assert resp != 0 - def test_hash(self): - resp = tresp() - assert hash(resp) - def test_serializable(self): resp = tresp() resp2 = http.Response.from_state(resp.get_state()) @@ -117,14 +112,6 @@ class TestMessageContentEncoding(object): assert r.content == b"message" assert r.raw_content != b"message" - r.raw_content = b"foo" - with mock.patch("netlib.encoding.decode") as e: - assert r.content - assert e.call_count == 1 - e.reset_mock() - assert r.content - assert e.call_count == 0 - def test_modify(self): r = tresp() assert "content-encoding" not in r.headers @@ -135,13 +122,6 @@ class TestMessageContentEncoding(object): r.decode() assert r.raw_content == b"foo" - r.encode("identity") - with mock.patch("netlib.encoding.encode") as e: - r.content = b"foo" - assert e.call_count == 0 - r.content = b"bar" - assert e.call_count == 1 - with tutils.raises(TypeError): r.content = u"foo" @@ -216,15 +196,6 @@ class TestMessageText(object): r.headers["content-type"] = "text/html; charset=utf8" assert r.text == u"ü" - r.encode("identity") - r.raw_content = b"foo" - with mock.patch("netlib.encoding.decode") as e: - assert r.text - assert e.call_count == 2 - e.reset_mock() - assert r.text - assert e.call_count == 0 - def test_guess_json(self): r = tresp(content=b'"\xc3\xbc"') r.headers["content-type"] = "application/json" @@ -249,14 +220,6 @@ class TestMessageText(object): assert r.raw_content == b"\xc3\xbc" assert r.headers["content-length"] == "2" - r.encode("identity") - with mock.patch("netlib.encoding.encode") as e: - e.return_value = b"" - r.text = u"ü" - assert e.call_count == 0 - r.text = u"ä" - assert e.call_count == 2 - def test_unknown_ce(self): r = tresp() r.headers["content-type"] = "text/html; charset=wtf" diff --git a/test/netlib/http/test_url.py b/test/netlib/http/test_url.py index 26b37230..768e5130 100644 --- a/test/netlib/http/test_url.py +++ b/test/netlib/http/test_url.py @@ -1,3 +1,4 @@ +import six from netlib import tutils from netlib.http import url @@ -57,10 +58,49 @@ def test_unparse(): assert url.unparse("https", "foo.com", 443, "") == "https://foo.com" -def test_urlencode(): +if six.PY2: + surrogates = bytes(bytearray(range(256))) +else: + surrogates = bytes(range(256)).decode("utf8", "surrogateescape") + +surrogates_quoted = ( + '%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F' + '%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F' + '%20%21%22%23%24%25%26%27%28%29%2A%2B%2C-./' + '0123456789%3A%3B%3C%3D%3E%3F' + '%40ABCDEFGHIJKLMNO' + 'PQRSTUVWXYZ%5B%5C%5D%5E_' + '%60abcdefghijklmno' + 'pqrstuvwxyz%7B%7C%7D%7E%7F' + '%80%81%82%83%84%85%86%87%88%89%8A%8B%8C%8D%8E%8F' + '%90%91%92%93%94%95%96%97%98%99%9A%9B%9C%9D%9E%9F' + '%A0%A1%A2%A3%A4%A5%A6%A7%A8%A9%AA%AB%AC%AD%AE%AF' + '%B0%B1%B2%B3%B4%B5%B6%B7%B8%B9%BA%BB%BC%BD%BE%BF' + '%C0%C1%C2%C3%C4%C5%C6%C7%C8%C9%CA%CB%CC%CD%CE%CF' + '%D0%D1%D2%D3%D4%D5%D6%D7%D8%D9%DA%DB%DC%DD%DE%DF' + '%E0%E1%E2%E3%E4%E5%E6%E7%E8%E9%EA%EB%EC%ED%EE%EF' + '%F0%F1%F2%F3%F4%F5%F6%F7%F8%F9%FA%FB%FC%FD%FE%FF' +) + + +def test_encode(): assert url.encode([('foo', 'bar')]) + assert url.encode([('foo', surrogates)]) -def test_urldecode(): +def test_decode(): s = "one=two&three=four" assert len(url.decode(s)) == 2 + assert url.decode(surrogates) + + +def test_quote(): + assert url.quote("foo") == "foo" + assert url.quote("foo bar") == "foo%20bar" + assert url.quote(surrogates) == surrogates_quoted + + +def test_unquote(): + assert url.unquote("foo") == "foo" + assert url.unquote("foo%20bar") == "foo bar" + assert url.unquote(surrogates_quoted) == surrogates diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py index de10fc48..a5e81379 100644 --- a/test/netlib/test_encoding.py +++ b/test/netlib/test_encoding.py @@ -1,3 +1,4 @@ +import mock from netlib import encoding, tutils @@ -37,3 +38,32 @@ def test_deflate(): ) with tutils.raises(ValueError): encoding.decode(b"bogus", "deflate") + + +def test_cache(): + decode_gzip = mock.MagicMock() + decode_gzip.return_value = b"decoded" + encode_gzip = mock.MagicMock() + encode_gzip.return_value = b"encoded" + + with mock.patch.dict(encoding.custom_decode, gzip=decode_gzip): + with mock.patch.dict(encoding.custom_encode, gzip=encode_gzip): + assert encoding.decode(b"encoded", "gzip") == b"decoded" + assert decode_gzip.call_count == 1 + + # should be cached + assert encoding.decode(b"encoded", "gzip") == b"decoded" + assert decode_gzip.call_count == 1 + + # the other way around as well + assert encoding.encode(b"decoded", "gzip") == b"encoded" + assert encode_gzip.call_count == 0 + + # different encoding + decode_gzip.return_value = b"bar" + assert encoding.encode(b"decoded", "deflate") != b"decoded" + assert encode_gzip.call_count == 0 + + # This is not in the cache anymore + assert encoding.encode(b"decoded", "gzip") == b"encoded" + assert encode_gzip.call_count == 1 diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py index 038441e7..58ae0f98 100644 --- a/test/netlib/test_multidict.py +++ b/test/netlib/test_multidict.py @@ -45,7 +45,7 @@ class TestMultiDict(object): assert md["foo"] == "bar" with tutils.raises(KeyError): - md["bar"] + assert md["bar"] md_multi = TMultiDict( [("foo", "a"), ("foo", "b")] @@ -101,6 +101,15 @@ class TestMultiDict(object): assert TMultiDict() != self._multi() assert TMultiDict() != 42 + def test_hash(self): + """ + If a class defines mutable objects and implements an __eq__() method, + it should not implement __hash__(), since the implementation of hashable + collections requires that a key's hash value is immutable. + """ + with tutils.raises(TypeError): + assert hash(TMultiDict()) + def test_get_all(self): md = self._multi() assert md.get_all("foo") == ["bar"] @@ -197,6 +206,9 @@ class TestImmutableMultiDict(object): with tutils.raises(TypeError): md.add("foo", "bar") + def test_hash(self): + assert hash(TImmutableMultiDict()) + def test_with_delitem(self): md = TImmutableMultiDict([("foo", "bar")]) assert md.with_delitem("foo").fields == () diff --git a/web/src/js/components/FlowTable/FlowRow.jsx b/web/src/js/components/FlowTable/FlowRow.jsx index 749bc0ce..7961d502 100644 --- a/web/src/js/components/FlowTable/FlowRow.jsx +++ b/web/src/js/components/FlowTable/FlowRow.jsx @@ -1,6 +1,7 @@ import React, { PropTypes } from 'react' import classnames from 'classnames' import columns from './FlowColumns' +import { pure } from '../../utils' FlowRow.propTypes = { onSelect: PropTypes.func.isRequired, @@ -9,7 +10,7 @@ FlowRow.propTypes = { selected: PropTypes.bool, } -export default function FlowRow({ flow, selected, highlighted, onSelect }) { +function FlowRow({ flow, selected, highlighted, onSelect }) { const className = classnames({ 'selected': selected, 'highlighted': highlighted, @@ -19,10 +20,12 @@ export default function FlowRow({ flow, selected, highlighted, onSelect }) { }) return ( - <tr className={className} onClick={() => onSelect(flow)}> + <tr className={className} onClick={() => onSelect(flow.id)}> {columns.map(Column => ( <Column key={Column.name} flow={flow}/> ))} </tr> ) } + +export default pure(FlowRow) diff --git a/web/src/js/components/MainView.jsx b/web/src/js/components/MainView.jsx index d7d1ebeb..f45f9eef 100644 --- a/web/src/js/components/MainView.jsx +++ b/web/src/js/components/MainView.jsx @@ -22,7 +22,7 @@ class MainView extends Component { flows={flows} selected={selectedFlow} highlight={highlight} - onSelect={flow => this.props.selectFlow(flow.id)} + onSelect={this.props.selectFlow} /> {selectedFlow && [ <Splitter key="splitter"/>, diff --git a/web/src/js/utils.js b/web/src/js/utils.js index d3b99bd0..cc17c565 100644 --- a/web/src/js/utils.js +++ b/web/src/js/utils.js @@ -1,7 +1,9 @@ -import _ from "lodash"; +import _ from 'lodash' +import React from 'react' +import shallowEqual from 'shallowequal' window._ = _; -window.React = require("react"); +window.React = React; export var Key = { UP: 38, @@ -106,15 +108,27 @@ fetchApi.put = (url, json, options) => fetchApi( } ) - export function getDiff(obj1, obj2) { let result = {...obj2}; for(let key in obj1) { if(_.isEqual(obj2[key], obj1[key])) - result[key] = undefined; + result[key] = undefined else if(!(Array.isArray(obj2[key]) && Array.isArray(obj1[key])) && typeof obj2[key] == 'object' && typeof obj1[key] == 'object') - result[key] = getDiff(obj1[key], obj2[key]); + result[key] = getDiff(obj1[key], obj2[key]) + } + return result +} + +export const pure = renderFn => class extends React.Component { + static displayName = renderFn.name + + shouldComponentUpdate(nextProps) { + console.log(!shallowEqual(this.props, nextProps)) + return !shallowEqual(this.props, nextProps) + } + + render() { + return renderFn(this.props) } - return result; } |