aboutsummaryrefslogtreecommitdiffstats
path: root/test/websockets/test_websockets.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/websockets/test_websockets.py')
-rw-r--r--test/websockets/test_websockets.py71
1 files changed, 35 insertions, 36 deletions
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py
index 3af5dc9c..6f67b84d 100644
--- a/test/websockets/test_websockets.py
+++ b/test/websockets/test_websockets.py
@@ -41,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
key = self.protocol.check_client_handshake(req.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble + "\r\n")
+ self.wfile.write(preamble.encode() + b"\r\n")
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
@@ -62,11 +62,11 @@ class WebSocketsClient(tcp.TCPClient):
def connect(self):
super(WebSocketsClient, self).connect()
- preamble = 'GET / HTTP/1.1'
- self.wfile.write(preamble + "\r\n")
+ preamble = b'GET / HTTP/1.1'
+ self.wfile.write(preamble + b"\r\n")
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers["sec-websocket-key"]
- self.wfile.write(str(headers) + "\r\n")
+ self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
resp = read_response(self.rfile, treq(method="GET"))
@@ -101,7 +101,7 @@ class TestWebSockets(tservers.ServerTestBase):
assert response == msg
def test_simple_echo(self):
- self.echo("hello I'm the client")
+ self.echo(b"hello I'm the client")
def test_frame_sizes(self):
# length can fit in the the 7 bit payload length
@@ -161,10 +161,10 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
client_hs = read_request(self.rfile)
self.protocol.check_client_handshake(client_hs.headers)
- preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble + "\r\n")
- headers = self.protocol.server_handshake_headers("malformed key")
- self.wfile.write(str(headers) + "\r\n")
+ preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
+ self.wfile.write(preamble.encode())
+ headers = self.protocol.server_handshake_headers(b"malformed key")
+ self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
self.handshake_done = True
@@ -180,7 +180,7 @@ class TestBadHandshake(tservers.ServerTestBase):
def test(self):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
- client.send_message("hello")
+ client.send_message(b"hello")
class TestFrameHeader:
@@ -188,8 +188,7 @@ class TestFrameHeader:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
- bytes = f.to_bytes()
- f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
+ f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
assert f == f2
round()
round(fin=1)
@@ -201,11 +200,11 @@ class TestFrameHeader:
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
- round(masking_key="test")
+ round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
- masking_key="test",
+ masking_key=b"test",
fin=True,
payload_length=10
)
@@ -214,23 +213,23 @@ class TestFrameHeader:
assert f.human_readable()
def test_funky(self):
- f = websockets.FrameHeader(masking_key="test", mask=False)
+ f = websockets.FrameHeader(masking_key=b"test", mask=False)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
- tutils.raises("masking key", websockets.FrameHeader, masking_key="x")
+ tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
- f = websockets.FrameHeader(masking_key="foob")
+ f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
- f = websockets.FrameHeader(masking_key="foob", mask=0)
+ f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
@@ -240,31 +239,31 @@ class TestFrame:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
- bytes = f.to_bytes()
- f2 = websockets.Frame.from_file(tutils.treader(bytes))
+ raw = bytes(f)
+ f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
- round("test")
- round("test", fin=1)
- round("test", rsv1=1)
- round("test", opcode=websockets.OPCODE.PING)
- round("test", masking_key="test")
+ round(b"test")
+ round(b"test", fin=1)
+ round(b"test", rsv1=1)
+ round(b"test", opcode=websockets.OPCODE.PING)
+ round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
- assert f.human_readable()
+ assert repr(f)
def test_masker():
tests = [
- ["a"],
- ["four"],
- ["fourf"],
- ["fourfive"],
- ["a", "aasdfasdfa", "asdf"],
- ["a" * 50, "aasdfasdfa", "asdf"],
+ [b"a"],
+ [b"four"],
+ [b"fourf"],
+ [b"fourfive"],
+ [b"a", b"aasdfasdfa", b"asdf"],
+ [b"a" * 50, b"aasdfasdfa", b"asdf"],
]
for i in tests:
- m = websockets.Masker("abcd")
- data = "".join([m(t) for t in i])
- data2 = websockets.Masker("abcd")(data)
- assert data2 == "".join(i)
+ m = websockets.Masker(b"abcd")
+ data = b"".join([m(t) for t in i])
+ data2 = websockets.Masker(b"abcd")(data)
+ assert data2 == b"".join(i)