diff options
Diffstat (limited to 'netlib/http')
-rw-r--r-- | netlib/http/message.py | 42 | ||||
-rw-r--r-- | netlib/http/request.py | 12 |
2 files changed, 38 insertions, 16 deletions
diff --git a/netlib/http/message.py b/netlib/http/message.py index ca3a4145..86ff64d1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -124,6 +124,9 @@ class Message(basetypes.Serializable): """ The HTTP message body decoded with the content-encoding header (e.g. gzip) + Raises: + ValueError, when getting the content and the content-encoding is invalid. + See also: :py:class:`raw_content`, :py:attr:`text` """ ce = self.headers.get("content-encoding") @@ -132,17 +135,21 @@ class Message(basetypes.Serializable): self._content_cache.encoding == ce ) if not cached: - try: - if not ce: - raise ValueError() + if ce: decoded = encoding.decode(self.raw_content, ce) - except ValueError: + else: decoded = self.raw_content self._content_cache = CachedDecode(self.raw_content, ce, decoded) return self._content_cache.decoded @content.setter def content(self, value): + if value is not None and not isinstance(value, bytes): + raise TypeError( + "Message content must be bytes, not {}. " + "Please use .text if you want to assign a str." + .format(type(value).__name__) + ) ce = self.headers.get("content-encoding") cached = ( self._content_cache.decoded == value and @@ -150,15 +157,15 @@ class Message(basetypes.Serializable): ) if not cached: try: - if not ce: - raise ValueError() - encoded = encoding.encode(value, ce) + if ce and value is not None: + encoded = encoding.encode(value, ce) + else: + encoded = value except ValueError: - # Do we have an unknown content-encoding? - # If so, we want to remove it. - if value and ce: - self.headers.pop("content-encoding", None) - ce = None + # So we have an invalid content-encoding? + # Let's remove it! + del self.headers["content-encoding"] + ce = None encoded = value self._content_cache = CachedDecode(encoded, ce, value) self.raw_content = self._content_cache.encoded @@ -262,6 +269,9 @@ class Message(basetypes.Serializable): Decodes body based on the current Content-Encoding header, then removes the header. If there is no Content-Encoding header, no action is taken. + + Raises: + ValueError, when the content-encoding is invalid. """ self.raw_content = self.content self.headers.pop("content-encoding", None) @@ -269,10 +279,16 @@ class Message(basetypes.Serializable): def encode(self, e): """ Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + Any existing content-encodings are overwritten, + the content is not decoded beforehand. + + Raises: + ValueError, when the specified content-encoding is invalid. """ - self.decode() # remove the current encoding self.headers["content-encoding"] = e self.content = self.raw_content + if "content-encoding" not in self.headers: + raise ValueError("Invalid content encoding {}".format(repr(e))) def replace(self, pattern, repl, flags=0): """ diff --git a/netlib/http/request.py b/netlib/http/request.py index 4ce94549..a8ec6238 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -347,7 +347,10 @@ class Request(message.Message): def _get_urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() if is_valid_content_type: - return tuple(netlib.http.url.decode(self.content)) + try: + return tuple(netlib.http.url.decode(self.content)) + except ValueError: + pass return () def _set_urlencoded_form(self, value): @@ -356,7 +359,7 @@ class Request(message.Message): This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(value) + self.content = netlib.http.url.encode(value).encode() @urlencoded_form.setter def urlencoded_form(self, value): @@ -376,7 +379,10 @@ class Request(message.Message): def _get_multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() if is_valid_content_type: - return multipart.decode(self.headers, self.content) + try: + return multipart.decode(self.headers, self.content) + except ValueError: + pass return () def _set_multipart_form(self, value): |