diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/http/authentication.py | 16 | ||||
-rw-r--r-- | netlib/http/cookies.py | 78 | ||||
-rw-r--r-- | netlib/http/headers.py | 10 | ||||
-rw-r--r-- | netlib/http/message.py | 6 | ||||
-rw-r--r-- | netlib/http/request.py | 22 | ||||
-rw-r--r-- | netlib/http/response.py | 66 | ||||
-rw-r--r-- | netlib/strutils.py | 3 | ||||
-rw-r--r-- | netlib/websockets/__init__.py | 34 | ||||
-rw-r--r-- | netlib/websockets/frame.py | 142 | ||||
-rw-r--r-- | netlib/websockets/masker.py | 33 | ||||
-rw-r--r-- | netlib/websockets/protocol.py | 112 | ||||
-rw-r--r-- | netlib/websockets/utils.py | 90 |
12 files changed, 372 insertions, 240 deletions
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 38ea46d6..58fc9bdc 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -50,9 +50,9 @@ class NullProxyAuth(object): return {} -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' +class BasicAuth(NullProxyAuth): + CHALLENGE_HEADER = None + AUTH_HEADER = None def __init__(self, password_manager, realm): NullProxyAuth.__init__(self, password_manager) @@ -80,6 +80,16 @@ class BasicProxyAuth(NullProxyAuth): return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} +class BasicWebsiteAuth(BasicAuth): + CHALLENGE_HEADER = 'WWW-Authenticate' + AUTH_HEADER = 'Authorization' + + +class BasicProxyAuth(BasicAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + + class PassMan(object): def test(self, username_, password_token_): diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index dd0af99c..1421d8eb 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -26,6 +26,12 @@ variants. Serialization follows RFC6265. http://tools.ietf.org/html/rfc2965 """ +_cookie_params = set(( + 'expires', 'path', 'comment', 'max-age', + 'secure', 'httponly', 'version', +)) + + # TODO: Disallow LHS-only Cookie values @@ -263,27 +269,69 @@ def refresh_set_cookie_header(c, delta): return ret -def is_expired(cookie_attrs): +def get_expiration_ts(cookie_attrs): """ - Determines whether a cookie has expired. + Determines the time when the cookie will be expired. - Returns: boolean - """ + Considering both 'expires' and 'max-age' parameters. - # See if 'expires' time is in the past - expires = False + Returns: timestamp of when the cookie will expire. + None, if no expiration time is set. + """ if 'expires' in cookie_attrs: e = email.utils.parsedate_tz(cookie_attrs["expires"]) if e: - exp_ts = email.utils.mktime_tz(e) + return email.utils.mktime_tz(e) + + elif 'max-age' in cookie_attrs: + try: + max_age = int(cookie_attrs['Max-Age']) + except ValueError: + pass + else: now_ts = time.time() - expires = exp_ts < now_ts + return now_ts + max_age + + return None - # or if Max-Age is 0 - max_age = False - try: - max_age = int(cookie_attrs.get('Max-Age', 1)) == 0 - except ValueError: - pass - return expires or max_age +def is_expired(cookie_attrs): + """ + Determines whether a cookie has expired. + + Returns: boolean + """ + + exp_ts = get_expiration_ts(cookie_attrs) + now_ts = time.time() + + # If no expiration information was provided with the cookie + if exp_ts is None: + return False + else: + return exp_ts <= now_ts + + +def group_cookies(pairs): + """ + Converts a list of pairs to a (name, value, attrs) for each cookie. + """ + + if not pairs: + return [] + + cookie_list = [] + + # First pair is always a new cookie + name, value = pairs[0] + attrs = [] + + for k, v in pairs[1:]: + if k.lower() in _cookie_params: + attrs.append((k, v)) + else: + cookie_list.append((name, value, CookieAttrs(attrs))) + name, value, attrs = k, v, [] + + cookie_list.append((name, value, CookieAttrs(attrs))) + return cookie_list diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 36e5060c..131e8ce5 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -158,7 +158,7 @@ class Headers(multidict.MultiDict): else: return super(Headers, self).items() - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in each "name: value" header line. @@ -172,10 +172,10 @@ class Headers(multidict.MultiDict): repl = strutils.escaped_str_to_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 - + flag_count = count > 0 fields = [] for name, value in self.fields: - line, n = pattern.subn(repl, name + b": " + value) + line, n = pattern.subn(repl, name + b": " + value, count=count) try: name, value = line.split(b": ", 1) except ValueError: @@ -184,6 +184,10 @@ class Headers(multidict.MultiDict): pass else: replacements += n + if flag_count: + count -= n + if count == 0: + break fields.append((name, value)) self.fields = tuple(fields) return replacements diff --git a/netlib/http/message.py b/netlib/http/message.py index ce92bab1..0b64d4a6 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -260,7 +260,7 @@ class Message(basetypes.Serializable): if "content-encoding" not in self.headers: raise ValueError("Invalid content encoding {}".format(repr(e))) - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in both the headers and the body of the message. Encoded body will be decoded @@ -276,9 +276,9 @@ class Message(basetypes.Serializable): replacements = 0 if self.content: self.content, replacements = re.subn( - pattern, repl, self.content, flags=flags + pattern, repl, self.content, flags=flags, count=count ) - replacements += self.headers.replace(pattern, repl, flags) + replacements += self.headers.replace(pattern, repl, flags=flags, count=count) return replacements # Legacy diff --git a/netlib/http/request.py b/netlib/http/request.py index d59fead4..e0aaa8a9 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -20,8 +20,20 @@ host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") class RequestData(message.MessageData): - def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None, - timestamp_start=None, timestamp_end=None): + def __init__( + self, + first_line_format, + method, + scheme, + host, + port, + path, + http_version, + headers=(), + content=None, + timestamp_start=None, + timestamp_end=None + ): if isinstance(method, six.text_type): method = method.encode("ascii", "strict") if isinstance(scheme, six.text_type): @@ -68,7 +80,7 @@ class Request(message.Message): self.method, hostport, path ) - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in the headers, the request path and the body of the request. Encoded content will be @@ -82,9 +94,9 @@ class Request(message.Message): if isinstance(repl, six.text_type): repl = strutils.escaped_str_to_bytes(repl) - c = super(Request, self).replace(pattern, repl, flags) + c = super(Request, self).replace(pattern, repl, flags, count) self.path, pc = re.subn( - pattern, repl, self.data.path, flags=flags + pattern, repl, self.data.path, flags=flags, count=count ) c += pc return c diff --git a/netlib/http/response.py b/netlib/http/response.py index 85f54940..ae29298f 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,19 +1,32 @@ from __future__ import absolute_import, print_function, division -from email.utils import parsedate_tz, formatdate, mktime_tz -import time import six - +import time +from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib import human +from netlib import multidict from netlib.http import cookies from netlib.http import headers as nheaders from netlib.http import message -from netlib import multidict -from netlib import human +from netlib.http import status_codes +from typing import AnyStr # noqa +from typing import Dict # noqa +from typing import Iterable # noqa +from typing import Tuple # noqa +from typing import Union # noqa class ResponseData(message.MessageData): - def __init__(self, http_version, status_code, reason=None, headers=(), content=None, - timestamp_start=None, timestamp_end=None): + def __init__( + self, + http_version, + status_code, + reason=None, + headers=(), + content=None, + timestamp_start=None, + timestamp_end=None + ): if isinstance(http_version, six.text_type): http_version = http_version.encode("ascii", "strict") if isinstance(reason, six.text_type): @@ -54,6 +67,45 @@ class Response(message.Message): details=details ) + @classmethod + def make( + cls, + status_code=200, # type: int + content=b"", # type: AnyStr + headers=() # type: Union[Dict[AnyStr, AnyStr], Iterable[Tuple[bytes, bytes]]] + ): + """ + Simplified API for creating response objects. + """ + resp = cls( + b"HTTP/1.1", + status_code, + status_codes.RESPONSES.get(status_code, "").encode(), + (), + None + ) + # Assign this manually to update the content-length header. + if isinstance(content, bytes): + resp.content = content + elif isinstance(content, str): + resp.text = content + else: + raise TypeError("Expected content to be str or bytes, but is {}.".format( + type(content).__name__ + )) + + # Headers can be list or dict, we differentiate here. + if isinstance(headers, dict): + resp.headers = nheaders.Headers(**headers) + elif isinstance(headers, Iterable): + resp.headers = nheaders.Headers(headers) + else: + raise TypeError("Expected headers to be an iterable or dict, but is {}.".format( + type(headers).__name__ + )) + + return resp + @property def status_code(self): """ diff --git a/netlib/strutils.py b/netlib/strutils.py index 4a46b6b1..4cb3b805 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -121,6 +121,9 @@ def escaped_str_to_bytes(data): def is_mostly_bin(s): # type: (bytes) -> bool + if not s or len(s) == 0: + return False + return sum( i < 9 or 13 < i < 32 or 126 < i for i in six.iterbytes(s[:100]) diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py index fea696d9..e14e8a7d 100644 --- a/netlib/websockets/__init__.py +++ b/netlib/websockets/__init__.py @@ -1,11 +1,37 @@ from __future__ import absolute_import, print_function, division -from .frame import FrameHeader, Frame, OPCODE -from .protocol import Masker, WebsocketsProtocol + +from .frame import FrameHeader +from .frame import Frame +from .frame import OPCODE +from .frame import CLOSE_REASON +from .masker import Masker +from .utils import MAGIC +from .utils import VERSION +from .utils import client_handshake_headers +from .utils import server_handshake_headers +from .utils import check_handshake +from .utils import check_client_version +from .utils import create_server_nonce +from .utils import get_extensions +from .utils import get_protocol +from .utils import get_client_key +from .utils import get_server_accept __all__ = [ "FrameHeader", "Frame", - "Masker", - "WebsocketsProtocol", "OPCODE", + "CLOSE_REASON", + "Masker", + "MAGIC", + "VERSION", + "client_handshake_headers", + "server_handshake_headers", + "check_handshake", + "check_client_version", + "create_server_nonce", + "get_extensions", + "get_protocol", + "get_client_key", + "get_server_accept", ] diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 7d355699..e62d0e87 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,7 +2,6 @@ from __future__ import absolute_import import os import struct import io -import warnings import six @@ -10,7 +9,7 @@ from netlib import tcp from netlib import strutils from netlib import utils from netlib import human -from netlib.websockets import protocol +from .masker import Masker MAX_16_BIT_INT = (1 << 16) @@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64) DEFAULT = object() +# RFC 6455, Section 5.2 - Base Framing Protocol OPCODE = utils.BiDi( CONTINUE=0x00, TEXT=0x01, @@ -27,6 +27,23 @@ OPCODE = utils.BiDi( PONG=0x0a ) +# RFC 6455, Section 7.4.1 - Defined Status Codes +CLOSE_REASON = utils.BiDi( + NORMAL_CLOSURE=1000, + GOING_AWAY=1001, + PROTOCOL_ERROR=1002, + UNSUPPORTED_DATA=1003, + RESERVED=1004, + RESERVED_NO_STATUS=1005, + RESERVED_ABNORMAL_CLOSURE=1006, + INVALID_PAYLOAD_DATA=1007, + POLICY_VIOLATION=1008, + MESSAGE_TOO_BIG=1009, + MANDATORY_EXTENSION=1010, + INTERNAL_ERROR=1011, + RESERVED_TLS_HANDHSAKE_FAILED=1015, +) + class FrameHeader(object): @@ -103,10 +120,6 @@ class FrameHeader(object): vals.append(" %s" % human.pretty_size(self.payload_length)) return "".join(vals) - def human_readable(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return repr(self) - def __bytes__(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) @@ -128,6 +141,9 @@ class FrameHeader(object): # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.payload_length) + else: + raise ValueError("Payload length exceeds 64bit integer") + if self.masking_key: b += self.masking_key return b @@ -135,10 +151,6 @@ class FrameHeader(object): if six.PY2: __str__ = __bytes__ - def to_bytes(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return bytes(self) - @classmethod def from_file(cls, fp): """ @@ -151,19 +163,17 @@ class FrameHeader(object): rsv1 = utils.getbit(first_byte, 6) rsv2 = utils.getbit(first_byte, 5) rsv3 = utils.getbit(first_byte, 4) - # grab right-most 4 bits - opcode = first_byte & 15 + opcode = first_byte & 0xF mask_bit = utils.getbit(second_byte, 7) - # grab the next 7 bits - length_code = second_byte & 127 + length_code = second_byte & 0x7F - # payload_lengthy > 125 indicates you need to read more bytes + # payload_length > 125 indicates you need to read more bytes # to get the actual payload length if length_code <= 125: payload_length = length_code elif length_code == 126: payload_length, = struct.unpack("!H", fp.safe_read(2)) - elif length_code == 127: + else: # length_code == 127: payload_length, = struct.unpack("!Q", fp.safe_read(8)) # masking key only present if mask bit set @@ -191,31 +201,30 @@ class FrameHeader(object): class Frame(object): - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ + Represents a single WebSockets frame. + Constructor takes human readable forms of the frame components. + from_bytes() reads from a file-like object to create a new Frame. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ """ def __init__(self, payload=b"", **kwargs): @@ -224,27 +233,6 @@ class Frame(object): self.header = FrameHeader(**kwargs) @classmethod - def default(cls, message, from_client=False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - if from_client: - mask_bit = 1 - masking_key = os.urandom(4) - else: - mask_bit = 0 - masking_key = None - - return cls( - message, - fin=1, # final frame - opcode=OPCODE.TEXT, # text - mask=mask_bit, - masking_key=masking_key, - ) - - @classmethod def from_bytes(cls, bytestring): """ Construct a websocket frame from an in-memory bytestring @@ -258,17 +246,13 @@ class Frame(object): ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) return ret - def human_readable(self): - warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) - return repr(self) - def __bytes__(self): """ Serialize the frame to wire format. Returns a string. """ b = bytes(self.header) if self.header.masking_key: - b += protocol.Masker(self.header.masking_key)(self.payload) + b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b @@ -276,15 +260,6 @@ class Frame(object): if six.PY2: __str__ = __bytes__ - def to_bytes(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return bytes(self) - - def to_file(self, writer): - warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) - writer.write(bytes(self)) - writer.flush() - @classmethod def from_file(cls, fp): """ @@ -297,20 +272,11 @@ class Frame(object): payload = fp.safe_read(header.payload_length) if header.mask == 1 and header.masking_key: - payload = protocol.Masker(header.masking_key)(payload) + payload = Masker(header.masking_key)(payload) - return cls( - payload, - fin=header.fin, - opcode=header.opcode, - mask=header.mask, - payload_length=header.payload_length, - masking_key=header.masking_key, - rsv1=header.rsv1, - rsv2=header.rsv2, - rsv3=header.rsv3, - length_code=header.length_code - ) + frame = cls(payload) + frame.header = header + return frame def __eq__(self, other): if isinstance(other, Frame): diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py new file mode 100644 index 00000000..bd39ed6a --- /dev/null +++ b/netlib/websockets/masker.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import + +import six + + +class Masker(object): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns. + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + + def __init__(self, key): + self.key = key + self.offset = 0 + + def mask(self, offset, data): + result = bytearray(data) + for i in range(len(data)): + if six.PY2: + result[i] ^= ord(self.key[offset % 4]) + else: + result[i] ^= self.key[offset % 4] + offset += 1 + result = bytes(result) + return result + + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py deleted file mode 100644 index af0eef7d..00000000 --- a/netlib/websockets/protocol.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Colleciton of utility functions that implement small portions of the RFC6455 -WebSockets Protocol Useful for building WebSocket clients and servers. - -Emphassis is on readabilty, simplicity and modularity, not performance or -completeness - -This is a work in progress and does not yet contain all the utilites need to -create fully complient client/servers # -Spec: https://tools.ietf.org/html/rfc6455 - -The magic sha that websocket servers must know to prove they understand -RFC6455 -""" - -from __future__ import absolute_import -import base64 -import hashlib -import os - -import six - -from netlib import http, strutils - -websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" - - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.offset = 0 - - def mask(self, offset, data): - result = bytearray(data) - if six.PY2: - for i in range(len(data)): - result[i] ^= ord(self.key[offset % 4]) - offset += 1 - result = str(result) - else: - - for i in range(len(data)): - result[i] ^= self.key[offset % 4] - offset += 1 - result = bytes(result) - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -class WebsocketsProtocol(object): - - def __init__(self): - pass - - @classmethod - def client_handshake_headers(self, key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of http.Headers - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('ascii') - return http.Headers( - sec_websocket_key=key, - sec_websocket_version=version, - connection="Upgrade", - upgrade="websocket", - ) - - @classmethod - def server_handshake_headers(self, key): - """ - The server response is a valid HTTP 101 response. - """ - return http.Headers( - sec_websocket_accept=self.create_server_nonce(key), - connection="Upgrade", - upgrade="websocket" - ) - - @classmethod - def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": - return - return headers.get("sec-websocket-key") - - @classmethod - def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": - return - return headers.get("sec-websocket-accept") - - @classmethod - def create_server_nonce(self, client_nonce): - return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest()) diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py new file mode 100644 index 00000000..aa0d39a1 --- /dev/null +++ b/netlib/websockets/utils.py @@ -0,0 +1,90 @@ +""" +Collection of WebSockets Protocol utility functions (RFC6455) +Spec: https://tools.ietf.org/html/rfc6455 +""" + +from __future__ import absolute_import + +import base64 +import hashlib +import os + +from netlib import http, strutils + +MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" + + +def client_handshake_headers(version=None, key=None, protocol=None, extensions=None): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of http.Headers + """ + if version is None: + version = VERSION + if key is None: + key = base64.b64encode(os.urandom(16)).decode('ascii') + h = http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version=version, + sec_websocket_key=key, + ) + if protocol is not None: + h['sec-websocket-protocol'] = protocol + if extensions is not None: + h['sec-websocket-extensions'] = extensions + return h + + +def server_handshake_headers(client_key, protocol=None, extensions=None): + """ + The server response is a valid HTTP 101 response. + + Returns an instance of http.Headers + """ + h = http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_accept=create_server_nonce(client_key), + ) + if protocol is not None: + h['sec-websocket-protocol'] = protocol + if extensions is not None: + h['sec-websocket-extensions'] = extensions + return h + + +def check_handshake(headers): + return ( + "upgrade" in headers.get("connection", "").lower() and + headers.get("upgrade", "").lower() == "websocket" and + (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None) + ) + + +def create_server_nonce(client_nonce): + return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest()) + + +def check_client_version(headers): + return headers.get("sec-websocket-version", "") == VERSION + + +def get_extensions(headers): + return headers.get("sec-websocket-extensions", None) + + +def get_protocol(headers): + return headers.get("sec-websocket-protocol", None) + + +def get_client_key(headers): + return headers.get("sec-websocket-key", None) + + +def get_server_accept(headers): + return headers.get("sec-websocket-accept", None) |