aboutsummaryrefslogtreecommitdiffstats
path: root/test/http/http1/test_protocol.py
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2015-07-16 22:56:34 +0200
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2015-07-22 15:30:51 +0200
commit808b294865257fc3f52b33ed2a796009658b126f (patch)
treeebf522088ab56eda052bba7c78f298faa4306557 /test/http/http1/test_protocol.py
parent230c16122b06f5c6af60e6ddc2d8e2e83cd75273 (diff)
downloadmitmproxy-808b294865257fc3f52b33ed2a796009658b126f.tar.gz
mitmproxy-808b294865257fc3f52b33ed2a796009658b126f.tar.bz2
mitmproxy-808b294865257fc3f52b33ed2a796009658b126f.zip
refactor HTTP/1 as protocol
Diffstat (limited to 'test/http/http1/test_protocol.py')
-rw-r--r--test/http/http1/test_protocol.py214
1 files changed, 111 insertions, 103 deletions
diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py
index d0a2ee02..6b8a884c 100644
--- a/test/http/http1/test_protocol.py
+++ b/test/http/http1/test_protocol.py
@@ -3,70 +3,79 @@ import textwrap
import binascii
from netlib import http, odict, tcp
-from netlib.http.http1 import protocol
+from netlib.http.http1 import HTTP1Protocol
from ... import tutils, tservers
+def mock_protocol(data='', chunked=False):
+ class TCPHandlerMock(object):
+ pass
+ tcp_handler = TCPHandlerMock()
+ tcp_handler.rfile = cStringIO.StringIO(data)
+ tcp_handler.wfile = cStringIO.StringIO()
+ return HTTP1Protocol(tcp_handler)
+
+
+
def test_has_chunked_encoding():
h = odict.ODictCaseless()
- assert not protocol.has_chunked_encoding(h)
+ assert not HTTP1Protocol.has_chunked_encoding(h)
h["transfer-encoding"] = ["chunked"]
- assert protocol.has_chunked_encoding(h)
+ assert HTTP1Protocol.has_chunked_encoding(h)
def test_read_chunked():
-
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
- s = cStringIO.StringIO("1\r\na\r\n0\r\n")
+ data = "1\r\na\r\n0\r\n"
tutils.raises(
"malformed chunked body",
- protocol.read_http_body,
- s, h, None, "GET", None, True
+ mock_protocol(data).read_http_body,
+ h, None, "GET", None, True
)
- s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
- assert protocol.read_http_body(s, h, None, "GET", None, True) == "a"
+ data = "1\r\na\r\n0\r\n\r\n"
+ assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a"
- s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n")
- assert protocol.read_http_body(s, h, None, "GET", None, True) == "a"
+ data = "\r\n\r\n1\r\na\r\n0\r\n\r\n"
+ assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a"
- s = cStringIO.StringIO("\r\n")
+ data = "\r\n"
tutils.raises(
"closed prematurely",
- protocol.read_http_body,
- s, h, None, "GET", None, True
+ mock_protocol(data).read_http_body,
+ h, None, "GET", None, True
)
- s = cStringIO.StringIO("1\r\nfoo")
+ data = "1\r\nfoo"
tutils.raises(
"malformed chunked body",
- protocol.read_http_body,
- s, h, None, "GET", None, True
+ mock_protocol(data).read_http_body,
+ h, None, "GET", None, True
)
- s = cStringIO.StringIO("foo\r\nfoo")
+ data = "foo\r\nfoo"
tutils.raises(
- protocol.HttpError,
- protocol.read_http_body,
- s, h, None, "GET", None, True
+ http.HttpError,
+ mock_protocol(data).read_http_body,
+ h, None, "GET", None, True
)
- s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
- tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True)
+ data = "5\r\naaaaa\r\n0\r\n\r\n"
+ tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True)
def test_connection_close():
h = odict.ODictCaseless()
- assert protocol.connection_close((1, 0), h)
- assert not protocol.connection_close((1, 1), h)
+ assert HTTP1Protocol.connection_close((1, 0), h)
+ assert not HTTP1Protocol.connection_close((1, 1), h)
h["connection"] = ["keep-alive"]
- assert not protocol.connection_close((1, 1), h)
+ assert not HTTP1Protocol.connection_close((1, 1), h)
h["connection"] = ["close"]
- assert protocol.connection_close((1, 1), h)
+ assert HTTP1Protocol.connection_close((1, 1), h)
def test_get_header_tokens():
@@ -82,119 +91,119 @@ def test_get_header_tokens():
def test_read_http_body_request():
h = odict.ODictCaseless()
- r = cStringIO.StringIO("testing")
- assert protocol.read_http_body(r, h, None, "GET", None, True) == ""
+ data = "testing"
+ assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == ""
def test_read_http_body_response():
h = odict.ODictCaseless()
- s = tcp.Reader(cStringIO.StringIO("testing"))
- assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing"
+ data = "testing"
+ assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing"
def test_read_http_body():
# test default case
h = odict.ODictCaseless()
h["content-length"] = [7]
- s = cStringIO.StringIO("testing")
- assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing"
+ data = "testing"
+ assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
# test content length: invalid header
h["content-length"] = ["foo"]
- s = cStringIO.StringIO("testing")
+ data = "testing"
tutils.raises(
- protocol.HttpError,
- protocol.read_http_body,
- s, h, None, "GET", 200, False
+ http.HttpError,
+ mock_protocol(data).read_http_body,
+ h, None, "GET", 200, False
)
# test content length: invalid header #2
h["content-length"] = [-1]
- s = cStringIO.StringIO("testing")
+ data = "testing"
tutils.raises(
- protocol.HttpError,
- protocol.read_http_body,
- s, h, None, "GET", 200, False
+ http.HttpError,
+ mock_protocol(data).read_http_body,
+ h, None, "GET", 200, False
)
# test content length: content length > actual content
h["content-length"] = [5]
- s = cStringIO.StringIO("testing")
+ data = "testing"
tutils.raises(
- protocol.HttpError,
- protocol.read_http_body,
- s, h, 4, "GET", 200, False
+ http.HttpError,
+ mock_protocol(data).read_http_body,
+ h, 4, "GET", 200, False
)
# test content length: content length < actual content
- s = cStringIO.StringIO("testing")
- assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5
+ data = "testing"
+ assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5
# test no content length: limit > actual content
h = odict.ODictCaseless()
- s = tcp.Reader(cStringIO.StringIO("testing"))
- assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7
+ data = "testing"
+ assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content
- s = tcp.Reader(cStringIO.StringIO("testing"))
+ data = "testing"
tutils.raises(
- protocol.HttpError,
- protocol.read_http_body,
- s, h, 4, "GET", 200, False
+ http.HttpError,
+ mock_protocol(data, chunked=True).read_http_body,
+ h, 4, "GET", 200, False
)
# test chunked
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
- s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n"))
- assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
+ data = "5\r\naaaaa\r\n0\r\n\r\n"
+ assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size():
# gibber in the content-length field
h = odict.ODictCaseless()
h["content-length"] = ["foo"]
- assert protocol.expected_http_body_size(h, False, "GET", 200) is None
+ assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None
# negative number in the content-length field
h = odict.ODictCaseless()
h["content-length"] = ["-7"]
- assert protocol.expected_http_body_size(h, False, "GET", 200) is None
+ assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None
# explicit length
h = odict.ODictCaseless()
h["content-length"] = ["5"]
- assert protocol.expected_http_body_size(h, False, "GET", 200) == 5
+ assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5
# no length
h = odict.ODictCaseless()
- assert protocol.expected_http_body_size(h, False, "GET", 200) == -1
+ assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1
# no length request
h = odict.ODictCaseless()
- assert protocol.expected_http_body_size(h, True, "GET", None) == 0
+ assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0
def test_parse_http_protocol():
- assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
- assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
- assert not protocol.parse_http_protocol("HTTP/a.1")
- assert not protocol.parse_http_protocol("HTTP/1.a")
- assert not protocol.parse_http_protocol("foo/0.0")
- assert not protocol.parse_http_protocol("HTTP/x")
+ assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
+ assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
+ assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1")
+ assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a")
+ assert not HTTP1Protocol.parse_http_protocol("foo/0.0")
+ assert not HTTP1Protocol.parse_http_protocol("HTTP/x")
def test_parse_init_connect():
- assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
- assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0")
- assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0")
- assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0")
- assert not protocol.parse_init_connect("bogus")
- assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
- assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
- assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
- assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0")
+ assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
+ assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0")
+ assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0")
+ assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0")
+ assert not HTTP1Protocol.parse_init_connect("bogus")
+ assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
+ assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
+ assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
+ assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0")
def test_parse_init_proxy():
u = "GET http://foo.com:8888/test HTTP/1.1"
- m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
+ m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u)
assert m == "GET"
assert s == "http"
assert h == "foo.com"
@@ -203,27 +212,27 @@ def test_parse_init_proxy():
assert httpversion == (1, 1)
u = "G\xfeET http://foo.com:8888/test HTTP/1.1"
- assert not protocol.parse_init_proxy(u)
+ assert not HTTP1Protocol.parse_init_proxy(u)
- assert not protocol.parse_init_proxy("invalid")
- assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
- assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
+ assert not HTTP1Protocol.parse_init_proxy("invalid")
+ assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1")
+ assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
def test_parse_init_http():
u = "GET /test HTTP/1.1"
- m, u, httpversion = protocol.parse_init_http(u)
+ m, u, httpversion = HTTP1Protocol.parse_init_http(u)
assert m == "GET"
assert u == "/test"
assert httpversion == (1, 1)
u = "G\xfeET /test HTTP/1.1"
- assert not protocol.parse_init_http(u)
+ assert not HTTP1Protocol.parse_init_http(u)
- assert not protocol.parse_init_http("invalid")
- assert not protocol.parse_init_http("GET invalid HTTP/1.1")
- assert not protocol.parse_init_http("GET /test foo/1.1")
- assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1")
+ assert not HTTP1Protocol.parse_init_http("invalid")
+ assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1")
+ assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1")
+ assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1")
class TestReadHeaders:
@@ -232,8 +241,7 @@ class TestReadHeaders:
if not verbatim:
data = textwrap.dedent(data)
data = data.strip()
- s = cStringIO.StringIO(data)
- return protocol.read_headers(s)
+ return mock_protocol(data).read_headers()
def test_read_simple(self):
data = """
@@ -287,16 +295,15 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase):
def test_no_content_length(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- resp = protocol.read_response(c.rfile, "GET", None)
+ resp = HTTP1Protocol(c).read_response("GET", None)
assert resp.content == "bar\r\n\r\n"
def test_read_response():
def tst(data, method, limit, include_body=True):
data = textwrap.dedent(data)
- r = cStringIO.StringIO(data)
- return protocol.read_response(
- r, method, limit, include_body=include_body
+ return mock_protocol(data).read_response(
+ method, limit, include_body=include_body
)
tutils.raises("server disconnect", tst, "", "GET", None)
@@ -358,16 +365,16 @@ def test_read_response():
def test_get_request_line():
- r = cStringIO.StringIO("\nfoo")
- assert protocol.get_request_line(r) == "foo"
- assert not protocol.get_request_line(r)
+ data = "\nfoo"
+ p = mock_protocol(data)
+ assert p.get_request_line() == "foo"
+ assert not p.get_request_line()
class TestReadRequest():
def tst(self, data, **kwargs):
- r = cStringIO.StringIO(data)
- return protocol.read_request(r, **kwargs)
+ return mock_protocol(data).read_request(**kwargs)
def test_invalid(self):
tutils.raises(
@@ -421,14 +428,15 @@ class TestReadRequest():
assert v.host == "foo.com"
def test_expect(self):
- w = cStringIO.StringIO()
- r = cStringIO.StringIO(
+ data = "".join(
"GET / HTTP/1.1\r\n"
"Content-Length: 3\r\n"
"Expect: 100-continue\r\n\r\n"
- "foobar",
+ "foobar"
)
- v = protocol.read_request(r, wfile=w)
- assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
+
+ p = mock_protocol(data)
+ v = p.read_request()
+ assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
assert v.content == "foo"
- assert r.read(3) == "bar"
+ assert p.tcp_handler.rfile.read(3) == "bar"