diff options
author | Chandler Abraham <cabraham@twitter.com> | 2015-04-19 22:18:30 -0700 |
---|---|---|
committer | Chandler Abraham <cabraham@twitter.com> | 2015-04-19 22:18:30 -0700 |
commit | 4ea1ccb638366fbdac2d294c23ce8052dcf250c2 (patch) | |
tree | be91238bb8a725a4bb9fd382015be55957b73eed /test/test_websockets.py | |
parent | 74389ef04a3fdda4d388acb6d655adde78fccd7d (diff) | |
download | mitmproxy-4ea1ccb638366fbdac2d294c23ce8052dcf250c2.tar.gz mitmproxy-4ea1ccb638366fbdac2d294c23ce8052dcf250c2.tar.bz2 mitmproxy-4ea1ccb638366fbdac2d294c23ce8052dcf250c2.zip |
fixing test coverage, adding to_file/from_file reader writes to match socks.py
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r-- | test/test_websockets.py | 61 |
1 files changed, 44 insertions, 17 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py index 62268423..34692183 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,6 +1,7 @@ from netlib import tcp from netlib import test from netlib import websockets +import io import os from nose.tools import raises @@ -20,16 +21,15 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload + decoded = websockets.Frame.from_file(self.rfile).decoded_payload self.on_message(decoded) def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - + frame.to_file(self.wfile) + def handshake(self): - client_hs = websockets.read_handshake(self.rfile.read, 1) + client_hs = websockets.read_handshake(self.rfile, 1) key = websockets.process_handshake_from_client(client_hs) response = websockets.create_server_handshake(key) self.wfile.write(response) @@ -62,22 +62,18 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - server_handshake = websockets.read_handshake(self.rfile.read, 1) - server_nounce = websockets.process_handshake_from_server( - server_handshake, self.client_nounce - ) + server_handshake = websockets.read_handshake(self.rfile, 1) + server_nounce = websockets.process_handshake_from_server(server_handshake) if not server_nounce == websockets.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): - return websockets.Frame.from_byte_stream(self.rfile.read).payload + return websockets.Frame.from_file(self.rfile).payload def send_message(self, message): frame = websockets.Frame.default(message, from_client = True) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - + frame.to_file(self.wfile) class TestWebSockets(test.ServerTestBase): handler = WebSocketsEchoHandler @@ -128,10 +124,10 @@ class TestWebSockets(test.ServerTestBase): frame = websockets.Frame.default( self.random_bytes(num_bytes), is_client ) - assert frame == websockets.Frame.from_bytes(frame.to_bytes()) + assert frame == websockets.Frame.from_bytes(frame.safe_to_bytes()) - bytes = b'\x81\x11cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes + bytes = b'\x81\x03cba' + assert websockets.Frame.from_bytes(bytes).safe_to_bytes() == bytes @raises(websockets.WebSocketFrameValidationException) def test_safe_to_bytes(self): @@ -139,10 +135,41 @@ class TestWebSockets(test.ServerTestBase): frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() + def test_handshake(self): + bad_upgrade = "not_websockets" + bad_header_handshake = websockets.build_handshake([ + ('Host', '%s:%s' % ("a", "b")), + ('Connection', "c"), + ('Upgrade', bad_upgrade), + ('Sec-WebSocket-Key', "d"), + ('Sec-WebSocket-Version', "e") + ], "f") + + # check behavior when required header values are missing + assert None == websockets.process_handshake_from_server(bad_header_handshake) + assert None == websockets.process_handshake_from_client(bad_header_handshake) + + key = "test_key" + + client_handshake = websockets.create_client_handshake("a","b",key,"d","e") + assert key == websockets.process_handshake_from_client(client_handshake) + + server_handshake = websockets.create_server_handshake(key) + assert websockets.create_server_nounce(key) == websockets.process_handshake_from_server(server_handshake) + + handshake = websockets.create_client_handshake("a","b","c","d","e") + stream = io.BytesIO(handshake) + assert handshake == websockets.read_handshake(stream, 1) + + # ensure readhandshake doesn't loop forever on empty stream + empty_stream = io.BytesIO("") + assert "" == websockets.read_handshake(empty_stream, 1) + + class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = websockets.read_handshake(self.rfile.read, 1) + client_hs = websockets.read_handshake(self.rfile, 1) websockets.process_handshake_from_client(client_hs) response = websockets.create_server_handshake("malformed_key") self.wfile.write(response) |