diff options
Diffstat (limited to 'netlib/test/websockets/test_websockets.py')
-rw-r--r-- | netlib/test/websockets/test_websockets.py | 266 |
1 files changed, 0 insertions, 266 deletions
diff --git a/netlib/test/websockets/test_websockets.py b/netlib/test/websockets/test_websockets.py deleted file mode 100644 index d53f0d83..00000000 --- a/netlib/test/websockets/test_websockets.py +++ /dev/null @@ -1,266 +0,0 @@ -import os - -from netlib.http.http1 import read_response, read_request - -from netlib import tcp, websockets, http, tutils, tservers -from netlib.http import status_codes -from netlib.tutils import treq - -from netlib.exceptions import * - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - - req = read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode() + b"\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(str(headers) + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = b'GET / HTTP/1.1' - self.wfile.write(preamble + b"\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers["sec-websocket-key"].encode("ascii") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - - resp = read_response(self.rfile, treq(method=b"GET")) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo(b"hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - client_frame = websockets.Frame.default(msg, from_client=True) - server_frame = websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - - client_hs = read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode()) - headers = self.protocol.server_handshake_headers(b"malformed key") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - def test(self): - with tutils.raises(TcpDisconnect): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(b"hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key=b"test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key=b"test", - fin=True, - payload_length=10 - ) - assert repr(f) - f = websockets.FrameHeader() - assert repr(f) - - def test_funky(self): - f = websockets.FrameHeader(masking_key=b"test", mask=False) - raw = bytes(f) - f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key=b"foob") - assert f.mask - - f = websockets.FrameHeader(masking_key=b"foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - raw = bytes(f) - f2 = websockets.Frame.from_file(tutils.treader(raw)) - assert f == f2 - round(b"test") - round(b"test", fin=1) - round(b"test", rsv1=1) - round(b"test", opcode=websockets.OPCODE.PING) - round(b"test", masking_key=b"test") - - def test_human_readable(self): - f = websockets.Frame() - assert repr(f) - - -def test_masker(): - tests = [ - [b"a"], - [b"four"], - [b"fourf"], - [b"fourfive"], - [b"a", b"aasdfasdfa", b"asdf"], - [b"a" * 50, b"aasdfasdfa", b"asdf"], - ] - for i in tests: - m = websockets.Masker(b"abcd") - data = b"".join([m(t) for t in i]) - data2 = websockets.Masker(b"abcd")(data) - assert data2 == b"".join(i) |