aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_websockets.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r--test/test_websockets.py112
1 files changed, 48 insertions, 64 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 1f2025bf..9b27e810 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -1,6 +1,4 @@
-from netlib import tcp
-from netlib import test
-from netlib import websockets
+from netlib import tcp, test, websockets, http, odict
import io
import os
from nose.tools import raises
@@ -21,18 +19,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
self.read_next_message()
def read_next_message(self):
- decoded = websockets.Frame.from_file(self.rfile).decoded_payload
- self.on_message(decoded)
+ frame = websockets.Frame.from_file(self.rfile)
+ self.on_message(frame.decoded_payload)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = False)
frame.to_file(self.wfile)
def handshake(self):
- client_hs = websockets.read_handshake(self.rfile, 1)
- key = websockets.process_handshake_from_client(client_hs)
- response = websockets.create_server_handshake(key)
- self.wfile.write(response)
+ req = http.read_request(self.rfile)
+ key = websockets.check_client_handshake(req)
+
+ self.wfile.write(http.response_preamble(101) + "\r\n")
+ headers = websockets.server_handshake_headers(key)
+ self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
self.handshake_done = True
@@ -44,28 +44,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
- self.version = "13"
- self.client_nonce = websockets.create_client_nonce()
- self.resource = "/"
+ self.client_nonce = None
def connect(self):
super(WebSocketsClient, self).connect()
- handshake = websockets.create_client_handshake(
- self.address.host,
- self.address.port,
- self.client_nonce,
- self.version,
- self.resource
- )
-
- self.wfile.write(handshake)
+ preamble = http.request_preamble("GET", "/")
+ self.wfile.write(preamble + "\r\n")
+ headers = websockets.client_handshake_headers()
+ self.client_nonce = headers.get_first("sec-websocket-key")
+ self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
- server_handshake = websockets.read_handshake(self.rfile, 1)
- server_nonce = websockets.process_handshake_from_server(
- server_handshake
- )
+ resp = http.read_response(self.rfile, "get", None)
+ server_nonce = websockets.check_server_handshake(resp)
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
self.close()
@@ -140,51 +132,43 @@ class TestWebSockets(test.ServerTestBase):
frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes()
- def test_handshake(self):
- bad_upgrade = "not_websockets"
- bad_header_handshake = websockets.build_handshake([
- ('Host', '%s:%s' % ("a", "b")),
- ('Connection', "c"),
- ('Upgrade', bad_upgrade),
- ('Sec-WebSocket-Key', "d"),
- ('Sec-WebSocket-Version', "e")
- ], "f")
-
- # check behavior when required header values are missing
- assert None is websockets.process_handshake_from_server(
- bad_header_handshake
- )
- assert None is websockets.process_handshake_from_client(
- bad_header_handshake
- )
-
- key = "test_key"
-
- client_handshake = websockets.create_client_handshake(
- "a", "b", key, "d", "e"
+ def test_check_server_handshake(self):
+ resp = http.Response(
+ (1, 1),
+ 101,
+ "Switching Protocols",
+ websockets.server_handshake_headers("key"),
+ ""
)
- assert key == websockets.process_handshake_from_client(
- client_handshake
+ assert websockets.check_server_handshake(resp)
+ resp.headers["Upgrade"] = ["not_websocket"]
+ assert not websockets.check_server_handshake(resp)
+
+ def test_check_client_handshake(self):
+ resp = http.Request(
+ "relative",
+ "get",
+ "http",
+ "host",
+ 22,
+ "/",
+ (1, 1),
+ websockets.client_handshake_headers("key"),
+ ""
)
-
- server_handshake = websockets.create_server_handshake(key)
- assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake)
-
- handshake = websockets.create_client_handshake("a", "b", "c", "d", "e")
- stream = io.BytesIO(handshake)
- assert handshake == websockets.read_handshake(stream, 1)
-
- # ensure readhandshake doesn't loop forever on empty stream
- empty_stream = io.BytesIO("")
- assert "" == websockets.read_handshake(empty_stream, 1)
+ assert websockets.check_client_handshake(resp) == "key"
+ resp.headers["Upgrade"] = ["not_websocket"]
+ assert not websockets.check_client_handshake(resp)
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
- client_hs = websockets.read_handshake(self.rfile, 1)
- websockets.process_handshake_from_client(client_hs)
- response = websockets.create_server_handshake("malformed_key")
- self.wfile.write(response)
+ client_hs = http.read_request(self.rfile)
+ websockets.check_client_handshake(client_hs)
+
+ self.wfile.write(http.response_preamble(101) + "\r\n")
+ headers = websockets.server_handshake_headers("malformed key")
+ self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
self.handshake_done = True