diff options
author | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2015-08-21 10:03:57 +0200 |
---|---|---|
committer | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2015-08-21 10:04:57 +0200 |
commit | cd9701050f58f90c757a34f7e4e6b5711700d649 (patch) | |
tree | a34ce39dc2a260d57873e14e35bfedd1d939c9cb | |
parent | 6fc2ff94694d70426663209e2ded977d9e0ecd3c (diff) | |
download | mitmproxy-cd9701050f58f90c757a34f7e4e6b5711700d649.tar.gz mitmproxy-cd9701050f58f90c757a34f7e4e6b5711700d649.tar.bz2 mitmproxy-cd9701050f58f90c757a34f7e4e6b5711700d649.zip |
read_response depends on request for stream_id
-rw-r--r-- | netlib/http/http1/protocol.py | 4 | ||||
-rw-r--r-- | netlib/http/http2/protocol.py | 18 | ||||
-rw-r--r-- | netlib/http/semantics.py | 34 | ||||
-rw-r--r-- | test/http/http1/test_protocol.py | 5 | ||||
-rw-r--r-- | test/http/http2/test_protocol.py | 39 | ||||
-rw-r--r-- | test/websockets/test_websockets.py | 6 |
6 files changed, 57 insertions, 49 deletions
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index dc33a8af..107a48d1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -136,7 +136,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_response( self, - request_method, + request, body_size_limit, include_body=True, ): @@ -175,7 +175,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): body = self.read_http_body( headers, body_size_limit, - request_method, + request.method, code, False ) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 66ce19c8..e032c2a0 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -74,7 +74,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() - stream_id, headers, body = self._receive_transmission(include_body) + stream_id, headers, body = self._receive_transmission( + include_body=include_body, + ) if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start @@ -127,7 +129,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_response( self, - request_method='', + request='', body_size_limit=None, include_body=True, ): @@ -137,7 +139,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() - stream_id, headers, body = self._receive_transmission(include_body) + stream_id, headers, body = self._receive_transmission( + stream_id=request.stream_id, + include_body=include_body, + ) if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start @@ -145,7 +150,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): if include_body: timestamp_end = time.time() - else: + else: # pragma: no cover timestamp_end = None response = http.Response( @@ -358,11 +363,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] - def _receive_transmission(self, include_body=True): + def _receive_transmission(self, stream_id=None, include_body=True): # TODO: include_body is not respected body_expected = True - stream_id = 0 header_block_fragment = b'' body = b'' @@ -370,7 +374,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): frm = self.read_frame() if ( (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and - (stream_id == 0 or frm.stream_id == stream_id) + (stream_id is None or frm.stream_id == stream_id) ): stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 836af550..e388a344 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -337,18 +337,32 @@ class Request(object): class EmptyRequest(Request): - def __init__(self): + def __init__( + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=None, + headers=None, + body="", + stream_id=None + ): super(EmptyRequest, self).__init__( - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=odict.ODictCaseless(), - body="", + form_in=form_in, + method=method, + scheme=scheme, + host=host, + port=port, + path=path, + httpversion=(httpversion or (0, 0)), + headers=(headers or odict.ODictCaseless()), + body=body, ) + if stream_id: + self.stream_id = stream_id class Response(object): diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 6704647f..31bf7dab 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -376,8 +376,9 @@ class TestReadRequest(object): class TestReadResponse(object): def tst(self, data, method, body_size_limit, include_body=True): data = textwrap.dedent(data) + request = http.EmptyRequest(method=method) return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body + request, body_size_limit, include_body=include_body ) def test_errors(self): @@ -457,7 +458,7 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) + resp = HTTP1Protocol(c).read_response(http.EmptyRequest(method="GET"), None) assert resp.body == "bar\r\n\r\n" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 726d8e2e..92fa109c 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -365,6 +365,8 @@ class TestReadRequestConnect(tservers.ServerTestBase): def handle(self): self.wfile.write( b'00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085'.decode('hex')) + self.wfile.write( + b'00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7'.decode('hex')) self.wfile.flush() ssl = True @@ -377,20 +379,25 @@ class TestReadRequestConnect(tservers.ServerTestBase): protocol.connection_preface_performed = True req = protocol.read_request() - assert req.form_in == "authority" assert req.method == "CONNECT" assert req.host == "address" assert req.port == 22 + req = protocol.read_request() + assert req.form_in == "authority" + assert req.method == "CONNECT" + assert req.host == "example.com" + assert req.port == 443 + class TestReadResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) + b'00000801040000002a88628594e78c767f'.decode('hex')) self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) + b'00000600010000002a666f6f626172'.decode('hex')) self.wfile.flush() self.rfile.safe_read(9) # just to keep the connection alive a bit longer @@ -403,8 +410,9 @@ class TestReadResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response() + resp = protocol.read_response(http.EmptyRequest(stream_id=42)) + assert resp.stream_id == 42 assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" @@ -412,29 +420,12 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.body == b'foobar' assert resp.timestamp_end - def test_read_response_no_body(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response(include_body=False) - - assert resp.httpversion == (2, 0) - assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING - assert not resp.timestamp_end - class TestReadEmptyResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): - def handle(self): self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) + b'00000801050000002a88628594e78c767f'.decode('hex')) self.wfile.flush() ssl = True @@ -446,9 +437,9 @@ class TestReadEmptyResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response() + resp = protocol.read_response(http.EmptyRequest(stream_id=42)) - assert resp.stream_id + assert resp.stream_id == 42 assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 752f2c3e..5f27c128 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,9 +2,7 @@ import os from nose.tools import raises -from netlib import tcp -from netlib import tutils -from netlib import websockets +from netlib import tcp, tutils, websockets, http from netlib.http import status_codes from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol @@ -72,7 +70,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - resp = http1_protocol.read_response("get", None) + resp = http1_protocol.read_response(http.EmptyRequest(method="GET"), None) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( |