diff options
-rw-r--r-- | netlib/websockets/implementations.py | 12 | ||||
-rw-r--r-- | netlib/websockets/websockets.py | 100 | ||||
-rw-r--r-- | test/test_websockets.py | 45 |
3 files changed, 81 insertions, 76 deletions
diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 1ded3b85..337c5496 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -9,7 +9,7 @@ import os # Simple websocket client and servers that are used to exercise the functionality in websockets.py # These are *not* fully RFC6455 compliant -class WebSocketsEchoHandler(tcp.BaseHandler): +class WebSocketsEchoHandler(tcp.BaseHandler): def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__(connection, address, server) self.handshake_done = False @@ -22,14 +22,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = False) + frame = ws.Frame.default(message, from_client = False) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() - + def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) key = ws.process_handshake_from_client(client_hs) @@ -72,9 +72,9 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + return ws.Frame.from_byte_stream(self.rfile.read).payload def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = True) + frame = ws.Frame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 8782ea49..86d98caf 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -29,7 +29,7 @@ class WebSocketFrameValidationException(Exception): pass -class WebSocketsFrame(object): +class Frame(object): """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -98,29 +98,29 @@ class WebSocketsFrame(object): length_code, actual_length = get_payload_length_pair(message) if from_client: - mask_bit = 1 + mask_bit = 1 masking_key = random_masking_key() - payload = apply_mask(message, masking_key) + payload = apply_mask(message, masking_key) else: - mask_bit = 0 + mask_bit = 0 masking_key = None - payload = message + payload = message return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, actual_payload_length = actual_length ) def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the - WebSocketsFrame has not been corrupted. + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. """ try: assert 0 <= self.fin <= 1 @@ -147,17 +147,18 @@ class WebSocketsFrame(object): def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length))]) + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) def safe_to_bytes(self): if self.is_valid(): @@ -167,11 +168,10 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring If - you haven't checked is_valid_frame() then there's no guarentees that - the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -199,13 +199,10 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length @@ -215,7 +212,6 @@ class WebSocketsFrame(object): bytes += self.masking_key bytes += self.payload # already will be encoded if neccessary - return bytes @classmethod @@ -264,29 +260,31 @@ class WebSocketsFrame(object): decoded_payload = payload return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, actual_payload_length = actual_payload_length ) def __eq__(self, other): return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length) + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + def apply_mask(message, masking_key): """ diff --git a/test/test_websockets.py b/test/test_websockets.py index 951aa41f..d1753638 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -5,6 +5,7 @@ from netlib.websockets import websockets as ws import os from nose.tools import raises + class TestWebSockets(test.ServerTestBase): handler = impl.WebSocketsEchoHandler @@ -22,9 +23,12 @@ class TestWebSockets(test.ServerTestBase): self.echo("hello I'm the client") def test_frame_sizes(self): - small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length - medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int - large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int + # length can fit in the the 7 bit payload length + small_msg = self.random_bytes(100) + # 50kb, sligthly larger than can fit in a 7 bit int + medium_msg = self.random_bytes(50000) + # 150kb, slightly larger than can fit in a 16 bit int + large_msg = self.random_bytes(150000) self.echo(small_msg) self.echo(medium_msg) @@ -33,51 +37,54 @@ class TestWebSockets(test.ServerTestBase): def test_default_builder(self): """ default builder should always generate valid frames - """ + """ msg = self.random_bytes() - client_frame = ws.WebSocketsFrame.default(msg, from_client = True) + client_frame = ws.Frame.default(msg, from_client = True) assert client_frame.is_valid() - server_frame = ws.WebSocketsFrame.default(msg, from_client = False) + server_frame = ws.Frame.default(msg, from_client = False) assert server_frame.is_valid() def test_serialization_bijection(self): """ - Ensure that various frame types can be serialized/deserialized back and forth - between to_bytes() and from_bytes() - """ + Ensure that various frame types can be serialized/deserialized back + and forth between to_bytes() and from_bytes() + """ for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) - assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes()) + for num_bytes in [100, 50000, 150000]: + frame = ws.Frame.default( + self.random_bytes(num_bytes), is_client + ) + assert frame == ws.Frame.from_bytes(frame.to_bytes()) bytes = b'\x81\x11cba' - assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + assert ws.Frame.from_bytes(bytes).to_bytes() == bytes @raises(ws.WebSocketFrameValidationException) def test_safe_to_bytes(self): - frame = ws.WebSocketsFrame.default(self.random_bytes(8)) - frame.actual_payload_length = 1 #corrupt the frame + frame = ws.Frame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() class BadHandshakeHandler(impl.WebSocketsEchoHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake("malformed_key") + ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") self.wfile.write(response) self.wfile.flush() self.handshake_done = True + class TestBadHandshake(test.ServerTestBase): """ Ensure that the client disconnects if the server handshake is malformed - """ + """ handler = BadHandshakeHandler @raises(tcp.NetLibDisconnect) def test(self): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello")
\ No newline at end of file + client.send_message("hello") |