aboutsummaryrefslogtreecommitdiffstats
path: root/test/netlib/websockets/test_websockets.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/netlib/websockets/test_websockets.py')
-rw-r--r--test/netlib/websockets/test_websockets.py269
1 files changed, 0 insertions, 269 deletions
diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py
deleted file mode 100644
index 50fa26e6..00000000
--- a/test/netlib/websockets/test_websockets.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import os
-
-from netlib.http.http1 import read_response, read_request
-
-from netlib import tcp
-from netlib import tutils
-from netlib import websockets
-from netlib.http import status_codes
-from netlib.tutils import treq
-from netlib import exceptions
-
-from .. import tservers
-
-
-class WebSocketsEchoHandler(tcp.BaseHandler):
-
- def __init__(self, connection, address, server):
- super(WebSocketsEchoHandler, self).__init__(
- connection, address, server
- )
- self.protocol = websockets.WebsocketsProtocol()
- self.handshake_done = False
-
- def handle(self):
- while True:
- if not self.handshake_done:
- self.handshake()
- else:
- self.read_next_message()
-
- def read_next_message(self):
- frame = websockets.Frame.from_file(self.rfile)
- self.on_message(frame.payload)
-
- def send_message(self, message):
- frame = websockets.Frame.default(message, from_client=False)
- frame.to_file(self.wfile)
-
- def handshake(self):
-
- req = read_request(self.rfile)
- key = self.protocol.check_client_handshake(req.headers)
-
- preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble.encode() + b"\r\n")
- headers = self.protocol.server_handshake_headers(key)
- self.wfile.write(str(headers) + "\r\n")
- self.wfile.flush()
- self.handshake_done = True
-
- def on_message(self, message):
- if message is not None:
- self.send_message(message)
-
-
-class WebSocketsClient(tcp.TCPClient):
-
- def __init__(self, address, source_address=None):
- super(WebSocketsClient, self).__init__(address, source_address)
- self.protocol = websockets.WebsocketsProtocol()
- self.client_nonce = None
-
- def connect(self):
- super(WebSocketsClient, self).connect()
-
- preamble = b'GET / HTTP/1.1'
- self.wfile.write(preamble + b"\r\n")
- headers = self.protocol.client_handshake_headers()
- self.client_nonce = headers["sec-websocket-key"].encode("ascii")
- self.wfile.write(bytes(headers) + b"\r\n")
- self.wfile.flush()
-
- resp = read_response(self.rfile, treq(method=b"GET"))
- server_nonce = self.protocol.check_server_handshake(resp.headers)
-
- if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):
- self.close()
-
- def read_next_message(self):
- return websockets.Frame.from_file(self.rfile).payload
-
- def send_message(self, message):
- frame = websockets.Frame.default(message, from_client=True)
- frame.to_file(self.wfile)
-
-
-class TestWebSockets(tservers.ServerTestBase):
- handler = WebSocketsEchoHandler
-
- def __init__(self):
- self.protocol = websockets.WebsocketsProtocol()
-
- def random_bytes(self, n=100):
- return os.urandom(n)
-
- def echo(self, msg):
- client = WebSocketsClient(("127.0.0.1", self.port))
- client.connect()
- client.send_message(msg)
- response = client.read_next_message()
- assert response == msg
-
- def test_simple_echo(self):
- self.echo(b"hello I'm the client")
-
- def test_frame_sizes(self):
- # length can fit in the the 7 bit payload length
- small_msg = self.random_bytes(100)
- # 50kb, sligthly larger than can fit in a 7 bit int
- medium_msg = self.random_bytes(50000)
- # 150kb, slightly larger than can fit in a 16 bit int
- large_msg = self.random_bytes(150000)
-
- self.echo(small_msg)
- self.echo(medium_msg)
- self.echo(large_msg)
-
- def test_default_builder(self):
- """
- default builder should always generate valid frames
- """
- msg = self.random_bytes()
- assert websockets.Frame.default(msg, from_client=True)
- assert websockets.Frame.default(msg, from_client=False)
-
- def test_serialization_bijection(self):
- """
- Ensure that various frame types can be serialized/deserialized back
- and forth between to_bytes() and from_bytes()
- """
- for is_client in [True, False]:
- for num_bytes in [100, 50000, 150000]:
- frame = websockets.Frame.default(
- self.random_bytes(num_bytes), is_client
- )
- frame2 = websockets.Frame.from_bytes(
- frame.to_bytes()
- )
- assert frame == frame2
-
- bytes = b'\x81\x03cba'
- assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
-
- def test_check_server_handshake(self):
- headers = self.protocol.server_handshake_headers("key")
- assert self.protocol.check_server_handshake(headers)
- headers["Upgrade"] = "not_websocket"
- assert not self.protocol.check_server_handshake(headers)
-
- def test_check_client_handshake(self):
- headers = self.protocol.client_handshake_headers("key")
- assert self.protocol.check_client_handshake(headers) == "key"
- headers["Upgrade"] = "not_websocket"
- assert not self.protocol.check_client_handshake(headers)
-
-
-class BadHandshakeHandler(WebSocketsEchoHandler):
-
- def handshake(self):
-
- client_hs = read_request(self.rfile)
- self.protocol.check_client_handshake(client_hs.headers)
-
- preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble.encode())
- headers = self.protocol.server_handshake_headers(b"malformed key")
- self.wfile.write(bytes(headers) + b"\r\n")
- self.wfile.flush()
- self.handshake_done = True
-
-
-class TestBadHandshake(tservers.ServerTestBase):
-
- """
- Ensure that the client disconnects if the server handshake is malformed
- """
- handler = BadHandshakeHandler
-
- def test(self):
- with tutils.raises(exceptions.TcpDisconnect):
- client = WebSocketsClient(("127.0.0.1", self.port))
- client.connect()
- client.send_message(b"hello")
-
-
-class TestFrameHeader:
-
- def test_roundtrip(self):
- def round(*args, **kwargs):
- f = websockets.FrameHeader(*args, **kwargs)
- f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
- assert f == f2
- round()
- round(fin=1)
- round(rsv1=1)
- round(rsv2=1)
- round(rsv3=1)
- round(payload_length=1)
- round(payload_length=100)
- round(payload_length=1000)
- round(payload_length=10000)
- round(opcode=websockets.OPCODE.PING)
- round(masking_key=b"test")
-
- def test_human_readable(self):
- f = websockets.FrameHeader(
- masking_key=b"test",
- fin=True,
- payload_length=10
- )
- assert repr(f)
- f = websockets.FrameHeader()
- assert repr(f)
-
- def test_funky(self):
- f = websockets.FrameHeader(masking_key=b"test", mask=False)
- raw = bytes(f)
- f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
- assert not f2.mask
-
- def test_violations(self):
- tutils.raises("opcode", websockets.FrameHeader, opcode=17)
- tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
-
- def test_automask(self):
- f = websockets.FrameHeader(mask=True)
- assert f.masking_key
-
- f = websockets.FrameHeader(masking_key=b"foob")
- assert f.mask
-
- f = websockets.FrameHeader(masking_key=b"foob", mask=0)
- assert not f.mask
- assert f.masking_key
-
-
-class TestFrame:
-
- def test_roundtrip(self):
- def round(*args, **kwargs):
- f = websockets.Frame(*args, **kwargs)
- raw = bytes(f)
- f2 = websockets.Frame.from_file(tutils.treader(raw))
- assert f == f2
- round(b"test")
- round(b"test", fin=1)
- round(b"test", rsv1=1)
- round(b"test", opcode=websockets.OPCODE.PING)
- round(b"test", masking_key=b"test")
-
- def test_human_readable(self):
- f = websockets.Frame()
- assert repr(f)
-
-
-def test_masker():
- tests = [
- [b"a"],
- [b"four"],
- [b"fourf"],
- [b"fourfive"],
- [b"a", b"aasdfasdfa", b"asdf"],
- [b"a" * 50, b"aasdfasdfa", b"asdf"],
- ]
- for i in tests:
- m = websockets.Masker(b"abcd")
- data = b"".join([m(t) for t in i])
- data2 = websockets.Masker(b"abcd")(data)
- assert data2 == b"".join(i)