aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/websockets.py
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2015-04-24 15:09:21 +1200
committerAldo Cortesi <aldo@nullcube.com>2015-04-24 15:09:21 +1200
commitf22bc0b4c74776bcc312fed1f4ceede83f869a6e (patch)
tree7d8b947d9940bb0faa68fe21d924642f6c3d1667 /netlib/websockets.py
parent3519871f340cb0466fc6935d6e8e3b7822d36c52 (diff)
downloadmitmproxy-f22bc0b4c74776bcc312fed1f4ceede83f869a6e.tar.gz
mitmproxy-f22bc0b4c74776bcc312fed1f4ceede83f869a6e.tar.bz2
mitmproxy-f22bc0b4c74776bcc312fed1f4ceede83f869a6e.zip
websocket: interface refactoring
- Separate out FrameHeader. We need to deal with this separately in many circumstances. - Simpler equality scheme. - Bits are now specified by truthiness - we don't care about the integer value. This means lots of validation is not needed any more.
Diffstat (limited to 'netlib/websockets.py')
-rw-r--r--netlib/websockets.py303
1 files changed, 143 insertions, 160 deletions
diff --git a/netlib/websockets.py b/netlib/websockets.py
index 7c127563..016e75c2 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -1,5 +1,4 @@
from __future__ import absolute_import
-
import base64
import hashlib
import os
@@ -83,23 +82,6 @@ def server_handshake_headers(key):
)
-def get_payload_length_pair(payload_bytestring):
- """
- A websockets frame contains an initial length_code, and an optional
- extended length code to represent the actual length if length code is
- larger than 125
- """
- actual_length = len(payload_bytestring)
-
- if actual_length <= 125:
- length_code = actual_length
- elif actual_length >= 126 and actual_length <= 65535:
- length_code = 126
- else:
- length_code = 127
- return (length_code, actual_length)
-
-
def make_length_code(len):
"""
A websockets frame contains an initial length_code, and an optional
@@ -132,40 +114,113 @@ def create_server_nonce(client_nonce):
)
-def frame_header_bytes(
- opcode = 0,
- payload_length = 0,
- fin = 0,
- rsv1 = 0,
- rsv2 = 0,
- rsv3 = 0,
- mask = 0,
- masking_key = None,
- length_code = None
-):
- first_byte = (fin << 7) | (rsv1 << 6) |\
- (rsv2 << 4) | (rsv3 << 4) | opcode
-
- if length_code is None:
- length_code = make_length_code(payload_length)
-
- second_byte = (mask << 7) | length_code
-
- b = chr(first_byte) + chr(second_byte)
-
- if payload_length < 126:
- pass
- elif payload_length < MAX_16_BIT_INT:
- # '!H' pack as 16 bit unsigned short
- # add 2 byte extended payload length
- b += struct.pack('!H', payload_length)
- elif payload_length < MAX_64_BIT_INT:
- # '!Q' = pack as 64 bit unsigned long long
- # add 8 bytes extended payload length
- b += struct.pack('!Q', payload_length)
- if masking_key is not None:
- b += masking_key
- return b
+DEFAULT = object()
+class FrameHeader:
+ def __init__(
+ self,
+ opcode = OPCODE.TEXT,
+ payload_length = 0,
+ fin = False,
+ rsv1 = False,
+ rsv2 = False,
+ rsv3 = False,
+ masking_key = None,
+ mask = DEFAULT,
+ length_code = DEFAULT
+ ):
+ self.opcode = opcode
+ self.payload_length = payload_length
+ self.fin = fin
+ self.rsv1 = rsv1
+ self.rsv2 = rsv2
+ self.rsv3 = rsv3
+ self.mask = mask
+ self.masking_key = masking_key
+ self.length_code = length_code
+
+ def to_bytes(self):
+ first_byte = utils.setbit(0, 7, self.fin)
+ first_byte = utils.setbit(first_byte, 6, self.rsv1)
+ first_byte = utils.setbit(first_byte, 5, self.rsv2)
+ first_byte = utils.setbit(first_byte, 4, self.rsv3)
+ first_byte = first_byte | self.opcode
+
+ if self.length_code is DEFAULT:
+ length_code = make_length_code(self.payload_length)
+ else:
+ length_code = self.length_code
+
+ if self.mask is DEFAULT:
+ mask = bool(self.masking_key)
+ else:
+ mask = self.mask
+
+ second_byte = (mask << 7) | length_code
+
+ b = chr(first_byte) + chr(second_byte)
+
+ if self.payload_length < 126:
+ pass
+ elif self.payload_length < MAX_16_BIT_INT:
+ # '!H' pack as 16 bit unsigned short
+ # add 2 byte extended payload length
+ b += struct.pack('!H', self.payload_length)
+ elif self.payload_length < MAX_64_BIT_INT:
+ # '!Q' = pack as 64 bit unsigned long long
+ # add 8 bytes extended payload length
+ b += struct.pack('!Q', self.payload_length)
+ if self.masking_key is not None:
+ b += self.masking_key
+ return b
+
+ @classmethod
+ def from_file(klass, fp):
+ """
+ read a websockets frame header
+ """
+ first_byte = utils.bytes_to_int(fp.read(1))
+ second_byte = utils.bytes_to_int(fp.read(1))
+
+ fin = utils.getbit(first_byte, 7)
+ rsv1 = utils.getbit(first_byte, 6)
+ rsv2 = utils.getbit(first_byte, 5)
+ rsv3 = utils.getbit(first_byte, 4)
+ # grab right most 4 bits by and-ing with 00001111
+ opcode = first_byte & 15
+ # grab left most bit
+ mask_bit = second_byte >> 7
+ # grab the next 7 bits
+ length_code = second_byte & 127
+
+ # payload_lengthy > 125 indicates you need to read more bytes
+ # to get the actual payload length
+ if length_code <= 125:
+ payload_length = length_code
+ elif length_code == 126:
+ payload_length = utils.bytes_to_int(fp.read(2))
+ elif length_code == 127:
+ payload_length = utils.bytes_to_int(fp.read(8))
+
+ # masking key only present if mask bit set
+ if mask_bit == 1:
+ masking_key = fp.read(4)
+ else:
+ masking_key = None
+
+ return klass(
+ fin = fin,
+ rsv1 = rsv1,
+ rsv2 = rsv2,
+ rsv3 = rsv3,
+ opcode = opcode,
+ mask = mask_bit,
+ length_code = length_code,
+ payload_length = payload_length,
+ masking_key = masking_key,
+ )
+
+ def __eq__(self, other):
+ return self.to_bytes() == other.to_bytes()
class Frame(object):
@@ -194,27 +249,10 @@ class Frame(object):
| Payload Data continued ... |
+---------------------------------------------------------------+
"""
- def __init__(
- self,
- fin, # decmial integer 1 or 0
- opcode, # decmial integer 1 - 4
- payload = "", # bytestring
- masking_key = None, # 32 bit byte string
- mask_bit = 0, # decimal integer 1 or 0
- payload_length_code = None, # decimal integer 1 - 127
- rsv1 = 0, # decimal integer 1 or 0
- rsv2 = 0, # decimal integer 1 or 0
- rsv3 = 0, # decimal integer 1 or 0
- ):
- self.fin = fin
- self.rsv1 = rsv1
- self.rsv2 = rsv2
- self.rsv3 = rsv3
- self.opcode = opcode
- self.mask_bit = mask_bit
- self.payload_length_code = payload_length_code
- self.masking_key = masking_key
+ def __init__(self, payload = "", **kwargs):
self.payload = payload
+ kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
+ self.header = FrameHeader(**kwargs)
@classmethod
def default(cls, message, from_client = False):
@@ -230,10 +268,10 @@ class Frame(object):
masking_key = None
return cls(
+ message,
fin = 1, # final frame
opcode = OPCODE.TEXT, # text
- mask_bit = mask_bit,
- payload = message,
+ mask = mask_bit,
masking_key = masking_key,
)
@@ -243,30 +281,30 @@ class Frame(object):
Frame has not been corrupted.
"""
constraints = [
- 0 <= self.fin <= 1,
- 0 <= self.rsv1 <= 1,
- 0 <= self.rsv2 <= 1,
- 0 <= self.rsv3 <= 1,
- 1 <= self.opcode <= 4,
- 0 <= self.mask_bit <= 1,
+ 0 <= self.header.fin <= 1,
+ 0 <= self.header.rsv1 <= 1,
+ 0 <= self.header.rsv2 <= 1,
+ 0 <= self.header.rsv3 <= 1,
+ 1 <= self.header.opcode <= 4,
+ 0 <= self.header.mask <= 1,
#1 <= self.payload_length_code <= 127,
- 1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
- self.masking_key is not None if self.mask_bit else True
+ 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
+ self.header.masking_key is not None if self.header.mask else True
]
if not all(constraints):
return False
return True
- def human_readable(self): # pragma: nocover
+ def human_readable(self):
return "\n".join([
- ("fin - " + str(self.fin)),
- ("rsv1 - " + str(self.rsv1)),
- ("rsv2 - " + str(self.rsv2)),
- ("rsv3 - " + str(self.rsv3)),
- ("opcode - " + str(self.opcode)),
- ("mask_bit - " + str(self.mask_bit)),
- ("payload_length_code - " + str(self.payload_length_code)),
- ("masking_key - " + repr(str(self.masking_key))),
+ ("fin - " + str(self.header.fin)),
+ ("rsv1 - " + str(self.header.rsv1)),
+ ("rsv2 - " + str(self.header.rsv2)),
+ ("rsv3 - " + str(self.header.rsv3)),
+ ("opcode - " + str(self.header.opcode)),
+ ("mask - " + str(self.header.mask)),
+ ("length_code - " + str(self.header.length_code)),
+ ("masking_key - " + repr(str(self.header.masking_key))),
("payload - " + repr(str(self.payload))),
])
@@ -284,18 +322,9 @@ class Frame(object):
If you haven't checked is_valid_frame() then there's no guarentees
that the serialized bytes will be correct. see safe_to_bytes()
"""
- b = frame_header_bytes(
- opcode = self.opcode,
- fin = self.fin,
- rsv1 = self.rsv1,
- rsv2 = self.rsv2,
- rsv3 = self.rsv3,
- mask = self.mask_bit,
- masking_key = self.masking_key,
- payload_length = len(self.payload) if self.payload else 0
- )
- if self.masking_key:
- b += apply_mask(self.payload, self.masking_key)
+ b = self.header.to_bytes()
+ if self.header.masking_key:
+ b += apply_mask(self.payload, self.header.masking_key)
else:
b += self.payload
return b
@@ -312,66 +341,20 @@ class Frame(object):
fp is a "file like" object that could be backed by a network
stream or a disk or an in memory stream reader
"""
- first_byte = utils.bytes_to_int(fp.read(1))
- second_byte = utils.bytes_to_int(fp.read(1))
-
- # grab the left most bit
- fin = first_byte >> 7
- # grab right most 4 bits by and-ing with 00001111
- opcode = first_byte & 15
- # grab left most bit
- mask_bit = second_byte >> 7
- # grab the next 7 bits
- payload_length = second_byte & 127
-
- # payload_lengthy > 125 indicates you need to read more bytes
- # to get the actual payload length
- if payload_length <= 125:
- actual_payload_length = payload_length
-
- elif payload_length == 126:
- actual_payload_length = utils.bytes_to_int(fp.read(2))
-
- elif payload_length == 127:
- actual_payload_length = utils.bytes_to_int(fp.read(8))
-
- # masking key only present if mask bit set
- if mask_bit == 1:
- masking_key = fp.read(4)
- else:
- masking_key = None
-
- payload = fp.read(actual_payload_length)
+ header = FrameHeader.from_file(fp)
+ payload = fp.read(header.payload_length)
- if mask_bit == 1 and masking_key:
- payload = apply_mask(payload, masking_key)
+ if header.mask == 1 and header.masking_key:
+ payload = apply_mask(payload, header.masking_key)
return cls(
- fin = fin,
- opcode = opcode,
- mask_bit = mask_bit,
- payload_length_code = payload_length,
- payload = payload,
- masking_key = masking_key,
+ payload,
+ fin = header.fin,
+ opcode = header.opcode,
+ mask = header.mask,
+ payload_length = header.payload_length,
+ masking_key = header.masking_key,
)
def __eq__(self, other):
- if self.payload_length_code is None:
- myplc = make_length_code(len(self.payload))
- else:
- myplc = self.payload_length_code
- if other.payload_length_code is None:
- otherplc = make_length_code(len(other.payload))
- else:
- otherplc = other.payload_length_code
- return (
- self.fin == other.fin and
- self.rsv1 == other.rsv1 and
- self.rsv2 == other.rsv2 and
- self.rsv3 == other.rsv3 and
- self.opcode == other.opcode and
- self.mask_bit == other.mask_bit and
- self.masking_key == other.masking_key and
- self.payload == other.payload,
- myplc == otherplc
- )
+ return self.to_bytes() == other.to_bytes()