about summary refs log tree commit diff stats
path: root/netlib/utils.py
diff options
context:
space:
mode:
author Maximilian Hils <git@maximilianhils.com> 2015-09-16 20:19:52 +0200
committer Maximilian Hils <git@maximilianhils.com> 2015-09-16 20:19:52 +0200
commit e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2 (patch)
tree c0eba50b522d1d0183b057e9cae7bf7cc38c4fc3 /netlib/utils.py
parent 2f9c566e480c377566a0ae044d698a75b45cd54c (diff)
parent 265f31e8782ee9da511ce4b63aa2da00221cbf66 (diff)
download mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.tar.gz
mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.tar.bz2
mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.zip
Merge pull request #92 from mitmproxy/python3
Python3 & HTTP1 Refactor
Diffstat (limited to 'netlib/utils.py')
-rw-r--r-- netlib/utils.py 166
1 files changed, 103 insertions, 63 deletions
diff --git a/netlib/utils.py b/netlib/utils.py
index d6774419..a86b8019 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -1,17 +1,17 @@
-from __future__ import (absolute_import, print_function, division)
+from __future__ import absolute_import, print_function, division
import os.path
-import cgi
-import urllib
-import urlparse
-import string
import re
-import six
+import string
import unicodedata
+import six
+
+from six.moves import urllib
+
-def isascii(s):
+def isascii(bytes):
try:
- s.decode("ascii")
+ bytes.decode("ascii")
except ValueError:
return False
return True
@@ -40,12 +40,12 @@ def clean_bin(s, keep_spacing=True):
)
else:
if keep_spacing:
- keep = b"\n\r\t"
+ keep = (9, 10, 13) # \t, \n, \r,
else:
- keep = b""
+ keep = ()
return b"".join(
- ch if (31 < ord(ch) < 127 or ch in keep) else b"."
- for ch in s
+ six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"."
+ for ch in six.iterbytes(s)
)
@@ -149,10 +149,7 @@ class Data(object):
return fullpath
-def is_valid_port(port):
- if not 0 <= port <= 65535:
- return False
- return True
+_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_host(host):
@@ -160,53 +157,79 @@ def is_valid_host(host):
host.decode("idna")
except ValueError:
return False
- if "\0" in host:
- return None
- return True
+ if len(host) > 255:
+ return False
+ if host[-1] == ".":
+ host = host[:-1]
+ return all(_label_valid.match(x) for x in host.split(b"."))
+
+
+def is_valid_port(port):
+ return 0 <= port <= 65535
+
+
+# PY2 workaround
+def decode_parse_result(result, enc):
+ if hasattr(result, "decode"):
+ return result.decode(enc)
+ else:
+ return urllib.parse.ParseResult(*[x.decode(enc) for x in result])
+
+
+# PY2 workaround
+def encode_parse_result(result, enc):
+ if hasattr(result, "encode"):
+ return result.encode(enc)
+ else:
+ return urllib.parse.ParseResult(*[x.encode(enc) for x in result])
def parse_url(url):
"""
- Returns a (scheme, host, port, path) tuple, or None on error.
+ URL-parsing function that checks that
+ - port is an integer 0-65535
+ - host is a valid IDNA-encoded hostname with no null-bytes
+ - path is valid ASCII
- Checks that:
- port is an integer 0-65535
- host is a valid IDNA-encoded hostname with no null-bytes
- path is valid ASCII
+ Args:
+ A URL (as bytes or as unicode)
+
+ Returns:
+ A (scheme, host, port, path) tuple
+
+ Raises:
+ ValueError, if the URL is not properly formatted.
"""
- try:
- scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
- except ValueError:
- return None
- if not scheme:
- return None
- if '@' in netloc:
- # FIXME: Consider what to do with the discarded credentials here Most
- # probably we should extend the signature to return these as a separate
- # value.
- _, netloc = string.rsplit(netloc, '@', maxsplit=1)
- if ':' in netloc:
- host, port = string.rsplit(netloc, ':', maxsplit=1)
- try:
- port = int(port)
- except ValueError:
- return None
+ parsed = urllib.parse.urlparse(url)
+
+ if not parsed.hostname:
+ raise ValueError("No hostname given")
+
+ if isinstance(url, six.binary_type):
+ host = parsed.hostname
+
+ # this should not raise a ValueError
+ decode_parse_result(parsed, "ascii")
else:
- host = netloc
- if scheme.endswith("https"):
- port = 443
- else:
- port = 80
- path = urlparse.urlunparse(('', '', path, params, query, fragment))
- if not path.startswith("/"):
- path = "/" + path
+ host = parsed.hostname.encode("idna")
+ parsed = encode_parse_result(parsed, "ascii")
+
+ port = parsed.port
+ if not port:
+ port = 443 if parsed.scheme == b"https" else 80
+
+ full_path = urllib.parse.urlunparse(
+ (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
+ )
+ if not full_path.startswith(b"/"):
+ full_path = b"/" + full_path
+
if not is_valid_host(host):
- return None
- if not isascii(path):
- return None
+ raise ValueError("Invalid Host")
if not is_valid_port(port):
- return None
- return scheme, host, port, path
+ raise ValueError("Invalid Port")
+
+ return parsed.scheme, host, port, full_path
def get_header_tokens(headers, key):
@@ -217,7 +240,7 @@ def get_header_tokens(headers, key):
"""
if key not in headers:
return []
- tokens = headers[key].split(",")
+ tokens = headers[key].split(b",")
return [token.strip() for token in tokens]
@@ -228,7 +251,7 @@ def hostport(scheme, host, port):
if (port, scheme) in [(80, "http"), (443, "https")]:
return host
else:
- return "%s:%s" % (host, port)
+ return b"%s:%d" % (host, port)
def unparse_url(scheme, host, port, path=""):
@@ -243,14 +266,14 @@ def urlencode(s):
Takes a list of (key, value) tuples and returns a urlencoded string.
"""
s = [tuple(i) for i in s]
- return urllib.urlencode(s, False)
+ return urllib.parse.urlencode(s, False)
def urldecode(s):
"""
Takes a urlencoded string and returns a list of (key, value) tuples.
"""
- return cgi.parse_qsl(s, keep_blank_values=True)
+ return urllib.parse.parse_qsl(s, keep_blank_values=True)
def parse_content_type(c):
@@ -267,14 +290,14 @@ def parse_content_type(c):
("text", "html", {"charset": "UTF-8"})
"""
- parts = c.split(";", 1)
- ts = parts[0].split("/", 1)
+ parts = c.split(b";", 1)
+ ts = parts[0].split(b"/", 1)
if len(ts) != 2:
return None
d = {}
if len(parts) == 2:
- for i in parts[1].split(";"):
- clause = i.split("=", 1)
+ for i in parts[1].split(b";"):
+ clause = i.split(b"=", 1)
if len(clause) == 2:
d[clause[0].strip()] = clause[1].strip()
return ts[0].lower(), ts[1].lower(), d
@@ -289,7 +312,7 @@ def multipartdecode(headers, content):
v = parse_content_type(v)
if not v:
return []
- boundary = v[2].get("boundary")
+ boundary = v[2].get(b"boundary")
if not boundary:
return []
@@ -306,3 +329,20 @@ def multipartdecode(headers, content):
r.append((key, value))
return r
return []
+
+
+def always_bytes(unicode_or_bytes, encoding):
+ if isinstance(unicode_or_bytes, six.text_type):
+ return unicode_or_bytes.encode(encoding)
+ return unicode_or_bytes
+
+
+def always_byte_args(encoding):
+ """Decorator that transparently encodes all arguments passed as unicode"""
+ def decorator(fun):
+ def _fun(*args, **kwargs):
+ args = [always_bytes(arg, encoding) for arg in args]
+ kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)}
+ return fun(*args, **kwargs)
+ return _fun
+ return decorator