diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-09-21 00:44:17 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-09-21 00:44:17 +0200 |
commit | 73586b1be95d97f0be76e85223b53d1f4ed697d6 (patch) | |
tree | 47754fc26d196355babbff026df035271953c5d0 /netlib/websockets | |
parent | daebd1bd275a398d42cc4dbfe5c6399c7fe3b3a0 (diff) | |
download | mitmproxy-73586b1be95d97f0be76e85223b53d1f4ed697d6.tar.gz mitmproxy-73586b1be95d97f0be76e85223b53d1f4ed697d6.tar.bz2 mitmproxy-73586b1be95d97f0be76e85223b53d1f4ed697d6.zip |
python 3++
Diffstat (limited to 'netlib/websockets')
-rw-r--r-- | netlib/websockets/frame.py | 77 | ||||
-rw-r--r-- | netlib/websockets/protocol.py | 52 |
2 files changed, 80 insertions, 49 deletions
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index ceddd273..55eeaf41 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,13 +2,14 @@ from __future__ import absolute_import import os import struct import io +import warnings + import six from .protocol import Masker from netlib import tcp from netlib import utils -DEFAULT = object() MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) @@ -33,9 +34,9 @@ class FrameHeader(object): rsv1=False, rsv2=False, rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT + masking_key=None, + mask=None, + length_code=None ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -46,18 +47,18 @@ class FrameHeader(object): self.rsv2 = rsv2 self.rsv3 = rsv3 - if length_code is DEFAULT: + if length_code is None: self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code - if mask is DEFAULT and masking_key is DEFAULT: + if mask is None and masking_key is None: self.mask = False - self.masking_key = "" - elif mask is DEFAULT: + self.masking_key = b"" + elif mask is None: self.mask = 1 self.masking_key = masking_key - elif masking_key is DEFAULT: + elif masking_key is None: self.mask = mask self.masking_key = os.urandom(4) else: @@ -81,7 +82,7 @@ class FrameHeader(object): else: return 127 - def human_readable(self): + def __repr__(self): vals = [ "ws frame:", OPCODE.get_name(self.opcode, hex(self.opcode)).lower() @@ -98,7 +99,11 @@ class FrameHeader(object): vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) - def to_bytes(self): + def human_readable(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return repr(self) + + def __bytes__(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) first_byte = utils.setbit(first_byte, 5, self.rsv2) @@ -107,7 +112,7 @@ class FrameHeader(object): second_byte = utils.setbit(self.length_code, 7, self.mask) - b = chr(first_byte) + chr(second_byte) + b = six.int2byte(first_byte) + six.int2byte(second_byte) if self.payload_length < 126: pass @@ -119,10 +124,17 @@ class FrameHeader(object): # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: + if self.masking_key: b += self.masking_key return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + @classmethod def from_file(cls, fp): """ @@ -154,7 +166,7 @@ class FrameHeader(object): if mask_bit == 1: masking_key = fp.safe_read(4) else: - masking_key = None + masking_key = False return cls( fin=fin, @@ -169,7 +181,9 @@ class FrameHeader(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, FrameHeader): + return bytes(self) == bytes(other) + return False class Frame(object): @@ -200,7 +214,7 @@ class Frame(object): +---------------------------------------------------------------+ """ - def __init__(self, payload="", **kwargs): + def __init__(self, payload=b"", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @@ -216,7 +230,7 @@ class Frame(object): masking_key = os.urandom(4) else: mask_bit = 0 - masking_key = None + masking_key = False return cls( message, @@ -234,28 +248,37 @@ class Frame(object): """ return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - def human_readable(self): - ret = self.header.human_readable() + def __repr__(self): + ret = repr(self.header) if self.payload: - ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload) + ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii") return ret - def __repr__(self): - return self.header.human_readable() + def human_readable(self): + warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) + return repr(self) - def to_bytes(self): + def __bytes__(self): """ Serialize the frame to wire format. Returns a string. """ - b = self.header.to_bytes() + b = bytes(self.header) if self.header.masking_key: b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + def to_file(self, writer): - writer.write(self.to_bytes()) + warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) + writer.write(bytes(self)) writer.flush() @classmethod @@ -286,4 +309,6 @@ class Frame(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, Frame): + return bytes(self) == bytes(other) + return False
\ No newline at end of file diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 68d827a5..778fe7e7 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,11 +17,12 @@ from __future__ import absolute_import import base64 import hashlib import os + +import binascii import six from ..http import Headers -from .. import utils -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" HEADER_WEBSOCKET_KEY = 'sec-websocket-key' @@ -41,14 +42,21 @@ class Masker(object): def __init__(self, key): self.key = key - self.masks = [six.byte2int(byte) for byte in key] self.offset = 0 def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 + result = bytearray(data) + if six.PY2: + for i in range(len(data)): + result[i] ^= ord(self.key[offset % 4]) + offset += 1 + result = str(result) + else: + + for i in range(len(data)): + result[i] ^= self.key[offset % 4] + offset += 1 + result = bytes(result) return result def __call__(self, data): @@ -73,37 +81,35 @@ class WebsocketsProtocol(object): """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return Headers([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_KEY, key), - (HEADER_WEBSOCKET_VERSION, version) - ]) + return Headers(**{ + HEADER_WEBSOCKET_KEY: key, + HEADER_WEBSOCKET_VERSION: version, + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def server_handshake_headers(self, key): """ The server response is a valid HTTP 101 response. """ - return Headers( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) - ] - ) + return Headers(**{ + HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @@ -111,5 +117,5 @@ class WebsocketsProtocol(object): @classmethod def create_server_nonce(self, client_nonce): return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) ) |