aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/basetypes.py34
-rw-r--r--netlib/certutils.py4
-rw-r--r--netlib/http/__init__.py14
-rw-r--r--netlib/http/authentication.py6
-rw-r--r--netlib/http/cookies.py14
-rw-r--r--netlib/http/headers.py56
-rw-r--r--netlib/http/http1/assemble.py10
-rw-r--r--netlib/http/http1/read.py91
-rw-r--r--netlib/http/http2/__init__.py4
-rw-r--r--netlib/http/http2/connections.py100
-rw-r--r--netlib/http/http2/framereader.py21
-rw-r--r--netlib/http/message.py33
-rw-r--r--netlib/http/multipart.py32
-rw-r--r--netlib/http/request.py80
-rw-r--r--netlib/http/response.py32
-rw-r--r--netlib/http/url.py96
-rw-r--r--netlib/human.py50
-rw-r--r--netlib/multidict.py48
-rw-r--r--netlib/odict.py8
-rw-r--r--netlib/socks.py5
-rw-r--r--netlib/strutils.py154
-rw-r--r--netlib/tcp.py103
-rw-r--r--netlib/tutils.py14
-rw-r--r--netlib/utils.py341
-rw-r--r--netlib/version_check.py1
-rw-r--r--netlib/websockets/__init__.py13
-rw-r--r--netlib/websockets/frame.py14
-rw-r--r--netlib/websockets/protocol.py35
-rw-r--r--netlib/wsgi.py31
29 files changed, 799 insertions, 645 deletions
diff --git a/netlib/basetypes.py b/netlib/basetypes.py
new file mode 100644
index 00000000..9d6c60ba
--- /dev/null
+++ b/netlib/basetypes.py
@@ -0,0 +1,34 @@
+import six
+import abc
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Serializable(object):
+ """
+ Abstract Base Class that defines an API to save an object's state and restore it later on.
+ """
+
+ @classmethod
+ @abc.abstractmethod
+ def from_state(cls, state):
+ """
+ Create a new object from the given state.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_state(self):
+ """
+ Retrieve object state.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def set_state(self, state):
+ """
+ Set object state to the given state.
+ """
+ raise NotImplementedError()
+
+ def copy(self):
+ return self.from_state(self.get_state())
diff --git a/netlib/certutils.py b/netlib/certutils.py
index 34e01ed3..9eb41d03 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -12,7 +12,7 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
-from .utils import Serializable
+from netlib import basetypes
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
@@ -364,7 +364,7 @@ class _GeneralNames(univ.SequenceOf):
constraint.ValueSizeConstraint(1, 1024)
-class SSLCert(Serializable):
+class SSLCert(basetypes.Serializable):
def __init__(self, cert):
"""
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
index c4eb1d58..af95f4d0 100644
--- a/netlib/http/__init__.py
+++ b/netlib/http/__init__.py
@@ -1,14 +1,14 @@
from __future__ import absolute_import, print_function, division
-from .request import Request
-from .response import Response
-from .headers import Headers
-from .message import decoded
-from . import http1, http2, status_codes
+from netlib.http.request import Request
+from netlib.http.response import Response
+from netlib.http.headers import Headers, parse_content_type
+from netlib.http.message import decoded
+from netlib.http import http1, http2, status_codes, multipart
__all__ = [
"Request",
"Response",
- "Headers",
+ "Headers", "parse_content_type",
"decoded",
- "http1", "http2", "status_codes",
+ "http1", "http2", "status_codes", "multipart",
]
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py
index 6db70fdd..38ea46d6 100644
--- a/netlib/http/authentication.py
+++ b/netlib/http/authentication.py
@@ -1,5 +1,5 @@
from __future__ import (absolute_import, print_function, division)
-from argparse import Action, ArgumentTypeError
+import argparse
import binascii
@@ -124,7 +124,7 @@ class PassManSingleUser(PassMan):
return self.username == username and self.password == password_token
-class AuthAction(Action):
+class AuthAction(argparse.Action):
"""
Helper class to allow seamless integration int argparse. Example usage:
@@ -148,7 +148,7 @@ class SingleuserAuthAction(AuthAction):
def getPasswordManager(self, s):
if len(s.split(':')) != 2:
- raise ArgumentTypeError(
+ raise argparse.ArgumentTypeError(
"Invalid single-user specification. Please use the format username:password"
)
username, password = s.split(':')
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index 88c76870..768a85df 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -1,9 +1,8 @@
import collections
import re
-from email.utils import parsedate_tz, formatdate, mktime_tz
-from netlib.multidict import ImmutableMultiDict
-from .. import odict
+import email.utils
+from netlib import multidict
"""
A flexible module for cookie parsing and manipulation.
@@ -28,6 +27,7 @@ variants. Serialization follows RFC6265.
# TODO: Disallow LHS-only Cookie values
+
def _read_until(s, start, term):
"""
Read until one of the characters in term is reached.
@@ -167,7 +167,7 @@ def parse_set_cookie_headers(headers):
return ret
-class CookieAttrs(ImmutableMultiDict):
+class CookieAttrs(multidict.ImmutableMultiDict):
@staticmethod
def _kconv(key):
return key.lower()
@@ -243,10 +243,10 @@ def refresh_set_cookie_header(c, delta):
raise ValueError("Invalid Cookie")
if "expires" in attrs:
- e = parsedate_tz(attrs["expires"])
+ e = email.utils.parsedate_tz(attrs["expires"])
if e:
- f = mktime_tz(e) + delta
- attrs = attrs.with_set_all("expires", [formatdate(f)])
+ f = email.utils.mktime_tz(e) + delta
+ attrs = attrs.with_set_all("expires", [email.utils.formatdate(f)])
else:
# This can happen when the expires tag is invalid.
# reddit.com sends a an expires tag like this: "Thu, 31 Dec
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 60d3f429..14888ea9 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -2,27 +2,28 @@ from __future__ import absolute_import, print_function, division
import re
-try:
- from collections.abc import MutableMapping
-except ImportError: # pragma: no cover
- from collections import MutableMapping # Workaround for Python < 3.3
-
import six
-from ..multidict import MultiDict
-from ..utils import always_bytes
+from netlib import multidict
+from netlib import strutils
# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
if six.PY2: # pragma: no cover
- _native = lambda x: x
- _always_bytes = lambda x: x
+ def _native(x):
+ return x
+
+ def _always_bytes(x):
+ return x
else:
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
- _native = lambda x: x.decode("utf-8", "surrogateescape")
- _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape")
+ def _native(x):
+ return x.decode("utf-8", "surrogateescape")
+
+ def _always_bytes(x):
+ return strutils.always_bytes(x, "utf-8", "surrogateescape")
-class Headers(MultiDict):
+class Headers(multidict.MultiDict):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.
@@ -70,7 +71,7 @@ class Headers(MultiDict):
For use with the "Set-Cookie" header, see :py:meth:`get_all`.
"""
- def __init__(self, fields=None, **headers):
+ def __init__(self, fields=(), **headers):
"""
Args:
fields: (optional) list of ``(name, value)`` header byte tuples,
@@ -91,7 +92,7 @@ class Headers(MultiDict):
headers = {
_always_bytes(name).replace(b"_", b"-"): _always_bytes(value)
for name, value in six.iteritems(headers)
- }
+ }
self.update(headers)
@staticmethod
@@ -174,3 +175,30 @@ class Headers(MultiDict):
fields.append([name, value])
self.fields = fields
return replacements
+
+
+def parse_content_type(c):
+ """
+ A simple parser for content-type values. Returns a (type, subtype,
+ parameters) tuple, where type and subtype are strings, and parameters
+ is a dict. If the string could not be parsed, return None.
+
+ E.g. the following string:
+
+ text/html; charset=UTF-8
+
+ Returns:
+
+ ("text", "html", {"charset": "UTF-8"})
+ """
+ parts = c.split(";", 1)
+ ts = parts[0].split("/", 1)
+ if len(ts) != 2:
+ return None
+ d = {}
+ if len(parts) == 2:
+ for i in parts[1].split(";"):
+ clause = i.split("=", 1)
+ if len(clause) == 2:
+ d[clause[0].strip()] = clause[1].strip()
+ return ts[0].lower(), ts[1].lower(), d
diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py
index f06ad5a1..00d1563b 100644
--- a/netlib/http/http1/assemble.py
+++ b/netlib/http/http1/assemble.py
@@ -1,12 +1,12 @@
from __future__ import absolute_import, print_function, division
-from ... import utils
-import itertools
-from ...exceptions import HttpException
+from netlib import utils
+from netlib import exceptions
+
def assemble_request(request):
if request.content is None:
- raise HttpException("Cannot assemble flow with missing content")
+ raise exceptions.HttpException("Cannot assemble flow with missing content")
head = assemble_request_head(request)
body = b"".join(assemble_body(request.data.headers, [request.data.content]))
return head + body
@@ -20,7 +20,7 @@ def assemble_request_head(request):
def assemble_response(response):
if response.content is None:
- raise HttpException("Cannot assemble flow with missing content")
+ raise exceptions.HttpException("Cannot assemble flow with missing content")
head = assemble_response_head(response)
body = b"".join(assemble_body(response.data.headers, [response.data.content]))
return head + body
diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py
index d30976bd..bf4c2f0c 100644
--- a/netlib/http/http1/read.py
+++ b/netlib/http/http1/read.py
@@ -3,9 +3,24 @@ import time
import sys
import re
-from ... import utils
-from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
-from .. import Request, Response, Headers
+from netlib.http import request
+from netlib.http import response
+from netlib.http import headers
+from netlib.http import url
+from netlib import utils
+from netlib import exceptions
+
+
+def get_header_tokens(headers, key):
+ """
+ Retrieve all tokens for a header key. A number of different headers
+ follow a pattern where each header line can containe comma-separated
+ tokens, and headers can be set multiple times.
+ """
+ if key not in headers:
+ return []
+ tokens = headers[key].split(",")
+ return [token.strip() for token in tokens]
def read_request(rfile, body_size_limit=None):
@@ -27,9 +42,9 @@ def read_request_head(rfile):
The HTTP request object (without body)
Raises:
- HttpReadDisconnect: No bytes can be read from rfile.
- HttpSyntaxException: The input is malformed HTTP.
- HttpException: Any other error occured.
+ exceptions.HttpReadDisconnect: No bytes can be read from rfile.
+ exceptions.HttpSyntaxException: The input is malformed HTTP.
+ exceptions.HttpException: Any other error occured.
"""
timestamp_start = time.time()
if hasattr(rfile, "reset_timestamps"):
@@ -42,7 +57,7 @@ def read_request_head(rfile):
# more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp
- return Request(
+ return request.Request(
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
)
@@ -66,9 +81,9 @@ def read_response_head(rfile):
The HTTP request object (without body)
Raises:
- HttpReadDisconnect: No bytes can be read from rfile.
- HttpSyntaxException: The input is malformed HTTP.
- HttpException: Any other error occured.
+ exceptions.HttpReadDisconnect: No bytes can be read from rfile.
+ exceptions.HttpSyntaxException: The input is malformed HTTP.
+ exceptions.HttpException: Any other error occured.
"""
timestamp_start = time.time()
@@ -82,7 +97,7 @@ def read_response_head(rfile):
# more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp
- return Response(http_version, status_code, message, headers, None, timestamp_start)
+ return response.Response(http_version, status_code, message, headers, None, timestamp_start)
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
@@ -99,7 +114,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
A generator that yields byte chunks of the content.
Raises:
- HttpException, if an error occurs
+ exceptions.HttpException, if an error occurs
Caveats:
max_chunk_size is not considered if the transfer encoding is chunked.
@@ -114,7 +129,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
yield x
elif expected_size >= 0:
if limit is not None and expected_size > limit:
- raise HttpException(
+ raise exceptions.HttpException(
"HTTP Body too large. "
"Limit is {}, content length was advertised as {}".format(limit, expected_size)
)
@@ -123,7 +138,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
chunk_size = min(bytes_left, max_chunk_size)
content = rfile.read(chunk_size)
if len(content) < chunk_size:
- raise HttpException("Unexpected EOF")
+ raise exceptions.HttpException("Unexpected EOF")
yield content
bytes_left -= chunk_size
else:
@@ -137,7 +152,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
bytes_left -= chunk_size
not_done = rfile.read(1)
if not_done:
- raise HttpException("HTTP body too large. Limit is {}.".format(limit))
+ raise exceptions.HttpException("HTTP body too large. Limit is {}.".format(limit))
def connection_close(http_version, headers):
@@ -147,7 +162,7 @@ def connection_close(http_version, headers):
"""
# At first, check if we have an explicit Connection header.
if "connection" in headers:
- tokens = utils.get_header_tokens(headers, "connection")
+ tokens = get_header_tokens(headers, "connection")
if "close" in tokens:
return True
elif "keep-alive" in tokens:
@@ -167,7 +182,7 @@ def expected_http_body_size(request, response=None):
- -1, if all data should be read until end of stream.
Raises:
- HttpSyntaxException, if the content length header is invalid
+ exceptions.HttpSyntaxException, if the content length header is invalid
"""
# Determine response size according to
# http://tools.ietf.org/html/rfc7230#section-3.3
@@ -202,7 +217,7 @@ def expected_http_body_size(request, response=None):
raise ValueError()
return size
except ValueError:
- raise HttpSyntaxException("Unparseable Content Length")
+ raise exceptions.HttpSyntaxException("Unparseable Content Length")
if is_request:
return 0
return -1
@@ -214,19 +229,19 @@ def _get_first_line(rfile):
if line == b"\r\n" or line == b"\n":
# Possible leftover from previous message
line = rfile.readline()
- except TcpDisconnect:
- raise HttpReadDisconnect("Remote disconnected")
+ except exceptions.TcpDisconnect:
+ raise exceptions.HttpReadDisconnect("Remote disconnected")
if not line:
- raise HttpReadDisconnect("Remote disconnected")
+ raise exceptions.HttpReadDisconnect("Remote disconnected")
return line.strip()
def _read_request_line(rfile):
try:
line = _get_first_line(rfile)
- except HttpReadDisconnect:
+ except exceptions.HttpReadDisconnect:
# We want to provide a better error message.
- raise HttpReadDisconnect("Client disconnected")
+ raise exceptions.HttpReadDisconnect("Client disconnected")
try:
method, path, http_version = line.split(b" ")
@@ -240,11 +255,11 @@ def _read_request_line(rfile):
scheme, path = None, None
else:
form = "absolute"
- scheme, host, port, path = utils.parse_url(path)
+ scheme, host, port, path = url.parse(path)
_check_http_version(http_version)
except ValueError:
- raise HttpSyntaxException("Bad HTTP request line: {}".format(line))
+ raise exceptions.HttpSyntaxException("Bad HTTP request line: {}".format(line))
return form, method, scheme, host, port, path, http_version
@@ -263,7 +278,7 @@ def _parse_authority_form(hostport):
if not utils.is_valid_host(host) or not utils.is_valid_port(port):
raise ValueError()
except ValueError:
- raise HttpSyntaxException("Invalid host specification: {}".format(hostport))
+ raise exceptions.HttpSyntaxException("Invalid host specification: {}".format(hostport))
return host, port
@@ -271,9 +286,9 @@ def _parse_authority_form(hostport):
def _read_response_line(rfile):
try:
line = _get_first_line(rfile)
- except HttpReadDisconnect:
+ except exceptions.HttpReadDisconnect:
# We want to provide a better error message.
- raise HttpReadDisconnect("Server disconnected")
+ raise exceptions.HttpReadDisconnect("Server disconnected")
try:
@@ -286,14 +301,14 @@ def _read_response_line(rfile):
_check_http_version(http_version)
except ValueError:
- raise HttpSyntaxException("Bad HTTP response line: {}".format(line))
+ raise exceptions.HttpSyntaxException("Bad HTTP response line: {}".format(line))
return http_version, status_code, message
def _check_http_version(http_version):
if not re.match(br"^HTTP/\d\.\d$", http_version):
- raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version))
+ raise exceptions.HttpSyntaxException("Unknown HTTP version: {}".format(http_version))
def _read_headers(rfile):
@@ -305,7 +320,7 @@ def _read_headers(rfile):
A headers object
Raises:
- HttpSyntaxException
+ exceptions.HttpSyntaxException
"""
ret = []
while True:
@@ -314,7 +329,7 @@ def _read_headers(rfile):
break
if line[0] in b" \t":
if not ret:
- raise HttpSyntaxException("Invalid headers")
+ raise exceptions.HttpSyntaxException("Invalid headers")
# continued header
ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
else:
@@ -325,8 +340,8 @@ def _read_headers(rfile):
raise ValueError()
ret.append((name, value))
except ValueError:
- raise HttpSyntaxException("Invalid headers")
- return Headers(ret)
+ raise exceptions.HttpSyntaxException("Invalid headers")
+ return headers.Headers(ret)
def _read_chunked(rfile, limit=sys.maxsize):
@@ -341,22 +356,22 @@ def _read_chunked(rfile, limit=sys.maxsize):
while True:
line = rfile.readline(128)
if line == b"":
- raise HttpException("Connection closed prematurely")
+ raise exceptions.HttpException("Connection closed prematurely")
if line != b"\r\n" and line != b"\n":
try:
length = int(line, 16)
except ValueError:
- raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
+ raise exceptions.HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
total += length
if total > limit:
- raise HttpException(
+ raise exceptions.HttpException(
"HTTP Body too large. Limit is {}, "
"chunked content longer than {}".format(limit, total)
)
chunk = rfile.read(length)
suffix = rfile.readline(5)
if suffix != b"\r\n":
- raise HttpSyntaxException("Malformed chunked body")
+ raise exceptions.HttpSyntaxException("Malformed chunked body")
if length == 0:
return
yield chunk
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
index 7043d36f..633e6a20 100644
--- a/netlib/http/http2/__init__.py
+++ b/netlib/http/http2/__init__.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import, print_function, division
from .connections import HTTP2Protocol
+from netlib.http.http2 import framereader
__all__ = [
- "HTTP2Protocol"
+ "HTTP2Protocol",
+ "framereader",
]
diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py
index 6643b6b9..8667d370 100644
--- a/netlib/http/http2/connections.py
+++ b/netlib/http/http2/connections.py
@@ -2,11 +2,15 @@ from __future__ import (absolute_import, print_function, division)
import itertools
import time
-from hpack.hpack import Encoder, Decoder
-from ... import utils
-from .. import Headers, Response, Request
+import hyperframe.frame
-from hyperframe import frame
+from hpack.hpack import Encoder, Decoder
+from netlib import utils
+from netlib.http import url
+import netlib.http.headers
+import netlib.http.response
+import netlib.http.request
+from netlib.http.http2 import framereader
class TCPHandler(object):
@@ -38,12 +42,12 @@ class HTTP2Protocol(object):
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
- frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
- frame.SettingsFrame.ENABLE_PUSH: 1,
- frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
- frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
- frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
- frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
+ hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
+ hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
+ hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
+ hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
+ hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
+ hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
@@ -98,6 +102,11 @@ class HTTP2Protocol(object):
method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/')
+
+ headers.clear(":method")
+ headers.clear(":scheme")
+ headers.clear(":path")
+
host = None
port = None
@@ -112,7 +121,7 @@ class HTTP2Protocol(object):
else:
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
- scheme, host, port, _ = utils.parse_url(path)
+ scheme, host, port, _ = url.parse(path)
scheme = scheme.decode('ascii')
host = host.decode('ascii')
@@ -122,7 +131,7 @@ class HTTP2Protocol(object):
port = 80 if scheme == 'http' else 443
port = int(port)
- request = Request(
+ request = netlib.http.request.Request(
first_line_format,
method.encode('ascii'),
scheme.encode('ascii'),
@@ -170,7 +179,7 @@ class HTTP2Protocol(object):
else:
timestamp_end = None
- response = Response(
+ response = netlib.http.response.Response(
b"HTTP/2.0",
int(headers.get(':status', 502)),
b'',
@@ -184,15 +193,15 @@ class HTTP2Protocol(object):
return response
def assemble(self, message):
- if isinstance(message, Request):
+ if isinstance(message, netlib.http.request.Request):
return self.assemble_request(message)
- elif isinstance(message, Response):
+ elif isinstance(message, netlib.http.response.Response):
return self.assemble_response(message)
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
- assert isinstance(request, Request)
+ assert isinstance(request, netlib.http.request.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
@@ -202,12 +211,9 @@ class HTTP2Protocol(object):
if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii'))
- if ':scheme' not in headers:
- headers.insert(0, b':scheme', request.scheme.encode('ascii'))
- if ':path' not in headers:
- headers.insert(0, b':path', request.path.encode('ascii'))
- if ':method' not in headers:
- headers.insert(0, b':method', request.method.encode('ascii'))
+ headers.insert(0, b':scheme', request.scheme.encode('ascii'))
+ headers.insert(0, b':path', request.path.encode('ascii'))
+ headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@@ -219,7 +225,7 @@ class HTTP2Protocol(object):
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
- assert isinstance(response, Response)
+ assert isinstance(response, netlib.http.response.Response)
headers = response.headers.copy()
@@ -251,9 +257,9 @@ class HTTP2Protocol(object):
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
- frm = frame.SettingsFrame(settings={
- frame.SettingsFrame.ENABLE_PUSH: 0,
- frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
+ frm = hyperframe.frame.SettingsFrame(settings={
+ hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
+ hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
@@ -264,7 +270,7 @@ class HTTP2Protocol(object):
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
- self.send_frame(frame.SettingsFrame(), hide=True)
+ self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
@@ -277,18 +283,18 @@ class HTTP2Protocol(object):
def read_frame(self, hide=False):
while True:
- frm = utils.http2_read_frame(self.tcp_handler.rfile)
+ frm = framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
- if isinstance(frm, frame.PingFrame):
- raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
+ if isinstance(frm, hyperframe.frame.PingFrame):
+ raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
- if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags:
+ if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
- if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0:
+ if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
@@ -300,7 +306,7 @@ class HTTP2Protocol(object):
return True
def _handle_unexpected_frame(self, frm):
- if isinstance(frm, frame.SettingsFrame):
+ if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
@@ -308,7 +314,7 @@ class HTTP2Protocol(object):
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
- if isinstance(frm, frame.SettingsFrame):
+ if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
@@ -332,31 +338,31 @@ class HTTP2Protocol(object):
old_value = '-'
self.http2_settings[setting] = value
- frm = frame.SettingsFrame(flags=['ACK'])
+ frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
- frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
+ frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
- frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
+ frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
- yield frame.HeadersFrame, i
+ yield hyperframe.frame.HeadersFrame, i
else:
- yield frame.ContinuationFrame, i
+ yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
- chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE]
+ chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
stream_id=stream_id,
- data=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)]
+ data=header_block_fragment[i:i + chunk_size]) for frm_cls, i in frame_cls(chunks)]
frms[-1].flags.add('END_HEADERS')
if end_stream:
@@ -372,12 +378,12 @@ class HTTP2Protocol(object):
if body is None or len(body) == 0:
return b''
- chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE]
+ chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
- frms = [frame.DataFrame(
+ frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
- data=body[i:i+chunk_size]) for i in chunks]
+ data=body[i:i + chunk_size]) for i in chunks]
frms[-1].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
@@ -398,7 +404,7 @@ class HTTP2Protocol(object):
while True:
frm = self.read_frame()
if (
- (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and
+ (isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
@@ -412,14 +418,14 @@ class HTTP2Protocol(object):
while body_expected:
frm = self.read_frame()
- if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id:
+ if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
- headers = Headers(
+ headers = netlib.http.headers.Headers(
(k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
diff --git a/netlib/http/http2/framereader.py b/netlib/http/http2/framereader.py
new file mode 100644
index 00000000..d45be646
--- /dev/null
+++ b/netlib/http/http2/framereader.py
@@ -0,0 +1,21 @@
+import codecs
+
+import hyperframe
+
+
+def http2_read_raw_frame(rfile):
+ header = rfile.safe_read(9)
+ length = int(codecs.encode(header[:3], 'hex_codec'), 16)
+
+ if length == 4740180:
+ raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20))
+
+ body = rfile.safe_read(length)
+ return [header, body]
+
+
+def http2_read_frame(rfile):
+ header, body = http2_read_raw_frame(rfile)
+ frame, length = hyperframe.frame.Frame.parse_frame_header(header)
+ frame.parse_body(memoryview(body))
+ return frame
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 028f43a1..b633b671 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -4,20 +4,25 @@ import warnings
import six
-from ..multidict import MultiDict
-from .headers import Headers
-from .. import encoding, utils
+from netlib import encoding, strutils, basetypes
+from netlib.http import headers
if six.PY2: # pragma: no cover
- _native = lambda x: x
- _always_bytes = lambda x: x
+ def _native(x):
+ return x
+
+ def _always_bytes(x):
+ return x
else:
- # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
- _native = lambda x: x.decode("utf-8", "surrogateescape")
- _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape")
+ # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
+ def _native(x):
+ return x.decode("utf-8", "surrogateescape")
+
+ def _always_bytes(x):
+ return strutils.always_bytes(x, "utf-8", "surrogateescape")
-class MessageData(utils.Serializable):
+class MessageData(basetypes.Serializable):
def __eq__(self, other):
if isinstance(other, MessageData):
return self.__dict__ == other.__dict__
@@ -32,7 +37,7 @@ class MessageData(utils.Serializable):
def set_state(self, state):
for k, v in state.items():
if k == "headers":
- v = Headers.from_state(v)
+ v = headers.Headers.from_state(v)
setattr(self, k, v)
def get_state(self):
@@ -42,11 +47,11 @@ class MessageData(utils.Serializable):
@classmethod
def from_state(cls, state):
- state["headers"] = Headers.from_state(state["headers"])
+ state["headers"] = headers.Headers.from_state(state["headers"])
return cls(**state)
-class Message(utils.Serializable):
+class Message(basetypes.Serializable):
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@@ -66,7 +71,7 @@ class Message(utils.Serializable):
@classmethod
def from_state(cls, state):
- state["headers"] = Headers.from_state(state["headers"])
+ state["headers"] = headers.Headers.from_state(state["headers"])
return cls(**state)
@property
@@ -195,7 +200,7 @@ class Message(utils.Serializable):
replacements = 0
if self.content:
with decoded(self):
- self.content, replacements = utils.safe_subn(
+ self.content, replacements = strutils.safe_subn(
pattern, repl, self.content, flags=flags
)
replacements += self.headers.replace(pattern, repl, flags)
diff --git a/netlib/http/multipart.py b/netlib/http/multipart.py
new file mode 100644
index 00000000..536b2809
--- /dev/null
+++ b/netlib/http/multipart.py
@@ -0,0 +1,32 @@
+import re
+
+from netlib.http import headers
+
+
+def decode(hdrs, content):
+ """
+ Takes a multipart boundary encoded string and returns list of (key, value) tuples.
+ """
+ v = hdrs.get("content-type")
+ if v:
+ v = headers.parse_content_type(v)
+ if not v:
+ return []
+ try:
+ boundary = v[2]["boundary"].encode("ascii")
+ except (KeyError, UnicodeError):
+ return []
+
+ rx = re.compile(br'\bname="([^"]+)"')
+ r = []
+
+ for i in content.split(b"--" + boundary):
+ parts = i.splitlines()
+ if len(parts) > 1 and parts[0][0:2] != b"--":
+ match = rx.search(parts[1])
+ if match:
+ key = match.group(1)
+ value = b"".join(parts[3 + parts[2:].index(b""):])
+ r.append((key, value))
+ return r
+ return []
diff --git a/netlib/http/request.py b/netlib/http/request.py
index 056a2d93..91d5f020 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -1,18 +1,18 @@
from __future__ import absolute_import, print_function, division
import re
-import warnings
import six
from six.moves import urllib
-from netlib import utils
+from netlib import encoding
+from netlib import multidict
+from netlib import strutils
+from netlib.http import multipart
from netlib.http import cookies
-from netlib.odict import ODict
-from .. import encoding
-from ..multidict import MultiDictView
-from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData
+from netlib.http import headers as nheaders
+from netlib.http import message
+import netlib.http.url
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
@@ -20,11 +20,11 @@ from .message import Message, _native, _always_bytes, MessageData
host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
-class RequestData(MessageData):
- def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
+class RequestData(message.MessageData):
+ def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None,
timestamp_start=None, timestamp_end=None):
- if not isinstance(headers, Headers):
- headers = Headers(headers)
+ if not isinstance(headers, nheaders.Headers):
+ headers = nheaders.Headers(headers)
self.first_line_format = first_line_format
self.method = method
@@ -39,7 +39,7 @@ class RequestData(MessageData):
self.timestamp_end = timestamp_end
-class Request(Message):
+class Request(message.Message):
"""
An HTTP request.
"""
@@ -67,7 +67,7 @@ class Request(Message):
"""
# TODO: Proper distinction between text and bytes.
c = super(Request, self).replace(pattern, repl, flags)
- self.path, pc = utils.safe_subn(
+ self.path, pc = strutils.safe_subn(
pattern, repl, self.path, flags=flags
)
c += pc
@@ -91,22 +91,22 @@ class Request(Message):
"""
HTTP request method, e.g. "GET".
"""
- return _native(self.data.method).upper()
+ return message._native(self.data.method).upper()
@method.setter
def method(self, method):
- self.data.method = _always_bytes(method)
+ self.data.method = message._always_bytes(method)
@property
def scheme(self):
"""
HTTP request scheme, which should be "http" or "https".
"""
- return _native(self.data.scheme)
+ return message._native(self.data.scheme)
@scheme.setter
def scheme(self, scheme):
- self.data.scheme = _always_bytes(scheme)
+ self.data.scheme = message._always_bytes(scheme)
@property
def host(self):
@@ -168,11 +168,11 @@ class Request(Message):
if self.data.path is None:
return None
else:
- return _native(self.data.path)
+ return message._native(self.data.path)
@path.setter
def path(self, path):
- self.data.path = _always_bytes(path)
+ self.data.path = message._always_bytes(path)
@property
def url(self):
@@ -181,11 +181,11 @@ class Request(Message):
"""
if self.first_line_format == "authority":
return "%s:%d" % (self.host, self.port)
- return utils.unparse_url(self.scheme, self.host, self.port, self.path)
+ return netlib.http.url.unparse(self.scheme, self.host, self.port, self.path)
@url.setter
def url(self, url):
- self.scheme, self.host, self.port, self.path = utils.parse_url(url)
+ self.scheme, self.host, self.port, self.path = netlib.http.url.parse(url)
def _parse_host_header(self):
"""Extract the host and port from Host header"""
@@ -221,28 +221,28 @@ class Request(Message):
"""
if self.first_line_format == "authority":
return "%s:%d" % (self.pretty_host, self.port)
- return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path)
+ return netlib.http.url.unparse(self.scheme, self.pretty_host, self.port, self.path)
@property
def query(self):
- # type: () -> MultiDictView
+ # type: () -> multidict.MultiDictView
"""
The request query string as an :py:class:`MultiDictView` object.
"""
- return MultiDictView(
+ return multidict.MultiDictView(
self._get_query,
self._set_query
)
def _get_query(self):
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
- return tuple(utils.urldecode(query))
+ return tuple(netlib.http.url.decode(query))
def _set_query(self, value):
- query = utils.urlencode(value)
+ query = netlib.http.url.encode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
- _, _, _, self.path = utils.parse_url(
- urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ _, _, _, self.path = netlib.http.url.parse(
+ urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
@query.setter
def query(self, value):
@@ -250,13 +250,13 @@ class Request(Message):
@property
def cookies(self):
- # type: () -> MultiDictView
+ # type: () -> multidict.MultiDictView
"""
The request cookies.
- An empty :py:class:`MultiDictView` object if the cookie monster ate them all.
+ An empty :py:class:`multidict.MultiDictView` object if the cookie monster ate them all.
"""
- return MultiDictView(
+ return multidict.MultiDictView(
self._get_cookies,
self._set_cookies
)
@@ -289,8 +289,8 @@ class Request(Message):
components = map(lambda x: urllib.parse.quote(x, safe=""), components)
path = "/" + "/".join(components)
scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
- _, _, _, self.path = utils.parse_url(
- urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ _, _, _, self.path = netlib.http.url.parse(
+ urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
def anticache(self):
"""
@@ -329,11 +329,11 @@ class Request(Message):
@property
def urlencoded_form(self):
"""
- The URL-encoded form data as an :py:class:`MultiDictView` object.
- An empty MultiDictView if the content-type indicates non-form data
+ The URL-encoded form data as an :py:class:`multidict.MultiDictView` object.
+ An empty multidict.MultiDictView if the content-type indicates non-form data
or the content could not be parsed.
"""
- return MultiDictView(
+ return multidict.MultiDictView(
self._get_urlencoded_form,
self._set_urlencoded_form
)
@@ -341,7 +341,7 @@ class Request(Message):
def _get_urlencoded_form(self):
is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
- return tuple(utils.urldecode(self.content))
+ return tuple(netlib.http.url.decode(self.content))
return ()
def _set_urlencoded_form(self, value):
@@ -350,7 +350,7 @@ class Request(Message):
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
- self.content = utils.urlencode(value)
+ self.content = netlib.http.url.encode(value)
@urlencoded_form.setter
def urlencoded_form(self, value):
@@ -362,7 +362,7 @@ class Request(Message):
The multipart form data as an :py:class:`MultipartFormDict` object.
None if the content-type indicates non-form data.
"""
- return MultiDictView(
+ return multidict.MultiDictView(
self._get_multipart_form,
self._set_multipart_form
)
@@ -370,7 +370,7 @@ class Request(Message):
def _get_multipart_form(self):
is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
- return utils.multipartdecode(self.headers, self.content)
+ return multipart.decode(self.headers, self.content)
return ()
def _set_multipart_form(self, value):
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 7d272e10..44b58be6 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -3,18 +3,18 @@ from __future__ import absolute_import, print_function, division
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
-from . import cookies
-from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData
-from ..multidict import MultiDictView
-from .. import utils
+from netlib.http import cookies
+from netlib.http import headers as nheaders
+from netlib.http import message
+from netlib import multidict
+from netlib import human
-class ResponseData(MessageData):
- def __init__(self, http_version, status_code, reason=None, headers=None, content=None,
+class ResponseData(message.MessageData):
+ def __init__(self, http_version, status_code, reason=None, headers=(), content=None,
timestamp_start=None, timestamp_end=None):
- if not isinstance(headers, Headers):
- headers = Headers(headers)
+ if not isinstance(headers, nheaders.Headers):
+ headers = nheaders.Headers(headers)
self.http_version = http_version
self.status_code = status_code
@@ -25,7 +25,7 @@ class ResponseData(MessageData):
self.timestamp_end = timestamp_end
-class Response(Message):
+class Response(message.Message):
"""
An HTTP response.
"""
@@ -36,7 +36,7 @@ class Response(Message):
if self.content:
details = "{}, {}".format(
self.headers.get("content-type", "unknown content type"),
- utils.pretty_size(len(self.content))
+ human.pretty_size(len(self.content))
)
else:
details = "no content"
@@ -63,17 +63,17 @@ class Response(Message):
HTTP Reason Phrase, e.g. "Not Found".
This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase.
"""
- return _native(self.data.reason)
+ return message._native(self.data.reason)
@reason.setter
def reason(self, reason):
- self.data.reason = _always_bytes(reason)
+ self.data.reason = message._always_bytes(reason)
@property
def cookies(self):
- # type: () -> MultiDictView
+ # type: () -> multidict.MultiDictView
"""
- The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are
+ The response cookies. A possibly empty :py:class:`multidict.MultiDictView`, where the keys are
cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is
an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly)
are indicated by a Null value.
@@ -81,7 +81,7 @@ class Response(Message):
Caveats:
Updating the attr
"""
- return MultiDictView(
+ return multidict.MultiDictView(
self._get_cookies,
self._set_cookies
)
diff --git a/netlib/http/url.py b/netlib/http/url.py
new file mode 100644
index 00000000..5d461387
--- /dev/null
+++ b/netlib/http/url.py
@@ -0,0 +1,96 @@
+import six
+from six.moves import urllib
+
+from netlib import utils
+
+
+# PY2 workaround
+def decode_parse_result(result, enc):
+ if hasattr(result, "decode"):
+ return result.decode(enc)
+ else:
+ return urllib.parse.ParseResult(*[x.decode(enc) for x in result])
+
+
+# PY2 workaround
+def encode_parse_result(result, enc):
+ if hasattr(result, "encode"):
+ return result.encode(enc)
+ else:
+ return urllib.parse.ParseResult(*[x.encode(enc) for x in result])
+
+
+def parse(url):
+ """
+ URL-parsing function that checks that
+ - port is an integer 0-65535
+ - host is a valid IDNA-encoded hostname with no null-bytes
+ - path is valid ASCII
+
+ Args:
+ A URL (as bytes or as unicode)
+
+ Returns:
+ A (scheme, host, port, path) tuple
+
+ Raises:
+ ValueError, if the URL is not properly formatted.
+ """
+ parsed = urllib.parse.urlparse(url)
+
+ if not parsed.hostname:
+ raise ValueError("No hostname given")
+
+ if isinstance(url, six.binary_type):
+ host = parsed.hostname
+
+ # this should not raise a ValueError,
+ # but we try to be very forgiving here and accept just everything.
+ # decode_parse_result(parsed, "ascii")
+ else:
+ host = parsed.hostname.encode("idna")
+ parsed = encode_parse_result(parsed, "ascii")
+
+ port = parsed.port
+ if not port:
+ port = 443 if parsed.scheme == b"https" else 80
+
+ full_path = urllib.parse.urlunparse(
+ (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
+ )
+ if not full_path.startswith(b"/"):
+ full_path = b"/" + full_path
+
+ if not utils.is_valid_host(host):
+ raise ValueError("Invalid Host")
+ if not utils.is_valid_port(port):
+ raise ValueError("Invalid Port")
+
+ return parsed.scheme, host, port, full_path
+
+
+def unparse(scheme, host, port, path=""):
+ """
+ Returns a URL string, constructed from the specified components.
+
+ Args:
+ All args must be str.
+ """
+ if path == "*":
+ path = ""
+ return "%s://%s%s" % (scheme, utils.hostport(scheme, host, port), path)
+
+
+def encode(s):
+ """
+ Takes a list of (key, value) tuples and returns a urlencoded string.
+ """
+ s = [tuple(i) for i in s]
+ return urllib.parse.urlencode(s, False)
+
+
+def decode(s):
+ """
+ Takes a urlencoded string and returns a list of (key, value) tuples.
+ """
+ return urllib.parse.parse_qsl(s, keep_blank_values=True)
diff --git a/netlib/human.py b/netlib/human.py
new file mode 100644
index 00000000..a007adc7
--- /dev/null
+++ b/netlib/human.py
@@ -0,0 +1,50 @@
+
+SIZE_TABLE = [
+ ("b", 1024 ** 0),
+ ("k", 1024 ** 1),
+ ("m", 1024 ** 2),
+ ("g", 1024 ** 3),
+ ("t", 1024 ** 4),
+]
+
+SIZE_UNITS = dict(SIZE_TABLE)
+
+
+def pretty_size(size):
+ for bottom, top in zip(SIZE_TABLE, SIZE_TABLE[1:]):
+ if bottom[1] <= size < top[1]:
+ suf = bottom[0]
+ lim = bottom[1]
+ x = round(size / lim, 2)
+ if x == int(x):
+ x = int(x)
+ return str(x) + suf
+ return "%s%s" % (size, SIZE_TABLE[0][0])
+
+
+def parse_size(s):
+ try:
+ return int(s)
+ except ValueError:
+ pass
+ for i in SIZE_UNITS.keys():
+ if s.endswith(i):
+ try:
+ return int(s[:-1]) * SIZE_UNITS[i]
+ except ValueError:
+ break
+ raise ValueError("Invalid size specification.")
+
+
+def pretty_duration(secs):
+ formatters = [
+ (100, "{:.0f}s"),
+ (10, "{:2.1f}s"),
+ (1, "{:1.2f}s"),
+ ]
+
+ for limit, formatter in formatters:
+ if secs >= limit:
+ return formatter.format(secs)
+ # less than 1 sec
+ return "{:.0f}ms".format(secs * 1000)
diff --git a/netlib/multidict.py b/netlib/multidict.py
index 248acdec..dc0f3466 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division
from abc import ABCMeta, abstractmethod
-from typing import Tuple, TypeVar
try:
from collections.abc import MutableMapping
@@ -10,14 +9,13 @@ except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
import six
-
-from .utils import Serializable
+from netlib import basetypes
@six.add_metaclass(ABCMeta)
-class _MultiDict(MutableMapping, Serializable):
+class _MultiDict(MutableMapping, basetypes.Serializable):
def __repr__(self):
- fields = tuple(
+ fields = (
repr(field)
for field in self.fields
)
@@ -172,6 +170,30 @@ class _MultiDict(MutableMapping, Serializable):
else:
return super(_MultiDict, self).items()
+ def clear(self, key):
+ """
+ Removes all items with the specified key, and does not raise an
+ exception if the key does not exist.
+ """
+ if key in self:
+ del self[key]
+
+ def collect(self):
+ """
+ Returns a list of (key, value) tuples, where values are either
+ singular if threre is only one matching item for a key, or a list
+ if there are more than one. The order of the keys matches the order
+ in the underlying fields list.
+ """
+ coll = []
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ coll.append([key, values[0]])
+ else:
+ coll.append([key, values])
+ return coll
+
def to_dict(self):
"""
Get the MultiDict as a plain Python dict.
@@ -191,12 +213,8 @@ class _MultiDict(MutableMapping, Serializable):
}
"""
d = {}
- for key in self:
- values = self.get_all(key)
- if len(values) == 1:
- d[key] = values[0]
- else:
- d[key] = values
+ for k, v in self.collect():
+ d[k] = v
return d
def get_state(self):
@@ -207,13 +225,15 @@ class _MultiDict(MutableMapping, Serializable):
@classmethod
def from_state(cls, state):
- return cls(tuple(x) for x in state)
+ return cls(state)
class MultiDict(_MultiDict):
- def __init__(self, fields=None):
+ def __init__(self, fields=()):
super(MultiDict, self).__init__()
- self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+ self.fields = tuple(
+ tuple(i) for i in fields
+ )
@six.add_metaclass(ABCMeta)
diff --git a/netlib/odict.py b/netlib/odict.py
index 8a638dab..f9f55991 100644
--- a/netlib/odict.py
+++ b/netlib/odict.py
@@ -3,10 +3,10 @@ import copy
import six
-from .utils import Serializable, safe_subn
+from netlib import basetypes, strutils
-class ODict(Serializable):
+class ODict(basetypes.Serializable):
"""
A dictionary-like object for managing ordered (key, value) data. Think
@@ -139,9 +139,9 @@ class ODict(Serializable):
"""
new, count = [], 0
for k, v in self.lst:
- k, c = safe_subn(pattern, repl, k, *args, **kwargs)
+ k, c = strutils.safe_subn(pattern, repl, k, *args, **kwargs)
count += c
- v, c = safe_subn(pattern, repl, v, *args, **kwargs)
+ v, c = strutils.safe_subn(pattern, repl, v, *args, **kwargs)
count += c
new.append([k, v])
self.lst = new
diff --git a/netlib/socks.py b/netlib/socks.py
index 57ccd1be..8d7c5279 100644
--- a/netlib/socks.py
+++ b/netlib/socks.py
@@ -2,7 +2,8 @@ from __future__ import (absolute_import, print_function, division)
import struct
import array
import ipaddress
-from . import tcp, utils
+
+from netlib import tcp, utils
class SocksError(Exception):
@@ -147,7 +148,7 @@ class UsernamePasswordAuth(object):
class UsernamePasswordAuthResponse(object):
- __slots__ = ("ver", "status")
+ __slots__ = ("ver", "status")
def __init__(self, ver, status):
self.ver = ver
diff --git a/netlib/strutils.py b/netlib/strutils.py
new file mode 100644
index 00000000..03b371f5
--- /dev/null
+++ b/netlib/strutils.py
@@ -0,0 +1,154 @@
+import re
+import unicodedata
+import codecs
+
+import six
+
+
+def always_bytes(unicode_or_bytes, *encode_args):
+ if isinstance(unicode_or_bytes, six.text_type):
+ return unicode_or_bytes.encode(*encode_args)
+ return unicode_or_bytes
+
+
+def native(s, *encoding_opts):
+ """
+ Convert :py:class:`bytes` or :py:class:`unicode` to the native
+ :py:class:`str` type, using latin1 encoding if conversion is necessary.
+
+ https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
+ """
+ if not isinstance(s, (six.binary_type, six.text_type)):
+ raise TypeError("%r is neither bytes nor unicode" % s)
+ if six.PY3:
+ if isinstance(s, six.binary_type):
+ return s.decode(*encoding_opts)
+ else:
+ if isinstance(s, six.text_type):
+ return s.encode(*encoding_opts)
+ return s
+
+
+def clean_bin(s, keep_spacing=True):
+ """
+ Cleans binary data to make it safe to display.
+
+ Args:
+ keep_spacing: If False, tabs and newlines will also be replaced.
+ """
+ if isinstance(s, six.text_type):
+ if keep_spacing:
+ keep = u" \n\r\t"
+ else:
+ keep = u" "
+ return u"".join(
+ ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"."
+ for ch in s
+ )
+ else:
+ if keep_spacing:
+ keep = (9, 10, 13) # \t, \n, \r,
+ else:
+ keep = ()
+ return b"".join(
+ six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"."
+ for ch in six.iterbytes(s)
+ )
+
+
+def safe_subn(pattern, repl, target, *args, **kwargs):
+ """
+ There are Unicode conversion problems with re.subn. We try to smooth
+ that over by casting the pattern and replacement to strings. We really
+ need a better solution that is aware of the actual content ecoding.
+ """
+ return re.subn(str(pattern), str(repl), target, *args, **kwargs)
+
+
+def bytes_to_escaped_str(data):
+ """
+ Take bytes and return a safe string that can be displayed to the user.
+
+ Single quotes are always escaped, double quotes are never escaped:
+ "'" + bytes_to_escaped_str(...) + "'"
+ gives a valid Python string.
+ """
+ # TODO: We may want to support multi-byte characters without escaping them.
+ # One way to do would be calling .decode("utf8", "backslashreplace") first
+ # and then escaping UTF8 control chars (see clean_bin).
+
+ if not isinstance(data, bytes):
+ raise ValueError("data must be bytes, but is {}".format(data.__class__.__name__))
+ # We always insert a double-quote here so that we get a single-quoted string back
+ # https://stackoverflow.com/questions/29019340/why-does-python-use-different-quotes-for-representing-strings-depending-on-their
+ return repr(b'"' + data).lstrip("b")[2:-1]
+
+
+def escaped_str_to_bytes(data):
+ """
+ Take an escaped string and return the unescaped bytes equivalent.
+ """
+ if not isinstance(data, six.string_types):
+ if six.PY2:
+ raise ValueError("data must be str or unicode, but is {}".format(data.__class__.__name__))
+ raise ValueError("data must be str, but is {}".format(data.__class__.__name__))
+
+ if six.PY2:
+ if isinstance(data, unicode):
+ data = data.encode("utf8")
+ return data.decode("string-escape")
+
+ # This one is difficult - we use an undocumented Python API here
+ # as per http://stackoverflow.com/a/23151714/934719
+ return codecs.escape_decode(data)[0]
+
+
+def isBin(s):
+ """
+ Does this string have any non-ASCII characters?
+ """
+ for i in s:
+ i = ord(i)
+ if i < 9 or 13 < i < 32 or 126 < i:
+ return True
+ return False
+
+
+def isMostlyBin(s):
+ s = s[:100]
+ return sum(isBin(ch) for ch in s) / len(s) > 0.3
+
+
+def isXML(s):
+ for i in s:
+ if i in "\n \t":
+ continue
+ elif i == "<":
+ return True
+ else:
+ return False
+
+
+def clean_hanging_newline(t):
+ """
+ Many editors will silently add a newline to the final line of a
+ document (I'm looking at you, Vim). This function fixes this common
+ problem at the risk of removing a hanging newline in the rare cases
+ where the user actually intends it.
+ """
+ if t and t[-1] == "\n":
+ return t[:-1]
+ return t
+
+
+def hexdump(s):
+ """
+ Returns:
+ A generator of (offset, hex, str) tuples
+ """
+ for i in range(0, len(s), 16):
+ offset = "{:0=10x}".format(i).encode()
+ part = s[i:i + 16]
+ x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part))
+ x = x.ljust(47) # 16*2 + 15
+ yield (offset, x, clean_bin(part, False))
diff --git a/netlib/tcp.py b/netlib/tcp.py
index d26bb5f7..de12102e 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -6,6 +6,7 @@ import sys
import threading
import time
import traceback
+import contextlib
import binascii
from six.moves import range
@@ -16,13 +17,10 @@ import six
import OpenSSL
from OpenSSL import SSL
-from . import certutils, version_check, utils
+from netlib import certutils, version_check, basetypes, exceptions
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
-from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
- TcpTimeout, TcpDisconnect, TcpException
-
version_check.check_pyopenssl_version()
if six.PY2:
@@ -71,6 +69,7 @@ sslversion_choices = {
"TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
}
+
class SSLKeyLogger(object):
def __init__(self, filename):
@@ -161,17 +160,17 @@ class Writer(_FileLike):
def flush(self):
"""
- May raise TcpDisconnect
+ May raise exceptions.TcpDisconnect
"""
if hasattr(self.o, "flush"):
try:
self.o.flush()
except (socket.error, IOError) as v:
- raise TcpDisconnect(str(v))
+ raise exceptions.TcpDisconnect(str(v))
def write(self, v):
"""
- May raise TcpDisconnect
+ May raise exceptions.TcpDisconnect
"""
if v:
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
@@ -184,7 +183,7 @@ class Writer(_FileLike):
self.add_log(v[:r])
return r
except (SSL.Error, socket.error) as e:
- raise TcpDisconnect(str(e))
+ raise exceptions.TcpDisconnect(str(e))
class Reader(_FileLike):
@@ -215,17 +214,17 @@ class Reader(_FileLike):
time.sleep(0.1)
continue
else:
- raise TcpTimeout()
+ raise exceptions.TcpTimeout()
except socket.timeout:
- raise TcpTimeout()
+ raise exceptions.TcpTimeout()
except socket.error as e:
- raise TcpDisconnect(str(e))
+ raise exceptions.TcpDisconnect(str(e))
except SSL.SysCallError as e:
if e.args == (-1, 'Unexpected EOF'):
break
- raise TlsException(str(e))
+ raise exceptions.TlsException(str(e))
except SSL.Error as e:
- raise TlsException(str(e))
+ raise exceptions.TlsException(str(e))
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
if not data:
break
@@ -259,9 +258,9 @@ class Reader(_FileLike):
result = self.read(length)
if length != -1 and len(result) != length:
if not result:
- raise TcpDisconnect()
+ raise exceptions.TcpDisconnect()
else:
- raise TcpReadIncomplete(
+ raise exceptions.TcpReadIncomplete(
"Expected %s bytes, got %s" % (length, len(result))
)
return result
@@ -274,7 +273,7 @@ class Reader(_FileLike):
Up to the next N bytes if peeking is successful.
Raises:
- TcpException if there was an error with the socket
+ exceptions.TcpException if there was an error with the socket
TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
"""
@@ -282,7 +281,7 @@ class Reader(_FileLike):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e:
- raise TcpException(repr(e))
+ raise exceptions.TcpException(repr(e))
elif isinstance(self.o, SSL.Connection):
try:
if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
@@ -296,12 +295,12 @@ class Reader(_FileLike):
self.o._raise_ssl_error(self.o._ssl, result)
return SSL._ffi.buffer(buf, result)[:]
except SSL.Error as e:
- six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2])
+ six.reraise(exceptions.TlsException, exceptions.TlsException(str(e)), sys.exc_info()[2])
else:
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
-class Address(utils.Serializable):
+class Address(basetypes.Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and
@@ -489,7 +488,7 @@ class _Connection(object):
try:
self.wfile.flush()
self.wfile.close()
- except TcpDisconnect:
+ except exceptions.TcpDisconnect:
pass
self.rfile.close()
@@ -553,7 +552,7 @@ class _Connection(object):
# TODO: maybe change this to with newer pyOpenSSL APIs
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
- raise TlsException("SSL cipher specification error: %s" % str(v))
+ raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
# SSLKEYLOGFILE
if log_ssl_key:
@@ -574,11 +573,17 @@ class _Connection(object):
elif alpn_select_callback is not None and alpn_select is None:
context.set_alpn_select_callback(alpn_select_callback)
elif alpn_select_callback is not None and alpn_select is not None:
- raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
+ raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
return context
+@contextlib.contextmanager
+def _closer(client):
+ yield
+ client.close()
+
+
class TCPClient(_Connection):
def __init__(self, address, source_address=None):
@@ -631,7 +636,7 @@ class TCPClient(_Connection):
context.use_privatekey_file(cert)
context.use_certificate_file(cert)
except SSL.Error as v:
- raise TlsException("SSL client certificate error: %s" % str(v))
+ raise exceptions.TlsException("SSL client certificate error: %s" % str(v))
return context
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
@@ -645,7 +650,7 @@ class TCPClient(_Connection):
"""
verification_mode = sslctx_kwargs.get('verify_options', None)
if verification_mode == SSL.VERIFY_PEER and not sni:
- raise TlsException("Cannot validate certificate hostname without SNI")
+ raise exceptions.TlsException("Cannot validate certificate hostname without SNI")
context = self.create_ssl_context(
alpn_protos=alpn_protos,
@@ -660,14 +665,14 @@ class TCPClient(_Connection):
self.connection.do_handshake()
except SSL.Error as v:
if self.ssl_verification_error:
- raise InvalidCertificateException("SSL handshake error: %s" % repr(v))
+ raise exceptions.InvalidCertificateException("SSL handshake error: %s" % repr(v))
else:
- raise TlsException("SSL handshake error: %s" % repr(v))
+ raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
else:
# Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
# certificate validation failure
if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None:
- raise InvalidCertificateException("SSL handshake error: certificate verify failed")
+ raise exceptions.InvalidCertificateException("SSL handshake error: certificate verify failed")
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
@@ -690,7 +695,7 @@ class TCPClient(_Connection):
except (ValueError, ssl_match_hostname.CertificateError) as e:
self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname")
if verification_mode == SSL.VERIFY_PEER:
- raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e)))
+ raise exceptions.InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e)))
self.ssl_established = True
self.rfile.set_descriptor(self.connection)
@@ -704,12 +709,13 @@ class TCPClient(_Connection):
connection.connect(self.address())
self.source_address = Address(connection.getsockname())
except (socket.error, IOError) as err:
- raise TcpException(
+ raise exceptions.TcpException(
'Error connecting to "%s": %s' %
(self.address.host, err))
self.connection = connection
self.ip_address = Address(connection.getpeername())
self._makefile()
+ return _closer(self)
def settimeout(self, n):
self.connection.settimeout(n)
@@ -817,7 +823,7 @@ class BaseHandler(_Connection):
try:
self.connection.do_handshake()
except SSL.Error as v:
- raise TlsException("SSL handshake error: %s" % repr(v))
+ raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
self.ssl_established = True
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
@@ -835,6 +841,25 @@ class BaseHandler(_Connection):
return b""
+class Counter:
+ def __init__(self):
+ self._count = 0
+ self._lock = threading.Lock()
+
+ @property
+ def count(self):
+ with self._lock:
+ return self._count
+
+ def __enter__(self):
+ with self._lock:
+ self._count += 1
+
+ def __exit__(self, *args):
+ with self._lock:
+ self._count -= 1
+
+
class TCPServer(object):
request_queue_size = 20
@@ -847,15 +872,17 @@ class TCPServer(object):
self.socket.bind(self.address())
self.address = Address.wrap(self.socket.getsockname())
self.socket.listen(self.request_queue_size)
+ self.handler_counter = Counter()
def connection_thread(self, connection, client_address):
- client_address = Address(client_address)
- try:
- self.handle_client_connection(connection, client_address)
- except:
- self.handle_error(connection, client_address)
- finally:
- close_socket(connection)
+ with self.handler_counter:
+ client_address = Address(client_address)
+ try:
+ self.handle_client_connection(connection, client_address)
+ except:
+ self.handle_error(connection, client_address)
+ finally:
+ close_socket(connection)
def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()
@@ -900,7 +927,7 @@ class TCPServer(object):
"""
# If a thread has persisted after interpreter exit, the module might be
# none.
- if traceback:
+ if traceback and six:
exc = six.text_type(traceback.format_exc())
print(u'-' * 40, file=fp)
print(
diff --git a/netlib/tutils.py b/netlib/tutils.py
index 18d632f0..452766d6 100644
--- a/netlib/tutils.py
+++ b/netlib/tutils.py
@@ -7,8 +7,7 @@ from contextlib import contextmanager
import six
import sys
-from . import utils, tcp
-from .http import Request, Response, Headers
+from netlib import utils, tcp, http
def treader(bytes):
@@ -91,8 +90,7 @@ class RaisesContext(object):
test_data = utils.Data(__name__)
# FIXME: Temporary workaround during repo merge.
-import os
-test_data.dirname = os.path.join(test_data.dirname,"..","test","netlib")
+test_data.dirname = os.path.join(test_data.dirname, "..", "test", "netlib")
def treq(**kwargs):
@@ -108,11 +106,11 @@ def treq(**kwargs):
port=22,
path=b"/path",
http_version=b"HTTP/1.1",
- headers=Headers(header="qvalue", content_length="7"),
+ headers=http.Headers(((b"header", b"qvalue"), (b"content-length", b"7"))),
content=b"content"
)
default.update(kwargs)
- return Request(**default)
+ return http.Request(**default)
def tresp(**kwargs):
@@ -124,10 +122,10 @@ def tresp(**kwargs):
http_version=b"HTTP/1.1",
status_code=200,
reason=b"OK",
- headers=Headers(header_response="svalue", content_length="7"),
+ headers=http.Headers(((b"header-response", b"svalue"), (b"content-length", b"7"))),
content=b"message",
timestamp_start=time.time(),
timestamp_end=time.time(),
)
default.update(kwargs)
- return Response(**default)
+ return http.Response(**default)
diff --git a/netlib/utils.py b/netlib/utils.py
index 7499f71f..b4b99679 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -1,121 +1,11 @@
from __future__ import absolute_import, print_function, division
import os.path
import re
-import codecs
-import unicodedata
-from abc import ABCMeta, abstractmethod
import importlib
import inspect
import six
-from six.moves import urllib
-import hyperframe
-
-
-@six.add_metaclass(ABCMeta)
-class Serializable(object):
- """
- Abstract Base Class that defines an API to save an object's state and restore it later on.
- """
-
- @classmethod
- @abstractmethod
- def from_state(cls, state):
- """
- Create a new object from the given state.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def get_state(self):
- """
- Retrieve object state.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def set_state(self, state):
- """
- Set object state to the given state.
- """
- raise NotImplementedError()
-
- def copy(self):
- return self.from_state(self.get_state())
-
-
-def always_bytes(unicode_or_bytes, *encode_args):
- if isinstance(unicode_or_bytes, six.text_type):
- return unicode_or_bytes.encode(*encode_args)
- return unicode_or_bytes
-
-
-def native(s, *encoding_opts):
- """
- Convert :py:class:`bytes` or :py:class:`unicode` to the native
- :py:class:`str` type, using latin1 encoding if conversion is necessary.
-
- https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
- """
- if not isinstance(s, (six.binary_type, six.text_type)):
- raise TypeError("%r is neither bytes nor unicode" % s)
- if six.PY3:
- if isinstance(s, six.binary_type):
- return s.decode(*encoding_opts)
- else:
- if isinstance(s, six.text_type):
- return s.encode(*encoding_opts)
- return s
-
-
-def isascii(bytes):
- try:
- bytes.decode("ascii")
- except ValueError:
- return False
- return True
-
-
-def clean_bin(s, keep_spacing=True):
- """
- Cleans binary data to make it safe to display.
-
- Args:
- keep_spacing: If False, tabs and newlines will also be replaced.
- """
- if isinstance(s, six.text_type):
- if keep_spacing:
- keep = u" \n\r\t"
- else:
- keep = u" "
- return u"".join(
- ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"."
- for ch in s
- )
- else:
- if keep_spacing:
- keep = (9, 10, 13) # \t, \n, \r,
- else:
- keep = ()
- return b"".join(
- six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"."
- for ch in six.iterbytes(s)
- )
-
-
-def hexdump(s):
- """
- Returns:
- A generator of (offset, hex, str) tuples
- """
- for i in range(0, len(s), 16):
- offset = "{:0=10x}".format(i).encode()
- part = s[i:i + 16]
- x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part))
- x = x.ljust(47) # 16*2 + 15
- yield (offset, x, clean_bin(part, False))
-
def setbit(byte, offset, value):
"""
@@ -161,22 +51,6 @@ class BiDi(object):
return self.values.get(n, default)
-def pretty_size(size):
- suffixes = [
- ("B", 2 ** 10),
- ("kB", 2 ** 20),
- ("MB", 2 ** 30),
- ]
- for suf, lim in suffixes:
- if size >= lim:
- continue
- else:
- x = round(size / float(lim / 2 ** 10), 2)
- if x == int(x):
- x = int(x)
- return str(x) + suf
-
-
class Data(object):
def __init__(self, name):
@@ -222,83 +96,6 @@ def is_valid_port(port):
return 0 <= port <= 65535
-# PY2 workaround
-def decode_parse_result(result, enc):
- if hasattr(result, "decode"):
- return result.decode(enc)
- else:
- return urllib.parse.ParseResult(*[x.decode(enc) for x in result])
-
-
-# PY2 workaround
-def encode_parse_result(result, enc):
- if hasattr(result, "encode"):
- return result.encode(enc)
- else:
- return urllib.parse.ParseResult(*[x.encode(enc) for x in result])
-
-
-def parse_url(url):
- """
- URL-parsing function that checks that
- - port is an integer 0-65535
- - host is a valid IDNA-encoded hostname with no null-bytes
- - path is valid ASCII
-
- Args:
- A URL (as bytes or as unicode)
-
- Returns:
- A (scheme, host, port, path) tuple
-
- Raises:
- ValueError, if the URL is not properly formatted.
- """
- parsed = urllib.parse.urlparse(url)
-
- if not parsed.hostname:
- raise ValueError("No hostname given")
-
- if isinstance(url, six.binary_type):
- host = parsed.hostname
-
- # this should not raise a ValueError,
- # but we try to be very forgiving here and accept just everything.
- # decode_parse_result(parsed, "ascii")
- else:
- host = parsed.hostname.encode("idna")
- parsed = encode_parse_result(parsed, "ascii")
-
- port = parsed.port
- if not port:
- port = 443 if parsed.scheme == b"https" else 80
-
- full_path = urllib.parse.urlunparse(
- (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
- )
- if not full_path.startswith(b"/"):
- full_path = b"/" + full_path
-
- if not is_valid_host(host):
- raise ValueError("Invalid Host")
- if not is_valid_port(port):
- raise ValueError("Invalid Port")
-
- return parsed.scheme, host, port, full_path
-
-
-def get_header_tokens(headers, key):
- """
- Retrieve all tokens for a header key. A number of different headers
- follow a pattern where each header line can containe comma-separated
- tokens, and headers can be set multiple times.
- """
- if key not in headers:
- return []
- tokens = headers[key].split(",")
- return [token.strip() for token in tokens]
-
-
def hostport(scheme, host, port):
"""
Returns the host component, with a port specifcation if needed.
@@ -310,141 +107,3 @@ def hostport(scheme, host, port):
return b"%s:%d" % (host, port)
else:
return "%s:%d" % (host, port)
-
-
-def unparse_url(scheme, host, port, path=""):
- """
- Returns a URL string, constructed from the specified components.
-
- Args:
- All args must be str.
- """
- if path == "*":
- path = ""
- return "%s://%s%s" % (scheme, hostport(scheme, host, port), path)
-
-
-def urlencode(s):
- """
- Takes a list of (key, value) tuples and returns a urlencoded string.
- """
- s = [tuple(i) for i in s]
- return urllib.parse.urlencode(s, False)
-
-
-def urldecode(s):
- """
- Takes a urlencoded string and returns a list of (key, value) tuples.
- """
- return urllib.parse.parse_qsl(s, keep_blank_values=True)
-
-
-def parse_content_type(c):
- """
- A simple parser for content-type values. Returns a (type, subtype,
- parameters) tuple, where type and subtype are strings, and parameters
- is a dict. If the string could not be parsed, return None.
-
- E.g. the following string:
-
- text/html; charset=UTF-8
-
- Returns:
-
- ("text", "html", {"charset": "UTF-8"})
- """
- parts = c.split(";", 1)
- ts = parts[0].split("/", 1)
- if len(ts) != 2:
- return None
- d = {}
- if len(parts) == 2:
- for i in parts[1].split(";"):
- clause = i.split("=", 1)
- if len(clause) == 2:
- d[clause[0].strip()] = clause[1].strip()
- return ts[0].lower(), ts[1].lower(), d
-
-
-def multipartdecode(headers, content):
- """
- Takes a multipart boundary encoded string and returns list of (key, value) tuples.
- """
- v = headers.get("content-type")
- if v:
- v = parse_content_type(v)
- if not v:
- return []
- try:
- boundary = v[2]["boundary"].encode("ascii")
- except (KeyError, UnicodeError):
- return []
-
- rx = re.compile(br'\bname="([^"]+)"')
- r = []
-
- for i in content.split(b"--" + boundary):
- parts = i.splitlines()
- if len(parts) > 1 and parts[0][0:2] != b"--":
- match = rx.search(parts[1])
- if match:
- key = match.group(1)
- value = b"".join(parts[3 + parts[2:].index(b""):])
- r.append((key, value))
- return r
- return []
-
-
-def http2_read_raw_frame(rfile):
- header = rfile.safe_read(9)
- length = int(codecs.encode(header[:3], 'hex_codec'), 16)
-
- if length == 4740180:
- raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20))
-
- body = rfile.safe_read(length)
- return [header, body]
-
-
-def http2_read_frame(rfile):
- header, body = http2_read_raw_frame(rfile)
- frame, length = hyperframe.frame.Frame.parse_frame_header(header)
- frame.parse_body(memoryview(body))
- return frame
-
-
-def safe_subn(pattern, repl, target, *args, **kwargs):
- """
- There are Unicode conversion problems with re.subn. We try to smooth
- that over by casting the pattern and replacement to strings. We really
- need a better solution that is aware of the actual content ecoding.
- """
- return re.subn(str(pattern), str(repl), target, *args, **kwargs)
-
-
-def bytes_to_escaped_str(data):
- """
- Take bytes and return a safe string that can be displayed to the user.
- """
- # TODO: We may want to support multi-byte characters without escaping them.
- # One way to do would be calling .decode("utf8", "backslashreplace") first
- # and then escaping UTF8 control chars (see clean_bin).
-
- if not isinstance(data, bytes):
- raise ValueError("data must be bytes")
- return repr(data).lstrip("b")[1:-1]
-
-
-def escaped_str_to_bytes(data):
- """
- Take an escaped string and return the unescaped bytes equivalent.
- """
- if not isinstance(data, str):
- raise ValueError("data must be str")
-
- if six.PY2:
- return data.decode("string-escape")
-
- # This one is difficult - we use an undocumented Python API here
- # as per http://stackoverflow.com/a/23151714/934719
- return codecs.escape_decode(data)[0]
diff --git a/netlib/version_check.py b/netlib/version_check.py
index 8e05b458..63f3e876 100644
--- a/netlib/version_check.py
+++ b/netlib/version_check.py
@@ -10,7 +10,6 @@ import os.path
import six
import OpenSSL
-from . import version
PYOPENSSL_MIN_VERSION = (0, 15)
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
index 1c143919..fea696d9 100644
--- a/netlib/websockets/__init__.py
+++ b/netlib/websockets/__init__.py
@@ -1,2 +1,11 @@
-from .frame import *
-from .protocol import *
+from __future__ import absolute_import, print_function, division
+from .frame import FrameHeader, Frame, OPCODE
+from .protocol import Masker, WebsocketsProtocol
+
+__all__ = [
+ "FrameHeader",
+ "Frame",
+ "Masker",
+ "WebsocketsProtocol",
+ "OPCODE",
+]
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
index fce2c9d3..42196ffb 100644
--- a/netlib/websockets/frame.py
+++ b/netlib/websockets/frame.py
@@ -6,15 +6,17 @@ import warnings
import six
-from .protocol import Masker
from netlib import tcp
+from netlib import strutils
from netlib import utils
+from netlib import human
+from netlib.websockets import protocol
MAX_16_BIT_INT = (1 << 16)
MAX_64_BIT_INT = (1 << 64)
-DEFAULT=object()
+DEFAULT = object()
OPCODE = utils.BiDi(
CONTINUE=0x00,
@@ -98,7 +100,7 @@ class FrameHeader(object):
if self.masking_key:
vals.append(":key=%s" % repr(self.masking_key))
if self.payload_length:
- vals.append(" %s" % utils.pretty_size(self.payload_length))
+ vals.append(" %s" % human.pretty_size(self.payload_length))
return "".join(vals)
def human_readable(self):
@@ -253,7 +255,7 @@ class Frame(object):
def __repr__(self):
ret = repr(self.header)
if self.payload:
- ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii")
+ ret = ret + "\nPayload:\n" + strutils.clean_bin(self.payload).decode("ascii")
return ret
def human_readable(self):
@@ -266,7 +268,7 @@ class Frame(object):
"""
b = bytes(self.header)
if self.header.masking_key:
- b += Masker(self.header.masking_key)(self.payload)
+ b += protocol.Masker(self.header.masking_key)(self.payload)
else:
b += self.payload
return b
@@ -295,7 +297,7 @@ class Frame(object):
payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
- payload = Masker(header.masking_key)(payload)
+ payload = protocol.Masker(header.masking_key)(payload)
return cls(
payload,
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
index 1e95fa1c..c1b7be2c 100644
--- a/netlib/websockets/protocol.py
+++ b/netlib/websockets/protocol.py
@@ -1,26 +1,26 @@
+"""
+Colleciton of utility functions that implement small portions of the RFC6455
+WebSockets Protocol Useful for building WebSocket clients and servers.
+Emphassis is on readabilty, simplicity and modularity, not performance or
+completeness
+This is a work in progress and does not yet contain all the utilites need to
+create fully complient client/servers #
+Spec: https://tools.ietf.org/html/rfc6455
-# Colleciton of utility functions that implement small portions of the RFC6455
-# WebSockets Protocol Useful for building WebSocket clients and servers.
-#
-# Emphassis is on readabilty, simplicity and modularity, not performance or
-# completeness
-#
-# This is a work in progress and does not yet contain all the utilites need to
-# create fully complient client/servers #
-# Spec: https://tools.ietf.org/html/rfc6455
+The magic sha that websocket servers must know to prove they understand
+RFC6455
+"""
-# The magic sha that websocket servers must know to prove they understand
-# RFC6455
from __future__ import absolute_import
import base64
import hashlib
import os
-import binascii
import six
-from ..http import Headers
+
+from netlib import http
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
@@ -73,11 +73,11 @@ class WebsocketsProtocol(object):
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
- Returns an instance of Headers
+ Returns an instance of http.Headers
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('ascii')
- return Headers(
+ return http.Headers(
sec_websocket_key=key,
sec_websocket_version=version,
connection="Upgrade",
@@ -89,27 +89,24 @@ class WebsocketsProtocol(object):
"""
The server response is a valid HTTP 101 response.
"""
- return Headers(
+ return http.Headers(
sec_websocket_accept=self.create_server_nonce(key),
connection="Upgrade",
upgrade="websocket"
)
-
@classmethod
def check_client_handshake(self, headers):
if headers.get("upgrade") != "websocket":
return
return headers.get("sec-websocket-key")
-
@classmethod
def check_server_handshake(self, headers):
if headers.get("upgrade") != "websocket":
return
return headers.get("sec-websocket-accept")
-
@classmethod
def create_server_nonce(self, client_nonce):
return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest())
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
index d6dfae5d..c66fddc2 100644
--- a/netlib/wsgi.py
+++ b/netlib/wsgi.py
@@ -1,14 +1,13 @@
from __future__ import (absolute_import, print_function, division)
-from io import BytesIO, StringIO
-import urllib
+
import time
import traceback
-
import six
+from io import BytesIO
from six.moves import urllib
-from netlib.utils import always_bytes, native
-from . import http, tcp
+from netlib import http, tcp, strutils
+
class ClientConn(object):
@@ -55,38 +54,38 @@ class WSGIAdaptor(object):
self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
def make_environ(self, flow, errsoc, **extra):
- path = native(flow.request.path, "latin-1")
+ path = strutils.native(flow.request.path, "latin-1")
if '?' in path:
- path_info, query = native(path, "latin-1").split('?', 1)
+ path_info, query = strutils.native(path, "latin-1").split('?', 1)
else:
path_info = path
query = ''
environ = {
'wsgi.version': (1, 0),
- 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"),
+ 'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"),
'wsgi.input': BytesIO(flow.request.content or b""),
'wsgi.errors': errsoc,
'wsgi.multithread': True,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': self.sversion,
- 'REQUEST_METHOD': native(flow.request.method, "latin-1"),
+ 'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"),
'SCRIPT_NAME': '',
'PATH_INFO': urllib.parse.unquote(path_info),
'QUERY_STRING': query,
- 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"),
- 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"),
+ 'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"),
+ 'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"),
'SERVER_NAME': self.domain,
'SERVER_PORT': str(self.port),
- 'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"),
+ 'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"),
}
environ.update(extra)
if flow.client_conn.address:
- environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1")
+ environ["REMOTE_ADDR"] = strutils.native(flow.client_conn.address.host, "latin-1")
environ["REMOTE_PORT"] = flow.client_conn.address.port
for key, value in flow.request.headers.items():
- key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_')
+ key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_')
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
environ[key] = value
return environ
@@ -140,7 +139,7 @@ class WSGIAdaptor(object):
elif state["status"]:
raise AssertionError('Response already started')
state["status"] = status
- state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers])
+ state["headers"] = http.Headers([[strutils.always_bytes(k), strutils.always_bytes(v)] for k, v in headers])
if exc_info:
self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2]))
state["headers_sent"] = True
@@ -154,7 +153,7 @@ class WSGIAdaptor(object):
write(i)
if not state["headers_sent"]:
write(b"")
- except Exception as e:
+ except Exception:
try:
s = traceback.format_exc()
errs.write(s.encode("utf-8", "replace"))