diff options
-rw-r--r-- | netlib/http/headers.py | 1 | ||||
-rw-r--r-- | netlib/strutils.py | 5 | ||||
-rw-r--r-- | test/netlib/http/test_headers.py | 9 | ||||
-rw-r--r-- | test/netlib/test_strutils.py | 2 |
4 files changed, 16 insertions, 1 deletions
diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 131e8ce5..b55874ca 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,6 +14,7 @@ if six.PY2: # pragma: no cover return x def _always_bytes(x): + strutils.always_bytes(x, "utf-8", "replace") # raises a TypeError if x != str/bytes/None. return x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. diff --git a/netlib/strutils.py b/netlib/strutils.py index 4cb3b805..d43c2aab 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -8,7 +8,10 @@ import six def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes + elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None: + return unicode_or_bytes + else: + raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__)) def native(s, *encoding_opts): diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index ad2bc548..e8752c52 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -43,6 +43,15 @@ class TestHeaders(object): with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) + def test_set(self): + headers = Headers() + headers[u"foo"] = u"1" + headers[b"bar"] = b"2" + headers["baz"] = b"3" + with raises(TypeError): + headers["foobar"] = 42 + assert len(headers) == 3 + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 5be254a3..0f58cac5 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -8,6 +8,8 @@ def test_always_bytes(): assert strutils.always_bytes("foo") == b"foo" with tutils.raises(ValueError): strutils.always_bytes(u"\u2605", "ascii") + with tutils.raises(TypeError): + strutils.always_bytes(42, "ascii") def test_native(): |