diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_filt.py | 2 | ||||
-rw-r--r-- | test/test_flow.py | 11 | ||||
-rw-r--r-- | test/test_protocol_http1.py (renamed from test/test_protocol_http.py) | 28 | ||||
-rw-r--r-- | test/test_protocol_http2.py | 431 | ||||
-rw-r--r-- | test/tutils.py | 16 |
5 files changed, 453 insertions, 35 deletions
diff --git a/test/test_filt.py b/test/test_filt.py index b1fd2ad9..b1f3a21f 100644 --- a/test/test_filt.py +++ b/test/test_filt.py @@ -1,7 +1,7 @@ import cStringIO from libmproxy import filt -from libmproxy.protocol import http from libmproxy.models import Error +from libmproxy.models import http from netlib.http import Headers from . import tutils diff --git a/test/test_flow.py b/test/test_flow.py index b8d1fad3..51b88fff 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -422,7 +422,7 @@ class TestFlow(object): assert not f == f2 f2.error = Error("e2") assert not f == f2 - f.load_state(f2.get_state()) + f.set_state(f2.get_state()) assert f.get_state() == f2.get_state() def test_kill(self): @@ -463,6 +463,11 @@ class TestFlow(object): f.response.content = "\xc2foo" f.replace("foo", u"bar") + def test_replace_no_content(self): + f = tutils.tflow() + f.request.content = CONTENT_MISSING + assert f.replace("foo", "bar") == 0 + def test_replace(self): f = tutils.tflow(resp=True) f.request.headers["foo"] = "foo" @@ -1199,7 +1204,7 @@ class TestError: e2 = Error("bar") assert not e == e2 - e.load_state(e2.get_state()) + e.set_state(e2.get_state()) assert e.get_state() == e2.get_state() e3 = e.copy() @@ -1219,7 +1224,7 @@ class TestClientConnection: assert not c == c2 c2.timestamp_start = 42 - c.load_state(c2.get_state()) + c.set_state(c2.get_state()) assert c.timestamp_start == 42 c3 = c.copy() diff --git a/test/test_protocol_http.py b/test/test_protocol_http1.py index 489be3f9..a1485f1b 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http1.py @@ -1,4 +1,3 @@ -from io import BytesIO from netlib.exceptions import HttpSyntaxException from netlib.http import http1 from netlib.tcp import TCPClient @@ -6,33 +5,6 @@ from netlib.tutils import treq, raises from . import tutils, tservers -class TestHTTPResponse: - - def test_read_from_stringio(self): - s = ( - b"HTTP/1.1 200 OK\r\n" - b"Content-Length: 7\r\n" - b"\r\n" - b"content\r\n" - b"HTTP/1.1 204 OK\r\n" - b"\r\n" - ) - rfile = BytesIO(s) - r = http1.read_response(rfile, treq()) - assert r.status_code == 200 - assert r.content == b"content" - assert http1.read_response(rfile, treq()).status_code == 204 - - rfile = BytesIO(s) - # HEAD must not have content by spec. We should leave it on the pipe. - r = http1.read_response(rfile, treq(method=b"HEAD")) - assert r.status_code == 200 - assert r.content == b"" - - with raises(HttpSyntaxException): - http1.read_response(rfile, treq()) - - class TestHTTPFlow(object): def test_repr(self): diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py new file mode 100644 index 00000000..38cfdfc3 --- /dev/null +++ b/test/test_protocol_http2.py @@ -0,0 +1,431 @@ +from __future__ import (absolute_import, print_function, division) + +import OpenSSL +import pytest +import traceback +import os +import tempfile +import sys + +from libmproxy.proxy.config import ProxyConfig +from libmproxy.proxy.server import ProxyServer +from libmproxy.cmdline import APP_HOST, APP_PORT + +import logging +logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) +logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) +logging.getLogger("passlib.utils.compat").setLevel(logging.WARNING) +logging.getLogger("passlib.registry").setLevel(logging.WARNING) +logging.getLogger("PIL.Image").setLevel(logging.WARNING) +logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) + +import netlib +from netlib import tservers as netlib_tservers +from netlib.utils import http2_read_raw_frame + +import h2 +from hyperframe.frame import Frame + +from libmproxy import utils +from . import tservers + +requires_alpn = pytest.mark.skipif( + not OpenSSL._util.lib.Cryptography_HAS_ALPN, + reason="requires OpenSSL with ALPN support") + + +class _Http2ServerBase(netlib_tservers.ServerTestBase): + ssl = dict(alpn_select=b'h2') + + class handler(netlib.tcp.BaseHandler): + + def handle(self): + h2_conn = h2.connection.H2Connection(client_side=False) + + preamble = self.rfile.read(24) + h2_conn.initiate_connection() + h2_conn.receive_data(preamble) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(http2_read_raw_frame(self.rfile)) + events = h2_conn.receive_data(raw) + except: + break + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + for event in events: + try: + if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): + done = True + break + except Exception as e: + print(repr(e)) + print(traceback.format_exc()) + done = True + break + + def handle_server_event(self, h2_conn, rfile, wfile): + raise NotImplementedError() + + +class _Http2TestBase(object): + + @classmethod + def setup_class(self): + self.config = ProxyConfig(**self.get_proxy_config()) + + tmaster = tservers.TestMaster(self.config) + tmaster.start_app(APP_HOST, APP_PORT) + self.proxy = tservers.ProxyThread(tmaster) + self.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @property + def master(self): + return self.proxy.tmaster + + @classmethod + def get_proxy_config(cls): + cls.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return dict( + no_upstream_cert = False, + cadir = cls.cadir, + authenticator = None, + ) + + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_server_event = self.handle_server_event + + def _setup_connection(self): + self.config.http2 = True + + client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + # send CONNECT request + client.wfile.write( + b"CONNECT localhost:%d HTTP/1.1\r\n" + b"Host: localhost:%d\r\n" + b"\r\n" % (self.server.server.address.port, self.server.server.address.port) + ) + client.wfile.flush() + + # read CONNECT response + while client.rfile.readline() != "\r\n": + pass + + client.convert_to_ssl(alpn_protos=[b'h2']) + + h2_conn = h2.connection.H2Connection(client_side=True) + h2_conn.initiate_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + return client, h2_conn + + def _send_request(self, wfile, h2_conn, stream_id=1, headers=[], body=b''): + h2_conn.send_headers( + stream_id=stream_id, + headers=headers, + end_stream=(len(body) == 0), + ) + if body: + h2_conn.send_data(stream_id, body) + h2_conn.end_stream(stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + +@requires_alpn +class TestSimple(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(1, [ + (':status', '200'), + ('foo', 'bar'), + ]) + h2_conn.send_data(1, b'foobar') + h2_conn.end_stream(1) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_simple(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], body='my request body echoed back to me') + + done = False + while not done: + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.status_code == 200 + assert self.master.state.flows[0].response.headers['foo'] == 'bar' + assert self.master.state.flows[0].response.body == b'foobar' + + +@requires_alpn +class TestWithBodies(_Http2TestBase, _Http2ServerBase): + tmp_data_buffer_foobar = b'' + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + if isinstance(event, h2.events.DataReceived): + self.tmp_data_buffer_foobar += event.data + elif isinstance(event, h2.events.StreamEnded): + h2_conn.send_headers(1, [ + (':status', '200'), + ]) + h2_conn.send_data(1, self.tmp_data_buffer_foobar) + h2_conn.end_stream(1) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_with_bodies(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + body='foobar with request body', + ) + + done = False + while not done: + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert self.master.state.flows[0].response.body == b'foobar with request body' + + +@requires_alpn +class TestPushPromise(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + if event.stream_id != 1: + # ignore requests initiated by push promises + return True + + h2_conn.send_headers(1, [(':status', '200')]) + h2_conn.push_stream(1, 2, [ + (':authority', "127.0.0.1:%s" % self.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_foo'), + ('foo', 'bar') + ]) + h2_conn.push_stream(1, 4, [ + (':authority', "127.0.0.1:%s" % self.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_bar'), + ('foo', 'bar') + ]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_headers(2, [(':status', '200')]) + h2_conn.send_headers(4, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_data(1, b'regular_stream') + h2_conn.send_data(2, b'pushed_stream_foo') + h2_conn.send_data(4, b'pushed_stream_bar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + h2_conn.end_stream(1) + h2_conn.end_stream(2) + h2_conn.end_stream(4) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_push_promise(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + raw = b''.join(http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses == 3 and ended_streams == 3 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert ended_streams == 3 + assert pushed_streams == 2 + + bodies = [flow.response.body for flow in self.master.state.flows] + assert len(bodies) == 3 + assert b'regular_stream' in bodies + assert b'pushed_stream_foo' in bodies + assert b'pushed_stream_bar' in bodies + + def test_push_promise_reset(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses >= 1 and ended_streams >= 1 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + bodies = [flow.response.body for flow in self.master.state.flows if flow.response] + assert len(bodies) >= 1 + assert b'regular_stream' in bodies + # the other two bodies might not be transmitted before the reset diff --git a/test/tutils.py b/test/tutils.py index 5bd91307..2ce0884d 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -76,7 +76,11 @@ def tclient_conn(): """ c = ClientConnection.from_state(dict( address=dict(address=("address", 22), use_ipv6=True), - clientcert=None + clientcert=None, + ssl_established=False, + timestamp_start=1, + timestamp_ssl_setup=2, + timestamp_end=3, )) c.reply = controller.DummyReply() return c @@ -88,9 +92,15 @@ def tserver_conn(): """ c = ServerConnection.from_state(dict( address=dict(address=("address", 22), use_ipv6=True), - state=[], source_address=dict(address=("address", 22), use_ipv6=True), - cert=None + cert=None, + timestamp_start=1, + timestamp_tcp_setup=2, + timestamp_ssl_setup=3, + timestamp_end=4, + ssl_established=False, + sni="address", + via=None )) c.reply = controller.DummyReply() return c |