aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/websockets
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-02-18 13:03:40 +0100
committerMaximilian Hils <git@maximilianhils.com>2016-02-18 13:03:40 +0100
commitd33d3663ecb166461d9cb5a78a29b44ee7a8fbb7 (patch)
treefe8856f65d1dafa946150c5acbaf6e942ba3c026 /netlib/websockets
parent294774d6f0dee95b02a93307ec493b111b7f171e (diff)
downloadmitmproxy-d33d3663ecb166461d9cb5a78a29b44ee7a8fbb7.tar.gz
mitmproxy-d33d3663ecb166461d9cb5a78a29b44ee7a8fbb7.tar.bz2
mitmproxy-d33d3663ecb166461d9cb5a78a29b44ee7a8fbb7.zip
combine projects
Diffstat (limited to 'netlib/websockets')
-rw-r--r--netlib/websockets/__init__.py2
-rw-r--r--netlib/websockets/frame.py316
-rw-r--r--netlib/websockets/protocol.py115
3 files changed, 433 insertions, 0 deletions
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
new file mode 100644
index 00000000..1c143919
--- /dev/null
+++ b/netlib/websockets/__init__.py
@@ -0,0 +1,2 @@
+from .frame import *
+from .protocol import *
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
new file mode 100644
index 00000000..fce2c9d3
--- /dev/null
+++ b/netlib/websockets/frame.py
@@ -0,0 +1,316 @@
+from __future__ import absolute_import
+import os
+import struct
+import io
+import warnings
+
+import six
+
+from .protocol import Masker
+from netlib import tcp
+from netlib import utils
+
+
+MAX_16_BIT_INT = (1 << 16)
+MAX_64_BIT_INT = (1 << 64)
+
+DEFAULT=object()
+
+OPCODE = utils.BiDi(
+ CONTINUE=0x00,
+ TEXT=0x01,
+ BINARY=0x02,
+ CLOSE=0x08,
+ PING=0x09,
+ PONG=0x0a
+)
+
+
+class FrameHeader(object):
+
+ def __init__(
+ self,
+ opcode=OPCODE.TEXT,
+ payload_length=0,
+ fin=False,
+ rsv1=False,
+ rsv2=False,
+ rsv3=False,
+ masking_key=DEFAULT,
+ mask=DEFAULT,
+ length_code=DEFAULT
+ ):
+ if not 0 <= opcode < 2 ** 4:
+ raise ValueError("opcode must be 0-16")
+ self.opcode = opcode
+ self.payload_length = payload_length
+ self.fin = fin
+ self.rsv1 = rsv1
+ self.rsv2 = rsv2
+ self.rsv3 = rsv3
+
+ if length_code is DEFAULT:
+ self.length_code = self._make_length_code(self.payload_length)
+ else:
+ self.length_code = length_code
+
+ if mask is DEFAULT and masking_key is DEFAULT:
+ self.mask = False
+ self.masking_key = b""
+ elif mask is DEFAULT:
+ self.mask = 1
+ self.masking_key = masking_key
+ elif masking_key is DEFAULT:
+ self.mask = mask
+ self.masking_key = os.urandom(4)
+ else:
+ self.mask = mask
+ self.masking_key = masking_key
+
+ if self.masking_key and len(self.masking_key) != 4:
+ raise ValueError("Masking key must be 4 bytes.")
+
+ @classmethod
+ def _make_length_code(self, length):
+ """
+ A websockets frame contains an initial length_code, and an optional
+ extended length code to represent the actual length if length code is
+ larger than 125
+ """
+ if length <= 125:
+ return length
+ elif length >= 126 and length <= 65535:
+ return 126
+ else:
+ return 127
+
+ def __repr__(self):
+ vals = [
+ "ws frame:",
+ OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
+ ]
+ flags = []
+ for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]:
+ if getattr(self, i):
+ flags.append(i)
+ if flags:
+ vals.extend([":", "|".join(flags)])
+ if self.masking_key:
+ vals.append(":key=%s" % repr(self.masking_key))
+ if self.payload_length:
+ vals.append(" %s" % utils.pretty_size(self.payload_length))
+ return "".join(vals)
+
+ def human_readable(self):
+ warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
+ return repr(self)
+
+ def __bytes__(self):
+ first_byte = utils.setbit(0, 7, self.fin)
+ first_byte = utils.setbit(first_byte, 6, self.rsv1)
+ first_byte = utils.setbit(first_byte, 5, self.rsv2)
+ first_byte = utils.setbit(first_byte, 4, self.rsv3)
+ first_byte = first_byte | self.opcode
+
+ second_byte = utils.setbit(self.length_code, 7, self.mask)
+
+ b = six.int2byte(first_byte) + six.int2byte(second_byte)
+
+ if self.payload_length < 126:
+ pass
+ elif self.payload_length < MAX_16_BIT_INT:
+ # '!H' pack as 16 bit unsigned short
+ # add 2 byte extended payload length
+ b += struct.pack('!H', self.payload_length)
+ elif self.payload_length < MAX_64_BIT_INT:
+ # '!Q' = pack as 64 bit unsigned long long
+ # add 8 bytes extended payload length
+ b += struct.pack('!Q', self.payload_length)
+ if self.masking_key:
+ b += self.masking_key
+ return b
+
+ if six.PY2:
+ __str__ = __bytes__
+
+ def to_bytes(self):
+ warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
+ return bytes(self)
+
+ @classmethod
+ def from_file(cls, fp):
+ """
+ read a websockets frame header
+ """
+ first_byte = six.byte2int(fp.safe_read(1))
+ second_byte = six.byte2int(fp.safe_read(1))
+
+ fin = utils.getbit(first_byte, 7)
+ rsv1 = utils.getbit(first_byte, 6)
+ rsv2 = utils.getbit(first_byte, 5)
+ rsv3 = utils.getbit(first_byte, 4)
+ # grab right-most 4 bits
+ opcode = first_byte & 15
+ mask_bit = utils.getbit(second_byte, 7)
+ # grab the next 7 bits
+ length_code = second_byte & 127
+
+ # payload_lengthy > 125 indicates you need to read more bytes
+ # to get the actual payload length
+ if length_code <= 125:
+ payload_length = length_code
+ elif length_code == 126:
+ payload_length, = struct.unpack("!H", fp.safe_read(2))
+ elif length_code == 127:
+ payload_length, = struct.unpack("!Q", fp.safe_read(8))
+
+ # masking key only present if mask bit set
+ if mask_bit == 1:
+ masking_key = fp.safe_read(4)
+ else:
+ masking_key = None
+
+ return cls(
+ fin=fin,
+ rsv1=rsv1,
+ rsv2=rsv2,
+ rsv3=rsv3,
+ opcode=opcode,
+ mask=mask_bit,
+ length_code=length_code,
+ payload_length=payload_length,
+ masking_key=masking_key,
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, FrameHeader):
+ return bytes(self) == bytes(other)
+ return False
+
+
+class Frame(object):
+
+ """
+ Represents one websockets frame.
+ Constructor takes human readable forms of the frame components
+ from_bytes() is also avaliable.
+
+ WebSockets Frame as defined in RFC6455
+
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ +-+-+-+-+-------+-+-------------+-------------------------------+
+ |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ |I|S|S|S| (4) |A| (7) | (16/64) |
+ |N|V|V|V| |S| | (if payload len==126/127) |
+ | |1|2|3| |K| | |
+ +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ | Extended payload length continued, if payload len == 127 |
+ + - - - - - - - - - - - - - - - +-------------------------------+
+ | |Masking-key, if MASK set to 1 |
+ +-------------------------------+-------------------------------+
+ | Masking-key (continued) | Payload Data |
+ +-------------------------------- - - - - - - - - - - - - - - - +
+ : Payload Data continued ... :
+ + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ | Payload Data continued ... |
+ +---------------------------------------------------------------+
+ """
+
+ def __init__(self, payload=b"", **kwargs):
+ self.payload = payload
+ kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
+ self.header = FrameHeader(**kwargs)
+
+ @classmethod
+ def default(cls, message, from_client=False):
+ """
+ Construct a basic websocket frame from some default values.
+ Creates a non-fragmented text frame.
+ """
+ if from_client:
+ mask_bit = 1
+ masking_key = os.urandom(4)
+ else:
+ mask_bit = 0
+ masking_key = None
+
+ return cls(
+ message,
+ fin=1, # final frame
+ opcode=OPCODE.TEXT, # text
+ mask=mask_bit,
+ masking_key=masking_key,
+ )
+
+ @classmethod
+ def from_bytes(cls, bytestring):
+ """
+ Construct a websocket frame from an in-memory bytestring
+ to construct a frame from a stream of bytes, use from_file() directly
+ """
+ return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
+
+ def __repr__(self):
+ ret = repr(self.header)
+ if self.payload:
+ ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii")
+ return ret
+
+ def human_readable(self):
+ warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
+ return repr(self)
+
+ def __bytes__(self):
+ """
+ Serialize the frame to wire format. Returns a string.
+ """
+ b = bytes(self.header)
+ if self.header.masking_key:
+ b += Masker(self.header.masking_key)(self.payload)
+ else:
+ b += self.payload
+ return b
+
+ if six.PY2:
+ __str__ = __bytes__
+
+ def to_bytes(self):
+ warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
+ return bytes(self)
+
+ def to_file(self, writer):
+ warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
+ writer.write(bytes(self))
+ writer.flush()
+
+ @classmethod
+ def from_file(cls, fp):
+ """
+ read a websockets frame sent by a server or client
+
+ fp is a "file like" object that could be backed by a network
+ stream or a disk or an in memory stream reader
+ """
+ header = FrameHeader.from_file(fp)
+ payload = fp.safe_read(header.payload_length)
+
+ if header.mask == 1 and header.masking_key:
+ payload = Masker(header.masking_key)(payload)
+
+ return cls(
+ payload,
+ fin=header.fin,
+ opcode=header.opcode,
+ mask=header.mask,
+ payload_length=header.payload_length,
+ masking_key=header.masking_key,
+ rsv1=header.rsv1,
+ rsv2=header.rsv2,
+ rsv3=header.rsv3,
+ length_code=header.length_code
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, Frame):
+ return bytes(self) == bytes(other)
+ return False
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
new file mode 100644
index 00000000..1e95fa1c
--- /dev/null
+++ b/netlib/websockets/protocol.py
@@ -0,0 +1,115 @@
+
+
+
+# Colleciton of utility functions that implement small portions of the RFC6455
+# WebSockets Protocol Useful for building WebSocket clients and servers.
+#
+# Emphassis is on readabilty, simplicity and modularity, not performance or
+# completeness
+#
+# This is a work in progress and does not yet contain all the utilites need to
+# create fully complient client/servers #
+# Spec: https://tools.ietf.org/html/rfc6455
+
+# The magic sha that websocket servers must know to prove they understand
+# RFC6455
+from __future__ import absolute_import
+import base64
+import hashlib
+import os
+
+import binascii
+import six
+from ..http import Headers
+
+websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+VERSION = "13"
+
+
+class Masker(object):
+
+ """
+ Data sent from the server must be masked to prevent malicious clients
+ from sending data over the wire in predictable patterns
+
+ Servers do not have to mask data they send to the client.
+ https://tools.ietf.org/html/rfc6455#section-5.3
+ """
+
+ def __init__(self, key):
+ self.key = key
+ self.offset = 0
+
+ def mask(self, offset, data):
+ result = bytearray(data)
+ if six.PY2:
+ for i in range(len(data)):
+ result[i] ^= ord(self.key[offset % 4])
+ offset += 1
+ result = str(result)
+ else:
+
+ for i in range(len(data)):
+ result[i] ^= self.key[offset % 4]
+ offset += 1
+ result = bytes(result)
+ return result
+
+ def __call__(self, data):
+ ret = self.mask(self.offset, data)
+ self.offset += len(ret)
+ return ret
+
+
+class WebsocketsProtocol(object):
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def client_handshake_headers(self, key=None, version=VERSION):
+ """
+ Create the headers for a valid HTTP upgrade request. If Key is not
+ specified, it is generated, and can be found in sec-websocket-key in
+ the returned header set.
+
+ Returns an instance of Headers
+ """
+ if not key:
+ key = base64.b64encode(os.urandom(16)).decode('ascii')
+ return Headers(
+ sec_websocket_key=key,
+ sec_websocket_version=version,
+ connection="Upgrade",
+ upgrade="websocket",
+ )
+
+ @classmethod
+ def server_handshake_headers(self, key):
+ """
+ The server response is a valid HTTP 101 response.
+ """
+ return Headers(
+ sec_websocket_accept=self.create_server_nonce(key),
+ connection="Upgrade",
+ upgrade="websocket"
+ )
+
+
+ @classmethod
+ def check_client_handshake(self, headers):
+ if headers.get("upgrade") != "websocket":
+ return
+ return headers.get("sec-websocket-key")
+
+
+ @classmethod
+ def check_server_handshake(self, headers):
+ if headers.get("upgrade") != "websocket":
+ return
+ return headers.get("sec-websocket-accept")
+
+
+ @classmethod
+ def create_server_nonce(self, client_nonce):
+ return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest())