aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.appveyor.yml2
-rw-r--r--.travis.yml2
-rw-r--r--mitmproxy/addons.py20
-rw-r--r--mitmproxy/builtins/__init__.py4
-rw-r--r--mitmproxy/builtins/clientplayback.py39
-rw-r--r--mitmproxy/builtins/dumper.py8
-rw-r--r--mitmproxy/builtins/replace.py4
-rw-r--r--mitmproxy/builtins/script.py16
-rw-r--r--mitmproxy/builtins/serverplayback.py122
-rw-r--r--mitmproxy/cmdline.py32
-rw-r--r--mitmproxy/console/flowlist.py61
-rw-r--r--mitmproxy/console/flowview.py4
-rw-r--r--mitmproxy/console/help.py2
-rw-r--r--mitmproxy/console/master.py67
-rw-r--r--mitmproxy/console/options.py6
-rw-r--r--mitmproxy/console/statusbar.py16
-rw-r--r--mitmproxy/console/window.py20
-rw-r--r--mitmproxy/dump.py28
-rw-r--r--mitmproxy/exceptions.py36
-rw-r--r--mitmproxy/filt.py3
-rw-r--r--mitmproxy/flow/__init__.py8
-rw-r--r--mitmproxy/flow/master.py231
-rw-r--r--mitmproxy/flow/modules.py132
-rw-r--r--mitmproxy/options.py30
-rw-r--r--mitmproxy/protocol/http.py11
-rw-r--r--mitmproxy/protocol/http2.py99
-rw-r--r--mitmproxy/protocol/http_replay.py3
-rw-r--r--mitmproxy/proxy/server.py2
-rw-r--r--mitmproxy/web/app.py2
-rw-r--r--netlib/debug.py4
-rw-r--r--netlib/encoding.py25
-rw-r--r--netlib/http/headers.py1
-rw-r--r--netlib/http/http2/__init__.py6
-rw-r--r--netlib/http/http2/framereader.py9
-rw-r--r--netlib/http/response.py19
-rw-r--r--netlib/strutils.py5
-rw-r--r--pathod/language/http2.py2
-rw-r--r--pathod/protocols/http2.py6
-rw-r--r--setup.py3
-rw-r--r--test/mitmproxy/builtins/test_clientplayback.py37
-rw-r--r--test/mitmproxy/builtins/test_script.py25
-rw-r--r--test/mitmproxy/builtins/test_serverplayback.py284
-rw-r--r--test/mitmproxy/console/test_master.py2
-rw-r--r--test/mitmproxy/mastertest.py21
-rw-r--r--test/mitmproxy/protocol/test_http1.py12
-rw-r--r--test/mitmproxy/protocol/test_http2.py49
-rw-r--r--test/mitmproxy/test_addons.py4
-rw-r--r--test/mitmproxy/test_dump.py11
-rw-r--r--test/mitmproxy/test_flow.py358
-rw-r--r--test/mitmproxy/test_fuzzing.py9
-rw-r--r--test/mitmproxy/test_server.py337
-rw-r--r--test/mitmproxy/tservers.py49
-rw-r--r--test/netlib/http/test_headers.py9
-rw-r--r--test/netlib/http/test_response.py5
-rw-r--r--test/netlib/test_encoding.py59
-rw-r--r--test/netlib/test_strutils.py2
-rw-r--r--test/pathod/test_protocols_http2.py6
57 files changed, 1178 insertions, 1191 deletions
diff --git a/.appveyor.yml b/.appveyor.yml
index 13782ee8..7fa65e1b 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -25,7 +25,7 @@ install:
- "pip install -U tox"
test_script:
- - ps: "tox -- --cov netlib --cov mitmproxy --cov pathod | Select-String -NotMatch Cryptography_locking_cb"
+ - ps: "tox -- --cov netlib --cov mitmproxy --cov pathod -v"
deploy_script:
ps: |
diff --git a/.travis.yml b/.travis.yml
index a114dafb..f7c5e839 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -53,7 +53,7 @@ install:
fi
- pip install tox
-script: tox -- --cov netlib --cov mitmproxy --cov pathod
+script: tox -- --cov netlib --cov mitmproxy --cov pathod -v
after_success:
- |
diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py
index 329d1215..2658c0af 100644
--- a/mitmproxy/addons.py
+++ b/mitmproxy/addons.py
@@ -4,7 +4,7 @@ import pprint
def _get_name(itm):
- return getattr(itm, "name", itm.__class__.__name__)
+ return getattr(itm, "name", itm.__class__.__name__.lower())
class Addons(object):
@@ -13,6 +13,16 @@ class Addons(object):
self.master = master
master.options.changed.connect(self.options_update)
+ def get(self, name):
+ """
+ Retrieve an addon by name. Addon names are equal to the .name
+ attribute on the instance, or the lower case class name if that
+ does not exist.
+ """
+ for i in self.chain:
+ if name == _get_name(i):
+ return i
+
def options_update(self, options, updated):
for i in self.chain:
with self.master.handlecontext():
@@ -39,14 +49,6 @@ class Addons(object):
for i in self.chain:
self.invoke_with_context(i, "done")
- def has_addon(self, name):
- """
- Is an addon with this name registered?
- """
- for i in self.chain:
- if _get_name(i) == name:
- return True
-
def __len__(self):
return len(self.chain)
diff --git a/mitmproxy/builtins/__init__.py b/mitmproxy/builtins/__init__.py
index 3974d736..26e9dfbd 100644
--- a/mitmproxy/builtins/__init__.py
+++ b/mitmproxy/builtins/__init__.py
@@ -8,6 +8,8 @@ from mitmproxy.builtins import stickycookie
from mitmproxy.builtins import script
from mitmproxy.builtins import replace
from mitmproxy.builtins import setheaders
+from mitmproxy.builtins import serverplayback
+from mitmproxy.builtins import clientplayback
def default_addons():
@@ -20,4 +22,6 @@ def default_addons():
filestreamer.FileStreamer(),
replace.Replace(),
setheaders.SetHeaders(),
+ serverplayback.ServerPlayback(),
+ clientplayback.ClientPlayback(),
]
diff --git a/mitmproxy/builtins/clientplayback.py b/mitmproxy/builtins/clientplayback.py
new file mode 100644
index 00000000..c40d1904
--- /dev/null
+++ b/mitmproxy/builtins/clientplayback.py
@@ -0,0 +1,39 @@
+from mitmproxy import exceptions, flow, ctx
+
+
+class ClientPlayback:
+ def __init__(self):
+ self.flows = None
+ self.current = None
+ self.keepserving = None
+ self.has_replayed = False
+
+ def count(self):
+ if self.flows:
+ return len(self.flows)
+ return 0
+
+ def load(self, flows):
+ self.flows = flows
+
+ def configure(self, options, updated):
+ if "client_replay" in updated:
+ if options.client_replay:
+ try:
+ flows = flow.read_flows_from_paths(options.client_replay)
+ except exceptions.FlowReadException as e:
+ raise exceptions.OptionsError(str(e))
+ self.load(flows)
+ else:
+ self.flows = None
+ self.keepserving = options.keepserving
+
+ def tick(self):
+ if self.current and not self.current.is_alive():
+ self.current = None
+ if self.flows and not self.current:
+ self.current = ctx.master.replay_request(self.flows.pop(0))
+ self.has_replayed = True
+ if self.has_replayed:
+ if not self.flows and not self.current and not self.keepserving:
+ ctx.master.shutdown()
diff --git a/mitmproxy/builtins/dumper.py b/mitmproxy/builtins/dumper.py
index 743ca72e..60d00518 100644
--- a/mitmproxy/builtins/dumper.py
+++ b/mitmproxy/builtins/dumper.py
@@ -232,6 +232,14 @@ class Dumper(object):
if self.match(f):
self.echo_flow(f)
+ def tcp_error(self, f):
+ self.echo(
+ "Error in TCP connection to {}: {}".format(
+ repr(f.server_conn.address), f.error
+ ),
+ fg="red"
+ )
+
def tcp_message(self, f):
if not self.match(f):
return
diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py
index c938d683..df3cab04 100644
--- a/mitmproxy/builtins/replace.py
+++ b/mitmproxy/builtins/replace.py
@@ -36,9 +36,9 @@ class Replace:
for rex, s, cpatt in self.lst:
if cpatt(f):
if f.response:
- f.response.replace(rex, s)
+ f.response.replace(rex, s, flags=re.DOTALL)
else:
- f.request.replace(rex, s)
+ f.request.replace(rex, s, flags=re.DOTALL)
def request(self, flow):
if not flow.reply.has_message:
diff --git a/mitmproxy/builtins/script.py b/mitmproxy/builtins/script.py
index ae1d1b91..1ebec873 100644
--- a/mitmproxy/builtins/script.py
+++ b/mitmproxy/builtins/script.py
@@ -10,6 +10,7 @@ import traceback
from mitmproxy import exceptions
from mitmproxy import controller
from mitmproxy import ctx
+from mitmproxy.flow import master as flowmaster
import watchdog.events
@@ -67,7 +68,11 @@ def scriptenv(path, args):
tb = tb.tb_next
if not os.path.abspath(s[0]).startswith(scriptdir):
break
- ctx.log.error("Script error: %s" % "".join(traceback.format_exception(etype, value, tb)))
+ ctx.log.error(
+ "Script error: %s" % "".join(
+ traceback.format_exception(etype, value, tb)
+ )
+ )
finally:
sys.argv = oldargs
sys.path.pop()
@@ -189,6 +194,15 @@ class ScriptLoader():
"""
An addon that manages loading scripts from options.
"""
+ def run_once(self, command, flows):
+ sc = Script(command)
+ sc.load_script()
+ for f in flows:
+ for evt, o in flowmaster.event_sequence(f):
+ sc.run(evt, o)
+ sc.done()
+ return sc
+
def configure(self, options, updated):
if "scripts" in updated:
for s in options.scripts:
diff --git a/mitmproxy/builtins/serverplayback.py b/mitmproxy/builtins/serverplayback.py
new file mode 100644
index 00000000..29fc95ef
--- /dev/null
+++ b/mitmproxy/builtins/serverplayback.py
@@ -0,0 +1,122 @@
+from __future__ import absolute_import, print_function, division
+from six.moves import urllib
+import hashlib
+
+from netlib import strutils
+from mitmproxy import exceptions, flow, ctx
+
+
+class ServerPlayback(object):
+ def __init__(self):
+ self.options = None
+
+ self.flowmap = {}
+ self.stop = False
+ self.final_flow = None
+
+ def load(self, flows):
+ for i in flows:
+ if i.response:
+ l = self.flowmap.setdefault(self._hash(i), [])
+ l.append(i)
+
+ def clear(self):
+ self.flowmap = {}
+
+ def count(self):
+ return sum([len(i) for i in self.flowmap.values()])
+
+ def _hash(self, flow):
+ """
+ Calculates a loose hash of the flow request.
+ """
+ r = flow.request
+
+ _, _, path, _, query, _ = urllib.parse.urlparse(r.url)
+ queriesArray = urllib.parse.parse_qsl(query, keep_blank_values=True)
+
+ key = [str(r.port), str(r.scheme), str(r.method), str(path)]
+ if not self.options.server_replay_ignore_content:
+ form_contents = r.urlencoded_form or r.multipart_form
+ if self.options.server_replay_ignore_payload_params and form_contents:
+ params = [
+ strutils.always_bytes(i)
+ for i in self.options.server_replay_ignore_payload_params
+ ]
+ for p in form_contents.items(multi=True):
+ if p[0] not in params:
+ key.append(p)
+ else:
+ key.append(str(r.raw_content))
+
+ if not self.options.server_replay_ignore_host:
+ key.append(r.host)
+
+ filtered = []
+ ignore_params = self.options.server_replay_ignore_params or []
+ for p in queriesArray:
+ if p[0] not in ignore_params:
+ filtered.append(p)
+ for p in filtered:
+ key.append(p[0])
+ key.append(p[1])
+
+ if self.options.server_replay_use_headers:
+ headers = []
+ for i in self.options.server_replay_use_headers:
+ v = r.headers.get(i)
+ headers.append((i, v))
+ key.append(headers)
+ return hashlib.sha256(
+ repr(key).encode("utf8", "surrogateescape")
+ ).digest()
+
+ def next_flow(self, request):
+ """
+ Returns the next flow object, or None if no matching flow was
+ found.
+ """
+ hsh = self._hash(request)
+ if hsh in self.flowmap:
+ if self.options.server_replay_nopop:
+ return self.flowmap[hsh][0]
+ else:
+ ret = self.flowmap[hsh].pop(0)
+ if not self.flowmap[hsh]:
+ del self.flowmap[hsh]
+ return ret
+
+ def configure(self, options, updated):
+ self.options = options
+ if "server_replay" in updated:
+ self.clear()
+ if options.server_replay:
+ try:
+ flows = flow.read_flows_from_paths(options.server_replay)
+ except exceptions.FlowReadException as e:
+ raise exceptions.OptionsError(str(e))
+ self.load(flows)
+
+ def tick(self):
+ if self.stop and not self.final_flow.live:
+ ctx.master.shutdown()
+
+ def request(self, f):
+ if self.flowmap:
+ rflow = self.next_flow(f)
+ if rflow:
+ response = rflow.response.copy()
+ response.is_replay = True
+ if self.options.refresh_server_playback:
+ response.refresh()
+ f.response = response
+ if not self.flowmap and not self.options.keepserving:
+ self.final_flow = f
+ self.stop = True
+ elif self.options.replay_kill_extra:
+ ctx.log.warn(
+ "server_playback: killed non-replay request {}".format(
+ f.request.url
+ )
+ )
+ f.reply.kill()
diff --git a/mitmproxy/cmdline.py b/mitmproxy/cmdline.py
index b5b939b8..9fb4a561 100644
--- a/mitmproxy/cmdline.py
+++ b/mitmproxy/cmdline.py
@@ -218,10 +218,10 @@ def get_common_options(args):
anticache=args.anticache,
anticomp=args.anticomp,
client_replay=args.client_replay,
- kill=args.kill,
+ replay_kill_extra=args.replay_kill_extra,
no_server=args.no_server,
refresh_server_playback=not args.norefresh,
- rheaders=args.rheaders,
+ server_replay_use_headers=args.server_replay_use_headers,
rfile=args.rfile,
replacements=reps,
setheaders=setheaders,
@@ -233,11 +233,11 @@ def get_common_options(args):
showhost=args.showhost,
outfile=args.outfile,
verbosity=args.verbose,
- nopop=args.nopop,
- replay_ignore_content=args.replay_ignore_content,
- replay_ignore_params=args.replay_ignore_params,
- replay_ignore_payload_params=args.replay_ignore_payload_params,
- replay_ignore_host=args.replay_ignore_host,
+ server_replay_nopop=args.server_replay_nopop,
+ server_replay_ignore_content=args.server_replay_ignore_content,
+ server_replay_ignore_params=args.server_replay_ignore_params,
+ server_replay_ignore_payload_params=args.server_replay_ignore_payload_params,
+ server_replay_ignore_host=args.server_replay_ignore_host,
auth_nonanonymous = args.auth_nonanonymous,
auth_singleuser = args.auth_singleuser,
@@ -600,13 +600,13 @@ def server_replay(parser):
help="Replay server responses from a saved file."
)
group.add_argument(
- "-k", "--kill",
- action="store_true", dest="kill", default=False,
+ "-k", "--replay-kill-extra",
+ action="store_true", dest="replay_kill_extra", default=False,
help="Kill extra requests during replay."
)
group.add_argument(
- "--rheader",
- action="append", dest="rheaders", type=str,
+ "--server-replay-use-header",
+ action="append", dest="server_replay_use_headers", type=str,
help="Request headers to be considered during replay. "
"Can be passed multiple times."
)
@@ -620,21 +620,21 @@ def server_replay(parser):
)
group.add_argument(
"--no-pop",
- action="store_true", dest="nopop", default=False,
+ action="store_true", dest="server_replay_nopop", default=False,
help="Disable response pop from response flow. "
"This makes it possible to replay same response multiple times."
)
payload = group.add_mutually_exclusive_group()
payload.add_argument(
"--replay-ignore-content",
- action="store_true", dest="replay_ignore_content", default=False,
+ action="store_true", dest="server_replay_ignore_content", default=False,
help="""
Ignore request's content while searching for a saved flow to replay
"""
)
payload.add_argument(
"--replay-ignore-payload-param",
- action="append", dest="replay_ignore_payload_params", type=str,
+ action="append", dest="server_replay_ignore_payload_params", type=str,
help="""
Request's payload parameters (application/x-www-form-urlencoded or multipart/form-data) to
be ignored while searching for a saved flow to replay.
@@ -644,7 +644,7 @@ def server_replay(parser):
group.add_argument(
"--replay-ignore-param",
- action="append", dest="replay_ignore_params", type=str,
+ action="append", dest="server_replay_ignore_params", type=str,
help="""
Request's parameters to be ignored while searching for a saved flow
to replay. Can be passed multiple times.
@@ -653,7 +653,7 @@ def server_replay(parser):
group.add_argument(
"--replay-ignore-host",
action="store_true",
- dest="replay_ignore_host",
+ dest="server_replay_ignore_host",
default=False,
help="Ignore request's destination host while searching for a saved flow to replay")
diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py
index 11e8fc99..c052de7b 100644
--- a/mitmproxy/console/flowlist.py
+++ b/mitmproxy/console/flowlist.py
@@ -18,14 +18,15 @@ def _mkhelp():
("d", "delete flow"),
("D", "duplicate flow"),
("e", "toggle eventlog"),
+ ("E", "export flow to file"),
("f", "filter view"),
("F", "toggle follow flow list"),
("L", "load saved flows"),
("m", "toggle flow mark"),
("M", "toggle marked flow view"),
("n", "create a new request"),
- ("E", "export flow to file"),
("r", "replay request"),
+ ("S", "server replay request/s"),
("U", "unmark all marked flows"),
("V", "revert changes to request"),
("w", "save flows "),
@@ -140,36 +141,13 @@ class ConnectionItem(urwid.WidgetWrap):
args = (self.flow,)
)
- def stop_server_playback_prompt(self, a):
- if a != "n":
- self.master.stop_server_playback()
-
def server_replay_prompt(self, k):
+ a = self.master.addons.get("serverplayback")
if k == "a":
- self.master.start_server_playback(
- [i.copy() for i in self.master.state.view],
- self.master.options.kill, self.master.options.rheaders,
- False, self.master.options.nopop,
- self.master.options.replay_ignore_params,
- self.master.options.replay_ignore_content,
- self.master.options.replay_ignore_payload_params,
- self.master.options.replay_ignore_host
- )
+ a.load([i.copy() for i in self.master.state.view])
elif k == "t":
- self.master.start_server_playback(
- [self.flow.copy()],
- self.master.options.kill, self.master.options.rheaders,
- False, self.master.options.nopop,
- self.master.options.replay_ignore_params,
- self.master.options.replay_ignore_content,
- self.master.options.replay_ignore_payload_params,
- self.master.options.replay_ignore_host
- )
- else:
- signals.status_prompt_path.send(
- prompt = "Server replay path",
- callback = self.master.server_playback_path
- )
+ a.load([self.flow.copy()])
+ signals.update_settings.send(self)
def mouse_event(self, size, event, button, col, row, focus):
if event == "mouse press" and button == 1:
@@ -202,29 +180,30 @@ class ConnectionItem(urwid.WidgetWrap):
self.state.enable_marked_filter()
signals.flowlist_change.send(self)
elif key == "r":
- r = self.master.replay_request(self.flow)
- if r:
- signals.status_message.send(message=r)
+ self.master.replay_request(self.flow)
signals.flowlist_change.send(self)
elif key == "S":
- if not self.master.server_playback:
+ def stop_server_playback(response):
+ if response == "y":
+ self.master.options.server_replay = []
+ a = self.master.addons.get("serverplayback")
+ if a.count():
signals.status_prompt_onekey.send(
- prompt = "Server Replay",
+ prompt = "Stop current server replay?",
keys = (
- ("all flows", "a"),
- ("this flow", "t"),
- ("file", "f"),
+ ("yes", "y"),
+ ("no", "n"),
),
- callback = self.server_replay_prompt,
+ callback = stop_server_playback,
)
else:
signals.status_prompt_onekey.send(
- prompt = "Stop current server replay?",
+ prompt = "Server Replay",
keys = (
- ("yes", "y"),
- ("no", "n"),
+ ("all flows", "a"),
+ ("this flow", "t"),
),
- callback = self.stop_server_playback_prompt,
+ callback = self.server_replay_prompt,
)
elif key == "U":
for f in self.state.flows:
diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py
index 15be379b..add10527 100644
--- a/mitmproxy/console/flowview.py
+++ b/mitmproxy/console/flowview.py
@@ -544,9 +544,7 @@ class FlowView(tabs.Tabs):
elif key == "p":
self.view_prev_flow(self.flow)
elif key == "r":
- r = self.master.replay_request(self.flow)
- if r:
- signals.status_message.send(message=r)
+ self.master.replay_request(self.flow)
signals.flow_change.send(self, flow = self.flow)
elif key == "V":
if self.flow.modified():
diff --git a/mitmproxy/console/help.py b/mitmproxy/console/help.py
index 8024dc31..e3e2f54c 100644
--- a/mitmproxy/console/help.py
+++ b/mitmproxy/console/help.py
@@ -53,7 +53,7 @@ class HelpView(urwid.ListBox):
("o", "options"),
("q", "quit / return to previous page"),
("Q", "quit without confirm prompt"),
- ("R", "replay of HTTP requests/responses"),
+ ("R", "replay of requests/responses from file"),
]
text.extend(
common.format_keyvals(keys, key="key", val="text", indent=4)
diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py
index a6942ca4..8ded1ea1 100644
--- a/mitmproxy/console/master.py
+++ b/mitmproxy/console/master.py
@@ -22,7 +22,6 @@ from mitmproxy import contentviews
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import flow
-from mitmproxy import script
from mitmproxy import utils
import mitmproxy.options
from mitmproxy.console import flowlist
@@ -67,11 +66,13 @@ class ConsoleState(flow.State):
def add_flow(self, f):
super(ConsoleState, self).add_flow(f)
+ signals.flowlist_change.send(self)
self.update_focus()
return f
def update_flow(self, f):
super(ConsoleState, self).update_flow(f)
+ signals.flowlist_change.send(self)
self.update_focus()
return f
@@ -245,12 +246,6 @@ class ConsoleMaster(flow.FlowMaster):
self.logbuffer = urwid.SimpleListWalker([])
self.follow = options.follow
- if options.client_replay:
- self.client_playback_path(options.client_replay)
-
- if options.server_replay:
- self.server_playback_path(options.server_replay)
-
self.view_stack = []
if options.app:
@@ -332,39 +327,13 @@ class ConsoleMaster(flow.FlowMaster):
self.loop.widget = window
self.loop.draw_screen()
- def _run_script_method(self, method, s, f):
- status, val = s.run(method, f)
- if val:
- if status:
- signals.add_log("Method %s return: %s" % (method, val), "debug")
- else:
- signals.add_log(
- "Method %s error: %s" %
- (method, val[1]), "error")
-
def run_script_once(self, command, f):
- if not command:
- return
- signals.add_log("Running script on flow: %s" % command, "debug")
-
+ sc = self.addons.get("scriptloader")
try:
- s = script.Script(command)
- s.load()
- except script.ScriptException as e:
- signals.status_message.send(
- message='Error loading "{}".'.format(command)
- )
- signals.add_log('Error loading "{}":\n{}'.format(command, e), "error")
- return
-
- if f.request:
- self._run_script_method("request", s, f)
- if f.response:
- self._run_script_method("response", s, f)
- if f.error:
- self._run_script_method("error", s, f)
- s.unload()
- signals.flow_change.send(self, flow = f)
+ with self.handlecontext():
+ sc.run_once(command, [f])
+ except mitmproxy.exceptions.AddonError as e:
+ signals.add_log("Script error: %s" % e, "warn")
def toggle_eventlog(self):
self.options.eventlog = not self.options.eventlog
@@ -384,28 +353,6 @@ class ConsoleMaster(flow.FlowMaster):
except exceptions.FlowReadException as e:
signals.status_message.send(message=str(e))
- def client_playback_path(self, path):
- if not isinstance(path, list):
- path = [path]
- flows = self._readflows(path)
- if flows:
- self.start_client_playback(flows, False)
-
- def server_playback_path(self, path):
- if not isinstance(path, list):
- path = [path]
- flows = self._readflows(path)
- if flows:
- self.start_server_playback(
- flows,
- self.options.kill, self.options.rheaders,
- False, self.options.nopop,
- self.options.replay_ignore_params,
- self.options.replay_ignore_content,
- self.options.replay_ignore_payload_params,
- self.options.replay_ignore_host
- )
-
def spawn_editor(self, data):
text = not isinstance(data, bytes)
fd, name = tempfile.mkstemp('', "mproxy", text=text)
diff --git a/mitmproxy/console/options.py b/mitmproxy/console/options.py
index f7fb2f90..97313bf4 100644
--- a/mitmproxy/console/options.py
+++ b/mitmproxy/console/options.py
@@ -114,8 +114,8 @@ class Options(urwid.WidgetWrap):
select.Option(
"Kill Extra",
"x",
- lambda: master.options.kill,
- master.options.toggler("kill")
+ lambda: master.options.replay_kill_extra,
+ master.options.toggler("replay_kill_extra")
),
select.Option(
"No Refresh",
@@ -165,7 +165,7 @@ class Options(urwid.WidgetWrap):
anticomp = False,
ignore_hosts = (),
tcp_hosts = (),
- kill = False,
+ replay_kill_extra = False,
no_upstream_cert = False,
refresh_server_playback = True,
replacements = [],
diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py
index 43d68d51..bbfb41ab 100644
--- a/mitmproxy/console/statusbar.py
+++ b/mitmproxy/console/statusbar.py
@@ -136,6 +136,9 @@ class StatusBar(urwid.WidgetWrap):
def get_status(self):
r = []
+ sreplay = self.master.addons.get("serverplayback")
+ creplay = self.master.addons.get("clientplayback")
+
if len(self.master.options.setheaders):
r.append("[")
r.append(("heading_key", "H"))
@@ -144,17 +147,14 @@ class StatusBar(urwid.WidgetWrap):
r.append("[")
r.append(("heading_key", "R"))
r.append("eplacing]")
- if self.master.client_playback:
+ if creplay.count():
r.append("[")
r.append(("heading_key", "cplayback"))
- r.append(":%s to go]" % self.master.client_playback.count())
- if self.master.server_playback:
+ r.append(":%s]" % creplay.count())
+ if sreplay.count():
r.append("[")
r.append(("heading_key", "splayback"))
- if self.master.options.nopop:
- r.append(":%s in file]" % self.master.server_playback.count())
- else:
- r.append(":%s to go]" % self.master.server_playback.count())
+ r.append(":%s]" % sreplay.count())
if self.master.options.ignore_hosts:
r.append("[")
r.append(("heading_key", "I"))
@@ -193,7 +193,7 @@ class StatusBar(urwid.WidgetWrap):
opts.append("showhost")
if not self.master.options.refresh_server_playback:
opts.append("norefresh")
- if self.master.options.kill:
+ if self.master.options.replay_kill_extra:
opts.append("killextra")
if self.master.options.no_upstream_cert:
opts.append("no-upstream-cert")
diff --git a/mitmproxy/console/window.py b/mitmproxy/console/window.py
index 35593643..159f68ed 100644
--- a/mitmproxy/console/window.py
+++ b/mitmproxy/console/window.py
@@ -57,13 +57,11 @@ class Window(urwid.Frame):
callback = self.master.stop_client_playback_prompt,
)
elif k == "s":
- if not self.master.server_playback:
- signals.status_prompt_path.send(
- self,
- prompt = "Server replay path",
- callback = self.master.server_playback_path
- )
- else:
+ a = self.master.addons.get("serverplayback")
+ if a.count():
+ def stop_server_playback(response):
+ if response == "y":
+ self.master.options.server_replay = []
signals.status_prompt_onekey.send(
self,
prompt = "Stop current server replay?",
@@ -71,7 +69,13 @@ class Window(urwid.Frame):
("yes", "y"),
("no", "n"),
),
- callback = self.master.stop_server_playback_prompt,
+ callback = stop_server_playback
+ )
+ else:
+ signals.status_prompt_path.send(
+ self,
+ prompt = "Server playback path",
+ callback = lambda x: self.master.options.setter("server_replay")([x])
)
def keypress(self, size, k):
diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py
index 51124224..778968b9 100644
--- a/mitmproxy/dump.py
+++ b/mitmproxy/dump.py
@@ -46,37 +46,17 @@ class DumpMaster(flow.FlowMaster):
self.addons.add(options, dumper.Dumper())
# This line is just for type hinting
self.options = self.options # type: Options
- self.replay_ignore_params = options.replay_ignore_params
- self.replay_ignore_content = options.replay_ignore_content
- self.replay_ignore_host = options.replay_ignore_host
- self.refresh_server_playback = options.refresh_server_playback
- self.replay_ignore_payload_params = options.replay_ignore_payload_params
-
self.set_stream_large_bodies(options.stream_large_bodies)
+ print("Proxy server listening at http://%s:%d" % (
+ (options.listen_host or "0.0.0.0"),
+ options.listen_port))
+
if self.server and self.options.http2 and not tcp.HAS_ALPN: # pragma: no cover
print("ALPN support missing (OpenSSL 1.0.2+ required)!\n"
"HTTP/2 is disabled. Use --no-http2 to silence this warning.",
file=sys.stderr)
- if options.server_replay:
- self.start_server_playback(
- self._readflow(options.server_replay),
- options.kill, options.rheaders,
- not options.keepserving,
- options.nopop,
- options.replay_ignore_params,
- options.replay_ignore_content,
- options.replay_ignore_payload_params,
- options.replay_ignore_host
- )
-
- if options.client_replay:
- self.start_client_playback(
- self._readflow(options.client_replay),
- not options.keepserving
- )
-
if options.rfile:
try:
self.load_flows_file(options.rfile)
diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py
index 94876514..6873215c 100644
--- a/mitmproxy/exceptions.py
+++ b/mitmproxy/exceptions.py
@@ -7,9 +7,6 @@ See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/
"""
from __future__ import absolute_import, print_function, division
-import sys
-import traceback
-
class ProxyException(Exception):
@@ -30,6 +27,10 @@ class Kill(ProxyException):
class ProtocolException(ProxyException):
+ """
+ ProtocolExceptions are caused by invalid user input, unavailable network resources,
+ or other events that are outside of our influence.
+ """
pass
@@ -62,6 +63,10 @@ class Http2ProtocolException(ProtocolException):
pass
+class Http2ZombieException(ProtocolException):
+ pass
+
+
class ServerException(ProxyException):
pass
@@ -74,27 +79,6 @@ class ReplayException(ProxyException):
pass
-class ScriptException(ProxyException):
-
- @classmethod
- def from_exception_context(cls, cut_tb=1):
- """
- Must be called while the current stack handles an exception.
-
- Args:
- cut_tb: remove N frames from the stack trace to hide internal calls.
- """
- exc_type, exc_value, exc_traceback = sys.exc_info()
-
- while cut_tb > 0:
- exc_traceback = exc_traceback.tb_next
- cut_tb -= 1
-
- tb = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
-
- return cls(tb)
-
-
class FlowReadException(ProxyException):
pass
@@ -113,3 +97,7 @@ class OptionsError(Exception):
class AddonError(Exception):
pass
+
+
+class ReplayError(Exception):
+ pass
diff --git a/mitmproxy/filt.py b/mitmproxy/filt.py
index 67915e5b..eb3e392b 100644
--- a/mitmproxy/filt.py
+++ b/mitmproxy/filt.py
@@ -244,6 +244,7 @@ class FHeadResponse(_Rex):
class FBod(_Rex):
code = "b"
help = "Body"
+ flags = re.DOTALL
@only(HTTPFlow, TCPFlow)
def __call__(self, f):
@@ -264,6 +265,7 @@ class FBod(_Rex):
class FBodRequest(_Rex):
code = "bq"
help = "Request body"
+ flags = re.DOTALL
@only(HTTPFlow, TCPFlow)
def __call__(self, f):
@@ -280,6 +282,7 @@ class FBodRequest(_Rex):
class FBodResponse(_Rex):
code = "bs"
help = "Response body"
+ flags = re.DOTALL
@only(HTTPFlow, TCPFlow)
def __call__(self, f):
diff --git a/mitmproxy/flow/__init__.py b/mitmproxy/flow/__init__.py
index 8a64180e..cb79482c 100644
--- a/mitmproxy/flow/__init__.py
+++ b/mitmproxy/flow/__init__.py
@@ -4,16 +4,12 @@ from mitmproxy.flow import export, modules
from mitmproxy.flow.io import FlowWriter, FilteredFlowWriter, FlowReader, read_flows_from_paths
from mitmproxy.flow.master import FlowMaster
from mitmproxy.flow.modules import (
- AppRegistry, StreamLargeBodies, ClientPlaybackState, ServerPlaybackState
+ AppRegistry, StreamLargeBodies
)
from mitmproxy.flow.state import State, FlowView
-# TODO: We may want to remove the imports from .modules and just expose "modules"
-
__all__ = [
"export", "modules",
"FlowWriter", "FilteredFlowWriter", "FlowReader", "read_flows_from_paths",
- "FlowMaster",
- "AppRegistry", "StreamLargeBodies", "ClientPlaybackState",
- "ServerPlaybackState", "State", "FlowView",
+ "FlowMaster", "AppRegistry", "StreamLargeBodies", "State", "FlowView",
]
diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py
index 9cdcc8dd..94b46f3f 100644
--- a/mitmproxy/flow/master.py
+++ b/mitmproxy/flow/master.py
@@ -15,6 +15,30 @@ from mitmproxy.onboarding import app
from mitmproxy.protocol import http_replay
+def event_sequence(f):
+ if isinstance(f, models.HTTPFlow):
+ if f.request:
+ yield "request", f
+ if f.response:
+ yield "responseheaders", f
+ yield "response", f
+ if f.error:
+ yield "error", f
+ elif isinstance(f, models.TCPFlow):
+ messages = f.messages
+ f.messages = []
+ f.reply = controller.DummyReply()
+ yield "tcp_open", f
+ while messages:
+ f.messages.append(messages.pop(0))
+ yield "tcp_message", f
+ if f.error:
+ yield "tcp_error", f
+ yield "tcp_close", f
+ else:
+ raise NotImplementedError
+
+
class FlowMaster(controller.Master):
@property
@@ -29,23 +53,11 @@ class FlowMaster(controller.Master):
if server:
self.add_server(server)
self.state = state
- self.server_playback = None # type: Optional[modules.ServerPlaybackState]
- self.client_playback = None # type: Optional[modules.ClientPlaybackState]
- self.kill_nonreplay = False
-
self.stream_large_bodies = None # type: Optional[modules.StreamLargeBodies]
- self.replay_ignore_params = False
- self.replay_ignore_content = None
- self.replay_ignore_host = False
-
self.apps = modules.AppRegistry()
def start_app(self, host, port):
- self.apps.add(
- app.mapp,
- host,
- port
- )
+ self.apps.add(app.mapp, host, port)
def set_stream_large_bodies(self, max_size):
if max_size is not None:
@@ -53,92 +65,6 @@ class FlowMaster(controller.Master):
else:
self.stream_large_bodies = False
- def start_client_playback(self, flows, exit):
- """
- flows: List of flows.
- """
- self.client_playback = modules.ClientPlaybackState(flows, exit)
-
- def stop_client_playback(self):
- self.client_playback = None
-
- def start_server_playback(
- self,
- flows,
- kill,
- headers,
- exit,
- nopop,
- ignore_params,
- ignore_content,
- ignore_payload_params,
- ignore_host):
- """
- flows: List of flows.
- kill: Boolean, should we kill requests not part of the replay?
- ignore_params: list of parameters to ignore in server replay
- ignore_content: true if request content should be ignored in server replay
- ignore_payload_params: list of content params to ignore in server replay
- ignore_host: true if request host should be ignored in server replay
- """
- self.server_playback = modules.ServerPlaybackState(
- headers,
- flows,
- exit,
- nopop,
- ignore_params,
- ignore_content,
- ignore_payload_params,
- ignore_host)
- self.kill_nonreplay = kill
-
- def stop_server_playback(self):
- self.server_playback = None
-
- def do_server_playback(self, flow):
- """
- This method should be called by child classes in the request
- handler. Returns True if playback has taken place, None if not.
- """
- if self.server_playback:
- rflow = self.server_playback.next_flow(flow)
- if not rflow:
- return None
- response = rflow.response.copy()
- response.is_replay = True
- if self.options.refresh_server_playback:
- response.refresh()
- flow.response = response
- return True
- return None
-
- def tick(self, timeout):
- if self.client_playback:
- stop = (
- self.client_playback.done() and
- self.state.active_flow_count() == 0
- )
- exit = self.client_playback.exit
- if stop:
- self.stop_client_playback()
- if exit:
- self.shutdown()
- else:
- self.client_playback.tick(self)
-
- if self.server_playback:
- stop = (
- self.server_playback.count() == 0 and
- self.state.active_flow_count() == 0 and
- not self.kill_nonreplay
- )
- exit = self.server_playback.exit
- if stop:
- self.stop_server_playback()
- if exit:
- self.shutdown()
- return super(FlowMaster, self).tick(timeout)
-
def duplicate_flow(self, f):
"""
Duplicate flow, and insert it into state without triggering any of
@@ -182,28 +108,9 @@ class FlowMaster(controller.Master):
f.request.host = self.server.config.upstream_server.address.host
f.request.port = self.server.config.upstream_server.address.port
f.request.scheme = self.server.config.upstream_server.scheme
-
- f.reply = controller.DummyReply()
- if f.request:
- self.request(f)
- if f.response:
- self.responseheaders(f)
- self.response(f)
- if f.error:
- self.error(f)
- elif isinstance(f, models.TCPFlow):
- messages = f.messages
- f.messages = []
- f.reply = controller.DummyReply()
- self.tcp_open(f)
- while messages:
- f.messages.append(messages.pop(0))
- self.tcp_message(f)
- if f.error:
- self.tcp_error(f)
- self.tcp_close(f)
- else:
- raise NotImplementedError()
+ f.reply = controller.DummyReply()
+ for e, o in event_sequence(f):
+ getattr(self, e)(o)
def load_flows(self, fr):
"""
@@ -229,43 +136,49 @@ class FlowMaster(controller.Master):
except IOError as v:
raise exceptions.FlowReadException(v.strerror)
- def process_new_request(self, f):
- if self.server_playback:
- pb = self.do_server_playback(f)
- if not pb and self.kill_nonreplay:
- self.add_log("Killed {}".format(f.request.url), "info")
- f.reply.kill()
-
def replay_request(self, f, block=False):
"""
- Returns None if successful, or error message if not.
+ Returns an http_replay.RequestReplayThred object.
+ May raise exceptions.ReplayError.
"""
if f.live:
- return "Can't replay live request."
+ raise exceptions.ReplayError(
+ "Can't replay live flow."
+ )
if f.intercepted:
- return "Can't replay while intercepting..."
+ raise exceptions.ReplayError(
+ "Can't replay intercepted flow."
+ )
if f.request.raw_content is None:
- return "Can't replay request with missing content..."
- if f.request:
- f.backup()
- f.request.is_replay = True
-
- # TODO: We should be able to remove this.
- if "Content-Length" in f.request.headers:
- f.request.headers["Content-Length"] = str(len(f.request.raw_content))
-
- f.response = None
- f.error = None
- self.process_new_request(f)
- rt = http_replay.RequestReplayThread(
- self.server.config,
- f,
- self.event_queue,
- self.should_exit
+ raise exceptions.ReplayError(
+ "Can't replay flow with missing content."
+ )
+ if not f.request:
+ raise exceptions.ReplayError(
+ "Can't replay flow with missing request."
)
- rt.start() # pragma: no cover
- if block:
- rt.join()
+
+ f.backup()
+ f.request.is_replay = True
+
+ # TODO: We should be able to remove this.
+ if "Content-Length" in f.request.headers:
+ f.request.headers["Content-Length"] = str(len(f.request.raw_content))
+
+ f.response = None
+ f.error = None
+ # FIXME: process through all addons?
+ # self.process_new_request(f)
+ rt = http_replay.RequestReplayThread(
+ self.server.config,
+ f,
+ self.event_queue,
+ self.should_exit
+ )
+ rt.start() # pragma: no cover
+ if block:
+ rt.join()
+ return rt
@controller.handler
def log(self, l):
@@ -294,9 +207,6 @@ class FlowMaster(controller.Master):
@controller.handler
def error(self, f):
self.state.update_flow(f)
- if self.client_playback:
- self.client_playback.clear(f)
- return f
@controller.handler
def request(self, f):
@@ -314,8 +224,6 @@ class FlowMaster(controller.Master):
return
if f not in self.state.flows: # don't add again on replay
self.state.add_flow(f)
- self.process_new_request(f)
- return f
@controller.handler
def responseheaders(self, f):
@@ -325,18 +233,14 @@ class FlowMaster(controller.Master):
except netlib.exceptions.HttpException:
f.reply.kill()
return
- return f
@controller.handler
def response(self, f):
self.state.update_flow(f)
- if self.client_playback:
- self.client_playback.clear(f)
- return f
@controller.handler
def websockets_handshake(self, f):
- return f
+ pass
def handle_intercept(self, f):
self.state.update_flow(f)
@@ -356,10 +260,7 @@ class FlowMaster(controller.Master):
@controller.handler
def tcp_error(self, flow):
- self.add_log("Error in TCP connection to {}: {}".format(
- repr(flow.server_conn.address),
- flow.error
- ), "info")
+ pass
@controller.handler
def tcp_close(self, flow):
diff --git a/mitmproxy/flow/modules.py b/mitmproxy/flow/modules.py
index fb3c52da..7d8a282e 100644
--- a/mitmproxy/flow/modules.py
+++ b/mitmproxy/flow/modules.py
@@ -1,13 +1,7 @@
from __future__ import absolute_import, print_function, division
-import hashlib
-
-from six.moves import urllib
-
-from mitmproxy import controller
from netlib import wsgi
from netlib import version
-from netlib import strutils
from netlib.http import http1
@@ -50,129 +44,3 @@ class StreamLargeBodies(object):
if not r.raw_content and not (0 <= expected_size <= self.max_size):
# r.stream may already be a callable, which we want to preserve.
r.stream = r.stream or True
-
-
-class ClientPlaybackState:
- def __init__(self, flows, exit):
- self.flows, self.exit = flows, exit
- self.current = None
- self.testing = False # Disables actual replay for testing.
-
- def count(self):
- return len(self.flows)
-
- def done(self):
- if len(self.flows) == 0 and not self.current:
- return True
- return False
-
- def clear(self, flow):
- """
- A request has returned in some way - if this is the one we're
- servicing, go to the next flow.
- """
- if flow is self.current:
- self.current = None
-
- def tick(self, master):
- if self.flows and not self.current:
- self.current = self.flows.pop(0).copy()
- if not self.testing:
- master.replay_request(self.current)
- else:
- self.current.reply = controller.DummyReply()
- master.request(self.current)
- if self.current.response:
- master.response(self.current)
-
-
-class ServerPlaybackState:
- def __init__(
- self,
- headers,
- flows,
- exit,
- nopop,
- ignore_params,
- ignore_content,
- ignore_payload_params,
- ignore_host):
- """
- headers: Case-insensitive list of request headers that should be
- included in request-response matching.
- """
- self.headers = headers
- self.exit = exit
- self.nopop = nopop
- self.ignore_params = ignore_params
- self.ignore_content = ignore_content
- self.ignore_payload_params = [strutils.always_bytes(x) for x in (ignore_payload_params or ())]
- self.ignore_host = ignore_host
- self.fmap = {}
- for i in flows:
- if i.response:
- l = self.fmap.setdefault(self._hash(i), [])
- l.append(i)
-
- def count(self):
- return sum(len(i) for i in self.fmap.values())
-
- def _hash(self, flow):
- """
- Calculates a loose hash of the flow request.
- """
- r = flow.request
-
- _, _, path, _, query, _ = urllib.parse.urlparse(r.url)
- queriesArray = urllib.parse.parse_qsl(query, keep_blank_values=True)
-
- key = [
- str(r.port),
- str(r.scheme),
- str(r.method),
- str(path),
- ]
-
- if not self.ignore_content:
- form_contents = r.urlencoded_form or r.multipart_form
- if self.ignore_payload_params and form_contents:
- key.extend(
- p for p in form_contents.items(multi=True)
- if p[0] not in self.ignore_payload_params
- )
- else:
- key.append(str(r.raw_content))
-
- if not self.ignore_host:
- key.append(r.host)
-
- filtered = []
- ignore_params = self.ignore_params or []
- for p in queriesArray:
- if p[0] not in ignore_params:
- filtered.append(p)
- for p in filtered:
- key.append(p[0])
- key.append(p[1])
-
- if self.headers:
- headers = []
- for i in self.headers:
- v = r.headers.get(i)
- headers.append((i, v))
- key.append(headers)
- return hashlib.sha256(repr(key).encode("utf8", "surrogateescape")).digest()
-
- def next_flow(self, request):
- """
- Returns the next flow object, or None if no matching flow was
- found.
- """
- l = self.fmap.get(self._hash(request))
- if not l:
- return None
-
- if self.nopop:
- return l[0]
- else:
- return l.pop(0)
diff --git a/mitmproxy/options.py b/mitmproxy/options.py
index 0ac44cd8..ba4ed0c7 100644
--- a/mitmproxy/options.py
+++ b/mitmproxy/options.py
@@ -30,15 +30,16 @@ class Options(optmanager.OptManager):
anticache=False, # type: bool
anticomp=False, # type: bool
client_replay=None, # type: Optional[str]
- kill=False, # type: bool
+ replay_kill_extra=False, # type: bool
+ keepserving=True, # type: bool
no_server=False, # type: bool
- nopop=False, # type: bool
+ server_replay_nopop=False, # type: bool
refresh_server_playback=False, # type: bool
rfile=None, # type: Optional[str]
scripts=(), # type: Sequence[str]
showhost=False, # type: bool
replacements=(), # type: Sequence[Tuple[str, str, str]]
- rheaders=(), # type: Sequence[str]
+ server_replay_use_headers=(), # type: Sequence[str]
setheaders=(), # type: Sequence[Tuple[str, str, str]]
server_replay=None, # type: Optional[str]
stickycookie=None, # type: Optional[str]
@@ -46,10 +47,10 @@ class Options(optmanager.OptManager):
stream_large_bodies=None, # type: Optional[str]
verbosity=2, # type: int
outfile=None, # type: Tuple[str, str]
- replay_ignore_content=False, # type: bool
- replay_ignore_params=(), # type: Sequence[str]
- replay_ignore_payload_params=(), # type: Sequence[str]
- replay_ignore_host=False, # type: bool
+ server_replay_ignore_content=False, # type: bool
+ server_replay_ignore_params=(), # type: Sequence[str]
+ server_replay_ignore_payload_params=(), # type: Sequence[str]
+ server_replay_ignore_host=False, # type: bool
# Proxy options
auth_nonanonymous=False, # type: bool
@@ -88,15 +89,16 @@ class Options(optmanager.OptManager):
self.anticache = anticache
self.anticomp = anticomp
self.client_replay = client_replay
- self.kill = kill
+ self.keepserving = keepserving
+ self.replay_kill_extra = replay_kill_extra
self.no_server = no_server
- self.nopop = nopop
+ self.server_replay_nopop = server_replay_nopop
self.refresh_server_playback = refresh_server_playback
self.rfile = rfile
self.scripts = scripts
self.showhost = showhost
self.replacements = replacements
- self.rheaders = rheaders
+ self.server_replay_use_headers = server_replay_use_headers
self.setheaders = setheaders
self.server_replay = server_replay
self.stickycookie = stickycookie
@@ -104,10 +106,10 @@ class Options(optmanager.OptManager):
self.stream_large_bodies = stream_large_bodies
self.verbosity = verbosity
self.outfile = outfile
- self.replay_ignore_content = replay_ignore_content
- self.replay_ignore_params = replay_ignore_params
- self.replay_ignore_payload_params = replay_ignore_payload_params
- self.replay_ignore_host = replay_ignore_host
+ self.server_replay_ignore_content = server_replay_ignore_content
+ self.server_replay_ignore_params = server_replay_ignore_params
+ self.server_replay_ignore_payload_params = server_replay_ignore_payload_params
+ self.server_replay_ignore_host = server_replay_ignore_host
# Proxy options
self.auth_nonanonymous = auth_nonanonymous
diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py
index 1418d6e9..1632e66f 100644
--- a/mitmproxy/protocol/http.py
+++ b/mitmproxy/protocol/http.py
@@ -153,12 +153,13 @@ class HttpLayer(base.Layer):
# We optimistically guess there might be an HTTP client on the
# other end
self.send_error_response(400, repr(e))
- self.log(
- "request",
- "warn",
- "HTTP protocol error in client request: %s" % e
+ six.reraise(
+ exceptions.ProtocolException,
+ exceptions.ProtocolException(
+ "HTTP protocol error in client request: {}".format(e)
+ ),
+ sys.exc_info()[2]
)
- return
self.log("request", "debug", [repr(request)])
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index 0e42d619..1595fb61 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -96,15 +96,17 @@ class Http2Layer(base.Layer):
self.server_to_client_stream_ids = dict([(0, 0)])
self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False, header_encoding=False)
- # make sure that we only pass actual SSL.Connection objects in here,
- # because otherwise ssl_read_select fails!
- self.active_conns = [self.client_conn.connection]
-
def _initiate_server_conn(self):
- self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False)
- self.server_conn.h2.initiate_connection()
- self.server_conn.send(self.server_conn.h2.data_to_send())
- self.active_conns.append(self.server_conn.connection)
+ if self.server_conn:
+ self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False)
+ self.server_conn.h2.initiate_connection()
+ self.server_conn.send(self.server_conn.h2.data_to_send())
+
+ def _complete_handshake(self):
+ preamble = self.client_conn.rfile.read(24)
+ self.client_conn.h2.initiate_connection()
+ self.client_conn.h2.receive_data(preamble)
+ self.client_conn.send(self.client_conn.h2.data_to_send())
def next_layer(self): # pragma: no cover
# WebSockets over HTTP/2?
@@ -126,7 +128,7 @@ class Http2Layer(base.Layer):
eid = event.stream_id
if isinstance(event, events.RequestReceived):
- return self._handle_request_received(eid, event)
+ return self._handle_request_received(eid, event, source_conn.h2)
elif isinstance(event, events.ResponseReceived):
return self._handle_response_received(eid, event)
elif isinstance(event, events.DataReceived):
@@ -138,9 +140,9 @@ class Http2Layer(base.Layer):
elif isinstance(event, events.RemoteSettingsChanged):
return self._handle_remote_settings_changed(event, other_conn)
elif isinstance(event, events.ConnectionTerminated):
- return self._handle_connection_terminated(event)
+ return self._handle_connection_terminated(event, is_server)
elif isinstance(event, events.PushedStreamReceived):
- return self._handle_pushed_stream_received(event)
+ return self._handle_pushed_stream_received(event, source_conn.h2)
elif isinstance(event, events.PriorityUpdated):
return self._handle_priority_updated(eid, event)
elif isinstance(event, events.TrailersReceived):
@@ -149,9 +151,9 @@ class Http2Layer(base.Layer):
# fail-safe for unhandled events
return True
- def _handle_request_received(self, eid, event):
+ def _handle_request_received(self, eid, event, h2_connection):
headers = netlib.http.Headers([[k, v] for k, v in event.headers])
- self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
+ self.streams[eid] = Http2SingleStreamLayer(self, h2_connection, eid, headers)
self.streams[eid].timestamp_start = time.time()
self.streams[eid].no_body = (event.stream_ended is not None)
if event.priority_updated is not None:
@@ -173,7 +175,7 @@ class Http2Layer(base.Layer):
def _handle_data_received(self, eid, event, source_conn):
bsl = self.config.options.body_size_limit
if bsl and self.streams[eid].queued_data_length > bsl:
- self.streams[eid].zombie = time.time()
+ self.streams[eid].kill()
source_conn.h2.safe_reset_stream(
event.stream_id,
h2.errors.REFUSED_STREAM
@@ -194,7 +196,7 @@ class Http2Layer(base.Layer):
return True
def _handle_stream_reset(self, eid, event, is_server, other_conn):
- self.streams[eid].zombie = time.time()
+ self.streams[eid].kill()
if eid in self.streams and event.error_code == h2.errors.CANCEL:
if is_server:
other_stream_id = self.streams[eid].client_stream_id
@@ -209,7 +211,13 @@ class Http2Layer(base.Layer):
other_conn.h2.safe_update_settings(new_settings)
return True
- def _handle_connection_terminated(self, event):
+ def _handle_connection_terminated(self, event, is_server):
+ self.log("HTTP/2 connection terminated by {}: error code: {}, last stream id: {}, additional data: {}".format(
+ "server" if is_server else "client",
+ event.error_code,
+ event.last_stream_id,
+ event.additional_data), "info")
+
if event.error_code != h2.errors.NO_ERROR:
# Something terrible has happened - kill everything!
self.client_conn.h2.close_connection(
@@ -226,7 +234,7 @@ class Http2Layer(base.Layer):
"""
return False
- def _handle_pushed_stream_received(self, event):
+ def _handle_pushed_stream_received(self, event, h2_connection):
# pushed stream ids should be unique and not dependent on race conditions
# only the parent stream id must be looked up first
parent_eid = self.server_to_client_stream_ids[event.parent_stream_id]
@@ -235,7 +243,7 @@ class Http2Layer(base.Layer):
self.client_conn.send(self.client_conn.h2.data_to_send())
headers = netlib.http.Headers([[k, v] for k, v in event.headers])
- self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers)
+ self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, h2_connection, event.pushed_stream_id, headers)
self.streams[event.pushed_stream_id].timestamp_start = time.time()
self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
@@ -253,7 +261,7 @@ class Http2Layer(base.Layer):
with self.server_conn.h2.lock:
mapped_stream_id = event.stream_id
if mapped_stream_id in self.streams and self.streams[mapped_stream_id].server_stream_id:
- # if the stream is already up and running and was sent to the server
+ # if the stream is already up and running and was sent to the server,
# use the mapped server stream id to update priority information
mapped_stream_id = self.streams[mapped_stream_id].server_stream_id
@@ -294,37 +302,36 @@ class Http2Layer(base.Layer):
def _kill_all_streams(self):
for stream in self.streams.values():
- if not stream.zombie:
- stream.zombie = time.time()
- stream.request_data_finished.set()
- stream.response_arrived.set()
- stream.data_finished.set()
+ stream.kill()
def __call__(self):
- if self.server_conn:
- self._initiate_server_conn()
+ self._initiate_server_conn()
+ self._complete_handshake()
- preamble = self.client_conn.rfile.read(24)
- self.client_conn.h2.initiate_connection()
- self.client_conn.h2.receive_data(preamble)
- self.client_conn.send(self.client_conn.h2.data_to_send())
+ client = self.client_conn.connection
+ server = self.server_conn.connection
+ conns = [client, server]
try:
while True:
- r = tcp.ssl_read_select(self.active_conns, 1)
+ r = tcp.ssl_read_select(conns, 1)
for conn in r:
- source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
- other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
+ source_conn = self.client_conn if conn == client else self.server_conn
+ other_conn = self.server_conn if conn == client else self.client_conn
is_server = (conn == self.server_conn.connection)
with source_conn.h2.lock:
try:
- raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile))
+ raw_frame = b''.join(http2.read_raw_frame(source_conn.rfile))
except:
# read frame failed: connection closed
self._kill_all_streams()
return
+ if source_conn.h2.state_machine.state == h2.connection.ConnectionState.CLOSED:
+ self.log("HTTP/2 connection entered closed state already", "debug")
+ return
+
incoming_events = source_conn.h2.receive_data(raw_frame)
source_conn.send(source_conn.h2.data_to_send())
@@ -354,10 +361,11 @@ def detect_zombie_stream(func):
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
- def __init__(self, ctx, stream_id, request_headers):
+ def __init__(self, ctx, h2_connection, stream_id, request_headers):
super(Http2SingleStreamLayer, self).__init__(
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
)
+ self.h2_connection = h2_connection
self.zombie = None
self.client_stream_id = stream_id
self.server_stream_id = None
@@ -365,6 +373,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.response_headers = None
self.pushed = False
+ self.timestamp_start = None
+ self.timestamp_end = None
+
self.request_data_queue = queue.Queue()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
@@ -381,6 +392,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.priority_weight = None
self.handled_priority_event = None
+ def kill(self):
+ if not self.zombie:
+ self.zombie = time.time()
+ self.request_data_finished.set()
+ self.response_arrived.set()
+ self.response_data_finished.set()
+
def connect(self): # pragma: no cover
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
@@ -421,10 +439,11 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.request_queued_data_length = v
def raise_zombie(self, pre_command=None):
- if self.zombie is not None:
+ connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
+ if self.zombie is not None or not hasattr(self.server_conn, 'h2') or connection_closed:
if pre_command is not None:
pre_command()
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise exceptions.Http2ZombieException("Connection already dead")
@detect_zombie_stream
def read_request(self):
@@ -508,6 +527,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except Exception as e: # pragma: no cover
raise e
finally:
+ self.raise_zombie()
self.server_conn.h2.lock.release()
if not self.no_body:
@@ -581,6 +601,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
try:
layer()
+ except exceptions.Http2ZombieException as e: # pragma: no cover
+ pass
except exceptions.ProtocolException as e: # pragma: no cover
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
@@ -589,5 +611,4 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except exceptions.Kill:
self.log("Connection killed", "info")
- if not self.zombie:
- self.zombie = time.time()
+ self.kill()
diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py
index bfde06c5..877eaa22 100644
--- a/mitmproxy/protocol/http_replay.py
+++ b/mitmproxy/protocol/http_replay.py
@@ -33,6 +33,7 @@ class RequestReplayThread(basethread.BaseThread):
def run(self):
r = self.flow.request
first_line_format_backup = r.first_line_format
+ server = None
try:
self.flow.response = None
@@ -103,3 +104,5 @@ class RequestReplayThread(basethread.BaseThread):
self.channel.tell("log", Log(traceback.format_exc(), "error"))
finally:
r.first_line_format = first_line_format_backup
+ if server:
+ server.finish()
diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py
index 4fd5755a..c5fd5f9e 100644
--- a/mitmproxy/proxy/server.py
+++ b/mitmproxy/proxy/server.py
@@ -132,7 +132,7 @@ class ConnectionHandler(object):
self.log(str(e), "warn")
self.log("Invalid certificate, closing connection. Pass --insecure to disable validation.", "warn")
else:
- self.log(repr(e), "warn")
+ self.log(str(e), "warn")
self.log(traceback.format_exc(), "debug")
# If an error propagates to the topmost level,
diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py
index c92ba4d3..5498c2d9 100644
--- a/mitmproxy/web/app.py
+++ b/mitmproxy/web/app.py
@@ -116,7 +116,7 @@ class RequestHandler(BasicAuth, tornado.web.RequestHandler):
def json(self):
if not self.request.headers.get("Content-Type").startswith("application/json"):
return None
- return json.loads(self.request.body)
+ return json.loads(self.request.body.decode())
@property
def state(self):
diff --git a/netlib/debug.py b/netlib/debug.py
index 29c7f655..f9c700de 100644
--- a/netlib/debug.py
+++ b/netlib/debug.py
@@ -37,7 +37,7 @@ def sysinfo():
return "\n".join(data)
-def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
+def dump_info(signal=None, frame=None, file=sys.stdout): # pragma: no cover
print("****************************************************", file=file)
print("Summary", file=file)
print("=======", file=file)
@@ -81,7 +81,7 @@ def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
print("****************************************************", file=file)
-def dump_stacks(signal, frame, file=sys.stdout):
+def dump_stacks(signal=None, frame=None, file=sys.stdout):
id2name = dict([(th.ident, th.name) for th in threading.enumerate()])
code = []
for threadId, stack in sys._current_frames().items():
diff --git a/netlib/encoding.py b/netlib/encoding.py
index 9c8acff7..9b8b3868 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -5,7 +5,9 @@ from __future__ import absolute_import
import codecs
import collections
+import six
from io import BytesIO
+
import gzip
import zlib
import brotli
@@ -32,6 +34,9 @@ def decode(encoded, encoding, errors='strict'):
Raises:
ValueError, if decoding fails.
"""
+ if len(encoded) == 0:
+ return encoded
+
global _cache
cached = (
isinstance(encoded, bytes) and
@@ -49,11 +54,14 @@ def decode(encoded, encoding, errors='strict'):
if encoding in ("gzip", "deflate", "br"):
_cache = CachedDecode(encoded, encoding, errors, decoded)
return decoded
+ except TypeError:
+ raise
except Exception as e:
- raise ValueError("{} when decoding {} with {}".format(
+ raise ValueError("{} when decoding {} with {}: {}".format(
type(e).__name__,
repr(encoded)[:10],
repr(encoding),
+ repr(e),
))
@@ -68,6 +76,9 @@ def encode(decoded, encoding, errors='strict'):
Raises:
ValueError, if encoding fails.
"""
+ if len(decoded) == 0:
+ return decoded
+
global _cache
cached = (
isinstance(decoded, bytes) and
@@ -79,17 +90,23 @@ def encode(decoded, encoding, errors='strict'):
return _cache.encoded
try:
try:
- encoded = custom_encode[encoding](decoded)
+ value = decoded
+ if not six.PY2 and isinstance(value, six.string_types):
+ value = decoded.encode()
+ encoded = custom_encode[encoding](value)
except KeyError:
encoded = codecs.encode(decoded, encoding, errors)
if encoding in ("gzip", "deflate", "br"):
_cache = CachedDecode(encoded, encoding, errors, decoded)
return encoded
+ except TypeError:
+ raise
except Exception as e:
- raise ValueError("{} when encoding {} with {}".format(
+ raise ValueError("{} when encoding {} with {}: {}".format(
type(e).__name__,
repr(decoded)[:10],
repr(encoding),
+ repr(e),
))
@@ -145,12 +162,14 @@ def encode_deflate(content):
custom_decode = {
+ "none": identity,
"identity": identity,
"gzip": decode_gzip,
"deflate": decode_deflate,
"br": decode_brotli,
}
custom_encode = {
+ "none": identity,
"identity": identity,
"gzip": encode_gzip,
"deflate": encode_deflate,
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 131e8ce5..b55874ca 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -14,6 +14,7 @@ if six.PY2: # pragma: no cover
return x
def _always_bytes(x):
+ strutils.always_bytes(x, "utf-8", "replace") # raises a TypeError if x != str/bytes/None.
return x
else:
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
index 60064190..7f84a1ab 100644
--- a/netlib/http/http2/__init__.py
+++ b/netlib/http/http2/__init__.py
@@ -1,8 +1,10 @@
from __future__ import absolute_import, print_function, division
-from netlib.http.http2 import framereader
+
+from netlib.http.http2.framereader import read_raw_frame, parse_frame
from netlib.http.http2.utils import parse_headers
__all__ = [
- "framereader",
+ "read_raw_frame",
+ "parse_frame",
"parse_headers",
]
diff --git a/netlib/http/http2/framereader.py b/netlib/http/http2/framereader.py
index eb9b069a..8b7cfffb 100644
--- a/netlib/http/http2/framereader.py
+++ b/netlib/http/http2/framereader.py
@@ -4,7 +4,7 @@ import hyperframe
from ...exceptions import HttpException
-def http2_read_raw_frame(rfile):
+def read_raw_frame(rfile):
header = rfile.safe_read(9)
length = int(codecs.encode(header[:3], 'hex_codec'), 16)
@@ -15,8 +15,11 @@ def http2_read_raw_frame(rfile):
return [header, body]
-def http2_read_frame(rfile):
- header, body = http2_read_raw_frame(rfile)
+def parse_frame(header, body=None):
+ if body is None:
+ body = header[9:]
+ header = header[:9]
+
frame, length = hyperframe.frame.Frame.parse_frame_header(header)
frame.parse_body(memoryview(body))
return frame
diff --git a/netlib/http/response.py b/netlib/http/response.py
index ae29298f..385e233a 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -84,15 +84,6 @@ class Response(message.Message):
(),
None
)
- # Assign this manually to update the content-length header.
- if isinstance(content, bytes):
- resp.content = content
- elif isinstance(content, str):
- resp.text = content
- else:
- raise TypeError("Expected content to be str or bytes, but is {}.".format(
- type(content).__name__
- ))
# Headers can be list or dict, we differentiate here.
if isinstance(headers, dict):
@@ -104,6 +95,16 @@ class Response(message.Message):
type(headers).__name__
))
+ # Assign this manually to update the content-length header.
+ if isinstance(content, bytes):
+ resp.content = content
+ elif isinstance(content, str):
+ resp.text = content
+ else:
+ raise TypeError("Expected content to be str or bytes, but is {}.".format(
+ type(content).__name__
+ ))
+
return resp
@property
diff --git a/netlib/strutils.py b/netlib/strutils.py
index 4cb3b805..d43c2aab 100644
--- a/netlib/strutils.py
+++ b/netlib/strutils.py
@@ -8,7 +8,10 @@ import six
def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args)
- return unicode_or_bytes
+ elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None:
+ return unicode_or_bytes
+ else:
+ raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__))
def native(s, *encoding_opts):
diff --git a/pathod/language/http2.py b/pathod/language/http2.py
index c0313baa..519ee699 100644
--- a/pathod/language/http2.py
+++ b/pathod/language/http2.py
@@ -189,7 +189,7 @@ class Response(_HTTP2Message):
resp = http.Response(
b'HTTP/2.0',
- self.status_code.string(),
+ int(self.status_code.string()),
b'',
headers,
body,
diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py
index 5ad120de..7b162664 100644
--- a/pathod/protocols/http2.py
+++ b/pathod/protocols/http2.py
@@ -6,7 +6,7 @@ import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
-from netlib import utils, strutils
+from netlib import utils
from netlib.http import http2
import netlib.http.headers
import netlib.http.response
@@ -201,7 +201,7 @@ class HTTP2StateProtocol(object):
headers = response.headers.copy()
if ':status' not in headers:
- headers.insert(0, b':status', strutils.always_bytes(response.status_code))
+ headers.insert(0, b':status', str(response.status_code).encode())
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
@@ -254,7 +254,7 @@ class HTTP2StateProtocol(object):
def read_frame(self, hide=False):
while True:
- frm = http2.framereader.http2_read_frame(self.tcp_handler.rfile)
+ frm = http2.parse_frame(*http2.read_raw_frame(self.tcp_handler.rfile))
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
diff --git a/setup.py b/setup.py
index b86a48c0..9f75ca15 100644
--- a/setup.py
+++ b/setup.py
@@ -85,7 +85,7 @@ setup(
"tornado>=4.3, <4.5",
"urwid>=1.3.1, <1.4",
"watchdog>=0.8.3, <0.9",
- "brotlipy>=0.3.0, <0.5",
+ "brotlipy>=0.5.1, <0.7",
],
extras_require={
':sys_platform == "win32"': [
@@ -107,6 +107,7 @@ setup(
"pytest-cov>=2.2.1, <3",
"pytest-timeout>=1.0.0, <2",
"pytest-xdist>=1.14, <2",
+ "pytest-faulthandler>=1.3.0, <2",
"sphinx>=1.3.5, <1.5",
"sphinx-autobuild>=0.5.2, <0.7",
"sphinxcontrib-documentedlist>=0.4.0, <0.5",
diff --git a/test/mitmproxy/builtins/test_clientplayback.py b/test/mitmproxy/builtins/test_clientplayback.py
new file mode 100644
index 00000000..15702340
--- /dev/null
+++ b/test/mitmproxy/builtins/test_clientplayback.py
@@ -0,0 +1,37 @@
+import mock
+
+from mitmproxy.builtins import clientplayback
+from mitmproxy import options
+
+from .. import tutils, mastertest
+
+
+class TestClientPlayback:
+ def test_playback(self):
+ cp = clientplayback.ClientPlayback()
+ cp.configure(options.Options(), [])
+ assert cp.count() == 0
+ f = tutils.tflow(resp=True)
+ cp.load([f])
+ assert cp.count() == 1
+ RP = "mitmproxy.protocol.http_replay.RequestReplayThread"
+ with mock.patch(RP) as rp:
+ assert not cp.current
+ with mastertest.mockctx():
+ cp.tick()
+ rp.assert_called()
+ assert cp.current
+
+ cp.keepserving = False
+ cp.flows = None
+ cp.current = None
+ with mock.patch("mitmproxy.controller.Master.shutdown") as sd:
+ with mastertest.mockctx():
+ cp.tick()
+ sd.assert_called()
+
+ def test_configure(self):
+ cp = clientplayback.ClientPlayback()
+ cp.configure(
+ options.Options(), []
+ )
diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py
index 0bac6ca0..09e5bc92 100644
--- a/test/mitmproxy/builtins/test_script.py
+++ b/test/mitmproxy/builtins/test_script.py
@@ -137,6 +137,31 @@ class TestScript(mastertest.MasterTest):
class TestScriptLoader(mastertest.MasterTest):
+ def test_run_once(self):
+ s = state.State()
+ o = options.Options(scripts=[])
+ m = master.FlowMaster(o, None, s)
+ sl = script.ScriptLoader()
+ m.addons.add(o, sl)
+
+ f = tutils.tflow(resp=True)
+ with m.handlecontext():
+ sc = sl.run_once(
+ tutils.test_data.path(
+ "data/addonscripts/recorder.py"
+ ), [f]
+ )
+ evts = [i[1] for i in sc.ns.call_log]
+ assert evts == ['start', 'request', 'responseheaders', 'response', 'done']
+
+ with m.handlecontext():
+ tutils.raises(
+ "file not found",
+ sl.run_once,
+ "nonexistent",
+ [f]
+ )
+
def test_simple(self):
s = state.State()
o = options.Options(scripts=[])
diff --git a/test/mitmproxy/builtins/test_serverplayback.py b/test/mitmproxy/builtins/test_serverplayback.py
new file mode 100644
index 00000000..4db509da
--- /dev/null
+++ b/test/mitmproxy/builtins/test_serverplayback.py
@@ -0,0 +1,284 @@
+from .. import tutils, mastertest
+
+import netlib.tutils
+from mitmproxy.builtins import serverplayback
+from mitmproxy import options
+from mitmproxy import exceptions
+from mitmproxy import flow
+
+
+class TestServerPlayback:
+ def test_server_playback(self):
+ sp = serverplayback.ServerPlayback()
+ sp.configure(options.Options(), [])
+ f = tutils.tflow(resp=True)
+
+ assert not sp.flowmap
+
+ sp.load([f])
+ assert sp.flowmap
+ assert sp.next_flow(f)
+ assert not sp.flowmap
+
+ def test_ignore_host(self):
+ sp = serverplayback.ServerPlayback()
+ sp.configure(options.Options(server_replay_ignore_host=True), [])
+
+ r = tutils.tflow(resp=True)
+ r2 = tutils.tflow(resp=True)
+
+ r.request.host = "address"
+ r2.request.host = "address"
+ assert sp._hash(r) == sp._hash(r2)
+ r2.request.host = "wrong_address"
+ assert sp._hash(r) == sp._hash(r2)
+
+ def test_ignore_content(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(options.Options(server_replay_ignore_content=False), [])
+
+ r = tutils.tflow(resp=True)
+ r2 = tutils.tflow(resp=True)
+
+ r.request.content = b"foo"
+ r2.request.content = b"foo"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.content = b"bar"
+ assert not s._hash(r) == s._hash(r2)
+
+ s.configure(options.Options(server_replay_ignore_content=True), [])
+ r = tutils.tflow(resp=True)
+ r2 = tutils.tflow(resp=True)
+ r.request.content = b"foo"
+ r2.request.content = b"foo"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.content = b"bar"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.content = b""
+ assert s._hash(r) == s._hash(r2)
+ r2.request.content = None
+ assert s._hash(r) == s._hash(r2)
+
+ def test_ignore_content_wins_over_params(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(
+ options.Options(
+ server_replay_ignore_content=True,
+ server_replay_ignore_payload_params=[
+ "param1", "param2"
+ ]
+ ),
+ []
+ )
+ # NOTE: parameters are mutually exclusive in options
+
+ r = tutils.tflow(resp=True)
+ r.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+ r.request.content = b"paramx=y"
+
+ r2 = tutils.tflow(resp=True)
+ r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+ r2.request.content = b"paramx=x"
+
+ # same parameters
+ assert s._hash(r) == s._hash(r2)
+
+ def test_ignore_payload_params_other_content_type(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(
+ options.Options(
+ server_replay_ignore_content=False,
+ server_replay_ignore_payload_params=[
+ "param1", "param2"
+ ]
+ ),
+ []
+
+ )
+ r = tutils.tflow(resp=True)
+ r.request.headers["Content-Type"] = "application/json"
+ r.request.content = b'{"param1":"1"}'
+ r2 = tutils.tflow(resp=True)
+ r2.request.headers["Content-Type"] = "application/json"
+ r2.request.content = b'{"param1":"1"}'
+ # same content
+ assert s._hash(r) == s._hash(r2)
+ # distint content (note only x-www-form-urlencoded payload is analysed)
+ r2.request.content = b'{"param1":"2"}'
+ assert not s._hash(r) == s._hash(r2)
+
+ def test_hash(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(options.Options(), [])
+
+ r = tutils.tflow()
+ r2 = tutils.tflow()
+
+ assert s._hash(r)
+ assert s._hash(r) == s._hash(r2)
+ r.request.headers["foo"] = "bar"
+ assert s._hash(r) == s._hash(r2)
+ r.request.path = "voing"
+ assert s._hash(r) != s._hash(r2)
+
+ r.request.path = "path?blank_value"
+ r2.request.path = "path?"
+ assert s._hash(r) != s._hash(r2)
+
+ def test_headers(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(options.Options(server_replay_use_headers=["foo"]), [])
+
+ r = tutils.tflow(resp=True)
+ r.request.headers["foo"] = "bar"
+ r2 = tutils.tflow(resp=True)
+ assert not s._hash(r) == s._hash(r2)
+ r2.request.headers["foo"] = "bar"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.headers["oink"] = "bar"
+ assert s._hash(r) == s._hash(r2)
+
+ r = tutils.tflow(resp=True)
+ r2 = tutils.tflow(resp=True)
+ assert s._hash(r) == s._hash(r2)
+
+ def test_load(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(options.Options(), [])
+
+ r = tutils.tflow(resp=True)
+ r.request.headers["key"] = "one"
+
+ r2 = tutils.tflow(resp=True)
+ r2.request.headers["key"] = "two"
+
+ s.load([r, r2])
+
+ assert s.count() == 2
+
+ n = s.next_flow(r)
+ assert n.request.headers["key"] == "one"
+ assert s.count() == 1
+
+ n = s.next_flow(r)
+ assert n.request.headers["key"] == "two"
+ assert not s.flowmap
+ assert s.count() == 0
+
+ assert not s.next_flow(r)
+
+ def test_load_with_server_replay_nopop(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(options.Options(server_replay_nopop=True), [])
+
+ r = tutils.tflow(resp=True)
+ r.request.headers["key"] = "one"
+
+ r2 = tutils.tflow(resp=True)
+ r2.request.headers["key"] = "two"
+
+ s.load([r, r2])
+
+ assert s.count() == 2
+ s.next_flow(r)
+ assert s.count() == 2
+
+ def test_ignore_params(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(
+ options.Options(
+ server_replay_ignore_params=["param1", "param2"]
+ ),
+ []
+ )
+
+ r = tutils.tflow(resp=True)
+ r.request.path = "/test?param1=1"
+ r2 = tutils.tflow(resp=True)
+ r2.request.path = "/test"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.path = "/test?param1=2"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.path = "/test?param2=1"
+ assert s._hash(r) == s._hash(r2)
+ r2.request.path = "/test?param3=2"
+ assert not s._hash(r) == s._hash(r2)
+
+ def test_ignore_payload_params(self):
+ s = serverplayback.ServerPlayback()
+ s.configure(
+ options.Options(
+ server_replay_ignore_payload_params=["param1", "param2"]
+ ),
+ []
+ )
+
+ r = tutils.tflow(resp=True)
+ r.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+ r.request.content = b"paramx=x&param1=1"
+ r2 = tutils.tflow(resp=True)
+ r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+ r2.request.content = b"paramx=x&param1=1"
+ # same parameters
+ assert s._hash(r) == s._hash(r2)
+ # ignored parameters !=
+ r2.request.content = b"paramx=x&param1=2"
+ assert s._hash(r) == s._hash(r2)
+ # missing parameter
+ r2.request.content = b"paramx=x"
+ assert s._hash(r) == s._hash(r2)
+ # ignorable parameter added
+ r2.request.content = b"paramx=x&param1=2"
+ assert s._hash(r) == s._hash(r2)
+ # not ignorable parameter changed
+ r2.request.content = b"paramx=y&param1=1"
+ assert not s._hash(r) == s._hash(r2)
+ # not ignorable parameter missing
+ r2.request.content = b"param1=1"
+ assert not s._hash(r) == s._hash(r2)
+
+ def test_server_playback_full(self):
+ state = flow.State()
+ s = serverplayback.ServerPlayback()
+ o = options.Options(refresh_server_playback = True, keepserving=False)
+ m = mastertest.RecordingMaster(o, None, state)
+ m.addons.add(o, s)
+
+ f = tutils.tflow()
+ f.response = netlib.tutils.tresp(content=f.request.content)
+ s.load([f, f])
+
+ tf = tutils.tflow()
+ assert not tf.response
+ m.request(tf)
+ assert tf.response == f.response
+
+ tf = tutils.tflow()
+ tf.request.content = b"gibble"
+ assert not tf.response
+ m.request(tf)
+ assert not tf.response
+
+ assert not s.stop
+ s.tick()
+ assert not s.stop
+
+ tf = tutils.tflow()
+ m.request(tutils.tflow())
+ assert s.stop
+
+ def test_server_playback_kill(self):
+ state = flow.State()
+ s = serverplayback.ServerPlayback()
+ o = options.Options(refresh_server_playback = True, replay_kill_extra=True)
+ m = mastertest.RecordingMaster(o, None, state)
+ m.addons.add(o, s)
+
+ f = tutils.tflow()
+ f.response = netlib.tutils.tresp(content=f.request.content)
+ s.load([f])
+
+ f = tutils.tflow()
+ f.request.host = "nonexistent"
+ m.request(f)
+ assert f.reply.value == exceptions.Kill
diff --git a/test/mitmproxy/console/test_master.py b/test/mitmproxy/console/test_master.py
index fcb87e1b..8388a6bd 100644
--- a/test/mitmproxy/console/test_master.py
+++ b/test/mitmproxy/console/test_master.py
@@ -107,7 +107,7 @@ def test_format_keyvals():
def test_options():
- assert console.master.Options(kill=True)
+ assert console.master.Options(replay_kill_extra=True)
class TestMaster(mastertest.MasterTest):
diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py
index 08659d19..a14fe02a 100644
--- a/test/mitmproxy/mastertest.py
+++ b/test/mitmproxy/mastertest.py
@@ -1,8 +1,14 @@
+import contextlib
+
from . import tutils
import netlib.tutils
from mitmproxy.flow import master
-from mitmproxy import flow, proxy, models, controller
+from mitmproxy import flow, proxy, models, controller, options
+
+
+class TestMaster:
+ pass
class MasterTest:
@@ -16,7 +22,9 @@ class MasterTest:
master.serverconnect(f.server_conn)
master.request(f)
if not f.error:
- f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content))
+ f.response = models.HTTPResponse.wrap(
+ netlib.tutils.tresp(content=content)
+ )
master.response(f)
master.clientdisconnect(f)
return f
@@ -41,3 +49,12 @@ class RecordingMaster(master.FlowMaster):
def add_log(self, e, level):
self.event_log.append((level, e))
+
+
+@contextlib.contextmanager
+def mockctx():
+ state = flow.State()
+ o = options.Options(refresh_server_playback = True, keepserving=False)
+ m = RecordingMaster(o, proxy.DummyServer(o), state)
+ with m.handlecontext():
+ yield
diff --git a/test/mitmproxy/protocol/test_http1.py b/test/mitmproxy/protocol/test_http1.py
index 7d04c56b..2fc4ac63 100644
--- a/test/mitmproxy/protocol/test_http1.py
+++ b/test/mitmproxy/protocol/test_http1.py
@@ -18,14 +18,15 @@ class TestInvalidRequests(tservers.HTTPProxyTest):
def test_double_connect(self):
p = self.pathoc()
- r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port))
+ with p.connect():
+ r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port))
assert r.status_code == 400
assert b"Invalid HTTP request form" in r.content
def test_relative_request(self):
p = self.pathoc_raw()
- p.connect()
- r = p.request("get:/p/200")
+ with p.connect():
+ r = p.request("get:/p/200")
assert r.status_code == 400
assert b"Invalid HTTP request form" in r.content
@@ -61,5 +62,8 @@ class TestHeadContentLength(tservers.HTTPProxyTest):
def test_head_content_length(self):
p = self.pathoc()
- resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase)
+ with p.connect():
+ resp = p.request(
+ """head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase
+ )
assert resp.headers["Content-Length"] == "42"
diff --git a/test/mitmproxy/protocol/test_http2.py b/test/mitmproxy/protocol/test_http2.py
index 1eabebf1..c4bd2049 100644
--- a/test/mitmproxy/protocol/test_http2.py
+++ b/test/mitmproxy/protocol/test_http2.py
@@ -15,7 +15,7 @@ from mitmproxy.proxy.config import ProxyConfig
import netlib
from ...netlib import tservers as netlib_tservers
from netlib.exceptions import HttpException
-from netlib.http.http2 import framereader
+from netlib.http import http1, http2
from .. import tservers
@@ -33,6 +33,11 @@ requires_alpn = pytest.mark.skipif(
reason='requires OpenSSL with ALPN support')
+# inspect the log:
+# for msg in self.proxy.tmaster.tlog:
+# print(msg)
+
+
class _Http2ServerBase(netlib_tservers.ServerTestBase):
ssl = dict(alpn_select=b'h2')
@@ -55,7 +60,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(self.rfile))
+ raw = b''.join(http2.read_raw_frame(self.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -124,11 +129,17 @@ class _Http2TestBase(object):
client.connect()
# send CONNECT request
- client.wfile.write(
- b"CONNECT localhost:%d HTTP/1.1\r\n"
- b"Host: localhost:%d\r\n"
- b"\r\n" % (self.server.server.address.port, self.server.server.address.port)
- )
+ client.wfile.write(http1.assemble_request(netlib.http.Request(
+ 'authority',
+ b'CONNECT',
+ b'',
+ b'localhost',
+ self.server.server.address.port,
+ b'/',
+ b'HTTP/1.1',
+ [(b'host', b'localhost:%d' % self.server.server.address.port)],
+ b'',
+ )))
client.wfile.flush()
# read CONNECT response
@@ -242,7 +253,7 @@ class TestSimple(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -319,7 +330,7 @@ class TestRequestWithPriority(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -358,7 +369,7 @@ class TestRequestWithPriority(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -430,7 +441,7 @@ class TestPriority(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -507,7 +518,7 @@ class TestPriorityWithExistingStream(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -558,7 +569,7 @@ class TestStreamResetFromServer(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -608,7 +619,7 @@ class TestBodySizeLimit(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -693,7 +704,7 @@ class TestPushPromise(_Http2Test):
responses = 0
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -746,7 +757,7 @@ class TestPushPromise(_Http2Test):
responses = 0
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -806,7 +817,7 @@ class TestConnectionLost(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
h2_conn.receive_data(raw)
except HttpException:
print(traceback.format_exc())
@@ -863,7 +874,7 @@ class TestMaxConcurrentStreams(_Http2Test):
ended_streams = 0
while ended_streams != len(new_streams):
try:
- header, body = framereader.http2_read_raw_frame(client.rfile)
+ header, body = http2.read_raw_frame(client.rfile)
events = h2_conn.receive_data(b''.join([header, body]))
except:
break
@@ -909,7 +920,7 @@ class TestConnectionTerminated(_Http2Test):
connection_terminated_event = None
while not done:
try:
- raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
for event in events:
if isinstance(event, h2.events.ConnectionTerminated):
diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py
index a5085ea0..52d7f07f 100644
--- a/test/mitmproxy/test_addons.py
+++ b/test/mitmproxy/test_addons.py
@@ -17,5 +17,5 @@ def test_simple():
m = controller.Master(o)
a = addons.Addons(m)
a.add(o, TAddon("one"))
- assert a.has_addon("one")
- assert not a.has_addon("two")
+ assert a.get("one")
+ assert not a.get("two")
diff --git a/test/mitmproxy/test_dump.py b/test/mitmproxy/test_dump.py
index 90f33264..06f39e9d 100644
--- a/test/mitmproxy/test_dump.py
+++ b/test/mitmproxy/test_dump.py
@@ -45,18 +45,17 @@ class TestDumpMaster(mastertest.MasterTest):
m = dump.DumpMaster(None, o)
f = tutils.tflow(err=True)
m.error(f)
- assert m.error(f)
assert "error" in o.tfile.getvalue()
def test_replay(self):
- o = dump.Options(server_replay=["nonexistent"], kill=True)
- tutils.raises(dump.DumpError, dump.DumpMaster, None, o)
+ o = dump.Options(server_replay=["nonexistent"], replay_kill_extra=True)
+ tutils.raises(exceptions.OptionsError, dump.DumpMaster, None, o)
with tutils.tmpdir() as t:
p = os.path.join(t, "rep")
self.flowfile(p)
- o = dump.Options(server_replay=[p], kill=True)
+ o = dump.Options(server_replay=[p], replay_kill_extra=True)
o.verbosity = 0
o.flow_detail = 0
m = dump.DumpMaster(None, o)
@@ -64,13 +63,13 @@ class TestDumpMaster(mastertest.MasterTest):
self.cycle(m, b"content")
self.cycle(m, b"content")
- o = dump.Options(server_replay=[p], kill=False)
+ o = dump.Options(server_replay=[p], replay_kill_extra=False)
o.verbosity = 0
o.flow_detail = 0
m = dump.DumpMaster(None, o)
self.cycle(m, b"nonexistent")
- o = dump.Options(client_replay=[p], kill=False)
+ o = dump.Options(client_replay=[p], replay_kill_extra=False)
o.verbosity = 0
o.flow_detail = 0
m = dump.DumpMaster(None, o)
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index 1caeb100..0fe45afb 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -37,261 +37,6 @@ def test_app_registry():
assert ar.get(r)
-class TestClientPlaybackState:
-
- def test_tick(self):
- first = tutils.tflow()
- s = flow.State()
- fm = flow.FlowMaster(None, None, s)
- fm.start_client_playback([first, tutils.tflow()], True)
- c = fm.client_playback
- c.testing = True
-
- assert not c.done()
- assert not s.flow_count()
- assert c.count() == 2
- c.tick(fm)
- assert s.flow_count()
- assert c.count() == 1
-
- c.tick(fm)
- assert c.count() == 1
-
- c.clear(c.current)
- c.tick(fm)
- assert c.count() == 0
- c.clear(c.current)
- assert c.done()
-
- fm.state.clear()
- fm.tick(timeout=0)
-
- fm.stop_client_playback()
- assert not fm.client_playback
-
-
-class TestServerPlaybackState:
-
- def test_hash(self):
- s = flow.ServerPlaybackState(
- None,
- [],
- False,
- False,
- None,
- False,
- None,
- False)
- r = tutils.tflow()
- r2 = tutils.tflow()
-
- assert s._hash(r)
- assert s._hash(r) == s._hash(r2)
- r.request.headers["foo"] = "bar"
- assert s._hash(r) == s._hash(r2)
- r.request.path = "voing"
- assert s._hash(r) != s._hash(r2)
-
- r.request.path = "path?blank_value"
- r2.request.path = "path?"
- assert s._hash(r) != s._hash(r2)
-
- def test_headers(self):
- s = flow.ServerPlaybackState(
- ["foo"],
- [],
- False,
- False,
- None,
- False,
- None,
- False)
- r = tutils.tflow(resp=True)
- r.request.headers["foo"] = "bar"
- r2 = tutils.tflow(resp=True)
- assert not s._hash(r) == s._hash(r2)
- r2.request.headers["foo"] = "bar"
- assert s._hash(r) == s._hash(r2)
- r2.request.headers["oink"] = "bar"
- assert s._hash(r) == s._hash(r2)
-
- r = tutils.tflow(resp=True)
- r2 = tutils.tflow(resp=True)
- assert s._hash(r) == s._hash(r2)
-
- def test_load(self):
- r = tutils.tflow(resp=True)
- r.request.headers["key"] = "one"
-
- r2 = tutils.tflow(resp=True)
- r2.request.headers["key"] = "two"
-
- s = flow.ServerPlaybackState(
- None, [
- r, r2], False, False, None, False, None, False)
- assert s.count() == 2
- assert len(s.fmap.keys()) == 1
-
- n = s.next_flow(r)
- assert n.request.headers["key"] == "one"
- assert s.count() == 1
-
- n = s.next_flow(r)
- assert n.request.headers["key"] == "two"
- assert s.count() == 0
-
- assert not s.next_flow(r)
-
- def test_load_with_nopop(self):
- r = tutils.tflow(resp=True)
- r.request.headers["key"] = "one"
-
- r2 = tutils.tflow(resp=True)
- r2.request.headers["key"] = "two"
-
- s = flow.ServerPlaybackState(
- None, [
- r, r2], False, True, None, False, None, False)
-
- assert s.count() == 2
- s.next_flow(r)
- assert s.count() == 2
-
- def test_ignore_params(self):
- s = flow.ServerPlaybackState(
- None, [], False, False, [
- "param1", "param2"], False, None, False)
- r = tutils.tflow(resp=True)
- r.request.path = "/test?param1=1"
- r2 = tutils.tflow(resp=True)
- r2.request.path = "/test"
- assert s._hash(r) == s._hash(r2)
- r2.request.path = "/test?param1=2"
- assert s._hash(r) == s._hash(r2)
- r2.request.path = "/test?param2=1"
- assert s._hash(r) == s._hash(r2)
- r2.request.path = "/test?param3=2"
- assert not s._hash(r) == s._hash(r2)
-
- def test_ignore_payload_params(self):
- s = flow.ServerPlaybackState(
- None, [], False, False, None, False, [
- "param1", "param2"], False)
- r = tutils.tflow(resp=True)
- r.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- r.request.content = b"paramx=x&param1=1"
- r2 = tutils.tflow(resp=True)
- r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- r2.request.content = b"paramx=x&param1=1"
- # same parameters
- assert s._hash(r) == s._hash(r2)
- # ignored parameters !=
- r2.request.content = b"paramx=x&param1=2"
- assert s._hash(r) == s._hash(r2)
- # missing parameter
- r2.request.content = b"paramx=x"
- assert s._hash(r) == s._hash(r2)
- # ignorable parameter added
- r2.request.content = b"paramx=x&param1=2"
- assert s._hash(r) == s._hash(r2)
- # not ignorable parameter changed
- r2.request.content = b"paramx=y&param1=1"
- assert not s._hash(r) == s._hash(r2)
- # not ignorable parameter missing
- r2.request.content = b"param1=1"
- assert not s._hash(r) == s._hash(r2)
-
- def test_ignore_payload_params_other_content_type(self):
- s = flow.ServerPlaybackState(
- None, [], False, False, None, False, [
- "param1", "param2"], False)
- r = tutils.tflow(resp=True)
- r.request.headers["Content-Type"] = "application/json"
- r.request.content = b'{"param1":"1"}'
- r2 = tutils.tflow(resp=True)
- r2.request.headers["Content-Type"] = "application/json"
- r2.request.content = b'{"param1":"1"}'
- # same content
- assert s._hash(r) == s._hash(r2)
- # distint content (note only x-www-form-urlencoded payload is analysed)
- r2.request.content = b'{"param1":"2"}'
- assert not s._hash(r) == s._hash(r2)
-
- def test_ignore_payload_wins_over_params(self):
- # NOTE: parameters are mutually exclusive in options
- s = flow.ServerPlaybackState(
- None, [], False, False, None, True, [
- "param1", "param2"], False)
- r = tutils.tflow(resp=True)
- r.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- r.request.content = b"paramx=y"
- r2 = tutils.tflow(resp=True)
- r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- r2.request.content = b"paramx=x"
- # same parameters
- assert s._hash(r) == s._hash(r2)
-
- def test_ignore_content(self):
- s = flow.ServerPlaybackState(
- None,
- [],
- False,
- False,
- None,
- False,
- None,
- False)
- r = tutils.tflow(resp=True)
- r2 = tutils.tflow(resp=True)
-
- r.request.content = b"foo"
- r2.request.content = b"foo"
- assert s._hash(r) == s._hash(r2)
- r2.request.content = b"bar"
- assert not s._hash(r) == s._hash(r2)
-
- # now ignoring content
- s = flow.ServerPlaybackState(
- None,
- [],
- False,
- False,
- None,
- True,
- None,
- False)
- r = tutils.tflow(resp=True)
- r2 = tutils.tflow(resp=True)
- r.request.content = b"foo"
- r2.request.content = b"foo"
- assert s._hash(r) == s._hash(r2)
- r2.request.content = b"bar"
- assert s._hash(r) == s._hash(r2)
- r2.request.content = b""
- assert s._hash(r) == s._hash(r2)
- r2.request.content = None
- assert s._hash(r) == s._hash(r2)
-
- def test_ignore_host(self):
- s = flow.ServerPlaybackState(
- None,
- [],
- False,
- False,
- None,
- False,
- None,
- True)
- r = tutils.tflow(resp=True)
- r2 = tutils.tflow(resp=True)
-
- r.request.host = "address"
- r2.request.host = "address"
- assert s._hash(r) == s._hash(r2)
- r2.request.host = "wrong_address"
- assert s._hash(r) == s._hash(r2)
-
-
class TestHTTPFlow(object):
def test_copy(self):
@@ -699,13 +444,13 @@ class TestFlowMaster:
fm = flow.FlowMaster(None, None, s)
f = tutils.tflow(resp=True)
f.request.content = None
- assert "missing" in fm.replay_request(f)
+ tutils.raises("missing", fm.replay_request, f)
f.intercepted = True
- assert "intercepting" in fm.replay_request(f)
+ tutils.raises("intercepted", fm.replay_request, f)
f.live = True
- assert "live" in fm.replay_request(f)
+ tutils.raises("live", fm.replay_request, f)
def test_duplicate_flow(self):
s = flow.State()
@@ -743,103 +488,6 @@ class TestFlowMaster:
fm.shutdown()
- def test_client_playback(self):
- s = flow.State()
-
- f = tutils.tflow(resp=True)
- pb = [tutils.tflow(resp=True), f]
- fm = flow.FlowMaster(
- options.Options(),
- DummyServer(ProxyConfig(options.Options())),
- s
- )
- assert not fm.start_server_playback(
- pb,
- False,
- [],
- False,
- False,
- None,
- False,
- None,
- False)
- assert not fm.start_client_playback(pb, False)
- fm.client_playback.testing = True
-
- assert not fm.state.flow_count()
- fm.tick(0)
- assert fm.state.flow_count()
-
- f.error = Error("error")
- fm.error(f)
-
- def test_server_playback(self):
- s = flow.State()
-
- f = tutils.tflow()
- f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request))
- pb = [f]
-
- fm = flow.FlowMaster(options.Options(), None, s)
- fm.refresh_server_playback = True
- assert not fm.do_server_playback(tutils.tflow())
-
- fm.start_server_playback(
- pb,
- False,
- [],
- False,
- False,
- None,
- False,
- None,
- False)
- assert fm.do_server_playback(tutils.tflow())
-
- fm.start_server_playback(
- pb,
- False,
- [],
- True,
- False,
- None,
- False,
- None,
- False)
- r = tutils.tflow()
- r.request.content = b"gibble"
- assert not fm.do_server_playback(r)
- assert fm.do_server_playback(tutils.tflow())
-
- fm.tick(0)
- assert fm.should_exit.is_set()
-
- fm.stop_server_playback()
- assert not fm.server_playback
-
- def test_server_playback_kill(self):
- s = flow.State()
- f = tutils.tflow()
- f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request))
- pb = [f]
- fm = flow.FlowMaster(None, None, s)
- fm.refresh_server_playback = True
- fm.start_server_playback(
- pb,
- True,
- [],
- False,
- False,
- None,
- False,
- None,
- False)
-
- f = tutils.tflow()
- f.request.host = "nonexistent"
- fm.request(f)
- assert f.reply.value == Kill
-
class TestRequest:
diff --git a/test/mitmproxy/test_fuzzing.py b/test/mitmproxy/test_fuzzing.py
index 27ea36a6..905ba1cd 100644
--- a/test/mitmproxy/test_fuzzing.py
+++ b/test/mitmproxy/test_fuzzing.py
@@ -11,17 +11,20 @@ class TestFuzzy(tservers.HTTPProxyTest):
def test_idna_err(self):
req = r'get:"http://localhost:%s":i10,"\xc6"'
p = self.pathoc()
- assert p.request(req % self.server.port).status_code == 400
+ with p.connect():
+ assert p.request(req % self.server.port).status_code == 400
def test_nullbytes(self):
req = r'get:"http://localhost:%s":i19,"\x00"'
p = self.pathoc()
- assert p.request(req % self.server.port).status_code == 400
+ with p.connect():
+ assert p.request(req % self.server.port).status_code == 400
def test_invalid_ipv6_url(self):
req = 'get:"http://localhost:%s":i13,"["'
p = self.pathoc()
- resp = p.request(req % self.server.port)
+ with p.connect():
+ resp = p.request(req % self.server.port)
assert resp.status_code == 400
# def test_invalid_upstream(self):
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index e0a8da47..c5a5bb71 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -60,7 +60,7 @@ class CommonMixin:
# Disconnect error
l.request.path = "/p/305:d0"
rt = self.master.replay_request(l, block=True)
- assert not rt
+ assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
@@ -72,7 +72,7 @@ class CommonMixin:
# In upstream mode with ssl, the replay will fail as we cannot establish
# SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True)
- assert not rt
+ assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
@@ -91,11 +91,11 @@ class CommonMixin:
def test_invalid_http(self):
t = tcp.TCPClient(("127.0.0.1", self.proxy.port))
- t.connect()
- t.wfile.write(b"invalid\r\n\r\n")
- t.wfile.flush()
- line = t.rfile.readline()
- assert (b"Bad Request" in line) or (b"Bad Gateway" in line)
+ with t.connect():
+ t.wfile.write(b"invalid\r\n\r\n")
+ t.wfile.flush()
+ line = t.rfile.readline()
+ assert (b"Bad Request" in line) or (b"Bad Gateway" in line)
def test_sni(self):
if not self.ssl:
@@ -208,20 +208,22 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
def test_app_err(self):
p = self.pathoc()
- ret = p.request("get:'http://errapp/'")
+ with p.connect():
+ ret = p.request("get:'http://errapp/'")
assert ret.status_code == 500
assert b"ValueError" in ret.content
def test_invalid_connect(self):
t = tcp.TCPClient(("127.0.0.1", self.proxy.port))
- t.connect()
- t.wfile.write(b"CONNECT invalid\n\n")
- t.wfile.flush()
- assert b"Bad Request" in t.rfile.readline()
+ with t.connect():
+ t.wfile.write(b"CONNECT invalid\n\n")
+ t.wfile.flush()
+ assert b"Bad Request" in t.rfile.readline()
def test_upstream_ssl_error(self):
p = self.pathoc()
- ret = p.request("get:'https://localhost:%s/'" % self.server.port)
+ with p.connect():
+ ret = p.request("get:'https://localhost:%s/'" % self.server.port)
assert ret.status_code == 400
def test_connection_close(self):
@@ -232,25 +234,28 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
# Lets sanity check that the connection does indeed stay open by
# issuing two requests over the same connection
p = self.pathoc()
- assert p.request("get:'%s'" % response)
- assert p.request("get:'%s'" % response)
+ with p.connect():
+ assert p.request("get:'%s'" % response)
+ assert p.request("get:'%s'" % response)
# Now check that the connection is closed as the client specifies
p = self.pathoc()
- assert p.request("get:'%s':h'Connection'='close'" % response)
- # There's a race here, which means we can get any of a number of errors.
- # Rather than introduce yet another sleep into the test suite, we just
- # relax the Exception specification.
- with raises(Exception):
- p.request("get:'%s'" % response)
+ with p.connect():
+ assert p.request("get:'%s':h'Connection'='close'" % response)
+ # There's a race here, which means we can get any of a number of errors.
+ # Rather than introduce yet another sleep into the test suite, we just
+ # relax the Exception specification.
+ with raises(Exception):
+ p.request("get:'%s'" % response)
def test_reconnect(self):
req = "get:'%s/p/200:b@1:da'" % self.server.urlbase
p = self.pathoc()
- assert p.request(req)
- # Server has disconnected. Mitmproxy should detect this, and reconnect.
- assert p.request(req)
- assert p.request(req)
+ with p.connect():
+ assert p.request(req)
+ # Server has disconnected. Mitmproxy should detect this, and reconnect.
+ assert p.request(req)
+ assert p.request(req)
def test_get_connection_switching(self):
def switched(l):
@@ -260,18 +265,21 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
req = "get:'%s/p/200:b@1'"
p = self.pathoc()
- assert p.request(req % self.server.urlbase)
- assert p.request(req % self.server2.urlbase)
+ with p.connect():
+ assert p.request(req % self.server.urlbase)
+ assert p.request(req % self.server2.urlbase)
assert switched(self.proxy.tlog)
def test_blank_leading_line(self):
p = self.pathoc()
- req = "get:'%s/p/201':i0,'\r\n'"
- assert p.request(req % self.server.urlbase).status_code == 201
+ with p.connect():
+ req = "get:'%s/p/201':i0,'\r\n'"
+ assert p.request(req % self.server.urlbase).status_code == 201
def test_invalid_headers(self):
p = self.pathoc()
- resp = p.request("get:'http://foo':h':foo'='bar'")
+ with p.connect():
+ resp = p.request("get:'http://foo':h':foo'='bar'")
assert resp.status_code == 400
def test_stream(self):
@@ -301,15 +309,16 @@ class TestHTTPAuth(tservers.HTTPProxyTest):
self.master.options.auth_singleuser = "test:test"
assert self.pathod("202").status_code == 407
p = self.pathoc()
- ret = p.request("""
- get
- 'http://localhost:%s/p/202'
- h'%s'='%s'
- """ % (
- self.server.port,
- http.authentication.BasicProxyAuth.AUTH_HEADER,
- authentication.assemble_http_basic_auth("basic", "test", "test")
- ))
+ with p.connect():
+ ret = p.request("""
+ get
+ 'http://localhost:%s/p/202'
+ h'%s'='%s'
+ """ % (
+ self.server.port,
+ http.authentication.BasicProxyAuth.AUTH_HEADER,
+ authentication.assemble_http_basic_auth("basic", "test", "test")
+ ))
assert ret.status_code == 202
@@ -318,14 +327,15 @@ class TestHTTPReverseAuth(tservers.ReverseProxyTest):
self.master.options.auth_singleuser = "test:test"
assert self.pathod("202").status_code == 401
p = self.pathoc()
- ret = p.request("""
- get
- '/p/202'
- h'%s'='%s'
- """ % (
- http.authentication.BasicWebsiteAuth.AUTH_HEADER,
- authentication.assemble_http_basic_auth("basic", "test", "test")
- ))
+ with p.connect():
+ ret = p.request("""
+ get
+ '/p/202'
+ h'%s'='%s'
+ """ % (
+ http.authentication.BasicWebsiteAuth.AUTH_HEADER,
+ authentication.assemble_http_basic_auth("basic", "test", "test")
+ ))
assert ret.status_code == 202
@@ -354,7 +364,8 @@ class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin):
def test_error_post_connect(self):
p = self.pathoc()
- assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400
+ with p.connect():
+ assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400
class TestHTTPSCertfile(tservers.HTTPProxyTest, CommonMixin):
@@ -389,7 +400,8 @@ class TestHTTPSUpstreamServerVerificationWTrustedCert(tservers.HTTPProxyTest):
def _request(self):
p = self.pathoc(sni="example.mitmproxy.org")
- return p.request("get:/p/242")
+ with p.connect():
+ return p.request("get:/p/242")
def test_verification_w_cadir(self):
self.config.options.update(
@@ -426,7 +438,8 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxyTest):
def _request(self):
p = self.pathoc(sni="example.mitmproxy.org")
- return p.request("get:/p/242")
+ with p.connect():
+ return p.request("get:/p/242")
@classmethod
def get_options(cls):
@@ -481,13 +494,15 @@ class TestSocks5(tservers.SocksModeTest):
def test_simple(self):
p = self.pathoc()
- p.socks_connect(("localhost", self.server.port))
- f = p.request("get:/p/200")
+ with p.connect():
+ p.socks_connect(("localhost", self.server.port))
+ f = p.request("get:/p/200")
assert f.status_code == 200
def test_with_authentication_only(self):
p = self.pathoc()
- f = p.request("get:/p/200")
+ with p.connect():
+ f = p.request("get:/p/200")
assert f.status_code == 502
assert b"SOCKS5 mode failure" in f.content
@@ -496,21 +511,21 @@ class TestSocks5(tservers.SocksModeTest):
mitmproxy doesn't support UDP or BIND SOCKS CMDs
"""
p = self.pathoc()
-
- socks.ClientGreeting(
- socks.VERSION.SOCKS5,
- [socks.METHOD.NO_AUTHENTICATION_REQUIRED]
- ).to_file(p.wfile)
- socks.Message(
- socks.VERSION.SOCKS5,
- socks.CMD.BIND,
- socks.ATYP.DOMAINNAME,
- ("example.com", 8080)
- ).to_file(p.wfile)
-
- p.wfile.flush()
- p.rfile.read(2) # read server greeting
- f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway.
+ with p.connect():
+ socks.ClientGreeting(
+ socks.VERSION.SOCKS5,
+ [socks.METHOD.NO_AUTHENTICATION_REQUIRED]
+ ).to_file(p.wfile)
+ socks.Message(
+ socks.VERSION.SOCKS5,
+ socks.CMD.BIND,
+ socks.ATYP.DOMAINNAME,
+ ("example.com", 8080)
+ ).to_file(p.wfile)
+
+ p.wfile.flush()
+ p.rfile.read(2) # read server greeting
+ f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway.
assert f.status_code == 502
assert b"SOCKS5 mode failure" in f.content
@@ -531,21 +546,23 @@ class TestHttps2Http(tservers.ReverseProxyTest):
p = pathoc.Pathoc(
("localhost", self.proxy.port), ssl=True, sni=sni, fp=None
)
- p.connect()
return p
def test_all(self):
p = self.pathoc(ssl=True)
- assert p.request("get:'/p/200'").status_code == 200
+ with p.connect():
+ assert p.request("get:'/p/200'").status_code == 200
def test_sni(self):
p = self.pathoc(ssl=True, sni="example.com")
- assert p.request("get:'/p/200'").status_code == 200
- assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)
+ with p.connect():
+ assert p.request("get:'/p/200'").status_code == 200
+ assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)
def test_http(self):
p = self.pathoc(ssl=False)
- assert p.request("get:'/p/200'").status_code == 200
+ with p.connect():
+ assert p.request("get:'/p/200'").status_code == 200
class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
@@ -703,29 +720,29 @@ class TestRedirectRequest(tservers.HTTPProxyTest):
self.master.redirect_port = self.server2.port
p = self.pathoc()
-
- self.server.clear_log()
- self.server2.clear_log()
- r1 = p.request("get:'/p/200'")
- assert r1.status_code == 200
- assert self.server.last_log()
- assert not self.server2.last_log()
-
- self.server.clear_log()
- self.server2.clear_log()
- r2 = p.request("get:'/p/201'")
- assert r2.status_code == 201
- assert not self.server.last_log()
- assert self.server2.last_log()
-
- self.server.clear_log()
- self.server2.clear_log()
- r3 = p.request("get:'/p/202'")
- assert r3.status_code == 202
- assert self.server.last_log()
- assert not self.server2.last_log()
-
- assert r1.content == r2.content == r3.content
+ with p.connect():
+ self.server.clear_log()
+ self.server2.clear_log()
+ r1 = p.request("get:'/p/200'")
+ assert r1.status_code == 200
+ assert self.server.last_log()
+ assert not self.server2.last_log()
+
+ self.server.clear_log()
+ self.server2.clear_log()
+ r2 = p.request("get:'/p/201'")
+ assert r2.status_code == 201
+ assert not self.server.last_log()
+ assert self.server2.last_log()
+
+ self.server.clear_log()
+ self.server2.clear_log()
+ r3 = p.request("get:'/p/202'")
+ assert r3.status_code == 202
+ assert self.server.last_log()
+ assert not self.server2.last_log()
+
+ assert r1.content == r2.content == r3.content
class MasterStreamRequest(tservers.TestMaster):
@@ -743,22 +760,22 @@ class TestStreamRequest(tservers.HTTPProxyTest):
def test_stream_simple(self):
p = self.pathoc()
-
- # a request with 100k of data but without content-length
- r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase)
- assert r1.status_code == 200
- assert len(r1.content) > 100000
+ with p.connect():
+ # a request with 100k of data but without content-length
+ r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase)
+ assert r1.status_code == 200
+ assert len(r1.content) > 100000
def test_stream_multiple(self):
p = self.pathoc()
+ with p.connect():
+ # simple request with streaming turned on
+ r1 = p.request("get:'%s/p/200'" % self.server.urlbase)
+ assert r1.status_code == 200
- # simple request with streaming turned on
- r1 = p.request("get:'%s/p/200'" % self.server.urlbase)
- assert r1.status_code == 200
-
- # now send back 100k of data, streamed but not chunked
- r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase)
- assert r1.status_code == 201
+ # now send back 100k of data, streamed but not chunked
+ r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase)
+ assert r1.status_code == 201
def test_stream_chunked(self):
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -887,7 +904,8 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest, CommonMixin, AppMixin):
("~s", "baz", "ORLY")
]
p = self.pathoc()
- req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase)
+ with p.connect():
+ req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase)
assert req.content == b"ORLY"
assert req.status_code == 418
@@ -948,7 +966,8 @@ class TestUpstreamProxySSL(
def test_simple(self):
p = self.pathoc()
- req = p.request("get:'/p/418:b\"content\"'")
+ with p.connect():
+ req = p.request("get:'/p/418:b\"content\"'")
assert req.content == b"content"
assert req.status_code == 418
@@ -1006,48 +1025,49 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):
])
p = self.pathoc()
- req = p.request("get:'/p/418:b\"content\"'")
- assert req.content == b"content"
- assert req.status_code == 418
-
- assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
- # CONNECT, failing request,
- assert self.chain[0].tmaster.state.flow_count() == 4
- # reCONNECT, request
- # failing request, request
- assert self.chain[1].tmaster.state.flow_count() == 2
- # (doesn't store (repeated) CONNECTs from chain[0]
- # as it is a regular proxy)
-
- assert not self.chain[1].tmaster.state.flows[0].response # killed
- assert self.chain[1].tmaster.state.flows[1].response
-
- assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority"
- assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative"
-
- assert self.chain[0].tmaster.state.flows[
- 0].request.first_line_format == "authority"
- assert self.chain[0].tmaster.state.flows[
- 1].request.first_line_format == "relative"
- assert self.chain[0].tmaster.state.flows[
- 2].request.first_line_format == "authority"
- assert self.chain[0].tmaster.state.flows[
- 3].request.first_line_format == "relative"
-
- assert self.chain[1].tmaster.state.flows[
- 0].request.first_line_format == "relative"
- assert self.chain[1].tmaster.state.flows[
- 1].request.first_line_format == "relative"
-
- req = p.request("get:'/p/418:b\"content2\"'")
-
- assert req.status_code == 502
- assert self.proxy.tmaster.state.flow_count() == 3 # + new request
- # + new request, repeated CONNECT from chain[1]
- assert self.chain[0].tmaster.state.flow_count() == 6
- # (both terminated)
- # nothing happened here
- assert self.chain[1].tmaster.state.flow_count() == 2
+ with p.connect():
+ req = p.request("get:'/p/418:b\"content\"'")
+ assert req.content == b"content"
+ assert req.status_code == 418
+
+ assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
+ # CONNECT, failing request,
+ assert self.chain[0].tmaster.state.flow_count() == 4
+ # reCONNECT, request
+ # failing request, request
+ assert self.chain[1].tmaster.state.flow_count() == 2
+ # (doesn't store (repeated) CONNECTs from chain[0]
+ # as it is a regular proxy)
+
+ assert not self.chain[1].tmaster.state.flows[0].response # killed
+ assert self.chain[1].tmaster.state.flows[1].response
+
+ assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority"
+ assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative"
+
+ assert self.chain[0].tmaster.state.flows[
+ 0].request.first_line_format == "authority"
+ assert self.chain[0].tmaster.state.flows[
+ 1].request.first_line_format == "relative"
+ assert self.chain[0].tmaster.state.flows[
+ 2].request.first_line_format == "authority"
+ assert self.chain[0].tmaster.state.flows[
+ 3].request.first_line_format == "relative"
+
+ assert self.chain[1].tmaster.state.flows[
+ 0].request.first_line_format == "relative"
+ assert self.chain[1].tmaster.state.flows[
+ 1].request.first_line_format == "relative"
+
+ req = p.request("get:'/p/418:b\"content2\"'")
+
+ assert req.status_code == 502
+ assert self.proxy.tmaster.state.flow_count() == 3 # + new request
+ # + new request, repeated CONNECT from chain[1]
+ assert self.chain[0].tmaster.state.flow_count() == 6
+ # (both terminated)
+ # nothing happened here
+ assert self.chain[1].tmaster.state.flow_count() == 2
class AddUpstreamCertsToClientChainMixin:
@@ -1066,12 +1086,13 @@ class AddUpstreamCertsToClientChainMixin:
d = f.read()
upstreamCert = SSLCert.from_pem(d)
p = self.pathoc()
- upstream_cert_found_in_client_chain = False
- for receivedCert in p.server_certs:
- if receivedCert.digest('sha256') == upstreamCert.digest('sha256'):
- upstream_cert_found_in_client_chain = True
- break
- assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain)
+ with p.connect():
+ upstream_cert_found_in_client_chain = False
+ for receivedCert in p.server_certs:
+ if receivedCert.digest('sha256') == upstreamCert.digest('sha256'):
+ upstream_cert_found_in_client_chain = True
+ break
+ assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain)
class TestHTTPSAddUpstreamCertsToClientChainTrue(
diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py
index 1597f59c..4291f743 100644
--- a/test/mitmproxy/tservers.py
+++ b/test/mitmproxy/tservers.py
@@ -3,6 +3,7 @@ import threading
import tempfile
import flask
import mock
+import sys
from mitmproxy.proxy.config import ProxyConfig
from mitmproxy.proxy.server import ProxyServer
@@ -10,6 +11,7 @@ import pathod.test
import pathod.pathoc
from mitmproxy import flow, controller, options
from mitmproxy import builtins
+import netlib.exceptions
testapp = flask.Flask(__name__)
@@ -104,6 +106,14 @@ class ProxyTestBase(object):
cls.server.shutdown()
cls.server2.shutdown()
+ def teardown(self):
+ try:
+ self.server.wait_for_silence()
+ except netlib.exceptions.Timeout:
+ # FIXME: Track down the Windows sync issues
+ if sys.platform != "win32":
+ raise
+
def setup(self):
self.master.clear_log()
self.master.state.clear()
@@ -125,6 +135,15 @@ class ProxyTestBase(object):
)
+class LazyPathoc(pathod.pathoc.Pathoc):
+ def __init__(self, lazy_connect, *args, **kwargs):
+ self.lazy_connect = lazy_connect
+ pathod.pathoc.Pathoc.__init__(self, *args, **kwargs)
+
+ def connect(self):
+ return pathod.pathoc.Pathoc.connect(self, self.lazy_connect)
+
+
class HTTPProxyTest(ProxyTestBase):
def pathoc_raw(self):
@@ -134,14 +153,14 @@ class HTTPProxyTest(ProxyTestBase):
"""
Returns a connected Pathoc instance.
"""
- p = pathod.pathoc.Pathoc(
- ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
- )
if self.ssl:
- p.connect(("127.0.0.1", self.server.port))
+ conn = ("127.0.0.1", self.server.port)
else:
- p.connect()
- return p
+ conn = None
+ return LazyPathoc(
+ conn,
+ ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
+ )
def pathod(self, spec, sni=None):
"""
@@ -152,18 +171,20 @@ class HTTPProxyTest(ProxyTestBase):
q = "get:'/p/%s'" % spec
else:
q = "get:'%s/p/%s'" % (self.server.urlbase, spec)
- return p.request(q)
+ with p.connect():
+ return p.request(q)
def app(self, page):
if self.ssl:
p = pathod.pathoc.Pathoc(
("127.0.0.1", self.proxy.port), True, fp=None
)
- p.connect((options.APP_HOST, options.APP_PORT))
- return p.request("get:'%s'" % page)
+ with p.connect((options.APP_HOST, options.APP_PORT)):
+ return p.request("get:'%s'" % page)
else:
p = self.pathoc()
- return p.request("get:'http://%s%s'" % (options.APP_HOST, page))
+ with p.connect():
+ return p.request("get:'http://%s%s'" % (options.APP_HOST, page))
class TResolver:
@@ -210,7 +231,8 @@ class TransparentProxyTest(ProxyTestBase):
else:
p = self.pathoc()
q = "get:'/p/%s'" % spec
- return p.request(q)
+ with p.connect():
+ return p.request(q)
def pathoc(self, sni=None):
"""
@@ -219,7 +241,6 @@ class TransparentProxyTest(ProxyTestBase):
p = pathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
)
- p.connect()
return p
@@ -247,7 +268,6 @@ class ReverseProxyTest(ProxyTestBase):
p = pathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
)
- p.connect()
return p
def pathod(self, spec, sni=None):
@@ -260,7 +280,8 @@ class ReverseProxyTest(ProxyTestBase):
else:
p = self.pathoc()
q = "get:'/p/%s'" % spec
- return p.request(q)
+ with p.connect():
+ return p.request(q)
class SocksModeTest(HTTPProxyTest):
diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py
index ad2bc548..e8752c52 100644
--- a/test/netlib/http/test_headers.py
+++ b/test/netlib/http/test_headers.py
@@ -43,6 +43,15 @@ class TestHeaders(object):
with raises(TypeError):
Headers([[b"Host", u"not-bytes"]])
+ def test_set(self):
+ headers = Headers()
+ headers[u"foo"] = u"1"
+ headers[b"bar"] = b"2"
+ headers["baz"] = b"3"
+ with raises(TypeError):
+ headers["foobar"] = 42
+ assert len(headers) == 3
+
def test_bytes(self):
headers = Headers(Host="example.com")
assert bytes(headers) == b"Host: example.com\r\n"
diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py
index c7b1b646..e97cc419 100644
--- a/test/netlib/http/test_response.py
+++ b/test/netlib/http/test_response.py
@@ -34,6 +34,11 @@ class TestResponseCore(object):
assert r.status_code == 200
assert r.content == b""
+ r = Response.make(418, "teatime")
+ assert r.status_code == 418
+ assert r.content == b"teatime"
+ assert r.headers["content-length"] == "7"
+
Response.make(content=b"foo")
Response.make(content="foo")
with raises(TypeError):
diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py
index 08e69ec5..e1175ef0 100644
--- a/test/netlib/test_encoding.py
+++ b/test/netlib/test_encoding.py
@@ -1,55 +1,46 @@
import mock
+import pytest
+
from netlib import encoding, tutils
-def test_identity():
- assert b"string" == encoding.decode(b"string", "identity")
- assert b"string" == encoding.encode(b"string", "identity")
+@pytest.mark.parametrize("encoder", [
+ 'identity',
+ 'none',
+])
+def test_identity(encoder):
+ assert b"string" == encoding.decode(b"string", encoder)
+ assert b"string" == encoding.encode(b"string", encoder)
with tutils.raises(ValueError):
encoding.encode(b"string", "nonexistent encoding")
-def test_gzip():
- assert b"string" == encoding.decode(
- encoding.encode(
- b"string",
- "gzip"
- ),
- "gzip"
- )
- with tutils.raises(ValueError):
- encoding.decode(b"bogus", "gzip")
+@pytest.mark.parametrize("encoder", [
+ 'gzip',
+ 'br',
+ 'deflate',
+])
+def test_encoders(encoder):
+ assert "" == encoding.decode("", encoder)
+ assert b"" == encoding.decode(b"", encoder)
-
-def test_brotli():
- assert b"string" == encoding.decode(
+ assert "string" == encoding.decode(
encoding.encode(
- b"string",
- "br"
+ "string",
+ encoder
),
- "br"
+ encoder
)
- with tutils.raises(ValueError):
- encoding.decode(b"bogus", "br")
-
-
-def test_deflate():
assert b"string" == encoding.decode(
encoding.encode(
b"string",
- "deflate"
+ encoder
),
- "deflate"
- )
- assert b"string" == encoding.decode(
- encoding.encode(
- b"string",
- "deflate"
- )[2:-4],
- "deflate"
+ encoder
)
+
with tutils.raises(ValueError):
- encoding.decode(b"bogus", "deflate")
+ encoding.decode(b"foobar", encoder)
def test_cache():
diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py
index 5be254a3..0f58cac5 100644
--- a/test/netlib/test_strutils.py
+++ b/test/netlib/test_strutils.py
@@ -8,6 +8,8 @@ def test_always_bytes():
assert strutils.always_bytes("foo") == b"foo"
with tutils.raises(ValueError):
strutils.always_bytes(u"\u2605", "ascii")
+ with tutils.raises(TypeError):
+ strutils.always_bytes(42, "ascii")
def test_native():
diff --git a/test/pathod/test_protocols_http2.py b/test/pathod/test_protocols_http2.py
index 8d7efc82..7f65c0eb 100644
--- a/test/pathod/test_protocols_http2.py
+++ b/test/pathod/test_protocols_http2.py
@@ -5,7 +5,7 @@ import hyperframe
from netlib import tcp, http
from netlib.tutils import raises
from netlib.exceptions import TcpDisconnect
-from netlib.http.http2 import framereader
+from netlib.http import http2
from ..netlib import tservers as netlib_tservers
@@ -112,11 +112,11 @@ class TestPerformServerConnectionPreface(netlib_tservers.ServerTestBase):
self.wfile.flush()
# check empty settings frame
- raw = framereader.http2_read_raw_frame(self.rfile)
+ raw = http2.read_raw_frame(self.rfile)
assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec')
# check settings acknowledgement
- raw = framereader.http2_read_raw_frame(self.rfile)
+ raw = http2.read_raw_frame(self.rfile)
assert raw == codecs.decode('000000040100000000', 'hex_codec')
# send settings acknowledgement