diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/encoding.py | 1 | ||||
-rw-r--r-- | netlib/http/__init__.py | 8 | ||||
-rw-r--r-- | netlib/http/cookies.py | 60 | ||||
-rw-r--r-- | netlib/http/headers.py | 140 | ||||
-rw-r--r-- | netlib/http/http1/read.py | 4 | ||||
-rw-r--r-- | netlib/http/http2/connections.py | 12 | ||||
-rw-r--r-- | netlib/http/message.py | 70 | ||||
-rw-r--r-- | netlib/http/request.py | 106 | ||||
-rw-r--r-- | netlib/http/response.py | 41 | ||||
-rw-r--r-- | netlib/multidict.py | 248 | ||||
-rw-r--r-- | netlib/utils.py | 11 |
11 files changed, 487 insertions, 214 deletions
diff --git a/netlib/encoding.py b/netlib/encoding.py index 14479e00..98502451 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib -from .utils import always_byte_args ENCODINGS = {"identity", "gzip", "deflate"} diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 917080f7..9fafa28f 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division from .request import Request from .response import Response from .headers import Headers -from .message import decoded -from . import http1, http2 +from .message import MultiDictView, decoded +from . import http1, http2, status_codes __all__ = [ "Request", "Response", "Headers", - "decoded", - "http1", "http2", + "MultiDictView", "decoded", + "http1", "http2", "status_codes", ] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 4451f1da..88c76870 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,8 +1,8 @@ -from six.moves import http_cookies as Cookie +import collections import re -import string from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib.multidict import ImmutableMultiDict from .. import odict """ @@ -157,42 +157,76 @@ def _parse_set_cookie_pairs(s): return pairs +def parse_set_cookie_headers(headers): + ret = [] + for header in headers: + v = parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append((name, SetCookie(value, attrs))) + return ret + + +class CookieAttrs(ImmutableMultiDict): + @staticmethod + def _kconv(key): + return key.lower() + + @staticmethod + def _reduce_values(values): + # See the StickyCookieTest for a weird cookie that only makes sense + # if we take the last part. + return values[-1] + + +SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) + + def parse_set_cookie_header(line): """ Parse a Set-Cookie header value Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute + CookieAttrs dict of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ pairs = _parse_set_cookie_pairs(line) if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:]) def format_set_cookie_header(name, value, attrs): """ Formats a Set-Cookie header value. """ - pairs = [[name, value]] - pairs.extend(attrs.lst) + pairs = [(name, value)] + pairs.extend( + attrs.fields if hasattr(attrs, "fields") else attrs + ) return _format_set_cookie_pairs(pairs) +def parse_cookie_headers(cookie_headers): + cookie_list = [] + for header in cookie_headers: + cookie_list.extend(parse_cookie_header(header)) + return cookie_list + + def parse_cookie_header(line): """ Parse a Cookie header value. - Returns a (possibly empty) ODict object. + Returns a list of (lhs, rhs) tuples. """ pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) + return pairs -def format_cookie_header(od): +def format_cookie_header(lst): """ Formats a Cookie header value. """ - return _format_pairs(od.lst) + return _format_pairs(lst) def refresh_set_cookie_header(c, delta): @@ -209,10 +243,10 @@ def refresh_set_cookie_header(c, delta): raise ValueError("Invalid Cookie") if "expires" in attrs: - e = parsedate_tz(attrs["expires"][-1]) + e = parsedate_tz(attrs["expires"]) if e: f = mktime_tz(e) + delta - attrs["expires"] = [formatdate(f)] + attrs = attrs.with_set_all("expires", [formatdate(f)]) else: # This can happen when the expires tag is invalid. # reddit.com sends a an expires tag like this: "Thu, 31 Dec @@ -220,7 +254,7 @@ def refresh_set_cookie_header(c, delta): # strictly correct according to the cookie spec. Browsers # appear to parse this tolerantly - maybe we should too. # For now, we just ignore this. - del attrs["expires"] + attrs = attrs.with_delitem("expires") ret = format_set_cookie_header(name, value, attrs) if not ret: diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 72739f90..60d3f429 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,9 +1,3 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -""" from __future__ import absolute_import, print_function, division import re @@ -13,23 +7,22 @@ try: except ImportError: # pragma: no cover from collections import MutableMapping # Workaround for Python < 3.3 - import six +from ..multidict import MultiDict +from ..utils import always_bytes -from netlib.utils import always_byte_args, always_bytes, Serializable +# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ if six.PY2: # pragma: no cover _native = lambda x: x _always_bytes = lambda x: x - _always_byte_args = lambda x: x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") - _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, Serializable): +class Headers(MultiDict): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -49,11 +42,11 @@ class Headers(MutableMapping, Serializable): >>> h["host"] "example.com" - # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + # Headers can also be created from a list of raw (header_name, header_value) byte tuples >>> h = Headers([ - [b"Host",b"example.com"], - [b"Accept",b"text/html"], - [b"accept",b"application/xml"] + (b"Host",b"example.com"), + (b"Accept",b"text/html"), + (b"accept",b"application/xml") ]) # Multiple headers are folded into a single header as per RFC7230 @@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ - @_always_byte_args def __init__(self, fields=None, **headers): """ Args: @@ -89,19 +81,29 @@ class Headers(MutableMapping, Serializable): If ``**headers`` contains multiple keys that have equal ``.lower()`` s, the behavior is undefined. """ - self.fields = fields or [] + super(Headers, self).__init__(fields) - for name, value in self.fields: - if not isinstance(name, bytes) or not isinstance(value, bytes): - raise ValueError("Headers passed as fields must be bytes.") + for key, value in self.fields: + if not isinstance(key, bytes) or not isinstance(value, bytes): + raise TypeError("Header fields must be bytes.") # content_type -> content-type headers = { - _always_bytes(name).replace(b"_", b"-"): value + _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) for name, value in six.iteritems(headers) } self.update(headers) + @staticmethod + def _reduce_values(values): + # Headers can be folded + return ", ".join(values) + + @staticmethod + def _kconv(key): + # Headers are case-insensitive + return key.lower() + def __bytes__(self): if self.fields: return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" @@ -111,98 +113,40 @@ class Headers(MutableMapping, Serializable): if six.PY2: # pragma: no cover __str__ = __bytes__ - @_always_byte_args - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return ", ".join(values) - - @_always_byte_args - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @_always_byte_args - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] + def __delitem__(self, key): + key = _always_bytes(key) + super(Headers, self).__delitem__(key) def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield _native(name) - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @_always_byte_args + for x in super(Headers, self).__iter__(): + yield _native(x) + def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - name_lower = name.lower() - values = [_native(value) for n, value in self.fields if n.lower() == name_lower] - return values + name = _always_bytes(name) + return [ + _native(x) for x in + super(Headers, self).get_all(name) + ] - @_always_byte_args def set_all(self, name, values): """ Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - values = map(_always_bytes, values) # _always_byte_args does not fix lists - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def get_state(self): - return tuple(tuple(field) for field in self.fields) - - def set_state(self, state): - self.fields = [list(field) for field in state] + name = _always_bytes(name) + values = [_always_bytes(x) for x in values] + return super(Headers, self).set_all(name, values) - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) + def insert(self, index, key, value): + key = _always_bytes(key) + value = _always_bytes(value) + super(Headers, self).insert(index, key, value) - @_always_byte_args def replace(self, pattern, repl, flags=0): """ Replaces a regular expression pattern with repl in each "name: value" @@ -211,6 +155,8 @@ class Headers(MutableMapping, Serializable): Returns: The number of replacements made. """ + pattern = _always_bytes(pattern) + repl = _always_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 6e3a1b93..d30976bd 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -316,14 +316,14 @@ def _read_headers(rfile): if not ret: raise HttpSyntaxException("Invalid headers") # continued header - ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) else: try: name, value = line.split(b":", 1) value = value.strip() if not name: raise ValueError() - ret.append([name, value]) + ret.append((name, value)) except ValueError: raise HttpSyntaxException("Invalid headers") return Headers(ret) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index f900b67c..6643b6b9 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -201,13 +201,13 @@ class HTTP2Protocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.fields.insert(0, (b':authority', authority.encode('ascii'))) + headers.insert(0, b':authority', authority.encode('ascii')) if ':scheme' not in headers: - headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) + headers.insert(0, b':scheme', request.scheme.encode('ascii')) if ':path' not in headers: - headers.fields.insert(0, (b':path', request.path.encode('ascii'))) + headers.insert(0, b':path', request.path.encode('ascii')) if ':method' not in headers: - headers.fields.insert(0, (b':method', request.method.encode('ascii'))) + headers.insert(0, b':method', request.method.encode('ascii')) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -224,7 +224,7 @@ class HTTP2Protocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) + headers.insert(0, b':status', str(response.status_code).encode('ascii')) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -420,7 +420,7 @@ class HTTP2Protocol(object): self._handle_unexpected_frame(frm) headers = Headers( - [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] + (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks) ) return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py index da9681a0..db4054b1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,6 +4,7 @@ import warnings import six +from ..multidict import MultiDict from .headers import Headers from .. import encoding, utils @@ -235,3 +236,72 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: self.message.encode(self.ce) + + +class MultiDictView(MultiDict): + """ + Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict. + It behaves mostly like an ordered dict but it can have several values for the same key. + + The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`. + That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves. + + For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header. + Any change to ``request.cookies`` will also modify the ``Cookie`` header. + Any change to the ``Cookie`` header will also modify ``request.cookies``. + + Example: + + .. code-block:: python + + # Cookies are represented as a MultiDict. + >>> request.cookies + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + + # MultiDicts mostly behave like a normal dict. + >>> request.cookies["name"] + "value" + + # If there is more than one value, only the first value is returned. + >>> request.cookies["a"] + "false" + + # `.get_all(key)` returns a list of all values. + >>> request.cookies.get_all("a") + ["false", "42"] + + # Changes to the headers are immediately reflected in the cookies. + >>> request.cookies + MultiDictView[("name", "value"), ...] + >>> del request.headers["Cookie"] + >>> request.cookies + MultiDictView[] # empty now + """ + + def __init__(self, attr, message): + if False: # pragma: no cover + # We do not want to call the parent constructor here as that + # would cause an unnecessary parse/unparse pass. + # This is here to silence linters. Message + super(MultiDictView, self).__init__(None) + self._attr = attr + self._message = message # type: Message + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return getattr(self._message, "_" + self._attr) + + @fields.setter + def fields(self, value): + setattr(self._message, self._attr, value) diff --git a/netlib/http/request.py b/netlib/http/request.py index a42150ff..ae28084b 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -11,7 +11,7 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MultiDictView # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -224,45 +224,54 @@ class Request(Message): @property def query(self): + # type: () -> MultiDictView """ - The request query string as an :py:class:`ODict` object. - None, if there is no query. + The request query string as an :py:class:`MultiDictView` object. """ + return MultiDictView("query", self) + + @property + def _query(self): _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None + return tuple(utils.urldecode(query)) @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) + def query(self, value): + query = utils.urlencode(value) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) _, _, _, self.path = utils.parse_url( urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) @property def cookies(self): + # type: () -> MultiDictView """ The request cookies. - An empty :py:class:`ODict` object if the cookie monster ate them all. + + An empty :py:class:`MultiDictView` object if the cookie monster ate them all. """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret + return MultiDictView("cookies", self) + + @property + def _cookies(self): + h = self.headers.get_all("Cookie") + return tuple(cookies.parse_cookie_headers(h)) @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) + def cookies(self, value): + self.headers["cookie"] = cookies.format_cookie_header(value) @property def path_components(self): """ - The URL's path components as a list of strings. + The URL's path components as a tuple of strings. Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] + # This needs to be a tuple so that it's immutable. + # Otherwise, this would fail silently: + # request.path_components.append("foo") + return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): @@ -309,64 +318,43 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The URL-encoded form data as an :py:class:`MultiDictView` object. + An empty MultiDictView if the content-type indicates non-form data + or the content could not be parsed. """ + return MultiDictView("urlencoded_form", self) + + @property + def _urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.urldecode(self.content)) - return None + if is_valid_content_type: + return tuple(utils.urldecode(self.content)) + return () @urlencoded_form.setter - def urlencoded_form(self, odict): + def urlencoded_form(self, value): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = utils.urlencode(odict.lst) + self.content = utils.urlencode(value) @property def multipart_form(self): """ - The multipart form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The multipart form data as an :py:class:`MultipartFormDict` object. + None if the content-type indicates non-form data. """ + return MultiDictView("multipart_form", self) + + @property + def _multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.multipartdecode(self.headers,self.content)) - return None + if is_valid_content_type: + return utils.multipartdecode(self.headers, self.content) + return () @multipart_form.setter def multipart_form(self, value): raise NotImplementedError() - - # Legacy - - def get_query(self): # pragma: no cover - warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) - return self.query or ODict([]) - - def set_query(self, odict): # pragma: no cover - warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) - self.query = odict - - def get_path_components(self): # pragma: no cover - warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) - return self.path_components - - def set_path_components(self, lst): # pragma: no cover - warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) - self.path_components = lst - - def get_form_urlencoded(self): # pragma: no cover - warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - return self.urlencoded_form or ODict([]) - - def set_form_urlencoded(self, odict): # pragma: no cover - warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - self.urlencoded_form = odict - - def get_form_multipart(self): # pragma: no cover - warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) - return self.multipart_form or ODict([]) diff --git a/netlib/http/response.py b/netlib/http/response.py index 2f06149e..6d56fc1f 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,14 +1,12 @@ from __future__ import absolute_import, print_function, division -import warnings from email.utils import parsedate_tz, formatdate, mktime_tz import time from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MultiDictView from .. import utils -from ..odict import ODict class ResponseData(MessageData): @@ -72,29 +70,30 @@ class Response(Message): @property def cookies(self): + # type: () -> MultiDictView """ - Get the contents of all Set-Cookie headers. + The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are + cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly) + are indicated by a Null value. - A possibly empty :py:class:`ODict`, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. + Caveats: + Updating the attr """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) + return MultiDictView("cookies", self) + + @property + def _cookies(self): + h = self.headers.get_all("set-cookie") + return tuple(cookies.parse_set_cookie_headers(h)) @cookies.setter - def cookies(self, odict): - values = [] - for i in odict.lst: - header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) - values.append(header) - self.headers.set_all("set-cookie", values) + def cookies(self, all_cookies): + cookie_headers = [] + for k, v in all_cookies: + header = cookies.format_set_cookie_header(k, v[0], v[1]) + cookie_headers.append(header) + self.headers.set_all("set-cookie", cookie_headers) def refresh(self, now=None): """ diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..a359d46b --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,248 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class MultiDict(MutableMapping, Serializable): + def __init__(self, fields=None): + + # it is important for us that .fields is immutable, so that we can easily + # detect changes to it. + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ + + @staticmethod + @abstractmethod + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super(MultiDict, self).items() + + def to_dict(self): + """ + Get the MultiDict as a plain Python dict. + Keys with multiple values are returned as lists. + + Example: + + .. code-block:: python + + # Simple dict with duplicate values. + >>> d + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d.to_dict() + { + "name": "value", + "a": ["false", "42"] + } + """ + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): + def _immutable(self, *_): + raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + + __delitem__ = set_all = insert = _immutable + + def with_delitem(self, key): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).__delitem__(key) + return ret + + def with_set_all(self, key, values): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).set_all(key, values) + return ret + + def with_insert(self, index, key, value): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).insert(index, key, value) + return ret diff --git a/netlib/utils.py b/netlib/utils.py index be2701a0..7499f71f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args): return unicode_or_bytes -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - def native(s, *encoding_opts): """ Convert :py:class:`bytes` or :py:class:`unicode` to the native |