Diffstat (limited to 'netlib')
-rw-r--r--  netlib/http/cookies.py  | 39
-rw-r--r--  netlib/http/request.py  | 10
-rw-r--r--  netlib/http/response.py | 10
-rw-r--r--  netlib/odict.py         | 81
-rw-r--r--  netlib/utils.py         | 30
5 files changed, 85 insertions, 85 deletions
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index caa84ff7..4451f1da 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -1,5 +1,6 @@
 from six.moves import http_cookies as Cookie
 import re
+import string
 from email.utils import parsedate_tz, formatdate, mktime_tz

 from .. import odict
@@ -27,7 +28,6 @@ variants. Serialization follows RFC6265.

 # TODO: Disallow LHS-only Cookie values

-
 def _read_until(s, start, term):
     """
         Read until one of the characters in term is reached.
@@ -203,25 +203,26 @@ def refresh_set_cookie_header(c, delta):
     Returns:
         A refreshed Set-Cookie string
     """
-    try:
-        c = Cookie.SimpleCookie(str(c))
-    except Cookie.CookieError:
+
+    name, value, attrs = parse_set_cookie_header(c)
+    if not name or not value:
         raise ValueError("Invalid Cookie")
-    for i in c.values():
-        if "expires" in i:
-            d = parsedate_tz(i["expires"])
-            if d:
-                d = mktime_tz(d) + delta
-                i["expires"] = formatdate(d)
-            else:
-                # This can happen when the expires tag is invalid.
-                # reddit.com sends a an expires tag like this: "Thu, 31 Dec
-                # 2037 23:59:59 GMT", which is valid RFC 1123, but not
-                # strictly correct according to the cookie spec. Browsers
-                # appear to parse this tolerantly - maybe we should too.
-                # For now, we just ignore this.
-                del i["expires"]
-    ret = c.output(header="").strip()
+
+    if "expires" in attrs:
+        e = parsedate_tz(attrs["expires"][-1])
+        if e:
+            f = mktime_tz(e) + delta
+            attrs["expires"] = [formatdate(f)]
+        else:
+            # This can happen when the expires tag is invalid.
+            # reddit.com sends a an expires tag like this: "Thu, 31 Dec
+            # 2037 23:59:59 GMT", which is valid RFC 1123, but not
+            # strictly correct according to the cookie spec. Browsers
+            # appear to parse this tolerantly - maybe we should too.
+            # For now, we just ignore this.
+            del attrs["expires"]
+
+    ret = format_set_cookie_header(name, value, attrs)
     if not ret:
         raise ValueError("Invalid Cookie")
     return ret
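The rewritten refresh_set_cookie_header now round-trips through parse_set_cookie_header and format_set_cookie_header instead of Cookie.SimpleCookie. A minimal usage sketch, not part of the commit; the exact serialization of the result depends on format_set_cookie_header, which is outside this diff, and the header value below is purely illustrative:

    from netlib.http import cookies

    # Push the Expires attribute of a Set-Cookie header one hour into the future.
    # delta is in seconds, matching the mktime_tz arithmetic above.
    header = 'id=abc123; Expires=Thu, 31 Dec 2037 23:59:59 GMT; Path=/'
    refreshed = cookies.refresh_set_cookie_header(header, delta=3600)
    print(refreshed)  # same cookie, with Expires shifted by 3600 seconds
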
""" if self.data.path is None: return None @@ -343,14 +343,6 @@ class Request(Message): # Legacy - def get_cookies(self): # pragma: no cover - warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) - return self.cookies - - def set_cookies(self, odict): # pragma: no cover - warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) - self.cookies = odict - def get_query(self): # pragma: no cover warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) return self.query or ODict([]) diff --git a/netlib/http/response.py b/netlib/http/response.py index efd7f60a..2f06149e 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -127,13 +127,3 @@ class Response(Message): c.append(refreshed) if c: self.headers.set_all("set-cookie", c) - - # Legacy - - def get_cookies(self): # pragma: no cover - warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) - return self.cookies - - def set_cookies(self, odict): # pragma: no cover - warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) - self.cookies = odict diff --git a/netlib/odict.py b/netlib/odict.py index 461192f7..8a638dab 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@ from __future__ import (absolute_import, print_function, division) import copy + import six from .utils import Serializable, safe_subn @@ -27,27 +28,24 @@ class ODict(Serializable): def __iter__(self): return self.lst.__iter__() - def __getitem__(self, k): + def __getitem__(self, key): """ Returns a list of values matching key. """ - ret = [] - k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - ret.append(i[1]) - return ret - def keys(self): - return list(set([self._kconv(i[0]) for i in self.lst])) + key = self._kconv(key) + return [ + v + for k, v in self.lst + if self._kconv(k) == key + ] - def _filter_lst(self, k, lst): - k = self._kconv(k) - new = [] - for i in lst: - if self._kconv(i[0]) != k: - new.append(i) - return new + def keys(self): + return list( + set( + self._kconv(k) for k, _ in self.lst + ) + ) def __len__(self): """ @@ -81,14 +79,19 @@ class ODict(Serializable): """ Delete all items matching k. """ - self.lst = self._filter_lst(k, self.lst) - - def __contains__(self, k): k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - return True - return False + self.lst = [ + i + for i in self.lst + if self._kconv(i[0]) != k + ] + + def __contains__(self, key): + key = self._kconv(key) + return any( + self._kconv(k) == key + for k, _ in self.lst + ) def add(self, key, value, prepend=False): if prepend: @@ -127,40 +130,24 @@ class ODict(Serializable): def __repr__(self): return repr(self.lst) - def in_any(self, key, value, caseless=False): - """ - Do any of the values matching key contain value? - - If caseless is true, value comparison is case-insensitive. - """ - if caseless: - value = value.lower() - for i in self[key]: - if caseless: - i = i.lower() - if value in i: - return True - return False - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both keys and - values. Encoded content will be decoded before replacement, and - re-encoded afterwards. + values. Returns the number of replacements made. 
""" - nlst, count = [], 0 - for i in self.lst: - k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) + new, count = [], 0 + for k, v in self.lst: + k, c = safe_subn(pattern, repl, k, *args, **kwargs) count += c - v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) + v, c = safe_subn(pattern, repl, v, *args, **kwargs) count += c - nlst.append([k, v]) - self.lst = nlst + new.append([k, v]) + self.lst = new return count - # Implement the StateObject protocol from mitmproxy + # Implement Serializable def get_state(self): return [tuple(i) for i in self.lst] diff --git a/netlib/utils.py b/netlib/utils.py index dda76808..be2701a0 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -330,6 +330,8 @@ def unparse_url(scheme, host, port, path=""): Args: All args must be str. """ + if path == "*": + path = "" return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) @@ -429,3 +431,31 @@ def safe_subn(pattern, repl, target, *args, **kwargs): need a better solution that is aware of the actual content ecoding. """ return re.subn(str(pattern), str(repl), target, *args, **kwargs) + + +def bytes_to_escaped_str(data): + """ + Take bytes and return a safe string that can be displayed to the user. + """ + # TODO: We may want to support multi-byte characters without escaping them. + # One way to do would be calling .decode("utf8", "backslashreplace") first + # and then escaping UTF8 control chars (see clean_bin). + + if not isinstance(data, bytes): + raise ValueError("data must be bytes") + return repr(data).lstrip("b")[1:-1] + + +def escaped_str_to_bytes(data): + """ + Take an escaped string and return the unescaped bytes equivalent. + """ + if not isinstance(data, str): + raise ValueError("data must be str") + + if six.PY2: + return data.decode("string-escape") + + # This one is difficult - we use an undocumented Python API here + # as per http://stackoverflow.com/a/23151714/934719 + return codecs.escape_decode(data)[0] |