diff options
author | Aldo Cortesi <aldo@nullcube.com> | 2015-04-21 22:39:45 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@nullcube.com> | 2015-04-21 22:39:45 +1200 |
commit | 3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13 (patch) | |
tree | 183a689254da4d824ffe6bd2eba8b07623f840bd /test/test_websockets.py | |
parent | e5f12648380cb4401f77e3cae51189ef97b603dc (diff) | |
download | mitmproxy-3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13.tar.gz mitmproxy-3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13.tar.bz2 mitmproxy-3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13.zip |
websockets: refactor to use http and header functions in http.py
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r-- | test/test_websockets.py | 112 |
1 files changed, 48 insertions, 64 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py index 1f2025bf..9b27e810 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,6 +1,4 @@ -from netlib import tcp -from netlib import test -from netlib import websockets +from netlib import tcp, test, websockets, http, odict import io import os from nose.tools import raises @@ -21,18 +19,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = websockets.Frame.from_file(self.rfile).decoded_payload - self.on_message(decoded) + frame = websockets.Frame.from_file(self.rfile) + self.on_message(frame.decoded_payload) def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) frame.to_file(self.wfile) def handshake(self): - client_hs = websockets.read_handshake(self.rfile, 1) - key = websockets.process_handshake_from_client(client_hs) - response = websockets.create_server_handshake(key) - self.wfile.write(response) + req = http.read_request(self.rfile) + key = websockets.check_client_handshake(req) + + self.wfile.write(http.response_preamble(101) + "\r\n") + headers = websockets.server_handshake_headers(key) + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -44,28 +44,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.client_nonce = websockets.create_client_nonce() - self.resource = "/" + self.client_nonce = None def connect(self): super(WebSocketsClient, self).connect() - handshake = websockets.create_client_handshake( - self.address.host, - self.address.port, - self.client_nonce, - self.version, - self.resource - ) - - self.wfile.write(handshake) + preamble = http.request_preamble("GET", "/") + self.wfile.write(preamble + "\r\n") + headers = websockets.client_handshake_headers() + self.client_nonce = headers.get_first("sec-websocket-key") + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - server_handshake = websockets.read_handshake(self.rfile, 1) - server_nonce = websockets.process_handshake_from_server( - server_handshake - ) + resp = http.read_response(self.rfile, "get", None) + server_nonce = websockets.check_server_handshake(resp) if not server_nonce == websockets.create_server_nonce(self.client_nonce): self.close() @@ -140,51 +132,43 @@ class TestWebSockets(test.ServerTestBase): frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() - def test_handshake(self): - bad_upgrade = "not_websockets" - bad_header_handshake = websockets.build_handshake([ - ('Host', '%s:%s' % ("a", "b")), - ('Connection', "c"), - ('Upgrade', bad_upgrade), - ('Sec-WebSocket-Key', "d"), - ('Sec-WebSocket-Version', "e") - ], "f") - - # check behavior when required header values are missing - assert None is websockets.process_handshake_from_server( - bad_header_handshake - ) - assert None is websockets.process_handshake_from_client( - bad_header_handshake - ) - - key = "test_key" - - client_handshake = websockets.create_client_handshake( - "a", "b", key, "d", "e" + def test_check_server_handshake(self): + resp = http.Response( + (1, 1), + 101, + "Switching Protocols", + websockets.server_handshake_headers("key"), + "" ) - assert key == websockets.process_handshake_from_client( - client_handshake + assert websockets.check_server_handshake(resp) + resp.headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_server_handshake(resp) + + def test_check_client_handshake(self): + resp = http.Request( + "relative", + "get", + "http", + "host", + 22, + "/", + (1, 1), + websockets.client_handshake_headers("key"), + "" ) - - server_handshake = websockets.create_server_handshake(key) - assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake) - - handshake = websockets.create_client_handshake("a", "b", "c", "d", "e") - stream = io.BytesIO(handshake) - assert handshake == websockets.read_handshake(stream, 1) - - # ensure readhandshake doesn't loop forever on empty stream - empty_stream = io.BytesIO("") - assert "" == websockets.read_handshake(empty_stream, 1) + assert websockets.check_client_handshake(resp) == "key" + resp.headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_client_handshake(resp) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = websockets.read_handshake(self.rfile, 1) - websockets.process_handshake_from_client(client_hs) - response = websockets.create_server_handshake("malformed_key") - self.wfile.write(response) + client_hs = http.read_request(self.rfile) + websockets.check_client_handshake(client_hs) + + self.wfile.write(http.response_preamble(101) + "\r\n") + headers = websockets.server_handshake_headers("malformed key") + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True |