diff options
Diffstat (limited to 'test/websockets/test_websockets.py')
-rw-r--r-- | test/websockets/test_websockets.py | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 3af5dc9c..6f67b84d 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -41,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): key = self.protocol.check_client_handshake(req.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble + "\r\n") + self.wfile.write(preamble.encode() + b"\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(str(headers) + "\r\n") self.wfile.flush() @@ -62,11 +62,11 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - preamble = 'GET / HTTP/1.1' - self.wfile.write(preamble + "\r\n") + preamble = b'GET / HTTP/1.1' + self.wfile.write(preamble + b"\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers["sec-websocket-key"] - self.wfile.write(str(headers) + "\r\n") + self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() resp = read_response(self.rfile, treq(method="GET")) @@ -101,7 +101,7 @@ class TestWebSockets(tservers.ServerTestBase): assert response == msg def test_simple_echo(self): - self.echo("hello I'm the client") + self.echo(b"hello I'm the client") def test_frame_sizes(self): # length can fit in the the 7 bit payload length @@ -161,10 +161,10 @@ class BadHandshakeHandler(WebSocketsEchoHandler): client_hs = read_request(self.rfile) self.protocol.check_client_handshake(client_hs.headers) - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble + "\r\n") - headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(str(headers) + "\r\n") + preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) + self.wfile.write(preamble.encode()) + headers = self.protocol.server_handshake_headers(b"malformed key") + self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() self.handshake_done = True @@ -180,7 +180,7 @@ class TestBadHandshake(tservers.ServerTestBase): def test(self): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") + client.send_message(b"hello") class TestFrameHeader: @@ -188,8 +188,7 @@ class TestFrameHeader: def test_roundtrip(self): def round(*args, **kwargs): f = websockets.FrameHeader(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) assert f == f2 round() round(fin=1) @@ -201,11 +200,11 @@ class TestFrameHeader: round(payload_length=1000) round(payload_length=10000) round(opcode=websockets.OPCODE.PING) - round(masking_key="test") + round(masking_key=b"test") def test_human_readable(self): f = websockets.FrameHeader( - masking_key="test", + masking_key=b"test", fin=True, payload_length=10 ) @@ -214,23 +213,23 @@ class TestFrameHeader: assert f.human_readable() def test_funky(self): - f = websockets.FrameHeader(masking_key="test", mask=False) + f = websockets.FrameHeader(masking_key=b"test", mask=False) bytes = f.to_bytes() f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert not f2.mask def test_violations(self): tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key="x") + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") def test_automask(self): f = websockets.FrameHeader(mask=True) assert f.masking_key - f = websockets.FrameHeader(masking_key="foob") + f = websockets.FrameHeader(masking_key=b"foob") assert f.mask - f = websockets.FrameHeader(masking_key="foob", mask=0) + f = websockets.FrameHeader(masking_key=b"foob", mask=0) assert not f.mask assert f.masking_key @@ -240,31 +239,31 @@ class TestFrame: def test_roundtrip(self): def round(*args, **kwargs): f = websockets.Frame(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.Frame.from_file(tutils.treader(bytes)) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) assert f == f2 - round("test") - round("test", fin=1) - round("test", rsv1=1) - round("test", opcode=websockets.OPCODE.PING) - round("test", masking_key="test") + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") def test_human_readable(self): f = websockets.Frame() - assert f.human_readable() + assert repr(f) def test_masker(): tests = [ - ["a"], - ["four"], - ["fourf"], - ["fourfive"], - ["a", "aasdfasdfa", "asdf"], - ["a" * 50, "aasdfasdfa", "asdf"], + [b"a"], + [b"four"], + [b"fourf"], + [b"fourfive"], + [b"a", b"aasdfasdfa", b"asdf"], + [b"a" * 50, b"aasdfasdfa", b"asdf"], ] for i in tests: - m = websockets.Masker("abcd") - data = "".join([m(t) for t in i]) - data2 = websockets.Masker("abcd")(data) - assert data2 == "".join(i) + m = websockets.Masker(b"abcd") + data = b"".join([m(t) for t in i]) + data2 = websockets.Masker(b"abcd")(data) + assert data2 == b"".join(i) |