aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_websockets.py
diff options
context:
space:
mode:
authorChandler Abraham <cabraham@twitter.com>2015-04-19 22:18:30 -0700
committerChandler Abraham <cabraham@twitter.com>2015-04-19 22:18:30 -0700
commit4ea1ccb638366fbdac2d294c23ce8052dcf250c2 (patch)
treebe91238bb8a725a4bb9fd382015be55957b73eed /test/test_websockets.py
parent74389ef04a3fdda4d388acb6d655adde78fccd7d (diff)
downloadmitmproxy-4ea1ccb638366fbdac2d294c23ce8052dcf250c2.tar.gz
mitmproxy-4ea1ccb638366fbdac2d294c23ce8052dcf250c2.tar.bz2
mitmproxy-4ea1ccb638366fbdac2d294c23ce8052dcf250c2.zip
fixing test coverage, adding to_file/from_file reader writes to match socks.py
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r--test/test_websockets.py61
1 files changed, 44 insertions, 17 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 62268423..34692183 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -1,6 +1,7 @@
from netlib import tcp
from netlib import test
from netlib import websockets
+import io
import os
from nose.tools import raises
@@ -20,16 +21,15 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
self.read_next_message()
def read_next_message(self):
- decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload
+ decoded = websockets.Frame.from_file(self.rfile).decoded_payload
self.on_message(decoded)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = False)
- self.wfile.write(frame.safe_to_bytes())
- self.wfile.flush()
-
+ frame.to_file(self.wfile)
+
def handshake(self):
- client_hs = websockets.read_handshake(self.rfile.read, 1)
+ client_hs = websockets.read_handshake(self.rfile, 1)
key = websockets.process_handshake_from_client(client_hs)
response = websockets.create_server_handshake(key)
self.wfile.write(response)
@@ -62,22 +62,18 @@ class WebSocketsClient(tcp.TCPClient):
self.wfile.write(handshake)
self.wfile.flush()
- server_handshake = websockets.read_handshake(self.rfile.read, 1)
- server_nounce = websockets.process_handshake_from_server(
- server_handshake, self.client_nounce
- )
+ server_handshake = websockets.read_handshake(self.rfile, 1)
+ server_nounce = websockets.process_handshake_from_server(server_handshake)
if not server_nounce == websockets.create_server_nounce(self.client_nounce):
self.close()
def read_next_message(self):
- return websockets.Frame.from_byte_stream(self.rfile.read).payload
+ return websockets.Frame.from_file(self.rfile).payload
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = True)
- self.wfile.write(frame.safe_to_bytes())
- self.wfile.flush()
-
+ frame.to_file(self.wfile)
class TestWebSockets(test.ServerTestBase):
handler = WebSocketsEchoHandler
@@ -128,10 +124,10 @@ class TestWebSockets(test.ServerTestBase):
frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client
)
- assert frame == websockets.Frame.from_bytes(frame.to_bytes())
+ assert frame == websockets.Frame.from_bytes(frame.safe_to_bytes())
- bytes = b'\x81\x11cba'
- assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
+ bytes = b'\x81\x03cba'
+ assert websockets.Frame.from_bytes(bytes).safe_to_bytes() == bytes
@raises(websockets.WebSocketFrameValidationException)
def test_safe_to_bytes(self):
@@ -139,10 +135,41 @@ class TestWebSockets(test.ServerTestBase):
frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes()
+ def test_handshake(self):
+ bad_upgrade = "not_websockets"
+ bad_header_handshake = websockets.build_handshake([
+ ('Host', '%s:%s' % ("a", "b")),
+ ('Connection', "c"),
+ ('Upgrade', bad_upgrade),
+ ('Sec-WebSocket-Key', "d"),
+ ('Sec-WebSocket-Version', "e")
+ ], "f")
+
+ # check behavior when required header values are missing
+ assert None == websockets.process_handshake_from_server(bad_header_handshake)
+ assert None == websockets.process_handshake_from_client(bad_header_handshake)
+
+ key = "test_key"
+
+ client_handshake = websockets.create_client_handshake("a","b",key,"d","e")
+ assert key == websockets.process_handshake_from_client(client_handshake)
+
+ server_handshake = websockets.create_server_handshake(key)
+ assert websockets.create_server_nounce(key) == websockets.process_handshake_from_server(server_handshake)
+
+ handshake = websockets.create_client_handshake("a","b","c","d","e")
+ stream = io.BytesIO(handshake)
+ assert handshake == websockets.read_handshake(stream, 1)
+
+ # ensure readhandshake doesn't loop forever on empty stream
+ empty_stream = io.BytesIO("")
+ assert "" == websockets.read_handshake(empty_stream, 1)
+
+
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
- client_hs = websockets.read_handshake(self.rfile.read, 1)
+ client_hs = websockets.read_handshake(self.rfile, 1)
websockets.process_handshake_from_client(client_hs)
response = websockets.create_server_handshake("malformed_key")
self.wfile.write(response)