aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/features/passthrough.rst4
-rw-r--r--docs/index.rst10
-rw-r--r--docs/protocols/http1.rst15
-rw-r--r--docs/protocols/http2.rst16
-rw-r--r--docs/protocols/tcpproxy.rst (renamed from docs/features/tcpproxy.rst)6
-rw-r--r--docs/protocols/websocket.rst22
-rw-r--r--docs/scripting/events.rst2
-rw-r--r--mitmproxy/addons/dumper.py2
-rw-r--r--mitmproxy/contrib/wsproto/compat.py20
-rw-r--r--mitmproxy/contrib/wsproto/connection.py477
-rw-r--r--mitmproxy/contrib/wsproto/events.py81
-rw-r--r--mitmproxy/contrib/wsproto/extensions.py257
-rw-r--r--mitmproxy/contrib/wsproto/frame_protocol.py579
-rw-r--r--mitmproxy/proxy/protocol/websocket.py182
-rw-r--r--mitmproxy/tools/console/consoleaddons.py2
-rw-r--r--setup.cfg9
-rw-r--r--setup.py1
-rw-r--r--test/mitmproxy/proxy/protocol/test_websocket.py142
18 files changed, 1708 insertions, 119 deletions
diff --git a/docs/features/passthrough.rst b/docs/features/passthrough.rst
index 00462e9d..dbaf3506 100644
--- a/docs/features/passthrough.rst
+++ b/docs/features/passthrough.rst
@@ -13,7 +13,7 @@ mechanism:
away. Note that mitmproxy's "Limit" option is often the better alternative here, as it is
not affected by the limitations listed below.
-If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcpproxy`
+If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcp_proxy`
feature.
If you want to ignore traffic from mitmproxy's processing because of large response bodies,
take a look at the :ref:`streaming` feature.
@@ -88,7 +88,7 @@ Here are some other examples for ignore patterns:
.. seealso::
- - :ref:`tcpproxy`
+ - :ref:`tcp_proxy`
- :ref:`streaming`
- mitmproxy's "Limit" feature
diff --git a/docs/index.rst b/docs/index.rst
index 7cf593ff..8dba4d04 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -22,6 +22,15 @@
.. toctree::
:hidden:
+ :caption: Protocols
+
+ protocols/http1
+ protocols/http2
+ protocols/websocket
+ protocols/tcpproxy
+
+.. toctree::
+ :hidden:
:caption: Features
features/anticache
@@ -36,7 +45,6 @@
features/streaming
features/socksproxy
features/sticky
- features/tcpproxy
features/upstreamproxy
features/upstreamcerts
diff --git a/docs/protocols/http1.rst b/docs/protocols/http1.rst
new file mode 100644
index 00000000..21e68785
--- /dev/null
+++ b/docs/protocols/http1.rst
@@ -0,0 +1,15 @@
+.. _http1_protocol:
+
+HTTP/1.0 and HTTP/1.1
+===========================
+
+.. seealso::
+
+ - `RFC7230: HTTP/1.1: Message Syntax and Routing <http://tools.ietf.org/html/rfc7230>`_
+ - `RFC7231: HTTP/1.1: Semantics and Content <http://tools.ietf.org/html/rfc7231>`_
+
+HTTP/1.0 and HTTP/1.1 support in mitmproxy is based on our custom HTTP stack,
+which takes care of all semantics and on-the-wire parsing/serialization tasks.
+
+mitmproxy currently does not support HTTP trailers - but if you want to send
+us a PR, we promise to take look!
diff --git a/docs/protocols/http2.rst b/docs/protocols/http2.rst
new file mode 100644
index 00000000..b3268ae5
--- /dev/null
+++ b/docs/protocols/http2.rst
@@ -0,0 +1,16 @@
+.. _http2_protocol:
+
+HTTP/2
+======
+
+.. seealso::
+
+ - `RFC7540: Hypertext Transfer Protocol Version 2 (HTTP/2) <http://tools.ietf.org/html/rfc7540>`_
+
+HTTP/2 support in mitmproxy is based on the amazing work by the python-hyper
+community with the `hyper-h2 <https://github.com/python-hyper/hyper-h2>`_
+project. It fully encapsulates the internal state of HTTP/2 connections and
+provides an easy-to-use event-based API.
+
+mitmproxy currently does not support HTTP/2 trailers - but if you want to send
+us a PR, we promise to take look!
diff --git a/docs/features/tcpproxy.rst b/docs/protocols/tcpproxy.rst
index cba374e3..77248573 100644
--- a/docs/features/tcpproxy.rst
+++ b/docs/protocols/tcpproxy.rst
@@ -1,7 +1,7 @@
-.. _tcpproxy:
+.. _tcp_proxy:
-TCP Proxy
-=========
+TCP Proxy / Fallback
+====================
In case mitmproxy does not handle a specific protocol, you can exempt
hostnames from processing, so that mitmproxy acts as a generic TCP forwarder.
diff --git a/docs/protocols/websocket.rst b/docs/protocols/websocket.rst
new file mode 100644
index 00000000..8a7e807f
--- /dev/null
+++ b/docs/protocols/websocket.rst
@@ -0,0 +1,22 @@
+.. _websocket_protocol:
+
+WebSocket
+=========
+
+.. seealso::
+
+ - `RFC6455: The WebSocket Protocol <http://tools.ietf.org/html/rfc6455>`_
+ - `RFC7692: Compression Extensions for WebSocket <http://tools.ietf.org/html/rfc7692>`_
+
+WebSocket support in mitmproxy is based on the amazing work by the python-hyper
+community with the `wsproto <https://github.com/python-hyper/wsproto>`_
+project. It fully encapsulates WebSocket frames/messages/connections and
+provides an easy-to-use event-based API.
+
+mitmproxy fully supports the compression extension for WebSocket messages,
+provided by wsproto.
+
+If an endpoint sends a PING to mitmproxy, a PONG will be sent back immediately
+(with the same payload if present). To keep the other connection alive, a new
+PING (without a payload) is sent to the other endpoint. Unsolicited PONG's are
+not forwarded. All PING's and PONG's are logged (with payload if present).
diff --git a/docs/scripting/events.rst b/docs/scripting/events.rst
index 8f9463ff..9e84dacf 100644
--- a/docs/scripting/events.rst
+++ b/docs/scripting/events.rst
@@ -211,7 +211,7 @@ TCP Events
----------
These events are called only if the connection is in :ref:`TCP mode
-<tcpproxy>`. So, for instance, TCP events are not called for ordinary HTTP/S
+<tcp_proxy>`. So, for instance, TCP events are not called for ordinary HTTP/S
connections.
.. list-table::
diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py
index 54526d5b..48bc8118 100644
--- a/mitmproxy/addons/dumper.py
+++ b/mitmproxy/addons/dumper.py
@@ -234,6 +234,8 @@ class Dumper:
message = f.messages[-1]
self.echo(f.message_info(message))
if ctx.options.flow_detail >= 3:
+ message = message.from_state(message.get_state())
+ message.content = message.content.encode() if isinstance(message.content, str) else message.content
self._echo_message(message)
def websocket_end(self, f):
diff --git a/mitmproxy/contrib/wsproto/compat.py b/mitmproxy/contrib/wsproto/compat.py
new file mode 100644
index 00000000..1911f83c
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/compat.py
@@ -0,0 +1,20 @@
+# flake8: noqa
+
+import sys
+
+
+PY2 = sys.version_info.major == 2
+PY3 = sys.version_info.major == 3
+
+
+if PY3:
+ unicode = str
+
+ def Utf8Validator():
+ return None
+else:
+ unicode = unicode
+ try:
+ from wsaccel.utf8validator import Utf8Validator
+ except ImportError:
+ from .utf8validator import Utf8Validator
diff --git a/mitmproxy/contrib/wsproto/connection.py b/mitmproxy/contrib/wsproto/connection.py
new file mode 100644
index 00000000..f994cd3a
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/connection.py
@@ -0,0 +1,477 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/connection
+~~~~~~~~~~~~~~
+
+An implementation of a WebSocket connection.
+"""
+
+import os
+import base64
+import hashlib
+from collections import deque
+
+from enum import Enum
+
+import h11
+
+from .events import (
+ ConnectionRequested, ConnectionEstablished, ConnectionClosed,
+ ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived
+)
+from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode
+
+
+# RFC6455, Section 1.3 - Opening Handshake
+ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+
+
+class ConnectionState(Enum):
+ """
+ RFC 6455, Section 4 - Opening Handshake
+ """
+ CONNECTING = 0
+ OPEN = 1
+ CLOSING = 2
+ CLOSED = 3
+
+
+class ConnectionType(Enum):
+ CLIENT = 1
+ SERVER = 2
+
+
+CLIENT = ConnectionType.CLIENT
+SERVER = ConnectionType.SERVER
+
+
+# Some convenience utilities for working with HTTP headers
+def _normed_header_dict(h11_headers):
+ # This mangles Set-Cookie headers. But it happens that we don't care about
+ # any of those, so it's OK. For every other HTTP header, if there are
+ # multiple instances then you're allowed to join them together with
+ # commas.
+ name_to_values = {}
+ for name, value in h11_headers:
+ name_to_values.setdefault(name, []).append(value)
+ name_to_normed_value = {}
+ for name, values in name_to_values.items():
+ name_to_normed_value[name] = b", ".join(values)
+ return name_to_normed_value
+
+
+# We use this for parsing the proposed protocol list, and for parsing the
+# proposed and accepted extension lists. For the proposed protocol list it's
+# fine, because the ABNF is just 1#token. But for the extension lists, it's
+# wrong, because those can contain quoted strings, which can in turn contain
+# commas. XX FIXME
+def _split_comma_header(value):
+ return [piece.decode('ascii').strip() for piece in value.split(b',')]
+
+
+class WSConnection(object):
+ """
+ A low-level WebSocket connection object.
+
+ This wraps two other protocol objects, an HTTP/1.1 protocol object used
+ to do the initial HTTP upgrade handshake and a WebSocket frame protocol
+ object used to exchange messages and other control frames.
+
+ :param conn_type: Whether this object is on the client- or server-side of
+ a connection. To initialise as a client pass ``CLIENT`` otherwise
+ pass ``SERVER``.
+ :type conn_type: ``ConnectionType``
+
+ :param host: The hostname to pass to the server when acting as a client.
+ :type host: ``str``
+
+ :param resource: The resource (aka path) to pass to the server when acting
+ as a client.
+ :type resource: ``str``
+
+ :param extensions: A list of extensions to use on this connection.
+ Extensions should be instances of a subclass of
+ :class:`Extension <wsproto.extensions.Extension>`.
+
+ :param subprotocols: A list of subprotocols to request when acting as a
+ client, ordered by preference. This has no impact on the connection
+ itself.
+ :type subprotocol: ``list`` of ``str``
+ """
+
+ def __init__(self, conn_type, host=None, resource=None, extensions=None,
+ subprotocols=None):
+ self.client = conn_type is ConnectionType.CLIENT
+
+ self.host = host
+ self.resource = resource
+
+ self.subprotocols = subprotocols or []
+ self.extensions = extensions or []
+
+ self.version = b'13'
+
+ self._state = ConnectionState.CONNECTING
+ self._close_reason = None
+
+ self._nonce = None
+ self._outgoing = b''
+ self._events = deque()
+ self._proto = None
+
+ if self.client:
+ self._upgrade_connection = h11.Connection(h11.CLIENT)
+ else:
+ self._upgrade_connection = h11.Connection(h11.SERVER)
+
+ if self.client:
+ if self.host is None:
+ raise ValueError(
+ "Host must not be None for a client-side connection.")
+ if self.resource is None:
+ raise ValueError(
+ "Resource must not be None for a client-side connection.")
+ self.initiate_connection()
+
+ def initiate_connection(self):
+ self._generate_nonce()
+
+ headers = {
+ b"Host": self.host.encode('ascii'),
+ b"Upgrade": b'WebSocket',
+ b"Connection": b'Upgrade',
+ b"Sec-WebSocket-Key": self._nonce,
+ b"Sec-WebSocket-Version": self.version,
+ }
+
+ if self.subprotocols:
+ headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols)
+
+ if self.extensions:
+ offers = {e.name: e.offer(self) for e in self.extensions}
+ extensions = []
+ for name, params in offers.items():
+ if params is True:
+ extensions.append(name.encode('ascii'))
+ elif params:
+ # py34 annoyance: doesn't support bytestring formatting
+ extensions.append(('%s; %s' % (name, params))
+ .encode("ascii"))
+ if extensions:
+ headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions)
+
+ upgrade = h11.Request(method=b'GET', target=self.resource,
+ headers=headers.items())
+ self._outgoing += self._upgrade_connection.send(upgrade)
+
+ def send_data(self, payload, final=True):
+ """
+ Send a message or part of a message to the remote peer.
+
+ If ``final`` is ``False`` it indicates that this is part of a longer
+ message. If ``final`` is ``True`` it indicates that this is either a
+ self-contained message or the last part of a longer message.
+
+ If ``payload`` is of type ``bytes`` then the message is flagged as
+ being binary If it is of type ``str`` encoded as UTF-8 and sent as
+ text.
+
+ :param payload: The message body to send.
+ :type payload: ``bytes`` or ``str``
+
+ :param final: Whether there are more parts to this message to be sent.
+ :type final: ``bool``
+ """
+
+ self._outgoing += self._proto.send_data(payload, final)
+
+ def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None):
+ self._outgoing += self._proto.close(code, reason)
+ self._state = ConnectionState.CLOSING
+
+ @property
+ def closed(self):
+ return self._state is ConnectionState.CLOSED
+
+ def bytes_to_send(self, amount=None):
+ """
+ Return any data that is to be sent to the remote peer.
+
+ :param amount: (optional) The maximum number of bytes to be provided.
+ If ``None`` or not provided it will return all available bytes.
+ :type amount: ``int``
+ """
+
+ if amount is None:
+ data = self._outgoing
+ self._outgoing = b''
+ else:
+ data = self._outgoing[:amount]
+ self._outgoing = self._outgoing[amount:]
+
+ return data
+
+ def receive_bytes(self, data):
+ """
+ Pass some received bytes to the connection for processing.
+
+ :param data: The data received from the remote peer.
+ :type data: ``bytes``
+ """
+
+ if data is None and self._state is ConnectionState.OPEN:
+ # "If _The WebSocket Connection is Closed_ and no Close control
+ # frame was received by the endpoint (such as could occur if the
+ # underlying transport connection is lost), _The WebSocket
+ # Connection Close Code_ is considered to be 1006."
+ self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE))
+ self._state = ConnectionState.CLOSED
+ return
+ elif data is None:
+ self._state = ConnectionState.CLOSED
+ return
+
+ if self._state is ConnectionState.CONNECTING:
+ event, data = self._process_upgrade(data)
+ if event is not None:
+ self._events.append(event)
+
+ if self._state is ConnectionState.OPEN:
+ self._proto.receive_bytes(data)
+
+ def _process_upgrade(self, data):
+ self._upgrade_connection.receive_data(data)
+ while True:
+ try:
+ event = self._upgrade_connection.next_event()
+ except h11.RemoteProtocolError:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad HTTP message"), b''
+ if event is h11.NEED_DATA:
+ break
+ elif self.client and isinstance(event, (h11.InformationalResponse,
+ h11.Response)):
+ data = self._upgrade_connection.trailing_data[0]
+ return self._establish_client_connection(event), data
+ elif not self.client and isinstance(event, h11.Request):
+ return self._process_connection_request(event), None
+ else:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad HTTP message"), b''
+
+ self._incoming = b''
+ return None, None
+
+ def events(self):
+ """
+ Return a generator that provides any events that have been generated
+ by protocol activity.
+
+ :returns: generator
+ """
+
+ while self._events:
+ yield self._events.popleft()
+
+ if self._proto is None:
+ return
+
+ try:
+ for frame in self._proto.received_frames():
+ if frame.opcode is Opcode.PING:
+ assert frame.frame_finished and frame.message_finished
+ self._outgoing += self._proto.pong(frame.payload)
+ yield PingReceived(frame.payload)
+
+ elif frame.opcode is Opcode.PONG:
+ assert frame.frame_finished and frame.message_finished
+ yield PongReceived(frame.payload)
+
+ elif frame.opcode is Opcode.CLOSE:
+ code, reason = frame.payload
+ self.close(code, reason)
+ yield ConnectionClosed(code, reason)
+
+ elif frame.opcode is Opcode.TEXT:
+ yield TextReceived(frame.payload,
+ frame.frame_finished,
+ frame.message_finished)
+
+ elif frame.opcode is Opcode.BINARY:
+ yield BytesReceived(frame.payload,
+ frame.frame_finished,
+ frame.message_finished)
+ except ParseFailed as exc:
+ # XX FIXME: apparently autobahn intentionally deviates from the
+ # spec in that on protocol errors it just closes the connection
+ # rather than trying to send a CLOSE frame. Investigate whether we
+ # should do the same.
+ self.close(code=exc.code, reason=str(exc))
+ yield ConnectionClosed(exc.code, reason=str(exc))
+
+ def _generate_nonce(self):
+ # os.urandom may be overkill for this use case, but I don't think this
+ # is a bottleneck, and better safe than sorry...
+ self._nonce = base64.b64encode(os.urandom(16))
+
+ def _generate_accept_token(self, token):
+ accept_token = token + ACCEPT_GUID
+ accept_token = hashlib.sha1(accept_token).digest()
+ return base64.b64encode(accept_token)
+
+ def _establish_client_connection(self, event):
+ if event.status_code != 101:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad status code from server")
+ headers = _normed_header_dict(event.headers)
+ if headers[b'connection'].lower() != b'upgrade':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Connection: Upgrade header")
+ if headers[b'upgrade'].lower() != b'websocket':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Upgrade: WebSocket header")
+
+ accept_token = self._generate_accept_token(self._nonce)
+ if headers[b'sec-websocket-accept'] != accept_token:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad accept token")
+
+ subprotocol = headers.get(b'sec-websocket-protocol', None)
+ if subprotocol is not None:
+ subprotocol = subprotocol.decode('ascii')
+ if subprotocol not in self.subprotocols:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "unrecognized subprotocol {!r}"
+ .format(subprotocol))
+
+ extensions = headers.get(b'sec-websocket-extensions', None)
+ if extensions:
+ accepts = _split_comma_header(extensions)
+
+ for accept in accepts:
+ name = accept.split(';', 1)[0].strip()
+ for extension in self.extensions:
+ if extension.name == name:
+ extension.finalize(self, accept)
+ break
+ else:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "unrecognized extension {!r}"
+ .format(name))
+
+ self._proto = FrameProtocol(self.client, self.extensions)
+ self._state = ConnectionState.OPEN
+ return ConnectionEstablished(subprotocol, extensions)
+
+ def _process_connection_request(self, event):
+ if event.method != b'GET':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Request method must be GET")
+ headers = _normed_header_dict(event.headers)
+ if headers[b'connection'].lower() != b'upgrade':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Connection: Upgrade header")
+ if headers[b'upgrade'].lower() != b'websocket':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Upgrade: WebSocket header")
+
+ if b'sec-websocket-version' not in headers:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Sec-WebSocket-Version header")
+ # XX FIXME: need to check Sec-Websocket-Version, and respond with a
+ # 400 if it's not what we expect
+
+ if b'sec-websocket-protocol' in headers:
+ proposed_subprotocols = _split_comma_header(
+ headers[b'sec-websocket-protocol'])
+ else:
+ proposed_subprotocols = []
+
+ if b'sec-websocket-key' not in headers:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Sec-WebSocket-Key header")
+
+ return ConnectionRequested(proposed_subprotocols, event)
+
+ def _extension_accept(self, extensions_header):
+ accepts = {}
+ offers = _split_comma_header(extensions_header)
+
+ for offer in offers:
+ name = offer.split(';', 1)[0].strip()
+ for extension in self.extensions:
+ if extension.name == name:
+ accept = extension.accept(self, offer)
+ if accept is True:
+ accepts[extension.name] = True
+ elif accept is not False and accept is not None:
+ accepts[extension.name] = accept.encode('ascii')
+
+ if accepts:
+ extensions = []
+ for name, params in accepts.items():
+ if params is True:
+ extensions.append(name.encode('ascii'))
+ else:
+ # py34 annoyance: doesn't support bytestring formatting
+ params = params.decode("ascii")
+ extensions.append(('%s; %s' % (name, params))
+ .encode("ascii"))
+ return b', '.join(extensions)
+
+ return None
+
+ def accept(self, event, subprotocol=None):
+ request = event.h11request
+ request_headers = _normed_header_dict(request.headers)
+
+ nonce = request_headers[b'sec-websocket-key']
+ accept_token = self._generate_accept_token(nonce)
+
+ headers = {
+ b"Upgrade": b'WebSocket',
+ b"Connection": b'Upgrade',
+ b"Sec-WebSocket-Accept": accept_token,
+ }
+
+ if subprotocol is not None:
+ if subprotocol not in event.proposed_subprotocols:
+ raise ValueError(
+ "unexpected subprotocol {!r}".format(subprotocol))
+ headers[b'Sec-WebSocket-Protocol'] = subprotocol
+
+ extensions = request_headers.get(b'sec-websocket-extensions', None)
+ if extensions:
+ accepts = self._extension_accept(extensions)
+ if accepts:
+ headers[b"Sec-WebSocket-Extensions"] = accepts
+
+ response = h11.InformationalResponse(status_code=101,
+ headers=headers.items())
+ self._outgoing += self._upgrade_connection.send(response)
+ self._proto = FrameProtocol(self.client, self.extensions)
+ self._state = ConnectionState.OPEN
+
+ def ping(self, payload=None):
+ """
+ Send a PING message to the peer.
+
+ :param payload: an optional payload to send with the message
+ """
+
+ payload = bytes(payload or b'')
+ self._outgoing += self._proto.ping(payload)
+
+ def pong(self, payload=None):
+ """
+ Send a PONG message to the peer.
+
+ This method can be used to send an unsolicted PONG to the peer.
+ It is not needed otherwise since every received PING causes a
+ corresponding PONG to be sent automatically.
+
+ :param payload: an optional payload to send with the message
+ """
+
+ payload = bytes(payload or b'')
+ self._outgoing += self._proto.pong(payload)
diff --git a/mitmproxy/contrib/wsproto/events.py b/mitmproxy/contrib/wsproto/events.py
new file mode 100644
index 00000000..73ce27aa
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/events.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/events
+~~~~~~~~~~
+
+Events that result from processing data on a WebSocket connection.
+"""
+
+
+class ConnectionRequested(object):
+ def __init__(self, proposed_subprotocols, h11request):
+ self.proposed_subprotocols = proposed_subprotocols
+ self.h11request = h11request
+
+ def __repr__(self):
+ path = self.h11request.target
+
+ headers = dict(self.h11request.headers)
+ host = headers[b'host']
+ version = headers[b'sec-websocket-version']
+ subprotocol = headers.get(b'sec-websocket-protocol', None)
+ extensions = []
+
+ fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>'
+ return fmt % (self.__class__.__name__, host, path, version,
+ subprotocol, extensions)
+
+
+class ConnectionEstablished(object):
+ def __init__(self, subprotocol=None, extensions=None):
+ self.subprotocol = subprotocol
+ self.extensions = extensions
+ if self.extensions is None:
+ self.extensions = []
+
+ def __repr__(self):
+ return '<ConnectionEstablished subprotocol=%r extensions=%r>' % \
+ (self.subprotocol, self.extensions)
+
+
+class ConnectionClosed(object):
+ def __init__(self, code, reason=None):
+ self.code = code
+ self.reason = reason
+
+ def __repr__(self):
+ return '<%s code=%r reason="%s">' % (self.__class__.__name__,
+ self.code, self.reason)
+
+
+class ConnectionFailed(ConnectionClosed):
+ pass
+
+
+class DataReceived(object):
+ def __init__(self, data, frame_finished, message_finished):
+ self.data = data
+ # This has no semantic content, but is provided just in case some
+ # weird edge case user wants to be able to reconstruct the
+ # fragmentation pattern of the original stream. You don't want it:
+ self.frame_finished = frame_finished
+ # This is the field that you almost certainly want:
+ self.message_finished = message_finished
+
+
+class TextReceived(DataReceived):
+ pass
+
+
+class BytesReceived(DataReceived):
+ pass
+
+
+class PingReceived(object):
+ def __init__(self, payload):
+ self.payload = payload
+
+
+class PongReceived(object):
+ def __init__(self, payload):
+ self.payload = payload
diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py
new file mode 100644
index 00000000..f7cf4fb6
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/extensions.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/extensions
+~~~~~~~~~~~~~~
+
+WebSocket extensions.
+"""
+
+import zlib
+
+from .frame_protocol import CloseReason, Opcode, RsvBits
+
+
+class Extension(object):
+ name = None
+
+ def enabled(self):
+ return False
+
+ def offer(self, connection):
+ pass
+
+ def accept(self, connection, offer):
+ pass
+
+ def finalize(self, connection, offer):
+ pass
+
+ def frame_inbound_header(self, proto, opcode, rsv, payload_length):
+ return RsvBits(False, False, False)
+
+ def frame_inbound_payload_data(self, proto, data):
+ return data
+
+ def frame_inbound_complete(self, proto, fin):
+ pass
+
+ def frame_outbound(self, proto, opcode, rsv, data, fin):
+ return (rsv, data)
+
+
+class PerMessageDeflate(Extension):
+ name = 'permessage-deflate'
+
+ DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
+ DEFAULT_SERVER_MAX_WINDOW_BITS = 15
+
+ def __init__(self, client_no_context_takeover=False,
+ client_max_window_bits=None, server_no_context_takeover=False,
+ server_max_window_bits=None):
+ self.client_no_context_takeover = client_no_context_takeover
+ if client_max_window_bits is None:
+ client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
+ self.client_max_window_bits = client_max_window_bits
+ self.server_no_context_takeover = server_no_context_takeover
+ if server_max_window_bits is None:
+ server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
+ self.server_max_window_bits = server_max_window_bits
+
+ self._compressor = None
+ self._decompressor = None
+ # This refers to the current frame
+ self._inbound_is_compressible = None
+ # This refers to the ongoing message (which might span multiple
+ # frames). Only the first frame in a fragmented message is flagged for
+ # compression, so this carries that bit forward.
+ self._inbound_compressed = None
+
+ self._enabled = False
+
+ def _compressible_opcode(self, opcode):
+ return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)
+
+ def enabled(self):
+ return self._enabled
+
+ def offer(self, connection):
+ parameters = [
+ 'client_max_window_bits=%d' % self.client_max_window_bits,
+ 'server_max_window_bits=%d' % self.server_max_window_bits,
+ ]
+
+ if self.client_no_context_takeover:
+ parameters.append('client_no_context_takeover')
+ if self.server_no_context_takeover:
+ parameters.append('server_no_context_takeover')
+
+ return '; '.join(parameters)
+
+ def finalize(self, connection, offer):
+ bits = [b.strip() for b in offer.split(';')]
+ for bit in bits[1:]:
+ if bit.startswith('client_no_context_takeover'):
+ self.client_no_context_takeover = True
+ elif bit.startswith('server_no_context_takeover'):
+ self.server_no_context_takeover = True
+ elif bit.startswith('client_max_window_bits'):
+ self.client_max_window_bits = int(bit.split('=', 1)[1].strip())
+ elif bit.startswith('server_max_window_bits'):
+ self.server_max_window_bits = int(bit.split('=', 1)[1].strip())
+
+ self._enabled = True
+
+ def _parse_params(self, params):
+ client_max_window_bits = None
+ server_max_window_bits = None
+
+ bits = [b.strip() for b in params.split(';')]
+ for bit in bits[1:]:
+ if bit.startswith('client_no_context_takeover'):
+ self.client_no_context_takeover = True
+ elif bit.startswith('server_no_context_takeover'):
+ self.server_no_context_takeover = True
+ elif bit.startswith('client_max_window_bits'):
+ if '=' in bit:
+ client_max_window_bits = int(bit.split('=', 1)[1].strip())
+ else:
+ client_max_window_bits = self.client_max_window_bits
+ elif bit.startswith('server_max_window_bits'):
+ if '=' in bit:
+ server_max_window_bits = int(bit.split('=', 1)[1].strip())
+ else:
+ server_max_window_bits = self.server_max_window_bits
+
+ return client_max_window_bits, server_max_window_bits
+
+ def accept(self, connection, offer):
+ client_max_window_bits, server_max_window_bits = \
+ self._parse_params(offer)
+
+ self._enabled = True
+
+ parameters = []
+
+ if self.client_no_context_takeover:
+ parameters.append('client_no_context_takeover')
+ if client_max_window_bits is not None:
+ parameters.append('client_max_window_bits=%d' %
+ client_max_window_bits)
+ self.client_max_window_bits = client_max_window_bits
+ if self.server_no_context_takeover:
+ parameters.append('server_no_context_takeover')
+ if server_max_window_bits is not None:
+ parameters.append('server_max_window_bits=%d' %
+ server_max_window_bits)
+ self.server_max_window_bits = server_max_window_bits
+
+ return '; '.join(parameters)
+
+ def frame_inbound_header(self, proto, opcode, rsv, payload_length):
+ if rsv.rsv1 and opcode.iscontrol():
+ return CloseReason.PROTOCOL_ERROR
+ elif rsv.rsv1 and opcode is Opcode.CONTINUATION:
+ return CloseReason.PROTOCOL_ERROR
+
+ self._inbound_is_compressible = self._compressible_opcode(opcode)
+
+ if self._inbound_compressed is None:
+ self._inbound_compressed = rsv.rsv1
+ if self._inbound_compressed:
+ assert self._inbound_is_compressible
+ if proto.client:
+ bits = self.server_max_window_bits
+ else:
+ bits = self.client_max_window_bits
+ if self._decompressor is None:
+ self._decompressor = zlib.decompressobj(-int(bits))
+
+ return RsvBits(True, False, False)
+
+ def frame_inbound_payload_data(self, proto, data):
+ if not self._inbound_compressed or not self._inbound_is_compressible:
+ return data
+
+ try:
+ return self._decompressor.decompress(bytes(data))
+ except zlib.error:
+ return CloseReason.INVALID_FRAME_PAYLOAD_DATA
+
+ def frame_inbound_complete(self, proto, fin):
+ if not fin:
+ return
+ elif not self._inbound_is_compressible:
+ return
+ elif not self._inbound_compressed:
+ return
+
+ try:
+ data = self._decompressor.decompress(b'\x00\x00\xff\xff')
+ data += self._decompressor.flush()
+ except zlib.error:
+ return CloseReason.INVALID_FRAME_PAYLOAD_DATA
+
+ if proto.client:
+ no_context_takeover = self.server_no_context_takeover
+ else:
+ no_context_takeover = self.client_no_context_takeover
+
+ if no_context_takeover:
+ self._decompressor = None
+
+ self._inbound_compressed = None
+
+ return data
+
+ def frame_outbound(self, proto, opcode, rsv, data, fin):
+ if not self._compressible_opcode(opcode):
+ return (rsv, data)
+
+ if opcode is not Opcode.CONTINUATION:
+ rsv = RsvBits(True, *rsv[1:])
+
+ if self._compressor is None:
+ assert opcode is not Opcode.CONTINUATION
+ if proto.client:
+ bits = self.client_max_window_bits
+ else:
+ bits = self.server_max_window_bits
+ self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+ zlib.DEFLATED, -int(bits))
+
+ data = self._compressor.compress(bytes(data))
+
+ if fin:
+ data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
+ data = data[:-4]
+
+ if proto.client:
+ no_context_takeover = self.client_no_context_takeover
+ else:
+ no_context_takeover = self.server_no_context_takeover
+
+ if no_context_takeover:
+ self._compressor = None
+
+ return (rsv, data)
+
+ def __repr__(self):
+ descr = ['client_max_window_bits=%d' % self.client_max_window_bits]
+ if self.client_no_context_takeover:
+ descr.append('client_no_context_takeover')
+ descr.append('server_max_window_bits=%d' % self.server_max_window_bits)
+ if self.server_no_context_takeover:
+ descr.append('server_no_context_takeover')
+
+ descr = '; '.join(descr)
+
+ return '<%s %s>' % (self.__class__.__name__, descr)
+
+
+#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
+#: This can be used to iterate all supported extensions of wsproto, instantiate
+#: new extensions based on their name, or check if a given extension is
+#: supported or not.
+SUPPORTED_EXTENSIONS = {
+ PerMessageDeflate.name: PerMessageDeflate
+}
diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py
new file mode 100644
index 00000000..b95dceec
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/frame_protocol.py
@@ -0,0 +1,579 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/frame_protocol
+~~~~~~~~~~~~~~
+
+WebSocket frame protocol implementation.
+"""
+
+import os
+import itertools
+import struct
+from codecs import getincrementaldecoder
+from collections import namedtuple
+
+from enum import Enum, IntEnum
+
+from .compat import unicode, Utf8Validator
+
+try:
+ from wsaccel.xormask import XorMaskerSimple
+except ImportError:
+ class XorMaskerSimple:
+ def __init__(self, masking_key):
+ self._maskbytes = itertools.cycle(bytearray(masking_key))
+
+ def process(self, data):
+ maskbytes = self._maskbytes
+ return bytearray(b ^ next(maskbytes) for b in bytearray(data))
+
+
+class XorMaskerNull:
+ def process(self, data):
+ return data
+
+
+# RFC6455, Section 5.2 - Base Framing Protocol
+
+# Payload length constants
+PAYLOAD_LENGTH_TWO_BYTE = 126
+PAYLOAD_LENGTH_EIGHT_BYTE = 127
+MAX_PAYLOAD_NORMAL = 125
+MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1
+MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1
+MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE
+
+# MASK and PAYLOAD LEN are packed into a byte
+MASK_MASK = 0x80
+PAYLOAD_LEN_MASK = 0x7f
+
+# FIN, RSV[123] and OPCODE are packed into a single byte
+FIN_MASK = 0x80
+RSV1_MASK = 0x40
+RSV2_MASK = 0x20
+RSV3_MASK = 0x10
+OPCODE_MASK = 0x0f
+
+
+class Opcode(IntEnum):
+ """
+ RFC 6455, Section 5.2 - Base Framing Protocol
+ """
+ CONTINUATION = 0x0
+ TEXT = 0x1
+ BINARY = 0x2
+ CLOSE = 0x8
+ PING = 0x9
+ PONG = 0xA
+
+ def iscontrol(self):
+ return bool(self & 0x08)
+
+
+class CloseReason(IntEnum):
+ """
+ RFC 6455, Section 7.4.1 - Defined Status Codes
+ """
+ NORMAL_CLOSURE = 1000
+ GOING_AWAY = 1001
+ PROTOCOL_ERROR = 1002
+ UNSUPPORTED_DATA = 1003
+ NO_STATUS_RCVD = 1005
+ ABNORMAL_CLOSURE = 1006
+ INVALID_FRAME_PAYLOAD_DATA = 1007
+ POLICY_VIOLATION = 1008
+ MESSAGE_TOO_BIG = 1009
+ MANDATORY_EXT = 1010
+ INTERNAL_ERROR = 1011
+ SERVICE_RESTART = 1012
+ TRY_AGAIN_LATER = 1013
+ TLS_HANDSHAKE_FAILED = 1015
+
+
+# RFC 6455, Section 7.4.1 - Defined Status Codes
+LOCAL_ONLY_CLOSE_REASONS = (
+ CloseReason.NO_STATUS_RCVD,
+ CloseReason.ABNORMAL_CLOSURE,
+ CloseReason.TLS_HANDSHAKE_FAILED,
+)
+
+
+# RFC 6455, Section 7.4.2 - Status Code Ranges
+MIN_CLOSE_REASON = 1000
+MIN_PROTOCOL_CLOSE_REASON = 1000
+MAX_PROTOCOL_CLOSE_REASON = 2999
+MIN_LIBRARY_CLOSE_REASON = 3000
+MAX_LIBRARY_CLOSE_REASON = 3999
+MIN_PRIVATE_CLOSE_REASON = 4000
+MAX_PRIVATE_CLOSE_REASON = 4999
+MAX_CLOSE_REASON = 4999
+
+
+NULL_MASK = struct.pack("!I", 0)
+
+
+class ParseFailed(Exception):
+ def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR):
+ super(ParseFailed, self).__init__(msg)
+ self.code = code
+
+
+Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split())
+
+
+Frame = namedtuple("Frame",
+ "opcode payload frame_finished message_finished".split())
+
+
+RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split())
+
+
+def _truncate_utf8(data, nbytes):
+ if len(data) <= nbytes:
+ return data
+
+ # Truncate
+ data = data[:nbytes]
+ # But we might have cut a codepoint in half, in which case we want to
+ # discard the partial character so the data is at least
+ # well-formed. This is a little inefficient since it processes the
+ # whole message twice when in theory we could just peek at the last
+ # few characters, but since this is only used for close messages (max
+ # length = 125 bytes) it really doesn't matter.
+ data = data.decode("utf-8", errors="ignore").encode("utf-8")
+ return data
+
+
+class Buffer(object):
+ def __init__(self, initial_bytes=None):
+ self.buffer = bytearray()
+ self.bytes_used = 0
+ if initial_bytes:
+ self.feed(initial_bytes)
+
+ def feed(self, new_bytes):
+ self.buffer += new_bytes
+
+ def consume_at_most(self, nbytes):
+ if not nbytes:
+ return bytearray()
+
+ data = self.buffer[self.bytes_used:self.bytes_used + nbytes]
+ self.bytes_used += len(data)
+ return data
+
+ def consume_exactly(self, nbytes):
+ if len(self.buffer) - self.bytes_used < nbytes:
+ return None
+
+ return self.consume_at_most(nbytes)
+
+ def commit(self):
+ # In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic
+ del self.buffer[:self.bytes_used]
+ self.bytes_used = 0
+
+ def rollback(self):
+ self.bytes_used = 0
+
+ def __len__(self):
+ return len(self.buffer)
+
+
+class MessageDecoder(object):
+ def __init__(self):
+ self.opcode = None
+ self.validator = None
+ self.decoder = None
+
+ def process_frame(self, frame):
+ assert not frame.opcode.iscontrol()
+
+ if self.opcode is None:
+ if frame.opcode is Opcode.CONTINUATION:
+ raise ParseFailed("unexpected CONTINUATION")
+ self.opcode = frame.opcode
+ elif frame.opcode is not Opcode.CONTINUATION:
+ raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode)
+
+ if frame.opcode is Opcode.TEXT:
+ self.validator = Utf8Validator()
+ self.decoder = getincrementaldecoder("utf-8")()
+
+ finished = frame.frame_finished and frame.message_finished
+
+ if self.decoder is not None:
+ data = self.decode_payload(frame.payload, finished)
+ else:
+ data = frame.payload
+
+ frame = Frame(self.opcode, data, frame.frame_finished, finished)
+
+ if finished:
+ self.opcode = None
+ self.decoder = None
+
+ return frame
+
+ def decode_payload(self, data, finished):
+ if self.validator is not None:
+ results = self.validator.validate(bytes(data))
+ if not results[0] or (finished and not results[1]):
+ raise ParseFailed(u'encountered invalid UTF-8 while processing'
+ ' text message at payload octet index %d' %
+ results[3],
+ CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+
+ try:
+ return self.decoder.decode(data, finished)
+ except UnicodeDecodeError as exc:
+ raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+
+
+class FrameDecoder(object):
+ def __init__(self, client, extensions=None):
+ self.client = client
+ self.extensions = extensions or []
+
+ self.buffer = Buffer()
+
+ self.header = None
+ self.effective_opcode = None
+ self.masker = None
+ self.payload_required = 0
+ self.payload_consumed = 0
+
+ def receive_bytes(self, data):
+ self.buffer.feed(data)
+
+ def process_buffer(self):
+ if not self.header:
+ if not self.parse_header():
+ return None
+
+ if len(self.buffer) < self.payload_required:
+ return None
+
+ payload_remaining = self.header.payload_len - self.payload_consumed
+ payload = self.buffer.consume_at_most(payload_remaining)
+ if not payload and self.header.payload_len > 0:
+ return None
+ self.buffer.commit()
+
+ self.payload_consumed += len(payload)
+ finished = self.payload_consumed == self.header.payload_len
+
+ payload = self.masker.process(payload)
+
+ for extension in self.extensions:
+ payload = extension.frame_inbound_payload_data(self, payload)
+ if isinstance(payload, CloseReason):
+ raise ParseFailed("error in extension", payload)
+
+ if finished:
+ final = bytearray()
+ for extension in self.extensions:
+ result = extension.frame_inbound_complete(self,
+ self.header.fin)
+ if isinstance(result, CloseReason):
+ raise ParseFailed("error in extension", result)
+ if result is not None:
+ final += result
+ payload += final
+
+ frame = Frame(self.effective_opcode, payload, finished,
+ self.header.fin)
+
+ if finished:
+ self.header = None
+ self.effective_opcode = None
+ self.masker = None
+ else:
+ self.effective_opcode = Opcode.CONTINUATION
+
+ return frame
+
+ def parse_header(self):
+ data = self.buffer.consume_exactly(2)
+ if data is None:
+ self.buffer.rollback()
+ return False
+
+ fin = bool(data[0] & FIN_MASK)
+ rsv = RsvBits(bool(data[0] & RSV1_MASK),
+ bool(data[0] & RSV2_MASK),
+ bool(data[0] & RSV3_MASK))
+ opcode = data[0] & OPCODE_MASK
+ try:
+ opcode = Opcode(opcode)
+ except ValueError:
+ raise ParseFailed("Invalid opcode {:#x}".format(opcode))
+
+ if opcode.iscontrol() and not fin:
+ raise ParseFailed("Invalid attempt to fragment control frame")
+
+ has_mask = bool(data[1] & MASK_MASK)
+ payload_len = data[1] & PAYLOAD_LEN_MASK
+ payload_len = self.parse_extended_payload_length(opcode, payload_len)
+ if payload_len is None:
+ self.buffer.rollback()
+ return False
+
+ self.extension_processing(opcode, rsv, payload_len)
+
+ if has_mask and self.client:
+ raise ParseFailed("client received unexpected masked frame")
+ if not has_mask and not self.client:
+ raise ParseFailed("server received unexpected unmasked frame")
+ if has_mask:
+ masking_key = self.buffer.consume_exactly(4)
+ if masking_key is None:
+ self.buffer.rollback()
+ return False
+ self.masker = XorMaskerSimple(masking_key)
+ else:
+ self.masker = XorMaskerNull()
+
+ self.buffer.commit()
+ self.header = Header(fin, rsv, opcode, payload_len, None)
+ self.effective_opcode = self.header.opcode
+ if self.header.opcode.iscontrol():
+ self.payload_required = payload_len
+ else:
+ self.payload_required = 0
+ self.payload_consumed = 0
+ return True
+
+ def parse_extended_payload_length(self, opcode, payload_len):
+ if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
+ raise ParseFailed("Control frame with payload len > 125")
+ if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
+ data = self.buffer.consume_exactly(2)
+ if data is None:
+ return None
+ (payload_len,) = struct.unpack("!H", data)
+ if payload_len <= MAX_PAYLOAD_NORMAL:
+ raise ParseFailed(
+ "Payload length used 2 bytes when 1 would have sufficed")
+ elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE:
+ data = self.buffer.consume_exactly(8)
+ if data is None:
+ return None
+ (payload_len,) = struct.unpack("!Q", data)
+ if payload_len <= MAX_PAYLOAD_TWO_BYTE:
+ raise ParseFailed(
+ "Payload length used 8 bytes when 2 would have sufficed")
+ if payload_len >> 63:
+ # I'm not sure why this is illegal, but that's what the RFC
+ # says, so...
+ raise ParseFailed("8-byte payload length with non-zero MSB")
+
+ return payload_len
+
+ def extension_processing(self, opcode, rsv, payload_len):
+ rsv_used = [False, False, False]
+ for extension in self.extensions:
+ result = extension.frame_inbound_header(self, opcode, rsv,
+ payload_len)
+ if isinstance(result, CloseReason):
+ raise ParseFailed("error in extension", result)
+ for bit, used in enumerate(result):
+ if used:
+ rsv_used[bit] = True
+ for expected, found in zip(rsv_used, rsv):
+ if found and not expected:
+ raise ParseFailed("Reserved bit set unexpectedly")
+
+
+class FrameProtocol(object):
+ class State(Enum):
+ HEADER = 1
+ PAYLOAD = 2
+ FRAME_COMPLETE = 3
+ FAILED = 4
+
+ def __init__(self, client, extensions):
+ self.client = client
+ self.extensions = [ext for ext in extensions if ext.enabled()]
+
+ # Global state
+ self._frame_decoder = FrameDecoder(self.client, self.extensions)
+ self._message_decoder = MessageDecoder()
+ self._parse_more = self.parse_more_gen()
+
+ self._outbound_opcode = None
+
+ def _process_close(self, frame):
+ data = frame.payload
+
+ if not data:
+ # "If this Close control frame contains no status code, _The
+ # WebSocket Connection Close Code_ is considered to be 1005"
+ data = (CloseReason.NO_STATUS_RCVD, "")
+ elif len(data) == 1:
+ raise ParseFailed("CLOSE with 1 byte payload")
+ else:
+ (code,) = struct.unpack("!H", data[:2])
+ if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
+ raise ParseFailed("CLOSE with invalid code")
+ try:
+ code = CloseReason(code)
+ except ValueError:
+ pass
+ if code in LOCAL_ONLY_CLOSE_REASONS:
+ raise ParseFailed(
+ "remote CLOSE with local-only reason")
+ if not isinstance(code, CloseReason) and \
+ code <= MAX_PROTOCOL_CLOSE_REASON:
+ raise ParseFailed(
+ "CLOSE with unknown reserved code")
+ validator = Utf8Validator()
+ if validator is not None:
+ results = validator.validate(bytes(data[2:]))
+ if not (results[0] and results[1]):
+ raise ParseFailed(u'encountered invalid UTF-8 while'
+ ' processing close message at payload'
+ ' octet index %d' %
+ results[3],
+ CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+ try:
+ reason = data[2:].decode("utf-8")
+ except UnicodeDecodeError as exc:
+ raise ParseFailed(
+ "Error decoding CLOSE reason: " + str(exc),
+ CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+ data = (code, reason)
+
+ return Frame(frame.opcode, data, frame.frame_finished,
+ frame.message_finished)
+
+ def parse_more_gen(self):
+ # Consume as much as we can from self._buffer, yielding events, and
+ # then yield None when we need more data. Or raise ParseFailed.
+
+ # XX FIXME this should probably be refactored so that we never see
+ # disabled extensions in the first place...
+ self.extensions = [ext for ext in self.extensions if ext.enabled()]
+ closed = False
+
+ while not closed:
+ frame = self._frame_decoder.process_buffer()
+
+ if frame is not None:
+ if not frame.opcode.iscontrol():
+ frame = self._message_decoder.process_frame(frame)
+ elif frame.opcode == Opcode.CLOSE:
+ frame = self._process_close(frame)
+ closed = True
+
+ yield frame
+
+ def receive_bytes(self, data):
+ self._frame_decoder.receive_bytes(data)
+
+ def received_frames(self):
+ for event in self._parse_more:
+ if event is None:
+ break
+ else:
+ yield event
+
+ def close(self, code=None, reason=None):
+ payload = bytearray()
+ if code is None and reason is not None:
+ raise TypeError("cannot specify a reason without a code")
+ if code in LOCAL_ONLY_CLOSE_REASONS:
+ code = CloseReason.NORMAL_CLOSURE
+ if code is not None:
+ payload += bytearray(struct.pack('!H', code))
+ if reason is not None:
+ payload += _truncate_utf8(reason.encode('utf-8'),
+ MAX_PAYLOAD_NORMAL - 2)
+
+ return self._serialize_frame(Opcode.CLOSE, payload)
+
+ def ping(self, payload=b''):
+ return self._serialize_frame(Opcode.PING, payload)
+
+ def pong(self, payload=b''):
+ return self._serialize_frame(Opcode.PONG, payload)
+
+ def send_data(self, payload=b'', fin=True):
+ if isinstance(payload, (bytes, bytearray, memoryview)):
+ opcode = Opcode.BINARY
+ elif isinstance(payload, unicode):
+ opcode = Opcode.TEXT
+ payload = payload.encode('utf-8')
+ else:
+ raise ValueError('Must provide bytes or text')
+
+ if self._outbound_opcode is None:
+ self._outbound_opcode = opcode
+ elif self._outbound_opcode is not opcode:
+ raise TypeError('Data type mismatch inside message')
+ else:
+ opcode = Opcode.CONTINUATION
+
+ if fin:
+ self._outbound_opcode = None
+
+ return self._serialize_frame(opcode, payload, fin)
+
+ def _make_fin_rsv_opcode(self, fin, rsv, opcode):
+ fin = int(fin) << 7
+ rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \
+ (int(rsv.rsv3) << 4)
+ opcode = int(opcode)
+
+ return fin | rsv | opcode
+
+ def _serialize_frame(self, opcode, payload=b'', fin=True):
+ rsv = RsvBits(False, False, False)
+ for extension in reversed(self.extensions):
+ rsv, payload = extension.frame_outbound(self, opcode, rsv, payload,
+ fin)
+
+ fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
+
+ payload_length = len(payload)
+ quad_payload = False
+ if payload_length <= MAX_PAYLOAD_NORMAL:
+ first_payload = payload_length
+ second_payload = None
+ elif payload_length <= MAX_PAYLOAD_TWO_BYTE:
+ first_payload = PAYLOAD_LENGTH_TWO_BYTE
+ second_payload = payload_length
+ else:
+ first_payload = PAYLOAD_LENGTH_EIGHT_BYTE
+ second_payload = payload_length
+ quad_payload = True
+
+ if self.client:
+ first_payload |= 1 << 7
+
+ header = bytearray([fin_rsv_opcode, first_payload])
+ if second_payload is not None:
+ if opcode.iscontrol():
+ raise ValueError("payload too long for control frame")
+ if quad_payload:
+ header += bytearray(struct.pack('!Q', second_payload))
+ else:
+ header += bytearray(struct.pack('!H', second_payload))
+
+ if self.client:
+ # "The masking key is a 32-bit value chosen at random by the
+ # client. When preparing a masked frame, the client MUST pick a
+ # fresh masking key from the set of allowed 32-bit values. The
+ # masking key needs to be unpredictable; thus, the masking key
+ # MUST be derived from a strong source of entropy, and the masking
+ # key for a given frame MUST NOT make it simple for a server/proxy
+ # to predict the masking key for a subsequent frame. The
+ # unpredictability of the masking key is essential to prevent
+ # authors of malicious applications from selecting the bytes that
+ # appear on the wire."
+ # -- https://tools.ietf.org/html/rfc6455#section-5.3
+ masking_key = os.urandom(4)
+ masker = XorMaskerSimple(masking_key)
+ return header + masking_key + masker.process(payload)
+
+ return header + payload
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index 19546eb2..34dcba06 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -1,14 +1,18 @@
-import os
import socket
-import struct
from OpenSSL import SSL
+from mitmproxy.contrib.wsproto import events
+from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
+from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
+from mitmproxy.contrib.wsproto.frame_protocol import Opcode
+
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
+from mitmproxy.utils import strutils
class WebSocketLayer(base.Layer):
@@ -44,26 +48,56 @@ class WebSocketLayer(base.Layer):
self.client_frame_buffer = []
self.server_frame_buffer = []
- def _handle_frame(self, frame, source_conn, other_conn, is_server):
- if frame.header.opcode & 0x8 == 0:
- return self._handle_data_frame(frame, source_conn, other_conn, is_server)
- elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
- return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
- elif frame.header.opcode == websockets.OPCODE.CLOSE:
- return self._handle_close(frame, source_conn, other_conn, is_server)
- else:
- return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
-
- def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
-
+ self.connections = {} # type: Dict[object, WSConnection]
+
+ extensions = []
+ if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
+ if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
+ extensions = [PerMessageDeflate()]
+ self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
+ extensions=extensions)
+ self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
+ host=handshake_flow.request.host,
+ resource=handshake_flow.request.path,
+ extensions=extensions)
+ if extensions:
+ for conn in self.connections.values():
+ conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
+
+ data = self.connections[self.server_conn].bytes_to_send()
+ self.connections[self.client_conn].receive_bytes(data)
+
+ event = next(self.connections[self.client_conn].events())
+ assert isinstance(event, events.ConnectionRequested)
+
+ self.connections[self.client_conn].accept(event)
+ self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
+ assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
+
+ def _handle_event(self, event, source_conn, other_conn, is_server):
+ if isinstance(event, events.DataReceived):
+ return self._handle_data_received(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.PingReceived):
+ return self._handle_ping_received(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.PongReceived):
+ return self._handle_pong_received(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.ConnectionClosed):
+ return self._handle_connection_closed(event, source_conn, other_conn, is_server)
+
+ # fail-safe for unhandled events
+ return True # pragma: no cover
+
+ def _handle_data_received(self, event, source_conn, other_conn, is_server):
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
- fb.append(frame)
+ fb.append(event.data)
- if frame.header.fin:
- payload = b''.join(f.payload for f in fb)
- original_chunk_sizes = [len(f.payload) for f in fb]
- message_type = fb[0].header.opcode
- compressed_message = fb[0].header.rsv1
+ if event.message_finished:
+ original_chunk_sizes = [len(f) for f in fb]
+ message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
+ if message_type == Opcode.TEXT:
+ payload = ''.join(fb)
+ else:
+ payload = b''.join(fb)
fb.clear()
websocket_message = WebSocketMessage(message_type, not is_server, payload)
@@ -77,7 +111,7 @@ class WebSocketLayer(base.Layer):
# message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
- yield payload[pos:pos + s]
+ yield (payload[pos:pos + s], True if pos + s == length else False)
pos += s
else:
# just re-chunk everything into 4kB frames
@@ -85,95 +119,81 @@ class WebSocketLayer(base.Layer):
chunk_size = 4092 if is_server else 4088
chunks = range(0, len(payload), chunk_size)
for i in chunks:
- yield payload[i:i + chunk_size]
-
- frms = [
- websockets.Frame(
- payload=chunk,
- opcode=frame.header.opcode,
- mask=(False if is_server else 1),
- masking_key=(b'' if is_server else os.urandom(4)))
- for chunk in get_chunk(websocket_message.content)
- ]
-
- if len(frms) > 0:
- frms[-1].header.fin = True
- else:
- frms.append(websockets.Frame(
- fin=True,
- opcode=websockets.OPCODE.CONTINUE,
- mask=(False if is_server else 1),
- masking_key=(b'' if is_server else os.urandom(4))))
-
- frms[0].header.opcode = message_type
- frms[0].header.rsv1 = compressed_message
-
- for frm in frms:
- other_conn.send(bytes(frm))
+ yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
+
+ for chunk, final in get_chunk(websocket_message.content):
+ self.connections[other_conn].send_data(chunk, final)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
else:
- other_conn.send(bytes(frame))
+ self.connections[other_conn].send_data(event.data, event.message_finished)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
elif self.flow.stream:
- other_conn.send(bytes(frame))
+ self.connections[other_conn].send_data(event.data, event.message_finished)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
+ return True
+ def _handle_ping_received(self, event, source_conn, other_conn, is_server):
+ # PING is automatically answered with a PONG by wsproto
+ self.connections[other_conn].ping()
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+ source_conn.send(self.connections[source_conn].bytes_to_send())
+ self.log(
+ "Ping Received from {}".format("server" if is_server else "client"),
+ "info",
+ [strutils.bytes_to_escaped_str(bytes(event.payload))]
+ )
return True
- def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
- # just forward the ping/pong to the other side
- other_conn.send(bytes(frame))
+ def _handle_pong_received(self, event, source_conn, other_conn, is_server):
+ self.log(
+ "Pong Received from {}".format("server" if is_server else "client"),
+ "info",
+ [strutils.bytes_to_escaped_str(bytes(event.payload))]
+ )
return True
- def _handle_close(self, frame, source_conn, other_conn, is_server):
+ def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
self.flow.close_sender = "server" if is_server else "client"
- if len(frame.payload) >= 2:
- code, = struct.unpack('!H', frame.payload[:2])
- self.flow.close_code = code
- self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
- if len(frame.payload) > 2:
- self.flow.close_reason = frame.payload[2:]
+ self.flow.close_code = event.code
+ self.flow.close_reason = event.reason
- other_conn.send(bytes(frame))
+ self.connections[other_conn].close(event.code, event.reason)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+ source_conn.send(self.connections[source_conn].bytes_to_send())
- # initiate close handshake
return False
- def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
- # unknown frame - just forward it
- other_conn.send(bytes(frame))
-
- sender = "server" if is_server else "client"
- self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
-
- return True
-
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow)
- client = self.client_conn.connection
- server = self.server_conn.connection
- conns = [client, server]
+ conns = [c.connection for c in self.connections.keys()]
close_received = False
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 0.1)
for conn in r:
- source_conn = self.client_conn if conn == client else self.server_conn
- other_conn = self.server_conn if conn == client else self.client_conn
- is_server = (conn == self.server_conn.connection)
+ source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
+ other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
+ is_server = (source_conn == self.server_conn)
frame = websockets.Frame.from_file(source_conn.rfile)
+ self.connections[source_conn].receive_bytes(bytes(frame))
+ source_conn.send(self.connections[source_conn].bytes_to_send())
+
+ if close_received:
+ return
- cont = self._handle_frame(frame, source_conn, other_conn, is_server)
- if not cont:
- if close_received:
- return
- else:
- close_received = True
+ for event in self.connections[source_conn].events():
+ if not self._handle_event(event, source_conn, other_conn, is_server):
+ if not close_received:
+ close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e:
s = 'server' if is_server else 'client'
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py
index 471e3a53..06ee3341 100644
--- a/mitmproxy/tools/console/consoleaddons.py
+++ b/mitmproxy/tools/console/consoleaddons.py
@@ -49,7 +49,7 @@ class UnsupportedLog:
def websocket_message(self, f):
message = f.messages[-1]
signals.add_log(f.message_info(message), "info")
- signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
+ signals.add_log(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content), "debug")
def websocket_end(self, f):
signals.add_log("WebSocket connection closed by {}: {} {}, {}".format(
diff --git a/setup.cfg b/setup.cfg
index eaabfa12..fd31d15b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,7 +21,13 @@ exclude_lines =
[tool:full_coverage]
exclude =
- mitmproxy/proxy/protocol/
+ mitmproxy/proxy/protocol/base.py
+ mitmproxy/proxy/protocol/http.py
+ mitmproxy/proxy/protocol/http1.py
+ mitmproxy/proxy/protocol/http2.py
+ mitmproxy/proxy/protocol/http_replay.py
+ mitmproxy/proxy/protocol/rawtcp.py
+ mitmproxy/proxy/protocol/tls.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
mitmproxy/tools/
@@ -64,7 +70,6 @@ exclude =
mitmproxy/proxy/protocol/http_replay.py
mitmproxy/proxy/protocol/rawtcp.py
mitmproxy/proxy/protocol/tls.py
- mitmproxy/proxy/protocol/websocket.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
mitmproxy/stateobject.py
diff --git a/setup.py b/setup.py
index 54c2811d..ad792881 100644
--- a/setup.py
+++ b/setup.py
@@ -65,6 +65,7 @@ setup(
"certifi>=2015.11.20.1", # no semver here - this should always be on the last release!
"click>=6.2, <7",
"cryptography>=2.0,<2.2",
+ 'h11>=0.7.0,<0.8',
"h2>=3.0, <4",
"hyperframe>=5.0, <6",
"kaitaistruct>=0.7, <0.8",
diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py
index 460d85f8..a7acdc4d 100644
--- a/test/mitmproxy/proxy/protocol/test_websocket.py
+++ b/test/mitmproxy/proxy/protocol/test_websocket.py
@@ -1,5 +1,6 @@
import pytest
import os
+import struct
import tempfile
import traceback
@@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
+ sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else ''
),
content=b'',
)
@@ -80,7 +82,7 @@ class _WebSocketTestBase:
if self.client:
self.client.close()
- def setup_connection(self):
+ def setup_connection(self, extension=False):
self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
self.client.connect()
@@ -115,6 +117,7 @@ class _WebSocketTestBase:
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
+ sec_websocket_extensions="permessage-deflate" if extension else ""
),
content=b'')
self.client.wfile.write(http.http1.assemble_request(request))
@@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest):
wfile.flush()
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
@pytest.mark.parametrize('streaming', [True, False])
@@ -164,36 +167,59 @@ class TestSimple(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
assert len(self.master.state.flows) == 2
assert isinstance(self.master.state.flows[0], HTTPFlow)
assert isinstance(self.master.state.flows[1], WebSocketFlow)
assert len(self.master.state.flows[1].messages) == 5
- assert self.master.state.flows[1].messages[0].content == b'server-foobar'
+ assert self.master.state.flows[1].messages[0].content == 'server-foobar'
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
- assert self.master.state.flows[1].messages[1].content == b'self.client-foobar'
+ assert self.master.state.flows[1].messages[1].content == 'self.client-foobar'
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
- assert self.master.state.flows[1].messages[2].content == b'self.client-foobar'
+ assert self.master.state.flows[1].messages[2].content == 'self.client-foobar'
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
+ def test_change_payload(self):
+ class Addon:
+ def websocket_message(self, f):
+ f.messages[-1].content = "foo"
+
+ self.master.addons.add(Addon())
+ self.setup_connection()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'foo'
+
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'foo'
+
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'foo'
+
class TestSimpleTLS(_WebSocketTest):
ssl = True
@@ -204,7 +230,7 @@ class TestSimpleTLS(_WebSocketTest):
wfile.flush()
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
def test_simple_tls(self):
@@ -213,13 +239,13 @@ class TestSimpleTLS(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
@@ -234,22 +260,24 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
- wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
+ wfile.flush()
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush()
+ websockets.Frame.from_file(rfile)
def test_ping(self):
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == websockets.OPCODE.PING
- assert frame.payload == b'foobar'
-
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ websockets.Frame.from_file(self.client.rfile)
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'' # We don't send payload to other end
- frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == websockets.OPCODE.TEXT
- assert frame.payload == b'pong-received'
+ assert self.master.has_log("Pong Received from server", "info")
class TestPong(_WebSocketTest):
@@ -258,20 +286,29 @@ class TestPong(_WebSocketTest):
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PING
- assert frame.payload == b'foobar'
+ assert frame.payload == b''
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
wfile.flush()
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ wfile.flush()
+ websockets.Frame.from_file(rfile)
+
def test_pong(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.flush()
+
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
+ assert self.master.has_log("Pong Received from server", "info")
class TestClose(_WebSocketTest):
@@ -279,7 +316,7 @@ class TestClose(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush()
@@ -289,7 +326,7 @@ class TestClose(_WebSocketTest):
def test_close(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@@ -299,7 +336,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_1(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@@ -309,7 +346,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_2(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@@ -329,8 +366,9 @@ class TestInvalidFrame(_WebSocketTest):
# with pytest.raises(exceptions.TcpDisconnect):
frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == 15
- assert frame.payload == b'foobar'
+ code, = struct.unpack('!H', frame.payload[:2])
+ assert code == 1002
+ assert frame.payload[2:].startswith(b'Invalid opcode')
class TestStreaming(_WebSocketTest):
@@ -360,3 +398,51 @@ class TestStreaming(_WebSocketTest):
assert frame
assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received
+
+
+class TestExtension(_WebSocketTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00')
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.rsv1
+ wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00')
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.rsv1
+ wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00')
+ wfile.flush()
+
+ def test_extension(self):
+ self.setup_connection(True)
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.header.rsv1
+
+ self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v')
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.header.rsv1
+
+ self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c')
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.header.rsv1
+
+ assert len(self.master.state.flows[1].messages) == 5
+ assert self.master.state.flows[1].messages[0].content == 'server-foobar'
+ assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
+ assert self.master.state.flows[1].messages[1].content == 'client-foobar'
+ assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
+ assert self.master.state.flows[1].messages[2].content == 'client-foobar'
+ assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
+ assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
+ assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
+ assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
+ assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY