diff options
author | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2015-07-16 22:56:34 +0200 |
---|---|---|
committer | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2015-07-22 15:30:51 +0200 |
commit | 808b294865257fc3f52b33ed2a796009658b126f (patch) | |
tree | ebf522088ab56eda052bba7c78f298faa4306557 /test/http/http1/test_protocol.py | |
parent | 230c16122b06f5c6af60e6ddc2d8e2e83cd75273 (diff) | |
download | mitmproxy-808b294865257fc3f52b33ed2a796009658b126f.tar.gz mitmproxy-808b294865257fc3f52b33ed2a796009658b126f.tar.bz2 mitmproxy-808b294865257fc3f52b33ed2a796009658b126f.zip |
refactor HTTP/1 as protocol
Diffstat (limited to 'test/http/http1/test_protocol.py')
-rw-r--r-- | test/http/http1/test_protocol.py | 214 |
1 files changed, 111 insertions, 103 deletions
diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index d0a2ee02..6b8a884c 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -3,70 +3,79 @@ import textwrap import binascii from netlib import http, odict, tcp -from netlib.http.http1 import protocol +from netlib.http.http1 import HTTP1Protocol from ... import tutils, tservers +def mock_protocol(data='', chunked=False): + class TCPHandlerMock(object): + pass + tcp_handler = TCPHandlerMock() + tcp_handler.rfile = cStringIO.StringIO(data) + tcp_handler.wfile = cStringIO.StringIO() + return HTTP1Protocol(tcp_handler) + + + def test_has_chunked_encoding(): h = odict.ODictCaseless() - assert not protocol.has_chunked_encoding(h) + assert not HTTP1Protocol.has_chunked_encoding(h) h["transfer-encoding"] = ["chunked"] - assert protocol.has_chunked_encoding(h) + assert HTTP1Protocol.has_chunked_encoding(h) def test_read_chunked(): - h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") + data = "1\r\na\r\n0\r\n" tutils.raises( "malformed chunked body", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + data = "1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n") + data = "\r\n" tutils.raises( "closed prematurely", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\nfoo") + data = "1\r\nfoo" tutils.raises( "malformed chunked body", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("foo\r\nfoo") + data = "foo\r\nfoo" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", None, True + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True) + data = "5\r\naaaaa\r\n0\r\n\r\n" + tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True) def test_connection_close(): h = odict.ODictCaseless() - assert protocol.connection_close((1, 0), h) - assert not protocol.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 0), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not protocol.connection_close((1, 1), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["close"] - assert protocol.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 1), h) def test_get_header_tokens(): @@ -82,119 +91,119 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert protocol.read_http_body(r, h, None, "GET", None, True) == "" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "" def test_read_http_body_response(): h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case h = odict.ODictCaseless() h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" # test content length: invalid header h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: invalid header #2 h["content-length"] = [-1] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: content length > actual content h["content-length"] = [5] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, 4, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, 4, "GET", 200, False ) # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5 + data = "testing" + assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5 # test no content length: limit > actual content h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7 + data = "testing" + assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, 4, "GET", 200, False + http.HttpError, + mock_protocol(data, chunked=True).read_http_body, + h, 4, "GET", 200, False ) # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + data = "5\r\naaaaa\r\n0\r\n\r\n" + assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() h["content-length"] = ["foo"] - assert protocol.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # negative number in the content-length field h = odict.ODictCaseless() h["content-length"] = ["-7"] - assert protocol.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # explicit length h = odict.ODictCaseless() h["content-length"] = ["5"] - assert protocol.expected_http_body_size(h, False, "GET", 200) == 5 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5 # no length h = odict.ODictCaseless() - assert protocol.expected_http_body_size(h, False, "GET", 200) == -1 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1 # no length request h = odict.ODictCaseless() - assert protocol.expected_http_body_size(h, True, "GET", None) == 0 + assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 def test_parse_http_protocol(): - assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not protocol.parse_http_protocol("HTTP/a.1") - assert not protocol.parse_http_protocol("HTTP/1.a") - assert not protocol.parse_http_protocol("foo/0.0") - assert not protocol.parse_http_protocol("HTTP/x") + assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1") + assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a") + assert not HTTP1Protocol.parse_http_protocol("foo/0.0") + assert not HTTP1Protocol.parse_http_protocol("HTTP/x") def test_parse_init_connect(): - assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not protocol.parse_init_connect("bogus") - assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("bogus") + assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -203,27 +212,27 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not protocol.parse_init_proxy(u) + assert not HTTP1Protocol.parse_init_proxy(u) - assert not protocol.parse_init_proxy("invalid") - assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + assert not HTTP1Protocol.parse_init_proxy("invalid") + assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion = protocol.parse_init_http(u) + m, u, httpversion = HTTP1Protocol.parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) u = "G\xfeET /test HTTP/1.1" - assert not protocol.parse_init_http(u) + assert not HTTP1Protocol.parse_init_http(u) - assert not protocol.parse_init_http("invalid") - assert not protocol.parse_init_http("GET invalid HTTP/1.1") - assert not protocol.parse_init_http("GET /test foo/1.1") - assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + assert not HTTP1Protocol.parse_init_http("invalid") + assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1") + assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1") + assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: @@ -232,8 +241,7 @@ class TestReadHeaders: if not verbatim: data = textwrap.dedent(data) data = data.strip() - s = cStringIO.StringIO(data) - return protocol.read_headers(s) + return mock_protocol(data).read_headers() def test_read_simple(self): data = """ @@ -287,16 +295,15 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = protocol.read_response(c.rfile, "GET", None) + resp = HTTP1Protocol(c).read_response("GET", None) assert resp.content == "bar\r\n\r\n" def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return protocol.read_response( - r, method, limit, include_body=include_body + return mock_protocol(data).read_response( + method, limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) @@ -358,16 +365,16 @@ def test_read_response(): def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert protocol.get_request_line(r) == "foo" - assert not protocol.get_request_line(r) + data = "\nfoo" + p = mock_protocol(data) + assert p.get_request_line() == "foo" + assert not p.get_request_line() class TestReadRequest(): def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return protocol.read_request(r, **kwargs) + return mock_protocol(data).read_request(**kwargs) def test_invalid(self): tutils.raises( @@ -421,14 +428,15 @@ class TestReadRequest(): assert v.host == "foo.com" def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( + data = "".join( "GET / HTTP/1.1\r\n" "Content-Length: 3\r\n" "Expect: 100-continue\r\n\r\n" - "foobar", + "foobar" ) - v = protocol.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + + p = mock_protocol(data) + v = p.read_request() + assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert v.content == "foo" - assert r.read(3) == "bar" + assert p.tcp_handler.rfile.read(3) == "bar" |