From e5b0dae7e9ef8d2ce62fc263c377c76546190825 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 16 Aug 2016 18:31:50 +0200 Subject: add websockets support to mitmproxy --- mitmproxy/protocol/http.py | 20 +- mitmproxy/protocol/websockets.py | 140 ++++++++++++++ pathod/language/http.py | 4 +- pathod/pathoc.py | 2 +- pathod/pathod.py | 9 +- pathod/protocols/websockets.py | 2 +- test/mitmproxy/protocol/test_websockets.py | 297 +++++++++++++++++++++++++++++ 7 files changed, 465 insertions(+), 9 deletions(-) create mode 100644 mitmproxy/protocol/websockets.py create mode 100644 test/mitmproxy/protocol/test_websockets.py diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index d81fc8ca..fbb52c92 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -7,12 +7,15 @@ import traceback import h2.exceptions import six -import netlib.exceptions from mitmproxy import exceptions from mitmproxy import models from mitmproxy.protocol import base +from .websockets import WebSocketsLayer + +import netlib.exceptions from netlib import http from netlib import tcp +from netlib import websockets class _HttpTransmissionLayer(base.Layer): @@ -189,6 +192,21 @@ class HttpLayer(base.Layer): self.process_request_hook(flow) try: + # WebSockets + if websockets.check_handshake(request.headers): + if websockets.check_client_version(request.headers): + layer = WebSocketsLayer(self, request) + layer() + return + else: + # we only support RFC6455 with WebSockets version 13 + self.send_response(models.make_error_response( + 400, + http.status_codes.RESPONSES.get(400), + http.Headers(sec_websocket_version="13") + )) + return + if not flow.response: self.establish_server_connection( flow.request.host, diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py new file mode 100644 index 00000000..05eaa537 --- /dev/null +++ b/mitmproxy/protocol/websockets.py @@ -0,0 +1,140 @@ +from __future__ import absolute_import, print_function, division + +import socket +import struct + +from OpenSSL import SSL + +from mitmproxy import exceptions +from mitmproxy import models +from mitmproxy.protocol import base + +import netlib.exceptions +from netlib import tcp +from netlib import http +from netlib import websockets + + +class WebSocketsLayer(base.Layer): + """ + WebSockets layer to intercept, modify, and forward WebSockets connections + + Only version 13 is supported (as specified in RFC6455) + Only HTTP/1.1-initiated connections are supported. + + The client starts by sending an Upgrade-request. + In order to determine the handshake and negotiate the correct protocol + and extensions, the Upgrade-request is forwarded to the server. + The response from the server is then parsed and negotiated settings are extracted. + Finally the handshake is completed by forwarding the server-response to the client. + After that, only WebSockets frames are exchanged. + + PING/PONG frames pass through and must be answered by the other endpoint. + + CLOSE frames are forwarded before this WebSocketsLayer terminates. + + This layer is transparent to any negotiated extensions. + This layer is transparent to any negotiated subprotocols. + Only raw frames are forwarded to the other endpoint. + """ + + def __init__(self, ctx, request): + super(WebSocketsLayer, self).__init__(ctx) + self._request = request + + self.client_key = websockets.get_client_key(self._request.headers) + self.client_protocol = websockets.get_protocol(self._request.headers) + self.client_extensions = websockets.get_extensions(self._request.headers) + + self.server_accept = None + self.server_protocol = None + self.server_extensions = None + + def _initiate_server_conn(self): + self.establish_server_connection( + self._request.host, + self._request.port, + self._request.scheme, + ) + + self.server_conn.send(netlib.http.http1.assemble_request(self._request)) + response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None) + + if not websockets.check_handshake(response.headers): + raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers)) + + self.server_accept = websockets.get_server_accept(response.headers) + self.server_protocol = websockets.get_protocol(response.headers) + self.server_extensions = websockets.get_extensions(response.headers) + + def _complete_handshake(self): + headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions) + self.send_response(models.HTTPResponse( + self._request.http_version, + 101, + http.status_codes.RESPONSES.get(101), + headers, + b"", + )) + + def _handle_frame(self, frame, source_conn, other_conn, is_server): + self.log( + "WebSockets Frame received from {}".format("server" if is_server else "client"), + "debug", + [repr(frame)] + ) + + if frame.header.opcode & 0x8 == 0: + # forward the data frame to the other side + other_conn.send(bytes(frame)) + self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug") + elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): + # just forward the ping/pong to the other side + other_conn.send(bytes(frame)) + elif frame.header.opcode == websockets.OPCODE.CLOSE: + other_conn.send(bytes(frame)) + + code = '(status code missing)' + msg = None + reason = '(message missing)' + if len(frame.payload) >= 2: + code, = struct.unpack('!H', frame.payload[:2]) + msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code') + if len(frame.payload) > 2: + reason = frame.payload[2:] + self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info") + + # close the connection + return False + else: + # unknown frame - just forward it + other_conn.send(bytes(frame)) + + # continue the connection + return True + + def __call__(self): + self._initiate_server_conn() + self._complete_handshake() + + client = self.client_conn.connection + server = self.server_conn.connection + conns = [client, server] + + try: + while not self.channel.should_exit.is_set(): + r = tcp.ssl_read_select(conns, 1) + for conn in r: + source_conn = self.client_conn if conn == client else self.server_conn + other_conn = self.server_conn if conn == client else self.client_conn + is_server = (conn == self.server_conn.connection) + + frame = websockets.Frame.from_file(source_conn.rfile) + + if not self._handle_frame(frame, source_conn, other_conn, is_server): + return + except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e: + self.log("WebSockets connection closed unexpectedly by {}: {}".format( + "server" if is_server else "client", repr(e)), "info") + except Exception as e: # pragma: no cover + raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e))) diff --git a/pathod/language/http.py b/pathod/language/http.py index fdc5bba6..46027ca3 100644 --- a/pathod/language/http.py +++ b/pathod/language/http.py @@ -198,7 +198,7 @@ class Response(_HTTPMessage): 1, StatusCode(101) ) - headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers( + headers = netlib.websockets.server_handshake_headers( settings.websocket_key ) for i in headers.fields: @@ -310,7 +310,7 @@ class Request(_HTTPMessage): 1, Method("get") ) - for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields: + for i in netlib.websockets.client_handshake_headers().fields: if not get_header(i[0], self.headers): tokens.append( Header( diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 5831ba3e..a8923013 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread): except exceptions.TcpDisconnect: return self.frames_queue.put(frm) - log("<< %s" % frm.header.human_readable()) + log("<< %s" % repr(frm.header)) if self.ws_read_limit is not None: self.ws_read_limit -= 1 starttime = time.time() diff --git a/pathod/pathod.py b/pathod/pathod.py index 7087cba6..bd0feb73 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler): retlog["cipher"] = self.get_current_cipher() m = utils.MemBool() - websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers) - self.settings.websocket_key = websocket_key + + valid_websockets_handshake = websockets.check_handshake(headers) + self.settings.websocket_key = websockets.get_client_key(headers) # If this is a websocket initiation, we respond with a proper # server response, unless over-ridden. - if websocket_key: + if valid_websockets_handshake: anchor_gen = language.parse_pathod("ws") else: anchor_gen = None @@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler): spec, lg ) - if nexthandler and websocket_key: + if nexthandler and valid_websockets_handshake: self.protocol = protocols.websockets.WebsocketsProtocol(self) return self.protocol.handle_websocket, retlog else: diff --git a/pathod/protocols/websockets.py b/pathod/protocols/websockets.py index a34e75e8..df83461a 100644 --- a/pathod/protocols/websockets.py +++ b/pathod/protocols/websockets.py @@ -20,7 +20,7 @@ class WebsocketsProtocol: lg("Error reading websocket frame: %s" % e) return None, None ended = time.time() - lg(frm.human_readable()) + lg(repr(frm)) retlog = dict( type="inbound", protocol="websockets", diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py new file mode 100644 index 00000000..cc478c0b --- /dev/null +++ b/test/mitmproxy/protocol/test_websockets.py @@ -0,0 +1,297 @@ +import pytest +import os +import tempfile +import traceback + +from mitmproxy import options +from mitmproxy.proxy.config import ProxyConfig + +import netlib +from netlib import http +from ...netlib import tservers as netlib_tservers +from .. import tservers + +from netlib import websockets + + +class _WebSocketsServerBase(netlib_tservers.ServerTestBase): + + class handler(netlib.tcp.BaseHandler): + + def handle(self): + try: + request = http.http1.read_request(self.rfile) + assert websockets.check_handshake(request.headers) + + response = http.Response( + "HTTP/1.1", + 101, + reason=http.status_codes.RESPONSES.get(101), + headers=http.Headers( + connection='upgrade', + upgrade='websocket', + sec_websocket_accept=b'', + ), + content=b'', + ) + self.wfile.write(http.http1.assemble_response(response)) + self.wfile.flush() + + self.server.handle_websockets(self.rfile, self.wfile) + except: + traceback.print_exc() + + +class _WebSocketsTestBase(object): + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + tmaster.start_app(options.APP_HOST, options.APP_PORT) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_websockets = self.handle_websockets + + def _setup_connection(self): + client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + request = http.Request( + "authority", + "CONNECT", + "", + "localhost", + self.server.server.address.port, + "", + "HTTP/1.1", + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + + if self.ssl: + client.convert_to_ssl() + assert client.ssl_established + + request = http.Request( + "relative", + "GET", + "http", + "localhost", + self.server.server.address.port, + "/ws", + "HTTP/1.1", + headers=http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version="13", + sec_websocket_key="1234", + ), + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + assert websockets.check_handshake(response.headers) + + return client + + +class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase): + + @classmethod + def setup_class(cls): + _WebSocketsTestBase.setup_class() + _WebSocketsServerBase.setup_class(ssl=cls.ssl) + + @classmethod + def teardown_class(cls): + _WebSocketsTestBase.teardown_class() + _WebSocketsServerBase.teardown_class() + + +class TestSimple(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestSimpleTLS(_WebSocketsTest): + ssl = True + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple_tls(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestPing(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.flush() + + def test_ping(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'pong-received' + + +class TestPong(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + wfile.flush() + + def test_pong(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + +class TestClose(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(rfile) + + def test_close(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_1(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_2(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + +class TestInvalidFrame(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar'))) + wfile.flush() + + def test_invalid_frame(self): + client = self._setup_connection() + + # with pytest.raises(netlib.exceptions.TcpDisconnect): + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == 15 + assert frame.payload == b'foobar' -- cgit v1.2.3