diff options
author | Maximilian Hils <git@maximilianhils.com> | 2016-09-21 21:00:07 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-21 21:00:07 -0700 |
commit | d5427c7298b0c6aee009a86dec041011427689e9 (patch) | |
tree | 00b84f3a616c1c2c6e3469c2fb2a0be4a9d5e265 | |
parent | 1e5a5b03f8d56df62a04a368bd5eb2d59cb7582a (diff) | |
parent | f1d519d7c4231513c868179abf0fbfbb9387e633 (diff) | |
download | mitmproxy-d5427c7298b0c6aee009a86dec041011427689e9.tar.gz mitmproxy-d5427c7298b0c6aee009a86dec041011427689e9.tar.bz2 mitmproxy-d5427c7298b0c6aee009a86dec041011427689e9.zip |
Merge pull request #1563 from mhils/fix-1562
Raise TypeError on invalid header assignment, fix #1562
-rw-r--r-- | netlib/http/headers.py | 1 | ||||
-rw-r--r-- | netlib/strutils.py | 5 | ||||
-rw-r--r-- | pathod/language/http2.py | 2 | ||||
-rw-r--r-- | pathod/protocols/http2.py | 4 | ||||
-rw-r--r-- | test/netlib/http/test_headers.py | 9 | ||||
-rw-r--r-- | test/netlib/test_strutils.py | 2 |
6 files changed, 19 insertions, 4 deletions
diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 131e8ce5..b55874ca 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,6 +14,7 @@ if six.PY2: # pragma: no cover return x def _always_bytes(x): + strutils.always_bytes(x, "utf-8", "replace") # raises a TypeError if x != str/bytes/None. return x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. diff --git a/netlib/strutils.py b/netlib/strutils.py index 4cb3b805..d43c2aab 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -8,7 +8,10 @@ import six def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes + elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None: + return unicode_or_bytes + else: + raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__)) def native(s, *encoding_opts): diff --git a/pathod/language/http2.py b/pathod/language/http2.py index c0313baa..519ee699 100644 --- a/pathod/language/http2.py +++ b/pathod/language/http2.py @@ -189,7 +189,7 @@ class Response(_HTTP2Message): resp = http.Response( b'HTTP/2.0', - self.status_code.string(), + int(self.status_code.string()), b'', headers, body, diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index a2aa91b4..7b162664 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -6,7 +6,7 @@ import time import hyperframe.frame from hpack.hpack import Encoder, Decoder -from netlib import utils, strutils +from netlib import utils from netlib.http import http2 import netlib.http.headers import netlib.http.response @@ -201,7 +201,7 @@ class HTTP2StateProtocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.insert(0, b':status', strutils.always_bytes(response.status_code)) + headers.insert(0, b':status', str(response.status_code).encode()) if hasattr(response, 'stream_id'): stream_id = response.stream_id diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index ad2bc548..e8752c52 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -43,6 +43,15 @@ class TestHeaders(object): with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) + def test_set(self): + headers = Headers() + headers[u"foo"] = u"1" + headers[b"bar"] = b"2" + headers["baz"] = b"3" + with raises(TypeError): + headers["foobar"] = 42 + assert len(headers) == 3 + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 5be254a3..0f58cac5 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -8,6 +8,8 @@ def test_always_bytes(): assert strutils.always_bytes("foo") == b"foo" with tutils.raises(ValueError): strutils.always_bytes(u"\u2605", "ascii") + with tutils.raises(TypeError): + strutils.always_bytes(42, "ascii") def test_native(): |