diff options
Diffstat (limited to 'netlib/http')
-rw-r--r-- | netlib/http/authentication.py | 16 | ||||
-rw-r--r-- | netlib/http/cookies.py | 78 | ||||
-rw-r--r-- | netlib/http/headers.py | 10 | ||||
-rw-r--r-- | netlib/http/message.py | 6 | ||||
-rw-r--r-- | netlib/http/request.py | 22 | ||||
-rw-r--r-- | netlib/http/response.py | 66 |
6 files changed, 162 insertions, 36 deletions
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 38ea46d6..58fc9bdc 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -50,9 +50,9 @@ class NullProxyAuth(object): return {} -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' +class BasicAuth(NullProxyAuth): + CHALLENGE_HEADER = None + AUTH_HEADER = None def __init__(self, password_manager, realm): NullProxyAuth.__init__(self, password_manager) @@ -80,6 +80,16 @@ class BasicProxyAuth(NullProxyAuth): return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} +class BasicWebsiteAuth(BasicAuth): + CHALLENGE_HEADER = 'WWW-Authenticate' + AUTH_HEADER = 'Authorization' + + +class BasicProxyAuth(BasicAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + + class PassMan(object): def test(self, username_, password_token_): diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index dd0af99c..1421d8eb 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -26,6 +26,12 @@ variants. Serialization follows RFC6265. http://tools.ietf.org/html/rfc2965 """ +_cookie_params = set(( + 'expires', 'path', 'comment', 'max-age', + 'secure', 'httponly', 'version', +)) + + # TODO: Disallow LHS-only Cookie values @@ -263,27 +269,69 @@ def refresh_set_cookie_header(c, delta): return ret -def is_expired(cookie_attrs): +def get_expiration_ts(cookie_attrs): """ - Determines whether a cookie has expired. + Determines the time when the cookie will be expired. - Returns: boolean - """ + Considering both 'expires' and 'max-age' parameters. - # See if 'expires' time is in the past - expires = False + Returns: timestamp of when the cookie will expire. + None, if no expiration time is set. + """ if 'expires' in cookie_attrs: e = email.utils.parsedate_tz(cookie_attrs["expires"]) if e: - exp_ts = email.utils.mktime_tz(e) + return email.utils.mktime_tz(e) + + elif 'max-age' in cookie_attrs: + try: + max_age = int(cookie_attrs['Max-Age']) + except ValueError: + pass + else: now_ts = time.time() - expires = exp_ts < now_ts + return now_ts + max_age + + return None - # or if Max-Age is 0 - max_age = False - try: - max_age = int(cookie_attrs.get('Max-Age', 1)) == 0 - except ValueError: - pass - return expires or max_age +def is_expired(cookie_attrs): + """ + Determines whether a cookie has expired. + + Returns: boolean + """ + + exp_ts = get_expiration_ts(cookie_attrs) + now_ts = time.time() + + # If no expiration information was provided with the cookie + if exp_ts is None: + return False + else: + return exp_ts <= now_ts + + +def group_cookies(pairs): + """ + Converts a list of pairs to a (name, value, attrs) for each cookie. + """ + + if not pairs: + return [] + + cookie_list = [] + + # First pair is always a new cookie + name, value = pairs[0] + attrs = [] + + for k, v in pairs[1:]: + if k.lower() in _cookie_params: + attrs.append((k, v)) + else: + cookie_list.append((name, value, CookieAttrs(attrs))) + name, value, attrs = k, v, [] + + cookie_list.append((name, value, CookieAttrs(attrs))) + return cookie_list diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 36e5060c..131e8ce5 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -158,7 +158,7 @@ class Headers(multidict.MultiDict): else: return super(Headers, self).items() - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in each "name: value" header line. @@ -172,10 +172,10 @@ class Headers(multidict.MultiDict): repl = strutils.escaped_str_to_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 - + flag_count = count > 0 fields = [] for name, value in self.fields: - line, n = pattern.subn(repl, name + b": " + value) + line, n = pattern.subn(repl, name + b": " + value, count=count) try: name, value = line.split(b": ", 1) except ValueError: @@ -184,6 +184,10 @@ class Headers(multidict.MultiDict): pass else: replacements += n + if flag_count: + count -= n + if count == 0: + break fields.append((name, value)) self.fields = tuple(fields) return replacements diff --git a/netlib/http/message.py b/netlib/http/message.py index ce92bab1..0b64d4a6 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -260,7 +260,7 @@ class Message(basetypes.Serializable): if "content-encoding" not in self.headers: raise ValueError("Invalid content encoding {}".format(repr(e))) - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in both the headers and the body of the message. Encoded body will be decoded @@ -276,9 +276,9 @@ class Message(basetypes.Serializable): replacements = 0 if self.content: self.content, replacements = re.subn( - pattern, repl, self.content, flags=flags + pattern, repl, self.content, flags=flags, count=count ) - replacements += self.headers.replace(pattern, repl, flags) + replacements += self.headers.replace(pattern, repl, flags=flags, count=count) return replacements # Legacy diff --git a/netlib/http/request.py b/netlib/http/request.py index d59fead4..e0aaa8a9 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -20,8 +20,20 @@ host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") class RequestData(message.MessageData): - def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None, - timestamp_start=None, timestamp_end=None): + def __init__( + self, + first_line_format, + method, + scheme, + host, + port, + path, + http_version, + headers=(), + content=None, + timestamp_start=None, + timestamp_end=None + ): if isinstance(method, six.text_type): method = method.encode("ascii", "strict") if isinstance(scheme, six.text_type): @@ -68,7 +80,7 @@ class Request(message.Message): self.method, hostport, path ) - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in the headers, the request path and the body of the request. Encoded content will be @@ -82,9 +94,9 @@ class Request(message.Message): if isinstance(repl, six.text_type): repl = strutils.escaped_str_to_bytes(repl) - c = super(Request, self).replace(pattern, repl, flags) + c = super(Request, self).replace(pattern, repl, flags, count) self.path, pc = re.subn( - pattern, repl, self.data.path, flags=flags + pattern, repl, self.data.path, flags=flags, count=count ) c += pc return c diff --git a/netlib/http/response.py b/netlib/http/response.py index 85f54940..ae29298f 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,19 +1,32 @@ from __future__ import absolute_import, print_function, division -from email.utils import parsedate_tz, formatdate, mktime_tz -import time import six - +import time +from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib import human +from netlib import multidict from netlib.http import cookies from netlib.http import headers as nheaders from netlib.http import message -from netlib import multidict -from netlib import human +from netlib.http import status_codes +from typing import AnyStr # noqa +from typing import Dict # noqa +from typing import Iterable # noqa +from typing import Tuple # noqa +from typing import Union # noqa class ResponseData(message.MessageData): - def __init__(self, http_version, status_code, reason=None, headers=(), content=None, - timestamp_start=None, timestamp_end=None): + def __init__( + self, + http_version, + status_code, + reason=None, + headers=(), + content=None, + timestamp_start=None, + timestamp_end=None + ): if isinstance(http_version, six.text_type): http_version = http_version.encode("ascii", "strict") if isinstance(reason, six.text_type): @@ -54,6 +67,45 @@ class Response(message.Message): details=details ) + @classmethod + def make( + cls, + status_code=200, # type: int + content=b"", # type: AnyStr + headers=() # type: Union[Dict[AnyStr, AnyStr], Iterable[Tuple[bytes, bytes]]] + ): + """ + Simplified API for creating response objects. + """ + resp = cls( + b"HTTP/1.1", + status_code, + status_codes.RESPONSES.get(status_code, "").encode(), + (), + None + ) + # Assign this manually to update the content-length header. + if isinstance(content, bytes): + resp.content = content + elif isinstance(content, str): + resp.text = content + else: + raise TypeError("Expected content to be str or bytes, but is {}.".format( + type(content).__name__ + )) + + # Headers can be list or dict, we differentiate here. + if isinstance(headers, dict): + resp.headers = nheaders.Headers(**headers) + elif isinstance(headers, Iterable): + resp.headers = nheaders.Headers(headers) + else: + raise TypeError("Expected headers to be an iterable or dict, but is {}.".format( + type(headers).__name__ + )) + + return resp + @property def status_code(self): """ |