aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_websockets.py
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2015-04-20 09:38:09 +1200
committerAldo Cortesi <aldo@nullcube.com>2015-04-20 09:38:09 +1200
commit74389ef04a3fdda4d388acb6d655adde78fccd7d (patch)
treeab28a997d24cad0fe5665915203dfc96011c8d0c /test/test_websockets.py
parent08ba987a84a66d1f8f6464f2800b49cdd461dd72 (diff)
downloadmitmproxy-74389ef04a3fdda4d388acb6d655adde78fccd7d.tar.gz
mitmproxy-74389ef04a3fdda4d388acb6d655adde78fccd7d.tar.bz2
mitmproxy-74389ef04a3fdda4d388acb6d655adde78fccd7d.zip
Websockets: reorganise
- websockets.py to top-level - implementations into test suite
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r--test/test_websockets.py105
1 files changed, 89 insertions, 16 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py
index d1753638..62268423 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -1,19 +1,92 @@
from netlib import tcp
from netlib import test
-from netlib.websockets import implementations as impl
-from netlib.websockets import websockets as ws
+from netlib import websockets
import os
from nose.tools import raises
+class WebSocketsEchoHandler(tcp.BaseHandler):
+ def __init__(self, connection, address, server):
+ super(WebSocketsEchoHandler, self).__init__(
+ connection, address, server
+ )
+ self.handshake_done = False
+
+ def handle(self):
+ while True:
+ if not self.handshake_done:
+ self.handshake()
+ else:
+ self.read_next_message()
+
+ def read_next_message(self):
+ decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload
+ self.on_message(decoded)
+
+ def send_message(self, message):
+ frame = websockets.Frame.default(message, from_client = False)
+ self.wfile.write(frame.safe_to_bytes())
+ self.wfile.flush()
+
+ def handshake(self):
+ client_hs = websockets.read_handshake(self.rfile.read, 1)
+ key = websockets.process_handshake_from_client(client_hs)
+ response = websockets.create_server_handshake(key)
+ self.wfile.write(response)
+ self.wfile.flush()
+ self.handshake_done = True
+
+ def on_message(self, message):
+ if message is not None:
+ self.send_message(message)
+
+
+class WebSocketsClient(tcp.TCPClient):
+ def __init__(self, address, source_address=None):
+ super(WebSocketsClient, self).__init__(address, source_address)
+ self.version = "13"
+ self.client_nounce = websockets.create_client_nounce()
+ self.resource = "/"
+
+ def connect(self):
+ super(WebSocketsClient, self).connect()
+
+ handshake = websockets.create_client_handshake(
+ self.address.host,
+ self.address.port,
+ self.client_nounce,
+ self.version,
+ self.resource
+ )
+
+ self.wfile.write(handshake)
+ self.wfile.flush()
+
+ server_handshake = websockets.read_handshake(self.rfile.read, 1)
+ server_nounce = websockets.process_handshake_from_server(
+ server_handshake, self.client_nounce
+ )
+
+ if not server_nounce == websockets.create_server_nounce(self.client_nounce):
+ self.close()
+
+ def read_next_message(self):
+ return websockets.Frame.from_byte_stream(self.rfile.read).payload
+
+ def send_message(self, message):
+ frame = websockets.Frame.default(message, from_client = True)
+ self.wfile.write(frame.safe_to_bytes())
+ self.wfile.flush()
+
+
class TestWebSockets(test.ServerTestBase):
- handler = impl.WebSocketsEchoHandler
+ handler = WebSocketsEchoHandler
def random_bytes(self, n = 100):
return os.urandom(n)
def echo(self, msg):
- client = impl.WebSocketsClient(("127.0.0.1", self.port))
+ client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message(msg)
response = client.read_next_message()
@@ -39,10 +112,10 @@ class TestWebSockets(test.ServerTestBase):
default builder should always generate valid frames
"""
msg = self.random_bytes()
- client_frame = ws.Frame.default(msg, from_client = True)
+ client_frame = websockets.Frame.default(msg, from_client = True)
assert client_frame.is_valid()
- server_frame = ws.Frame.default(msg, from_client = False)
+ server_frame = websockets.Frame.default(msg, from_client = False)
assert server_frame.is_valid()
def test_serialization_bijection(self):
@@ -52,26 +125,26 @@ class TestWebSockets(test.ServerTestBase):
"""
for is_client in [True, False]:
for num_bytes in [100, 50000, 150000]:
- frame = ws.Frame.default(
+ frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client
)
- assert frame == ws.Frame.from_bytes(frame.to_bytes())
+ assert frame == websockets.Frame.from_bytes(frame.to_bytes())
bytes = b'\x81\x11cba'
- assert ws.Frame.from_bytes(bytes).to_bytes() == bytes
+ assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
- @raises(ws.WebSocketFrameValidationException)
+ @raises(websockets.WebSocketFrameValidationException)
def test_safe_to_bytes(self):
- frame = ws.Frame.default(self.random_bytes(8))
+ frame = websockets.Frame.default(self.random_bytes(8))
frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes()
-class BadHandshakeHandler(impl.WebSocketsEchoHandler):
+class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
- client_hs = ws.read_handshake(self.rfile.read, 1)
- ws.process_handshake_from_client(client_hs)
- response = ws.create_server_handshake("malformed_key")
+ client_hs = websockets.read_handshake(self.rfile.read, 1)
+ websockets.process_handshake_from_client(client_hs)
+ response = websockets.create_server_handshake("malformed_key")
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
@@ -85,6 +158,6 @@ class TestBadHandshake(test.ServerTestBase):
@raises(tcp.NetLibDisconnect)
def test(self):
- client = impl.WebSocketsClient(("127.0.0.1", self.port))
+ client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message("hello")