diff options
-rw-r--r-- | mitmproxy/addons/save.py | 9 | ||||
-rw-r--r-- | mitmproxy/contrib/wsproto/__init__.py | 13 | ||||
-rw-r--r-- | mitmproxy/contrib/wsproto/extensions.py | 2 | ||||
-rw-r--r-- | mitmproxy/contrib/wsproto/frame_protocol.py | 2 | ||||
-rw-r--r-- | mitmproxy/flowfilter.py | 10 | ||||
-rw-r--r-- | mitmproxy/master.py | 30 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/http.py | 1 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/websocket.py | 10 | ||||
-rw-r--r-- | mitmproxy/test/tflow.py | 6 | ||||
-rw-r--r-- | mitmproxy/websocket.py | 18 | ||||
-rw-r--r-- | setup.cfg | 3 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_save.py | 13 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 16 | ||||
-rw-r--r-- | test/mitmproxy/test_flowfilter.py | 14 |
14 files changed, 128 insertions, 19 deletions
diff --git a/mitmproxy/addons/save.py b/mitmproxy/addons/save.py index 1778855d..44afef68 100644 --- a/mitmproxy/addons/save.py +++ b/mitmproxy/addons/save.py @@ -75,6 +75,15 @@ class Save: self.stream.add(flow) self.active_flows.discard(flow) + def websocket_start(self, flow): + if self.stream: + self.active_flows.add(flow) + + def websocket_end(self, flow): + if self.stream: + self.stream.add(flow) + self.active_flows.discard(flow) + def response(self, flow): if self.stream: self.stream.add(flow) diff --git a/mitmproxy/contrib/wsproto/__init__.py b/mitmproxy/contrib/wsproto/__init__.py new file mode 100644 index 00000000..d0592bc5 --- /dev/null +++ b/mitmproxy/contrib/wsproto/__init__.py @@ -0,0 +1,13 @@ +from . import compat +from . import connection +from . import events +from . import extensions +from . import frame_protocol + +__all__ = [ + 'compat', + 'connection', + 'events', + 'extensions', + 'frame_protocol', +] diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py index f7cf4fb6..0e0d2018 100644 --- a/mitmproxy/contrib/wsproto/extensions.py +++ b/mitmproxy/contrib/wsproto/extensions.py @@ -1,3 +1,5 @@ +# type: ignore + # -*- coding: utf-8 -*- """ wsproto/extensions diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py index b95dceec..30f146c6 100644 --- a/mitmproxy/contrib/wsproto/frame_protocol.py +++ b/mitmproxy/contrib/wsproto/frame_protocol.py @@ -1,3 +1,5 @@ +# type: ignore + # -*- coding: utf-8 -*- """ wsproto/frame_protocol diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index 23e47e2b..d1fd8299 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -322,8 +322,10 @@ class FDomain(_Rex): flags = re.IGNORECASE is_binary = False - @only(http.HTTPFlow) + @only(http.HTTPFlow, websocket.WebSocketFlow) def __call__(self, f): + if isinstance(f, websocket.WebSocketFlow): + f = f.handshake_flow return bool( self.re.search(f.request.host) or self.re.search(f.request.pretty_host) @@ -342,9 +344,11 @@ class FUrl(_Rex): toks = toks[1:] return klass(*toks) - @only(http.HTTPFlow) + @only(http.HTTPFlow, websocket.WebSocketFlow) def __call__(self, f): - if not f.request: + if isinstance(f, websocket.WebSocketFlow): + f = f.handshake_flow + if not f or not f.request: return False return self.re.search(f.request.pretty_url) diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 5997ff6d..de3b24e1 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -9,6 +9,7 @@ from mitmproxy import eventsequence from mitmproxy import exceptions from mitmproxy import command from mitmproxy import http +from mitmproxy import websocket from mitmproxy import log from mitmproxy.net import server_spec from mitmproxy.proxy.protocol import http_replay @@ -41,6 +42,7 @@ class Master: self.should_exit = threading.Event() self._server = None self.first_tick = True + self.waiting_flows = [] @property def server(self): @@ -117,15 +119,33 @@ class Master: self.should_exit.set() self.addons.trigger("done") + def _change_reverse_host(self, f): + """ + When we load flows in reverse proxy mode, we adjust the target host to + the reverse proxy destination for all flows we load. This makes it very + easy to replay saved flows against a different host. + """ + if self.options.mode.startswith("reverse:"): + _, upstream_spec = server_spec.parse_with_mode(self.options.mode) + f.request.host, f.request.port = upstream_spec.address + f.request.scheme = upstream_spec.scheme + def load_flow(self, f): """ - Loads a flow + Loads a flow and links websocket & handshake flows """ + if isinstance(f, http.HTTPFlow): - if self.options.mode.startswith("reverse:"): - _, upstream_spec = server_spec.parse_with_mode(self.options.mode) - f.request.host, f.request.port = upstream_spec.address - f.request.scheme = upstream_spec.scheme + self._change_reverse_host(f) + if 'websocket' in f.metadata: + self.waiting_flows.append(f) + + if isinstance(f, websocket.WebSocketFlow): + hf = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']][0] + f.handshake_flow = hf + self.waiting_flows.remove(hf) + self._change_reverse_host(f.handshake_flow) + f.reply = controller.DummyReply() for e, o in eventsequence.iterate(f): self.addons.handle_lifecycle(e, o) diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index 57ac0f16..076ffa62 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -321,6 +321,7 @@ class HttpLayer(base.Layer): try: if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): + f.metadata['websocket'] = True # We only support RFC6455 with WebSocket version 13 # allow inline scripts to manipulate the client handshake self.channel.ask("websocket_handshake", f) diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 34dcba06..1bd5284d 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,10 +1,11 @@ import socket from OpenSSL import SSL + +from mitmproxy.contrib import wsproto from mitmproxy.contrib.wsproto import events from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate -from mitmproxy.contrib.wsproto.frame_protocol import Opcode from mitmproxy import exceptions from mitmproxy import flow @@ -93,11 +94,14 @@ class WebSocketLayer(base.Layer): if event.message_finished: original_chunk_sizes = [len(f) for f in fb] - message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY - if message_type == Opcode.TEXT: + + if isinstance(event, events.TextReceived): + message_type = wsproto.frame_protocol.Opcode.TEXT payload = ''.join(fb) else: + message_type = wsproto.frame_protocol.Opcode.BINARY payload = b''.join(fb) + fb.clear() websocket_message = WebSocketMessage(message_type, not is_server, payload) diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index c3dab30c..91747866 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -44,7 +44,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, "GET", "http", "example.com", - "80", + 80, "/ws", "HTTP/1.1", headers=net_http.Headers( @@ -75,7 +75,9 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, handshake_flow.response = resp f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow) - handshake_flow.metadata['websocket_flow'] = f + f.metadata['websocket_handshake'] = handshake_flow.id + handshake_flow.metadata['websocket_flow'] = f.id + handshake_flow.metadata['websocket'] = True if messages is True: messages = [ diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 6c1e7000..8efd4117 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,6 +1,8 @@ import time from typing import List, Optional +from mitmproxy.contrib import wsproto + from mitmproxy import flow from mitmproxy.net import websockets from mitmproxy.coretypes import serializable @@ -11,7 +13,7 @@ class WebSocketMessage(serializable.Serializable): def __init__( self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None ) -> None: - self.type = type + self.type = wsproto.frame_protocol.Opcode(type) # type: ignore self.from_client = from_client self.content = content self.timestamp = timestamp or int(time.time()) # type: int @@ -21,13 +23,14 @@ class WebSocketMessage(serializable.Serializable): return cls(*state) def get_state(self): - return self.type, self.from_client, self.content, self.timestamp + return int(self.type), self.from_client, self.content, self.timestamp def set_state(self, state): self.type, self.from_client, self.content, self.timestamp = state + self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int def __repr__(self): - if self.type == websockets.OPCODE.TEXT: + if self.type == wsproto.frame_protocol.Opcode.TEXT: return "text message: {}".format(repr(self.content)) else: return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) @@ -42,7 +45,7 @@ class WebSocketFlow(flow.Flow): super().__init__("websocket", client_conn, server_conn, live) self.messages = [] # type: List[WebSocketMessage] self.close_sender = 'client' - self.close_code = '(status code missing)' + self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE self.close_message = '(message missing)' self.close_reason = 'unknown status code' self.stream = False @@ -69,7 +72,7 @@ class WebSocketFlow(flow.Flow): _stateobject_attributes.update(dict( messages=List[WebSocketMessage], close_sender=str, - close_code=str, + close_code=int, close_message=str, close_reason=str, client_key=str, @@ -83,6 +86,11 @@ class WebSocketFlow(flow.Flow): # dumping the handshake_flow will include the WebSocketFlow too. )) + def get_state(self): + d = super().get_state() + d['close_code'] = int(d['close_code']) # replace enum with bare int + return d + @classmethod def from_state(cls, state): f = cls(None, None, None) @@ -19,6 +19,9 @@ exclude_lines = pragma: no cover raise NotImplementedError() +[mypy-mitmproxy.contrib.*] +ignore_errors = True + [tool:full_coverage] exclude = mitmproxy/proxy/protocol/base.py diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index a4e425cd..2dee708f 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -44,6 +44,19 @@ def test_tcp(tmpdir): assert rd(p) +def test_websocket(tmpdir): + sa = save.Save() + with taddons.context() as tctx: + p = str(tmpdir.join("foo")) + tctx.configure(sa, save_stream_file=p) + + f = tflow.twebsocketflow() + sa.websocket_start(f) + sa.websocket_end(f) + tctx.configure(sa, save_stream_file=None) + assert rd(p) + + def test_save_command(tmpdir): sa = save.Save() with taddons.context() as tctx: diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index fcc766b5..8cc11a16 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -97,7 +97,7 @@ class TestSerialize: class TestFlowMaster: - def test_load_flow_reverse(self): + def test_load_http_flow_reverse(self): s = tservers.TestState() opts = options.Options( mode="reverse:https://use-this-domain" @@ -108,6 +108,20 @@ class TestFlowMaster: fm.load_flow(f) assert s.flows[0].request.host == "use-this-domain" + def test_load_websocket_flow(self): + s = tservers.TestState() + opts = options.Options( + mode="reverse:https://use-this-domain" + ) + fm = master.Master(opts) + fm.addons.add(s) + f = tflow.twebsocketflow() + fm.load_flow(f.handshake_flow) + fm.load_flow(f) + assert s.flows[0].request.host == "use-this-domain" + assert s.flows[1].handshake_flow == f.handshake_flow + assert len(s.flows[1].messages) == len(f.messages) + def test_replay(self): opts = options.Options() fm = master.Master(opts) diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index c411258a..4eb37d81 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -420,6 +420,20 @@ class TestMatchingWebSocketFlow: e = self.err() assert self.q("~e", e) + def test_domain(self): + q = self.flow() + assert self.q("~d example.com", q) + assert not self.q("~d none", q) + + def test_url(self): + q = self.flow() + assert self.q("~u example.com", q) + assert self.q("~u example.com/ws", q) + assert not self.q("~u moo/path", q) + + q.handshake_flow = None + assert not self.q("~u example.com", q) + def test_body(self): f = self.flow() |