aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/http/authentication.py16
-rw-r--r--netlib/http/cookies.py78
-rw-r--r--netlib/http/headers.py10
-rw-r--r--netlib/http/message.py6
-rw-r--r--netlib/http/request.py22
-rw-r--r--netlib/http/response.py66
-rw-r--r--netlib/strutils.py3
-rw-r--r--netlib/websockets/__init__.py34
-rw-r--r--netlib/websockets/frame.py142
-rw-r--r--netlib/websockets/masker.py33
-rw-r--r--netlib/websockets/protocol.py112
-rw-r--r--netlib/websockets/utils.py90
12 files changed, 372 insertions, 240 deletions
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py
index 38ea46d6..58fc9bdc 100644
--- a/netlib/http/authentication.py
+++ b/netlib/http/authentication.py
@@ -50,9 +50,9 @@ class NullProxyAuth(object):
return {}
-class BasicProxyAuth(NullProxyAuth):
- CHALLENGE_HEADER = 'Proxy-Authenticate'
- AUTH_HEADER = 'Proxy-Authorization'
+class BasicAuth(NullProxyAuth):
+ CHALLENGE_HEADER = None
+ AUTH_HEADER = None
def __init__(self, password_manager, realm):
NullProxyAuth.__init__(self, password_manager)
@@ -80,6 +80,16 @@ class BasicProxyAuth(NullProxyAuth):
return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm}
+class BasicWebsiteAuth(BasicAuth):
+ CHALLENGE_HEADER = 'WWW-Authenticate'
+ AUTH_HEADER = 'Authorization'
+
+
+class BasicProxyAuth(BasicAuth):
+ CHALLENGE_HEADER = 'Proxy-Authenticate'
+ AUTH_HEADER = 'Proxy-Authorization'
+
+
class PassMan(object):
def test(self, username_, password_token_):
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index dd0af99c..1421d8eb 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -26,6 +26,12 @@ variants. Serialization follows RFC6265.
http://tools.ietf.org/html/rfc2965
"""
+_cookie_params = set((
+ 'expires', 'path', 'comment', 'max-age',
+ 'secure', 'httponly', 'version',
+))
+
+
# TODO: Disallow LHS-only Cookie values
@@ -263,27 +269,69 @@ def refresh_set_cookie_header(c, delta):
return ret
-def is_expired(cookie_attrs):
+def get_expiration_ts(cookie_attrs):
"""
- Determines whether a cookie has expired.
+ Determines the time when the cookie will be expired.
- Returns: boolean
- """
+ Considering both 'expires' and 'max-age' parameters.
- # See if 'expires' time is in the past
- expires = False
+ Returns: timestamp of when the cookie will expire.
+ None, if no expiration time is set.
+ """
if 'expires' in cookie_attrs:
e = email.utils.parsedate_tz(cookie_attrs["expires"])
if e:
- exp_ts = email.utils.mktime_tz(e)
+ return email.utils.mktime_tz(e)
+
+ elif 'max-age' in cookie_attrs:
+ try:
+ max_age = int(cookie_attrs['Max-Age'])
+ except ValueError:
+ pass
+ else:
now_ts = time.time()
- expires = exp_ts < now_ts
+ return now_ts + max_age
+
+ return None
- # or if Max-Age is 0
- max_age = False
- try:
- max_age = int(cookie_attrs.get('Max-Age', 1)) == 0
- except ValueError:
- pass
- return expires or max_age
+def is_expired(cookie_attrs):
+ """
+ Determines whether a cookie has expired.
+
+ Returns: boolean
+ """
+
+ exp_ts = get_expiration_ts(cookie_attrs)
+ now_ts = time.time()
+
+ # If no expiration information was provided with the cookie
+ if exp_ts is None:
+ return False
+ else:
+ return exp_ts <= now_ts
+
+
+def group_cookies(pairs):
+ """
+ Converts a list of pairs to a (name, value, attrs) for each cookie.
+ """
+
+ if not pairs:
+ return []
+
+ cookie_list = []
+
+ # First pair is always a new cookie
+ name, value = pairs[0]
+ attrs = []
+
+ for k, v in pairs[1:]:
+ if k.lower() in _cookie_params:
+ attrs.append((k, v))
+ else:
+ cookie_list.append((name, value, CookieAttrs(attrs)))
+ name, value, attrs = k, v, []
+
+ cookie_list.append((name, value, CookieAttrs(attrs)))
+ return cookie_list
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 36e5060c..131e8ce5 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -158,7 +158,7 @@ class Headers(multidict.MultiDict):
else:
return super(Headers, self).items()
- def replace(self, pattern, repl, flags=0):
+ def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
header line.
@@ -172,10 +172,10 @@ class Headers(multidict.MultiDict):
repl = strutils.escaped_str_to_bytes(repl)
pattern = re.compile(pattern, flags)
replacements = 0
-
+ flag_count = count > 0
fields = []
for name, value in self.fields:
- line, n = pattern.subn(repl, name + b": " + value)
+ line, n = pattern.subn(repl, name + b": " + value, count=count)
try:
name, value = line.split(b": ", 1)
except ValueError:
@@ -184,6 +184,10 @@ class Headers(multidict.MultiDict):
pass
else:
replacements += n
+ if flag_count:
+ count -= n
+ if count == 0:
+ break
fields.append((name, value))
self.fields = tuple(fields)
return replacements
diff --git a/netlib/http/message.py b/netlib/http/message.py
index ce92bab1..0b64d4a6 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -260,7 +260,7 @@ class Message(basetypes.Serializable):
if "content-encoding" not in self.headers:
raise ValueError("Invalid content encoding {}".format(repr(e)))
- def replace(self, pattern, repl, flags=0):
+ def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in both the headers
and the body of the message. Encoded body will be decoded
@@ -276,9 +276,9 @@ class Message(basetypes.Serializable):
replacements = 0
if self.content:
self.content, replacements = re.subn(
- pattern, repl, self.content, flags=flags
+ pattern, repl, self.content, flags=flags, count=count
)
- replacements += self.headers.replace(pattern, repl, flags)
+ replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
return replacements
# Legacy
diff --git a/netlib/http/request.py b/netlib/http/request.py
index d59fead4..e0aaa8a9 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -20,8 +20,20 @@ host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
class RequestData(message.MessageData):
- def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None,
- timestamp_start=None, timestamp_end=None):
+ def __init__(
+ self,
+ first_line_format,
+ method,
+ scheme,
+ host,
+ port,
+ path,
+ http_version,
+ headers=(),
+ content=None,
+ timestamp_start=None,
+ timestamp_end=None
+ ):
if isinstance(method, six.text_type):
method = method.encode("ascii", "strict")
if isinstance(scheme, six.text_type):
@@ -68,7 +80,7 @@ class Request(message.Message):
self.method, hostport, path
)
- def replace(self, pattern, repl, flags=0):
+ def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in the headers, the
request path and the body of the request. Encoded content will be
@@ -82,9 +94,9 @@ class Request(message.Message):
if isinstance(repl, six.text_type):
repl = strutils.escaped_str_to_bytes(repl)
- c = super(Request, self).replace(pattern, repl, flags)
+ c = super(Request, self).replace(pattern, repl, flags, count)
self.path, pc = re.subn(
- pattern, repl, self.data.path, flags=flags
+ pattern, repl, self.data.path, flags=flags, count=count
)
c += pc
return c
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 85f54940..ae29298f 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -1,19 +1,32 @@
from __future__ import absolute_import, print_function, division
-from email.utils import parsedate_tz, formatdate, mktime_tz
-import time
import six
-
+import time
+from email.utils import parsedate_tz, formatdate, mktime_tz
+from netlib import human
+from netlib import multidict
from netlib.http import cookies
from netlib.http import headers as nheaders
from netlib.http import message
-from netlib import multidict
-from netlib import human
+from netlib.http import status_codes
+from typing import AnyStr # noqa
+from typing import Dict # noqa
+from typing import Iterable # noqa
+from typing import Tuple # noqa
+from typing import Union # noqa
class ResponseData(message.MessageData):
- def __init__(self, http_version, status_code, reason=None, headers=(), content=None,
- timestamp_start=None, timestamp_end=None):
+ def __init__(
+ self,
+ http_version,
+ status_code,
+ reason=None,
+ headers=(),
+ content=None,
+ timestamp_start=None,
+ timestamp_end=None
+ ):
if isinstance(http_version, six.text_type):
http_version = http_version.encode("ascii", "strict")
if isinstance(reason, six.text_type):
@@ -54,6 +67,45 @@ class Response(message.Message):
details=details
)
+ @classmethod
+ def make(
+ cls,
+ status_code=200, # type: int
+ content=b"", # type: AnyStr
+ headers=() # type: Union[Dict[AnyStr, AnyStr], Iterable[Tuple[bytes, bytes]]]
+ ):
+ """
+ Simplified API for creating response objects.
+ """
+ resp = cls(
+ b"HTTP/1.1",
+ status_code,
+ status_codes.RESPONSES.get(status_code, "").encode(),
+ (),
+ None
+ )
+ # Assign this manually to update the content-length header.
+ if isinstance(content, bytes):
+ resp.content = content
+ elif isinstance(content, str):
+ resp.text = content
+ else:
+ raise TypeError("Expected content to be str or bytes, but is {}.".format(
+ type(content).__name__
+ ))
+
+ # Headers can be list or dict, we differentiate here.
+ if isinstance(headers, dict):
+ resp.headers = nheaders.Headers(**headers)
+ elif isinstance(headers, Iterable):
+ resp.headers = nheaders.Headers(headers)
+ else:
+ raise TypeError("Expected headers to be an iterable or dict, but is {}.".format(
+ type(headers).__name__
+ ))
+
+ return resp
+
@property
def status_code(self):
"""
diff --git a/netlib/strutils.py b/netlib/strutils.py
index 4a46b6b1..4cb3b805 100644
--- a/netlib/strutils.py
+++ b/netlib/strutils.py
@@ -121,6 +121,9 @@ def escaped_str_to_bytes(data):
def is_mostly_bin(s):
# type: (bytes) -> bool
+ if not s or len(s) == 0:
+ return False
+
return sum(
i < 9 or 13 < i < 32 or 126 < i
for i in six.iterbytes(s[:100])
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
index fea696d9..e14e8a7d 100644
--- a/netlib/websockets/__init__.py
+++ b/netlib/websockets/__init__.py
@@ -1,11 +1,37 @@
from __future__ import absolute_import, print_function, division
-from .frame import FrameHeader, Frame, OPCODE
-from .protocol import Masker, WebsocketsProtocol
+
+from .frame import FrameHeader
+from .frame import Frame
+from .frame import OPCODE
+from .frame import CLOSE_REASON
+from .masker import Masker
+from .utils import MAGIC
+from .utils import VERSION
+from .utils import client_handshake_headers
+from .utils import server_handshake_headers
+from .utils import check_handshake
+from .utils import check_client_version
+from .utils import create_server_nonce
+from .utils import get_extensions
+from .utils import get_protocol
+from .utils import get_client_key
+from .utils import get_server_accept
__all__ = [
"FrameHeader",
"Frame",
- "Masker",
- "WebsocketsProtocol",
"OPCODE",
+ "CLOSE_REASON",
+ "Masker",
+ "MAGIC",
+ "VERSION",
+ "client_handshake_headers",
+ "server_handshake_headers",
+ "check_handshake",
+ "check_client_version",
+ "create_server_nonce",
+ "get_extensions",
+ "get_protocol",
+ "get_client_key",
+ "get_server_accept",
]
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
index 7d355699..e62d0e87 100644
--- a/netlib/websockets/frame.py
+++ b/netlib/websockets/frame.py
@@ -2,7 +2,6 @@ from __future__ import absolute_import
import os
import struct
import io
-import warnings
import six
@@ -10,7 +9,7 @@ from netlib import tcp
from netlib import strutils
from netlib import utils
from netlib import human
-from netlib.websockets import protocol
+from .masker import Masker
MAX_16_BIT_INT = (1 << 16)
@@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64)
DEFAULT = object()
+# RFC 6455, Section 5.2 - Base Framing Protocol
OPCODE = utils.BiDi(
CONTINUE=0x00,
TEXT=0x01,
@@ -27,6 +27,23 @@ OPCODE = utils.BiDi(
PONG=0x0a
)
+# RFC 6455, Section 7.4.1 - Defined Status Codes
+CLOSE_REASON = utils.BiDi(
+ NORMAL_CLOSURE=1000,
+ GOING_AWAY=1001,
+ PROTOCOL_ERROR=1002,
+ UNSUPPORTED_DATA=1003,
+ RESERVED=1004,
+ RESERVED_NO_STATUS=1005,
+ RESERVED_ABNORMAL_CLOSURE=1006,
+ INVALID_PAYLOAD_DATA=1007,
+ POLICY_VIOLATION=1008,
+ MESSAGE_TOO_BIG=1009,
+ MANDATORY_EXTENSION=1010,
+ INTERNAL_ERROR=1011,
+ RESERVED_TLS_HANDHSAKE_FAILED=1015,
+)
+
class FrameHeader(object):
@@ -103,10 +120,6 @@ class FrameHeader(object):
vals.append(" %s" % human.pretty_size(self.payload_length))
return "".join(vals)
- def human_readable(self):
- warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
- return repr(self)
-
def __bytes__(self):
first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1)
@@ -128,6 +141,9 @@ class FrameHeader(object):
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', self.payload_length)
+ else:
+ raise ValueError("Payload length exceeds 64bit integer")
+
if self.masking_key:
b += self.masking_key
return b
@@ -135,10 +151,6 @@ class FrameHeader(object):
if six.PY2:
__str__ = __bytes__
- def to_bytes(self):
- warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
- return bytes(self)
-
@classmethod
def from_file(cls, fp):
"""
@@ -151,19 +163,17 @@ class FrameHeader(object):
rsv1 = utils.getbit(first_byte, 6)
rsv2 = utils.getbit(first_byte, 5)
rsv3 = utils.getbit(first_byte, 4)
- # grab right-most 4 bits
- opcode = first_byte & 15
+ opcode = first_byte & 0xF
mask_bit = utils.getbit(second_byte, 7)
- # grab the next 7 bits
- length_code = second_byte & 127
+ length_code = second_byte & 0x7F
- # payload_lengthy > 125 indicates you need to read more bytes
+ # payload_length > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
payload_length, = struct.unpack("!H", fp.safe_read(2))
- elif length_code == 127:
+ else: # length_code == 127:
payload_length, = struct.unpack("!Q", fp.safe_read(8))
# masking key only present if mask bit set
@@ -191,31 +201,30 @@ class FrameHeader(object):
class Frame(object):
-
"""
- Represents one websockets frame.
- Constructor takes human readable forms of the frame components
- from_bytes() is also avaliable.
-
- WebSockets Frame as defined in RFC6455
-
- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
- +-+-+-+-+-------+-+-------------+-------------------------------+
- |F|R|R|R| opcode|M| Payload len | Extended payload length |
- |I|S|S|S| (4) |A| (7) | (16/64) |
- |N|V|V|V| |S| | (if payload len==126/127) |
- | |1|2|3| |K| | |
- +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
- | Extended payload length continued, if payload len == 127 |
- + - - - - - - - - - - - - - - - +-------------------------------+
- | |Masking-key, if MASK set to 1 |
- +-------------------------------+-------------------------------+
- | Masking-key (continued) | Payload Data |
- +-------------------------------- - - - - - - - - - - - - - - - +
- : Payload Data continued ... :
- + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
- | Payload Data continued ... |
- +---------------------------------------------------------------+
+ Represents a single WebSockets frame.
+ Constructor takes human readable forms of the frame components.
+ from_bytes() reads from a file-like object to create a new Frame.
+
+ WebSockets Frame as defined in RFC6455
+
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ +-+-+-+-+-------+-+-------------+-------------------------------+
+ |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ |I|S|S|S| (4) |A| (7) | (16/64) |
+ |N|V|V|V| |S| | (if payload len==126/127) |
+ | |1|2|3| |K| | |
+ +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ | Extended payload length continued, if payload len == 127 |
+ + - - - - - - - - - - - - - - - +-------------------------------+
+ | |Masking-key, if MASK set to 1 |
+ +-------------------------------+-------------------------------+
+ | Masking-key (continued) | Payload Data |
+ +-------------------------------- - - - - - - - - - - - - - - - +
+ : Payload Data continued ... :
+ + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ | Payload Data continued ... |
+ +---------------------------------------------------------------+
"""
def __init__(self, payload=b"", **kwargs):
@@ -224,27 +233,6 @@ class Frame(object):
self.header = FrameHeader(**kwargs)
@classmethod
- def default(cls, message, from_client=False):
- """
- Construct a basic websocket frame from some default values.
- Creates a non-fragmented text frame.
- """
- if from_client:
- mask_bit = 1
- masking_key = os.urandom(4)
- else:
- mask_bit = 0
- masking_key = None
-
- return cls(
- message,
- fin=1, # final frame
- opcode=OPCODE.TEXT, # text
- mask=mask_bit,
- masking_key=masking_key,
- )
-
- @classmethod
def from_bytes(cls, bytestring):
"""
Construct a websocket frame from an in-memory bytestring
@@ -258,17 +246,13 @@ class Frame(object):
ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
return ret
- def human_readable(self):
- warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
- return repr(self)
-
def __bytes__(self):
"""
Serialize the frame to wire format. Returns a string.
"""
b = bytes(self.header)
if self.header.masking_key:
- b += protocol.Masker(self.header.masking_key)(self.payload)
+ b += Masker(self.header.masking_key)(self.payload)
else:
b += self.payload
return b
@@ -276,15 +260,6 @@ class Frame(object):
if six.PY2:
__str__ = __bytes__
- def to_bytes(self):
- warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
- return bytes(self)
-
- def to_file(self, writer):
- warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
- writer.write(bytes(self))
- writer.flush()
-
@classmethod
def from_file(cls, fp):
"""
@@ -297,20 +272,11 @@ class Frame(object):
payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
- payload = protocol.Masker(header.masking_key)(payload)
+ payload = Masker(header.masking_key)(payload)
- return cls(
- payload,
- fin=header.fin,
- opcode=header.opcode,
- mask=header.mask,
- payload_length=header.payload_length,
- masking_key=header.masking_key,
- rsv1=header.rsv1,
- rsv2=header.rsv2,
- rsv3=header.rsv3,
- length_code=header.length_code
- )
+ frame = cls(payload)
+ frame.header = header
+ return frame
def __eq__(self, other):
if isinstance(other, Frame):
diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py
new file mode 100644
index 00000000..bd39ed6a
--- /dev/null
+++ b/netlib/websockets/masker.py
@@ -0,0 +1,33 @@
+from __future__ import absolute_import
+
+import six
+
+
+class Masker(object):
+ """
+ Data sent from the server must be masked to prevent malicious clients
+ from sending data over the wire in predictable patterns.
+
+ Servers do not have to mask data they send to the client.
+ https://tools.ietf.org/html/rfc6455#section-5.3
+ """
+
+ def __init__(self, key):
+ self.key = key
+ self.offset = 0
+
+ def mask(self, offset, data):
+ result = bytearray(data)
+ for i in range(len(data)):
+ if six.PY2:
+ result[i] ^= ord(self.key[offset % 4])
+ else:
+ result[i] ^= self.key[offset % 4]
+ offset += 1
+ result = bytes(result)
+ return result
+
+ def __call__(self, data):
+ ret = self.mask(self.offset, data)
+ self.offset += len(ret)
+ return ret
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
deleted file mode 100644
index af0eef7d..00000000
--- a/netlib/websockets/protocol.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""
-Colleciton of utility functions that implement small portions of the RFC6455
-WebSockets Protocol Useful for building WebSocket clients and servers.
-
-Emphassis is on readabilty, simplicity and modularity, not performance or
-completeness
-
-This is a work in progress and does not yet contain all the utilites need to
-create fully complient client/servers #
-Spec: https://tools.ietf.org/html/rfc6455
-
-The magic sha that websocket servers must know to prove they understand
-RFC6455
-"""
-
-from __future__ import absolute_import
-import base64
-import hashlib
-import os
-
-import six
-
-from netlib import http, strutils
-
-websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
-VERSION = "13"
-
-
-class Masker(object):
-
- """
- Data sent from the server must be masked to prevent malicious clients
- from sending data over the wire in predictable patterns
-
- Servers do not have to mask data they send to the client.
- https://tools.ietf.org/html/rfc6455#section-5.3
- """
-
- def __init__(self, key):
- self.key = key
- self.offset = 0
-
- def mask(self, offset, data):
- result = bytearray(data)
- if six.PY2:
- for i in range(len(data)):
- result[i] ^= ord(self.key[offset % 4])
- offset += 1
- result = str(result)
- else:
-
- for i in range(len(data)):
- result[i] ^= self.key[offset % 4]
- offset += 1
- result = bytes(result)
- return result
-
- def __call__(self, data):
- ret = self.mask(self.offset, data)
- self.offset += len(ret)
- return ret
-
-
-class WebsocketsProtocol(object):
-
- def __init__(self):
- pass
-
- @classmethod
- def client_handshake_headers(self, key=None, version=VERSION):
- """
- Create the headers for a valid HTTP upgrade request. If Key is not
- specified, it is generated, and can be found in sec-websocket-key in
- the returned header set.
-
- Returns an instance of http.Headers
- """
- if not key:
- key = base64.b64encode(os.urandom(16)).decode('ascii')
- return http.Headers(
- sec_websocket_key=key,
- sec_websocket_version=version,
- connection="Upgrade",
- upgrade="websocket",
- )
-
- @classmethod
- def server_handshake_headers(self, key):
- """
- The server response is a valid HTTP 101 response.
- """
- return http.Headers(
- sec_websocket_accept=self.create_server_nonce(key),
- connection="Upgrade",
- upgrade="websocket"
- )
-
- @classmethod
- def check_client_handshake(self, headers):
- if headers.get("upgrade") != "websocket":
- return
- return headers.get("sec-websocket-key")
-
- @classmethod
- def check_server_handshake(self, headers):
- if headers.get("upgrade") != "websocket":
- return
- return headers.get("sec-websocket-accept")
-
- @classmethod
- def create_server_nonce(self, client_nonce):
- return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest())
diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py
new file mode 100644
index 00000000..aa0d39a1
--- /dev/null
+++ b/netlib/websockets/utils.py
@@ -0,0 +1,90 @@
+"""
+Collection of WebSockets Protocol utility functions (RFC6455)
+Spec: https://tools.ietf.org/html/rfc6455
+"""
+
+from __future__ import absolute_import
+
+import base64
+import hashlib
+import os
+
+from netlib import http, strutils
+
+MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+VERSION = "13"
+
+
+def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
+ """
+ Create the headers for a valid HTTP upgrade request. If Key is not
+ specified, it is generated, and can be found in sec-websocket-key in
+ the returned header set.
+
+ Returns an instance of http.Headers
+ """
+ if version is None:
+ version = VERSION
+ if key is None:
+ key = base64.b64encode(os.urandom(16)).decode('ascii')
+ h = http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_version=version,
+ sec_websocket_key=key,
+ )
+ if protocol is not None:
+ h['sec-websocket-protocol'] = protocol
+ if extensions is not None:
+ h['sec-websocket-extensions'] = extensions
+ return h
+
+
+def server_handshake_headers(client_key, protocol=None, extensions=None):
+ """
+ The server response is a valid HTTP 101 response.
+
+ Returns an instance of http.Headers
+ """
+ h = http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_accept=create_server_nonce(client_key),
+ )
+ if protocol is not None:
+ h['sec-websocket-protocol'] = protocol
+ if extensions is not None:
+ h['sec-websocket-extensions'] = extensions
+ return h
+
+
+def check_handshake(headers):
+ return (
+ "upgrade" in headers.get("connection", "").lower() and
+ headers.get("upgrade", "").lower() == "websocket" and
+ (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
+ )
+
+
+def create_server_nonce(client_nonce):
+ return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
+
+
+def check_client_version(headers):
+ return headers.get("sec-websocket-version", "") == VERSION
+
+
+def get_extensions(headers):
+ return headers.get("sec-websocket-extensions", None)
+
+
+def get_protocol(headers):
+ return headers.get("sec-websocket-protocol", None)
+
+
+def get_client_key(headers):
+ return headers.get("sec-websocket-key", None)
+
+
+def get_server_accept(headers):
+ return headers.get("sec-websocket-accept", None)