From c4f028d7017ea77293c05d38700abd14461d7119 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 24 May 2017 11:35:20 +0200 Subject: websocket tests: fix leaking sockets --- test/mitmproxy/proxy/protocol/test_websocket.py | 122 ++++++++++++------------ 1 file changed, 62 insertions(+), 60 deletions(-) (limited to 'test') diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 8dfc4f2b..f78e173f 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -79,9 +79,13 @@ class _WebSocketTestBase: self.master.reset([]) self.server.server.handle_websockets = self.handle_websockets - def _setup_connection(self): - client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() + def teardown(self): + if self.client: + self.client.close() + + def setup_connection(self): + self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) + self.client.connect() request = http.Request( "authority", @@ -92,14 +96,14 @@ class _WebSocketTestBase: "", "HTTP/1.1", content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() + self.client.wfile.write(http.http1.assemble_request(request)) + self.client.wfile.flush() - response = http.http1.read_response(client.rfile, request) + response = http.http1.read_response(self.client.rfile, request) if self.ssl: - client.convert_to_ssl() - assert client.ssl_established + self.client.convert_to_ssl() + assert self.client.ssl_established request = http.Request( "relative", @@ -116,14 +120,12 @@ class _WebSocketTestBase: sec_websocket_key="1234", ), content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() + self.client.wfile.write(http.http1.assemble_request(request)) + self.client.wfile.flush() - response = http.http1.read_response(client.rfile, request) + response = http.http1.read_response(self.client.rfile, request) assert websockets.check_handshake(response.headers) - return client - class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase): @@ -154,25 +156,25 @@ class TestSimple(_WebSocketTest): wfile.flush() def test_simple(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'self.client-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() assert len(self.master.state.flows) == 2 assert isinstance(self.master.state.flows[0], HTTPFlow) @@ -180,9 +182,9 @@ class TestSimple(_WebSocketTest): assert len(self.master.state.flows[1].messages) == 5 assert self.master.state.flows[1].messages[0].content == b'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'client-foobar' + assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'client-foobar' + assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY @@ -203,19 +205,19 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() def test_simple_tls(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'self.client-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() class TestPing(_WebSocketTest): @@ -233,16 +235,16 @@ class TestPing(_WebSocketTest): wfile.flush() def test_ping(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.PING assert frame.payload == b'foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.TEXT assert frame.payload == b'pong-received' @@ -259,12 +261,12 @@ class TestPong(_WebSocketTest): wfile.flush() def test_pong(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' @@ -282,34 +284,34 @@ class TestClose(_WebSocketTest): websockets.Frame.from_file(rfile) def test_close(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) def test_close_payload_1(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) def test_close_payload_2(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) class TestInvalidFrame(_WebSocketTest): @@ -320,9 +322,9 @@ class TestInvalidFrame(_WebSocketTest): wfile.flush() def test_invalid_frame(self): - client = self._setup_connection() + self.setup_connection() # with pytest.raises(exceptions.TcpDisconnect): - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == 15 assert frame.payload == b'foobar' -- cgit v1.2.3