diff options
-rw-r--r-- | .travis.yml | 6 | ||||
-rw-r--r-- | netlib/encoding.py | 20 | ||||
-rw-r--r-- | netlib/http/models.py | 48 | ||||
-rw-r--r-- | netlib/odict.py | 25 | ||||
-rw-r--r-- | netlib/tutils.py | 4 | ||||
-rw-r--r-- | netlib/utils.py | 22 | ||||
-rw-r--r-- | test/http/test_models.py | 8 | ||||
-rw-r--r-- | test/test_encoding.py | 10 | ||||
-rw-r--r-- | test/test_odict.py | 40 | ||||
-rw-r--r-- | test/test_socks.py | 55 | ||||
-rw-r--r-- | test/test_utils.py | 10 |
11 files changed, 105 insertions, 143 deletions
diff --git a/.travis.yml b/.travis.yml index fa997542..7e18176c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,11 @@ matrix: packages: - libssl-dev - python: 3.5 - script: "nosetests --with-cov --cov-report term-missing test/http/http1" + script: + - nosetests --with-cov --cov-report term-missing test/http/http1 + - nosetests --with-cov --cov-report term-missing test/test_utils.py + - nosetests --with-cov --cov-report term-missing test/test_encoding.py + - nosetests --with-cov --cov-report term-missing test/test_odict.py - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/encoding.py b/netlib/encoding.py index 06830f2c..8ac59905 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,28 +5,30 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib +from .utils import always_byte_args -__ALL__ = ["ENCODINGS"] -ENCODINGS = {"identity", "gzip", "deflate"} +ENCODINGS = {b"identity", b"gzip", b"deflate"} +@always_byte_args("ascii", "ignore") def decode(e, content): encoding_map = { - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, + b"identity": identity, + b"gzip": decode_gzip, + b"deflate": decode_deflate, } if e not in encoding_map: return None return encoding_map[e](content) +@always_byte_args("ascii", "ignore") def encode(e, content): encoding_map = { - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, + b"identity": identity, + b"gzip": encode_gzip, + b"deflate": encode_deflate, } if e not in encoding_map: return None @@ -80,3 +82,5 @@ def encode_deflate(content): Returns compressed content, always including zlib header and checksum. """ return zlib.compress(content) + +__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/http/models.py b/netlib/http/models.py index 54b8b112..bc681de3 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -136,7 +136,7 @@ class Headers(MutableMapping, object): def __len__(self): return len(set(name.lower() for name, _ in self.fields)) - #__hash__ = object.__hash__ + # __hash__ = object.__hash__ def _index(self, name): name = name.lower() @@ -227,11 +227,11 @@ class Request(Message): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. _headers_to_strip_off = [ - b'Proxy-Connection', - b'Keep-Alive', - b'Connection', - b'Transfer-Encoding', - b'Upgrade', + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', ] def __init__( @@ -275,8 +275,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"if-modified-since", - b"if-none-match", + b"If-Modified-Since", + b"If-None-Match", ] for i in delheaders: self.headers.pop(i, None) @@ -286,16 +286,16 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers[b"accept-encoding"] = b"identity" + self.headers["Accept-Encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get(b"accept-encoding") + accept_encoding = self.headers.get(b"Accept-Encoding") if accept_encoding: - self.headers[b"accept-encoding"] = ( + self.headers["Accept-Encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -316,9 +316,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): return self.get_form_multipart() return ODict([]) @@ -328,12 +328,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): return ODict( utils.multipartdecode( self.headers, @@ -405,9 +405,9 @@ class Request(Message): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - if hostheader and b"Host" in self.headers: + if hostheader and "Host" in self.headers: try: - return self.headers[b"Host"].decode("idna") + return self.headers["Host"].decode("idna") except ValueError: pass if self.host: @@ -426,7 +426,7 @@ class Request(Message): Returns a possibly empty netlib.odict.ODict object. """ ret = ODict() - for i in self.headers.get_all("cookie"): + for i in self.headers.get_all("Cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -468,9 +468,9 @@ class Request(Message): class Response(Message): _headers_to_strip_off = [ - b'Proxy-Connection', - b'Alternate-Protocol', - b'Alt-Svc', + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', ] def __init__( @@ -498,7 +498,7 @@ class Response(Message): return "<Response: {status_code} {msg} ({contenttype}, {size})>".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), + contenttype=self.headers.get("Content-Type", "unknown content type"), size=size) def get_cookies(self): @@ -511,7 +511,7 @@ class Response(Message): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all(b"set-cookie"): + for header in self.headers.get_all("Set-Cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -534,4 +534,4 @@ class Response(Message): i[1][1] ) ) - self.headers.set_all(b"Set-Cookie", values) + self.headers.set_all("Set-Cookie", values) diff --git a/netlib/odict.py b/netlib/odict.py index 11d5d52a..1124b23a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,6 +1,7 @@ from __future__ import (absolute_import, print_function, division) import re import copy +import six def safe_subn(pattern, repl, target, *args, **kwargs): @@ -67,10 +68,10 @@ class ODict(object): Sets the values for key k. If there are existing values for this key, they are cleared. """ - if isinstance(valuelist, basestring): + if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type): raise ValueError( "Expected list of values instead of string. " - "Example: odict['Host'] = ['www.example.com']" + "Example: odict[b'Host'] = [b'www.example.com']" ) kc = self._kconv(k) new = [] @@ -134,13 +135,6 @@ class ODict(object): def __repr__(self): return repr(self.lst) - def format(self): - elements = [] - for itm in self.lst: - elements.append(itm[0] + ": " + str(itm[1])) - elements.append("") - return "\r\n".join(elements) - def in_any(self, key, value, caseless=False): """ Do any of the values matching key contain value? @@ -156,19 +150,6 @@ class ODict(object): return True return False - def match_re(self, expr): - """ - Match the regular expression against each (key, value) pair. For - each pair a string of the following format is matched against: - - "key: value" - """ - for k, v in self.lst: - s = "%s: %s" % (k, v) - if re.search(expr, s): - return True - return False - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both keys and diff --git a/netlib/tutils.py b/netlib/tutils.py index b69495a3..746e1488 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -123,9 +123,7 @@ def tresp(**kwargs): status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), - body=b"message", - timestamp_start=time.time(), - timestamp_end=time.time() + body=b"message" ) default.update(kwargs) return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 14b428d7..6fed44b6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -246,7 +246,7 @@ def unparse_url(scheme, host, port, path=""): """ Returns a URL string, constructed from the specified compnents. """ - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path) def urlencode(s): @@ -295,7 +295,7 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get("content-type") + v = headers.get(b"Content-Type") if v: v = parse_content_type(v) if not v: @@ -304,33 +304,33 @@ def multipartdecode(headers, content): if not boundary: return [] - rx = re.compile(r'\bname="([^"]+)"') + rx = re.compile(br'\bname="([^"]+)"') r = [] - for i in content.split("--" + boundary): + for i in content.split(b"--" + boundary): parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != "--": + if len(parts) > 1 and parts[0][0:2] != b"--": match = rx.search(parts[1]) if match: key = match.group(1) - value = "".join(parts[3 + parts[2:].index(""):]) + value = b"".join(parts[3 + parts[2:].index(b""):]) r.append((key, value)) return r return [] -def always_bytes(unicode_or_bytes, encoding): +def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(encoding) + return unicode_or_bytes.encode(*encode_args) return unicode_or_bytes -def always_byte_args(encoding): +def always_byte_args(*encode_args): """Decorator that transparently encodes all arguments passed as unicode""" def decorator(fun): def _fun(*args, **kwargs): - args = [always_bytes(arg, encoding) for arg in args] - kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} return fun(*args, **kwargs) return _fun return decorator diff --git a/test/http/test_models.py b/test/http/test_models.py index 8fce2e9d..c3ab4d0f 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -36,8 +36,8 @@ class TestRequest(object): assert isinstance(req.headers, Headers) def test_equal(self): - a = tutils.treq() - b = tutils.treq() + a = tutils.treq(timestamp_start=42, timestamp_end=43) + b = tutils.treq(timestamp_start=42, timestamp_end=43) assert a == b assert not a == 'foo' @@ -319,8 +319,8 @@ class TestResponse(object): assert isinstance(resp.headers, Headers) def test_equal(self): - a = tutils.tresp() - b = tutils.tresp() + a = tutils.tresp(timestamp_start=42, timestamp_end=43) + b = tutils.tresp(timestamp_start=42, timestamp_end=43) assert a == b assert not a == 'foo' diff --git a/test/test_encoding.py b/test/test_encoding.py index 9da3a38d..90f99338 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -2,10 +2,12 @@ from netlib import encoding def test_identity(): - assert "string" == encoding.decode("identity", "string") - assert "string" == encoding.encode("identity", "string") - assert not encoding.encode("nonexistent", "string") - assert None == encoding.decode("nonexistent encoding", "string") + assert b"string" == encoding.decode("identity", b"string") + assert b"string" == encoding.encode("identity", b"string") + assert b"string" == encoding.encode(b"identity", b"string") + assert b"string" == encoding.decode(b"identity", b"string") + assert not encoding.encode("nonexistent", b"string") + assert not encoding.decode("nonexistent encoding", b"string") def test_gzip(): diff --git a/test/test_odict.py b/test/test_odict.py index be3d862d..962c0daa 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -1,7 +1,7 @@ from netlib import odict, tutils -class TestODict: +class TestODict(object): def setUp(self): self.od = odict.ODict() @@ -13,21 +13,10 @@ class TestODict: def test_str_err(self): h = odict.ODict() - tutils.raises(ValueError, h.__setitem__, "key", "foo") - - def test_dictToHeader1(self): - self.od.add("one", "uno") - self.od.add("two", "due") - self.od.add("two", "tre") - expected = [ - "one: uno\r\n", - "two: due\r\n", - "two: tre\r\n", - "\r\n" - ] - out = self.od.format() - for i in expected: - assert out.find(i) >= 0 + with tutils.raises(ValueError): + h["key"] = u"foo" + with tutils.raises(ValueError): + h["key"] = b"foo" def test_getset_state(self): self.od.add("foo", 1) @@ -40,23 +29,6 @@ class TestODict: b.load_state(state) assert b == self.od - def test_dictToHeader2(self): - self.od["one"] = ["uno"] - expected1 = "one: uno\r\n" - expected2 = "\r\n" - out = self.od.format() - assert out.find(expected1) >= 0 - assert out.find(expected2) >= 0 - - def test_match_re(self): - h = odict.ODict() - h.add("one", "uno") - h.add("two", "due") - h.add("two", "tre") - assert h.match_re("uno") - assert h.match_re("two: due") - assert not h.match_re("nonono") - def test_in_any(self): self.od["one"] = ["atwoa", "athreea"] assert self.od.in_any("one", "two") @@ -122,7 +94,7 @@ class TestODict: assert a["a"] == ["b", "b"] -class TestODictCaseless: +class TestODictCaseless(object): def setUp(self): self.od = odict.ODictCaseless() diff --git a/test/test_socks.py b/test/test_socks.py index 3d109f42..65a0f0eb 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,12 +1,12 @@ -from cStringIO import StringIO +from io import BytesIO import socket from nose.plugins.skip import SkipTest from netlib import socks, tcp, tutils def test_client_greeting(): - raw = tutils.treader("\x05\x02\x00\xBE\xEF") - out = StringIO() + raw = tutils.treader(b"\x05\x02\x00\xBE\xEF") + out = BytesIO() msg = socks.ClientGreeting.from_file(raw) msg.assert_socks5() msg.to_file(out) @@ -19,11 +19,11 @@ def test_client_greeting(): def test_client_greeting_assert_socks5(): - raw = tutils.treader("\x00\x00") + raw = tutils.treader(b"\x00\x00") msg = socks.ClientGreeting.from_file(raw) tutils.raises(socks.SocksError, msg.assert_socks5) - raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) + raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100) msg = socks.ClientGreeting.from_file(raw) try: msg.assert_socks5() @@ -33,7 +33,7 @@ def test_client_greeting_assert_socks5(): else: assert False - raw = tutils.treader("GET / HTTP/1.1" + " " * 100) + raw = tutils.treader(b"GET / HTTP/1.1" + " " * 100) msg = socks.ClientGreeting.from_file(raw) try: msg.assert_socks5() @@ -43,7 +43,7 @@ def test_client_greeting_assert_socks5(): else: assert False - raw = tutils.treader("XX") + raw = tutils.treader(b"XX") tutils.raises( socks.SocksError, socks.ClientGreeting.from_file, @@ -52,8 +52,8 @@ def test_client_greeting_assert_socks5(): def test_server_greeting(): - raw = tutils.treader("\x05\x02") - out = StringIO() + raw = tutils.treader(b"\x05\x02") + out = BytesIO() msg = socks.ServerGreeting.from_file(raw) msg.assert_socks5() msg.to_file(out) @@ -64,7 +64,7 @@ def test_server_greeting(): def test_server_greeting_assert_socks5(): - raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) + raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100) msg = socks.ServerGreeting.from_file(raw) try: msg.assert_socks5() @@ -74,7 +74,7 @@ def test_server_greeting_assert_socks5(): else: assert False - raw = tutils.treader("GET / HTTP/1.1" + " " * 100) + raw = tutils.treader(b"GET / HTTP/1.1" + " " * 100) msg = socks.ServerGreeting.from_file(raw) try: msg.assert_socks5() @@ -86,36 +86,37 @@ def test_server_greeting_assert_socks5(): def test_message(): - raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") - out = StringIO() + raw = tutils.treader(b"\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + out = BytesIO() msg = socks.Message.from_file(raw) msg.assert_socks5() - assert raw.read(2) == "\xBE\xEF" + assert raw.read(2) == b"\xBE\xEF" msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] assert msg.ver == 5 assert msg.msg == 0x01 assert msg.atyp == 0x03 - assert msg.addr == ("example.com", 0xDEAD) + assert msg.addr == (b"example.com", 0xDEAD) def test_message_assert_socks5(): - raw = tutils.treader("\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + raw = tutils.treader(b"\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") msg = socks.Message.from_file(raw) tutils.raises(socks.SocksError, msg.assert_socks5) def test_message_ipv4(): # Test ATYP=0x01 (IPV4) - raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") - out = StringIO() + raw = tutils.treader(b"\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + out = BytesIO() msg = socks.Message.from_file(raw) - assert raw.read(2) == "\xBE\xEF" + left = raw.read(2) + assert left == b"\xBE\xEF" msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] - assert msg.addr == ("127.0.0.1", 0xDEAD) + assert msg.addr == (b"127.0.0.1", 0xDEAD) def test_message_ipv6(): @@ -125,14 +126,14 @@ def test_message_ipv6(): ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" raw = tutils.treader( - "\x05\x01\x00\x04" + + b"\x05\x01\x00\x04" + socket.inet_pton( socket.AF_INET6, ipv6_addr) + - "\xDE\xAD\xBE\xEF") - out = StringIO() + b"\xDE\xAD\xBE\xEF") + out = BytesIO() msg = socks.Message.from_file(raw) - assert raw.read(2) == "\xBE\xEF" + assert raw.read(2) == b"\xBE\xEF" msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] @@ -140,13 +141,13 @@ def test_message_ipv6(): def test_message_invalid_rsv(): - raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader(b"\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) def test_message_unknown_atyp(): - raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader(b"\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) - tutils.raises(socks.SocksError, m.to_file, StringIO()) + tutils.raises(socks.SocksError, m.to_file, BytesIO()) diff --git a/test/test_utils.py b/test/test_utils.py index 0db75578..ff27486c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -84,10 +84,10 @@ def test_parse_url(): def test_unparse_url(): - assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" - assert utils.unparse_url("http", "foo.com", 80, "") == "http://foo.com" - assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" - assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" + assert utils.unparse_url(b"http", b"foo.com", 99, b"") == b"http://foo.com:99" + assert utils.unparse_url(b"http", b"foo.com", 80, b"/bar") == b"http://foo.com/bar" + assert utils.unparse_url(b"https", b"foo.com", 80, b"") == b"https://foo.com:80" + assert utils.unparse_url(b"https", b"foo.com", 443, b"") == b"https://foo.com" def test_urlencode(): @@ -122,7 +122,7 @@ def test_multipartdecode(): "--{0}\n" "Content-Disposition: form-data; name=\"field2\"\n\n" "value2\n" - "--{0}--".format(boundary).encode("ascii") + "--{0}--".format(boundary.decode()).encode() ) form = utils.multipartdecode(headers, content) |