aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/http/authentication.py4
-rw-r--r--netlib/http/exceptions.py18
-rw-r--r--netlib/http/http1/protocol.py41
-rw-r--r--netlib/http/http2/protocol.py44
-rw-r--r--netlib/http/semantics.py297
-rw-r--r--netlib/tutils.py10
-rw-r--r--netlib/utils.py13
-rw-r--r--netlib/websockets/protocol.py28
-rw-r--r--netlib/wsgi.py22
-rw-r--r--test/http/http1/test_protocol.py127
-rw-r--r--test/http/http2/test_protocol.py22
-rw-r--r--test/http/test_authentication.py44
-rw-r--r--test/http/test_exceptions.py20
-rw-r--r--test/http/test_semantics.py238
-rw-r--r--test/test_utils.py25
-rw-r--r--test/test_wsgi.py8
-rw-r--r--test/websockets/test_websockets.py12
17 files changed, 607 insertions, 366 deletions
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py
index 29b9eb3c..fe1f0d14 100644
--- a/netlib/http/authentication.py
+++ b/netlib/http/authentication.py
@@ -62,10 +62,10 @@ class BasicProxyAuth(NullProxyAuth):
del headers[self.AUTH_HEADER]
def authenticate(self, headers):
- auth_value = headers.get(self.AUTH_HEADER, [])
+ auth_value = headers.get(self.AUTH_HEADER)
if not auth_value:
return False
- parts = parse_http_basic_auth(auth_value[0])
+ parts = parse_http_basic_auth(auth_value)
if not parts:
return False
scheme, username, password = parts
diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py
index 987a7908..8a2bbebc 100644
--- a/netlib/http/exceptions.py
+++ b/netlib/http/exceptions.py
@@ -1,6 +1,3 @@
-from netlib import odict
-
-
class HttpError(Exception):
def __init__(self, code, message):
@@ -10,18 +7,3 @@ class HttpError(Exception):
class HttpErrorConnClosed(HttpError):
pass
-
-
-class HttpAuthenticationError(Exception):
-
- def __init__(self, auth_headers=None):
- super(HttpAuthenticationError, self).__init__(
- "Proxy Authentication Required"
- )
- if isinstance(auth_headers, dict):
- auth_headers = odict.ODictCaseless(auth_headers.items())
- self.headers = auth_headers
- self.code = 407
-
- def __repr__(self):
- return "Proxy Authentication Required"
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py
index 50975818..bf33a18e 100644
--- a/netlib/http/http1/protocol.py
+++ b/netlib/http/http1/protocol.py
@@ -3,8 +3,8 @@ import string
import sys
import time
-from netlib import odict, utils, tcp, http
-from netlib.http import semantics
+from ... import utils, tcp, http
+from .. import semantics, Headers
from ..exceptions import *
@@ -96,7 +96,7 @@ class HTTP1Protocol(semantics.ProtocolMixin):
if headers is None:
raise HttpError(400, "Invalid headers")
- expect_header = headers.get_first("expect", "").lower()
+ expect_header = headers.get("expect", "").lower()
if expect_header == "100-continue" and httpversion == (1, 1):
self.tcp_handler.wfile.write(
'HTTP/1.1 100 Continue\r\n'
@@ -232,10 +232,9 @@ class HTTP1Protocol(semantics.ProtocolMixin):
Read a set of headers.
Stop once a blank line is reached.
- Return a ODictCaseless object, or None if headers are invalid.
+ Return a Header object, or None if headers are invalid.
"""
ret = []
- name = ''
while True:
line = self.tcp_handler.rfile.readline()
if not line or line == '\r\n' or line == '\n':
@@ -254,7 +253,7 @@ class HTTP1Protocol(semantics.ProtocolMixin):
ret.append([name, value])
else:
return None
- return odict.ODictCaseless(ret)
+ return Headers(ret)
def read_http_body(self, *args, **kwargs):
@@ -272,7 +271,7 @@ class HTTP1Protocol(semantics.ProtocolMixin):
):
"""
Read an HTTP message body:
- headers: An ODictCaseless object
+ headers: A Header object
limit: Size limit.
is_request: True if the body to read belongs to a request, False
otherwise
@@ -356,7 +355,7 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return None
if "content-length" in headers:
try:
- size = int(headers["content-length"][0])
+ size = int(headers["content-length"])
if size < 0:
raise ValueError()
return size
@@ -369,9 +368,7 @@ class HTTP1Protocol(semantics.ProtocolMixin):
@classmethod
def has_chunked_encoding(self, headers):
- return "chunked" in [
- i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding")
- ]
+ return "chunked" in headers.get("transfer-encoding", "").lower()
def _get_request_line(self):
@@ -547,18 +544,20 @@ class HTTP1Protocol(semantics.ProtocolMixin):
def _assemble_request_headers(self, request):
headers = request.headers.copy()
for k in request._headers_to_strip_off:
- del headers[k]
+ headers.pop(k, None)
if 'host' not in headers and request.scheme and request.host and request.port:
- headers["Host"] = [utils.hostport(request.scheme,
- request.host,
- request.port)]
+ headers["Host"] = utils.hostport(
+ request.scheme,
+ request.host,
+ request.port
+ )
# If content is defined (i.e. not None or CONTENT_MISSING), we always
# add a content-length header.
if request.body or request.body == "":
- headers["Content-Length"] = [str(len(request.body))]
+ headers["Content-Length"] = str(len(request.body))
- return headers.format()
+ return str(headers)
def _assemble_response_first_line(self, response):
return 'HTTP/%s.%s %s %s' % (
@@ -575,13 +574,13 @@ class HTTP1Protocol(semantics.ProtocolMixin):
):
headers = response.headers.copy()
for k in response._headers_to_strip_off:
- del headers[k]
+ headers.pop(k, None)
if not preserve_transfer_encoding:
- del headers['Transfer-Encoding']
+ headers.pop('Transfer-Encoding', None)
# If body is defined (i.e. not None or CONTENT_MISSING), we always
# add a content-length header.
if response.body or response.body == "":
- headers["Content-Length"] = [str(len(response.body))]
+ headers["Content-Length"] = str(len(response.body))
- return headers.format()
+ return str(headers)
diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py
index 4328ebdd..b6d376d3 100644
--- a/netlib/http/http2/protocol.py
+++ b/netlib/http/http2/protocol.py
@@ -3,7 +3,7 @@ import itertools
import time
from hpack.hpack import Encoder, Decoder
-from netlib import http, utils, odict
+from netlib import http, utils
from netlib.http import semantics
from . import frame
@@ -85,10 +85,10 @@ class HTTP2Protocol(semantics.ProtocolMixin):
timestamp_end = time.time()
- authority = headers.get_first(':authority', '')
- method = headers.get_first(':method', 'GET')
- scheme = headers.get_first(':scheme', 'https')
- path = headers.get_first(':path', '/')
+ authority = headers.get(':authority', '')
+ method = headers.get(':method', 'GET')
+ scheme = headers.get(':scheme', 'https')
+ path = headers.get(':path', '/')
host = None
port = None
@@ -161,7 +161,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
response = http.Response(
(2, 0),
- int(headers.get_first(':status')),
+ int(headers.get(':status', 502)),
"",
headers,
body,
@@ -181,16 +181,14 @@ class HTTP2Protocol(semantics.ProtocolMixin):
headers = request.headers.copy()
- if ':authority' not in headers.keys():
- headers.add(':authority', bytes(authority), prepend=True)
- if ':scheme' not in headers.keys():
- headers.add(':scheme', bytes(request.scheme), prepend=True)
- if ':path' not in headers.keys():
- headers.add(':path', bytes(request.path), prepend=True)
- if ':method' not in headers.keys():
- headers.add(':method', bytes(request.method), prepend=True)
-
- headers = headers.items()
+ if ':authority' not in headers:
+ headers.fields.insert(0, (':authority', bytes(authority)))
+ if ':scheme' not in headers:
+ headers.fields.insert(0, (':scheme', bytes(request.scheme)))
+ if ':path' not in headers:
+ headers.fields.insert(0, (':path', bytes(request.path)))
+ if ':method' not in headers:
+ headers.fields.insert(0, (':method', bytes(request.method)))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@@ -206,10 +204,8 @@ class HTTP2Protocol(semantics.ProtocolMixin):
headers = response.headers.copy()
- if ':status' not in headers.keys():
- headers.add(':status', bytes(str(response.status_code)), prepend=True)
-
- headers = headers.items()
+ if ':status' not in headers:
+ headers.fields.insert(0, (':status', bytes(str(response.status_code))))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
@@ -336,7 +332,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
else:
yield frame.ContinuationFrame, i
- header_block_fragment = self.encoder.encode(headers)
+ header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
@@ -409,8 +405,8 @@ class HTTP2Protocol(semantics.ProtocolMixin):
else:
self._handle_unexpected_frame(frm)
- headers = odict.ODictCaseless()
- for header, value in self.decoder.decode(header_block_fragment):
- headers.add(header, value)
+ headers = http.Headers(
+ [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)]
+ )
return stream_id, headers, body
diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py
index 2b960483..edf5fc07 100644
--- a/netlib/http/semantics.py
+++ b/netlib/http/semantics.py
@@ -1,8 +1,10 @@
from __future__ import (absolute_import, print_function, division)
+import UserDict
+import copy
import urllib
import urlparse
-from .. import utils, odict
+from .. import odict
from . import cookies, exceptions
from netlib import utils, encoding
@@ -12,8 +14,165 @@ HDR_FORM_MULTIPART = "multipart/form-data"
CONTENT_MISSING = 0
-class ProtocolMixin(object):
+class Headers(UserDict.DictMixin):
+ """
+ Header class which allows both convenient access to individual headers as well as
+ direct access to the underlying raw data. Provides a full dictionary interface.
+
+ Example:
+
+ .. code-block:: python
+
+ # Create header from a list of (header_name, header_value) tuples
+ >>> h = Headers([
+ ["Host","example.com"],
+ ["Accept","text/html"],
+ ["accept","application/xml"]
+ ])
+
+ # Headers mostly behave like a normal dict.
+ >>> h["Host"]
+ "example.com"
+
+ # HTTP Headers are case insensitive
+ >>> h["host"]
+ "example.com"
+
+ # Multiple headers are folded into a single header as per RFC7230
+ >>> h["Accept"]
+ "text/html, application/xml"
+
+ # Setting a header removes all existing headers with the same name.
+ >>> h["Accept"] = "application/text"
+ >>> h["Accept"]
+ "application/text"
+
+ # str(h) returns a HTTP1 header block.
+ >>> print(h)
+ Host: example.com
+ Accept: application/text
+
+ # For full control, the raw header fields can be accessed
+ >>> h.fields
+
+ # Headers can also be crated from keyword arguments
+ >>> h = Headers(host="example.com", content_type="application/xml")
+
+ Caveats:
+ For use with the "Set-Cookie" header, see :py:meth:`get_all`.
+ """
+
+ def __init__(self, fields=None, **headers):
+ """
+ Args:
+ fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]``
+ **headers: Additional headers to set. Will overwrite existing values from `fields`.
+ For convenience, underscores in header names will be transformed to dashes -
+ this behaviour does not extend to other methods.
+ If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
+ the behavior is undefined.
+ """
+ self.fields = fields or []
+
+ # content_type -> content-type
+ headers = {
+ name.replace("_", "-"): value
+ for name, value in headers.iteritems()
+ }
+ self.update(headers)
+
+ def __str__(self):
+ return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n"
+
+ def __getitem__(self, name):
+ values = self.get_all(name)
+ if not values:
+ raise KeyError(name)
+ else:
+ return ", ".join(values)
+
+ def __setitem__(self, name, value):
+ idx = self._index(name)
+
+ # To please the human eye, we insert at the same position the first existing header occured.
+ if idx is not None:
+ del self[name]
+ self.fields.insert(idx, [name, value])
+ else:
+ self.fields.append([name, value])
+
+ def __delitem__(self, name):
+ if name not in self:
+ raise KeyError(name)
+ name = name.lower()
+ self.fields = [
+ field for field in self.fields
+ if name != field[0].lower()
+ ]
+
+ def _index(self, name):
+ name = name.lower()
+ for i, field in enumerate(self.fields):
+ if field[0].lower() == name:
+ return i
+ return None
+
+ def keys(self):
+ seen = set()
+ names = []
+ for name, _ in self.fields:
+ name_lower = name.lower()
+ if name_lower not in seen:
+ seen.add(name_lower)
+ names.append(name)
+ return names
+
+ def __eq__(self, other):
+ if isinstance(other, Headers):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def get_all(self, name, default=[]):
+ """
+ Like :py:meth:`get`, but does not fold multiple headers into a single one.
+ This is useful for Set-Cookie headers, which do not support folding.
+
+ See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
+ """
+ name = name.lower()
+ values = [value for n, value in self.fields if n.lower() == name]
+ return values or default
+
+ def set_all(self, name, values):
+ """
+ Explicitly set multiple headers for the given key.
+ See: :py:meth:`get_all`
+ """
+ if name in self:
+ del self[name]
+ self.fields.extend(
+ [name, value] for value in values
+ )
+
+ def copy(self):
+ return Headers(copy.copy(self.fields))
+
+ # Implement the StateObject protocol from mitmproxy
+ def get_state(self, short=False):
+ return tuple(tuple(field) for field in self.fields)
+
+ def load_state(self, state):
+ self.fields = [list(field) for field in state]
+
+ @classmethod
+ def from_state(cls, state):
+ return cls([list(field) for field in state])
+
+class ProtocolMixin(object):
def read_request(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError
@@ -47,23 +206,23 @@ class Request(object):
]
def __init__(
- self,
- form_in,
- method,
- scheme,
- host,
- port,
- path,
- httpversion,
- headers=None,
- body=None,
- timestamp_start=None,
- timestamp_end=None,
- form_out=None
+ self,
+ form_in,
+ method,
+ scheme,
+ host,
+ port,
+ path,
+ httpversion,
+ headers=None,
+ body=None,
+ timestamp_start=None,
+ timestamp_end=None,
+ form_out=None
):
if not headers:
- headers = odict.ODictCaseless()
- assert isinstance(headers, odict.ODictCaseless)
+ headers = Headers()
+ assert isinstance(headers, Headers)
self.form_in = form_in
self.method = method
@@ -80,8 +239,10 @@ class Request(object):
def __eq__(self, other):
try:
- self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
- other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
+ self_d = [self.__dict__[k] for k in self.__dict__ if
+ k not in ('timestamp_start', 'timestamp_end')]
+ other_d = [other.__dict__[k] for k in other.__dict__ if
+ k not in ('timestamp_start', 'timestamp_end')]
return self_d == other_d
except:
return False
@@ -134,30 +295,35 @@ class Request(object):
"if-none-match",
]
for i in delheaders:
- del self.headers[i]
+ self.headers.pop(i, None)
def anticomp(self):
"""
Modifies this request to remove headers that will compress the
resource's data.
"""
- self.headers["accept-encoding"] = ["identity"]
+ self.headers["accept-encoding"] = "identity"
def constrain_encoding(self):
"""
Limits the permissible Accept-Encoding values, based on what we can
decode appropriately.
"""
- if self.headers["accept-encoding"]:
- self.headers["accept-encoding"] = [
+ accept_encoding = self.headers.get("accept-encoding")
+ if accept_encoding:
+ self.headers["accept-encoding"] = (
', '.join(
- e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))]
+ e
+ for e in encoding.ENCODINGS
+ if e in accept_encoding
+ )
+ )
def update_host_header(self):
"""
Update the host header to reflect the current target.
"""
- self.headers["Host"] = [self.host]
+ self.headers["Host"] = self.host
def get_form(self):
"""
@@ -166,9 +332,9 @@ class Request(object):
indicates non-form data.
"""
if self.body:
- if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True):
+ if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower():
return self.get_form_urlencoded()
- elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True):
+ elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower():
return self.get_form_multipart()
return odict.ODict([])
@@ -178,18 +344,12 @@ class Request(object):
Returns an empty ODict if there is no data or the content-type
indicates non-form data.
"""
- if self.body and self.headers.in_any(
- "content-type",
- HDR_FORM_URLENCODED,
- True):
+ if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower():
return odict.ODict(utils.urldecode(self.body))
return odict.ODict([])
def get_form_multipart(self):
- if self.body and self.headers.in_any(
- "content-type",
- HDR_FORM_MULTIPART,
- True):
+ if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower():
return odict.ODict(
utils.multipartdecode(
self.headers,
@@ -204,7 +364,7 @@ class Request(object):
"""
# FIXME: If there's an existing content-type header indicating a
# url-encoded form, leave it alone.
- self.headers["Content-Type"] = [HDR_FORM_URLENCODED]
+ self.headers["Content-Type"] = HDR_FORM_URLENCODED
self.body = utils.urlencode(odict.lst)
def get_path_components(self):
@@ -263,7 +423,7 @@ class Request(object):
"""
host = None
if hostheader:
- host = self.headers.get_first("host")
+ host = self.headers.get("Host")
if not host:
host = self.host
if host:
@@ -287,7 +447,7 @@ class Request(object):
Returns a possibly empty netlib.odict.ODict object.
"""
ret = odict.ODict()
- for i in self.headers["cookie"]:
+ for i in self.headers.get_all("cookie"):
ret.extend(cookies.parse_cookie_header(i))
return ret
@@ -297,7 +457,7 @@ class Request(object):
headers.
"""
v = cookies.format_cookie_header(odict)
- self.headers["Cookie"] = [v]
+ self.headers["Cookie"] = v
@property
def url(self):
@@ -336,18 +496,17 @@ class Request(object):
class EmptyRequest(Request):
-
def __init__(
- self,
- form_in="",
- method="",
- scheme="",
- host="",
- port="",
- path="",
- httpversion=(0, 0),
- headers=None,
- body=""
+ self,
+ form_in="",
+ method="",
+ scheme="",
+ host="",
+ port="",
+ path="",
+ httpversion=(0, 0),
+ headers=None,
+ body=""
):
super(EmptyRequest, self).__init__(
form_in=form_in,
@@ -357,7 +516,7 @@ class EmptyRequest(Request):
port=port,
path=path,
httpversion=httpversion,
- headers=(headers or odict.ODictCaseless()),
+ headers=headers,
body=body,
)
@@ -370,19 +529,19 @@ class Response(object):
]
def __init__(
- self,
- httpversion,
- status_code,
- msg=None,
- headers=None,
- body=None,
- sslinfo=None,
- timestamp_start=None,
- timestamp_end=None,
+ self,
+ httpversion,
+ status_code,
+ msg=None,
+ headers=None,
+ body=None,
+ sslinfo=None,
+ timestamp_start=None,
+ timestamp_end=None,
):
if not headers:
- headers = odict.ODictCaseless()
- assert isinstance(headers, odict.ODictCaseless)
+ headers = Headers()
+ assert isinstance(headers, Headers)
self.httpversion = httpversion
self.status_code = status_code
@@ -395,8 +554,10 @@ class Response(object):
def __eq__(self, other):
try:
- self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
- other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
+ self_d = [self.__dict__[k] for k in self.__dict__ if
+ k not in ('timestamp_start', 'timestamp_end')]
+ other_d = [other.__dict__[k] for k in other.__dict__ if
+ k not in ('timestamp_start', 'timestamp_end')]
return self_d == other_d
except:
return False
@@ -412,9 +573,7 @@ class Response(object):
return "<Response: {status_code} {msg} ({contenttype}, {size})>".format(
status_code=self.status_code,
msg=self.msg,
- contenttype=self.headers.get_first(
- "content-type",
- "unknown content type"),
+ contenttype=self.headers.get("content-type", "unknown content type"),
size=size)
def get_cookies(self):
@@ -427,7 +586,7 @@ class Response(object):
attributes (e.g. HTTPOnly) are indicated by a Null value.
"""
ret = []
- for header in self.headers["set-cookie"]:
+ for header in self.headers.get_all("set-cookie"):
v = cookies.parse_set_cookie_header(header)
if v:
name, value, attrs = v
@@ -450,7 +609,7 @@ class Response(object):
i[1][1]
)
)
- self.headers["Set-Cookie"] = values
+ self.headers.set_all("Set-Cookie", values)
@property
def content(self): # pragma: no cover
diff --git a/netlib/tutils.py b/netlib/tutils.py
index 7434c108..951ef3d9 100644
--- a/netlib/tutils.py
+++ b/netlib/tutils.py
@@ -5,7 +5,7 @@ import time
import shutil
from contextlib import contextmanager
-from netlib import tcp, utils, odict, http
+from netlib import tcp, utils, http
def treader(bytes):
@@ -73,8 +73,8 @@ def treq(content="content", scheme="http", host="address", port=22):
"""
@return: libmproxy.protocol.http.HTTPRequest
"""
- headers = odict.ODictCaseless()
- headers["header"] = ["qvalue"]
+ headers = http.Headers()
+ headers["header"] = "qvalue"
req = http.Request(
"relative",
"GET",
@@ -108,8 +108,8 @@ def tresp(content="message"):
@return: libmproxy.protocol.http.HTTPResponse
"""
- headers = odict.ODictCaseless()
- headers["header_response"] = ["svalue"]
+ headers = http.Headers()
+ headers["header_response"] = "svalue"
resp = http.semantics.Response(
(1, 1),
diff --git a/netlib/utils.py b/netlib/utils.py
index d6190673..aae187da 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -204,11 +204,10 @@ def get_header_tokens(headers, key):
follow a pattern where each header line can containe comma-separated
tokens, and headers can be set multiple times.
"""
- toks = []
- for i in headers[key]:
- for j in i.split(","):
- toks.append(j.strip())
- return toks
+ if key not in headers:
+ return []
+ tokens = headers[key].split(",")
+ return [token.strip() for token in tokens]
def hostport(scheme, host, port):
@@ -270,11 +269,11 @@ def parse_content_type(c):
return ts[0].lower(), ts[1].lower(), d
-def multipartdecode(hdrs, content):
+def multipartdecode(headers, content):
"""
Takes a multipart boundary encoded string and returns list of (key, value) tuples.
"""
- v = hdrs.get_first("content-type")
+ v = headers.get("content-type")
if v:
v = parse_content_type(v)
if not v:
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
index 6ce32eac..46c02875 100644
--- a/netlib/websockets/protocol.py
+++ b/netlib/websockets/protocol.py
@@ -1,10 +1,5 @@
-from __future__ import absolute_import
-import base64
-import hashlib
-import os
-from netlib import odict
-from netlib import utils
+
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
@@ -18,6 +13,13 @@ from netlib import utils
# The magic sha that websocket servers must know to prove they understand
# RFC6455
+from __future__ import absolute_import
+import base64
+import hashlib
+import os
+from ..http import Headers
+from .. import utils
+
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
@@ -66,11 +68,11 @@ class WebsocketsProtocol(object):
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
- Returns an instance of ODictCaseless
+ Returns an instance of Headers
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('utf-8')
- return odict.ODictCaseless([
+ return Headers([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
(HEADER_WEBSOCKET_KEY, key),
@@ -82,7 +84,7 @@ class WebsocketsProtocol(object):
"""
The server response is a valid HTTP 101 response.
"""
- return odict.ODictCaseless(
+ return Headers(
[
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
@@ -93,16 +95,16 @@ class WebsocketsProtocol(object):
@classmethod
def check_client_handshake(self, headers):
- if headers.get_first("upgrade", None) != "websocket":
+ if headers.get("upgrade") != "websocket":
return
- return headers.get_first(HEADER_WEBSOCKET_KEY)
+ return headers.get(HEADER_WEBSOCKET_KEY)
@classmethod
def check_server_handshake(self, headers):
- if headers.get_first("upgrade", None) != "websocket":
+ if headers.get("upgrade") != "websocket":
return
- return headers.get_first(HEADER_WEBSOCKET_ACCEPT)
+ return headers.get(HEADER_WEBSOCKET_ACCEPT)
@classmethod
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
index 99afe00e..8a98884a 100644
--- a/netlib/wsgi.py
+++ b/netlib/wsgi.py
@@ -3,7 +3,7 @@ import cStringIO
import urllib
import time
import traceback
-from . import odict, tcp
+from . import http, tcp
class ClientConn(object):
@@ -68,8 +68,8 @@ class WSGIAdaptor(object):
'SCRIPT_NAME': '',
'PATH_INFO': urllib.unquote(path_info),
'QUERY_STRING': query,
- 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0],
- 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0],
+ 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''),
+ 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''),
'SERVER_NAME': self.domain,
'SERVER_PORT': str(self.port),
# FIXME: We need to pick up the protocol read from the request.
@@ -115,12 +115,12 @@ class WSGIAdaptor(object):
def write(data):
if not state["headers_sent"]:
soc.write("HTTP/1.1 %s\r\n" % state["status"])
- h = state["headers"]
- if 'server' not in h:
- h["Server"] = [self.sversion]
- if 'date' not in h:
- h["Date"] = [date_time_string()]
- soc.write(h.format())
+ headers = state["headers"]
+ if 'server' not in headers:
+ headers["Server"] = self.sversion
+ if 'date' not in headers:
+ headers["Date"] = date_time_string()
+ soc.write(str(headers))
soc.write("\r\n")
state["headers_sent"] = True
if data:
@@ -137,7 +137,7 @@ class WSGIAdaptor(object):
elif state["status"]:
raise AssertionError('Response already started')
state["status"] = status
- state["headers"] = odict.ODictCaseless(headers)
+ state["headers"] = http.Headers(headers)
return write
errs = cStringIO.StringIO()
@@ -149,7 +149,7 @@ class WSGIAdaptor(object):
write(i)
if not state["headers_sent"]:
write("")
- except Exception:
+ except Exception as e:
try:
s = traceback.format_exc()
errs.write(s)
diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py
index 6704647f..f7c615bd 100644
--- a/test/http/http1/test_protocol.py
+++ b/test/http/http1/test_protocol.py
@@ -2,7 +2,7 @@ import cStringIO
import textwrap
from netlib import http, odict, tcp, tutils
-from netlib.http import semantics
+from netlib.http import semantics, Headers
from netlib.http.http1 import HTTP1Protocol
from ... import tservers
@@ -29,164 +29,161 @@ def test_stripped_chunked_encoding_no_content():
"""
r = tutils.treq(content="")
- r.headers["Transfer-Encoding"] = ["chunked"]
+ r.headers["Transfer-Encoding"] = "chunked"
assert "Content-Length" in mock_protocol()._assemble_request_headers(r)
r = tutils.tresp(content="")
- r.headers["Transfer-Encoding"] = ["chunked"]
+ r.headers["Transfer-Encoding"] = "chunked"
assert "Content-Length" in mock_protocol()._assemble_response_headers(r)
def test_has_chunked_encoding():
- h = odict.ODictCaseless()
- assert not HTTP1Protocol.has_chunked_encoding(h)
- h["transfer-encoding"] = ["chunked"]
- assert HTTP1Protocol.has_chunked_encoding(h)
+ headers = http.Headers()
+ assert not HTTP1Protocol.has_chunked_encoding(headers)
+ headers["transfer-encoding"] = "chunked"
+ assert HTTP1Protocol.has_chunked_encoding(headers)
def test_read_chunked():
- h = odict.ODictCaseless()
- h["transfer-encoding"] = ["chunked"]
+ headers = http.Headers()
+ headers["transfer-encoding"] = "chunked"
data = "1\r\na\r\n0\r\n"
tutils.raises(
"malformed chunked body",
mock_protocol(data).read_http_body,
- h, None, "GET", None, True
+ headers, None, "GET", None, True
)
data = "1\r\na\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a"
+ assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a"
data = "\r\n\r\n1\r\na\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a"
+ assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a"
data = "\r\n"
tutils.raises(
"closed prematurely",
mock_protocol(data).read_http_body,
- h, None, "GET", None, True
+ headers, None, "GET", None, True
)
data = "1\r\nfoo"
tutils.raises(
"malformed chunked body",
mock_protocol(data).read_http_body,
- h, None, "GET", None, True
+ headers, None, "GET", None, True
)
data = "foo\r\nfoo"
tutils.raises(
http.HttpError,
mock_protocol(data).read_http_body,
- h, None, "GET", None, True
+ headers, None, "GET", None, True
)
data = "5\r\naaaaa\r\n0\r\n\r\n"
- tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True)
+ tutils.raises("too large", mock_protocol(data).read_http_body, headers, 2, "GET", None, True)
def test_connection_close():
- h = odict.ODictCaseless()
- assert HTTP1Protocol.connection_close((1, 0), h)
- assert not HTTP1Protocol.connection_close((1, 1), h)
+ headers = Headers()
+ assert HTTP1Protocol.connection_close((1, 0), headers)
+ assert not HTTP1Protocol.connection_close((1, 1), headers)
- h["connection"] = ["keep-alive"]
- assert not HTTP1Protocol.connection_close((1, 1), h)
+ headers["connection"] = "keep-alive"
+ assert not HTTP1Protocol.connection_close((1, 1), headers)
- h["connection"] = ["close"]
- assert HTTP1Protocol.connection_close((1, 1), h)
+ headers["connection"] = "close"
+ assert HTTP1Protocol.connection_close((1, 1), headers)
def test_read_http_body_request():
- h = odict.ODictCaseless()
+ headers = Headers()
data = "testing"
- assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == ""
+ assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == ""
def test_read_http_body_response():
- h = odict.ODictCaseless()
+ headers = Headers()
data = "testing"
- assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
+ assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing"
def test_read_http_body():
# test default case
- h = odict.ODictCaseless()
- h["content-length"] = [7]
+ headers = Headers()
+ headers["content-length"] = "7"
data = "testing"
- assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
+ assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing"
# test content length: invalid header
- h["content-length"] = ["foo"]
+ headers["content-length"] = "foo"
data = "testing"
tutils.raises(
http.HttpError,
mock_protocol(data).read_http_body,
- h, None, "GET", 200, False
+ headers, None, "GET", 200, False
)
# test content length: invalid header #2
- h["content-length"] = [-1]
+ headers["content-length"] = "-1"
data = "testing"
tutils.raises(
http.HttpError,
mock_protocol(data).read_http_body,
- h, None, "GET", 200, False
+ headers, None, "GET", 200, False
)
# test content length: content length > actual content
- h["content-length"] = [5]
+ headers["content-length"] = "5"
data = "testing"
tutils.raises(
http.HttpError,
mock_protocol(data).read_http_body,
- h, 4, "GET", 200, False
+ headers, 4, "GET", 200, False
)
# test content length: content length < actual content
data = "testing"
- assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5
+ assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5
# test no content length: limit > actual content
- h = odict.ODictCaseless()
+ headers = Headers()
data = "testing"
- assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7
+ assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content
data = "testing"
tutils.raises(
http.HttpError,
mock_protocol(data).read_http_body,
- h, 4, "GET", 200, False
+ headers, 4, "GET", 200, False
)
# test chunked
- h = odict.ODictCaseless()
- h["transfer-encoding"] = ["chunked"]
+ headers = Headers()
+ headers["transfer-encoding"] = "chunked"
data = "5\r\naaaaa\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
+ assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size():
# gibber in the content-length field
- h = odict.ODictCaseless()
- h["content-length"] = ["foo"]
- assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None
+ headers = Headers(content_length="foo")
+ assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None
# negative number in the content-length field
- h = odict.ODictCaseless()
- h["content-length"] = ["-7"]
- assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None
+ headers = Headers(content_length="-7")
+ assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None
# explicit length
- h = odict.ODictCaseless()
- h["content-length"] = ["5"]
- assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5
+ headers = Headers(content_length="5")
+ assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == 5
# no length
- h = odict.ODictCaseless()
- assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1
+ headers = Headers()
+ assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == -1
# no length request
- h = odict.ODictCaseless()
- assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0
+ headers = Headers()
+ assert HTTP1Protocol.expected_http_body_size(headers, True, "GET", None) == 0
def test_get_request_line():
@@ -265,8 +262,8 @@ class TestReadHeaders:
Header2: two
\r\n
"""
- h = self._read(data)
- assert h.lst == [["Header", "one"], ["Header2", "two"]]
+ headers = self._read(data)
+ assert headers.fields == [["Header", "one"], ["Header2", "two"]]
def test_read_multi(self):
data = """
@@ -274,8 +271,8 @@ class TestReadHeaders:
Header: two
\r\n
"""
- h = self._read(data)
- assert h.lst == [["Header", "one"], ["Header", "two"]]
+ headers = self._read(data)
+ assert headers.fields == [["Header", "one"], ["Header", "two"]]
def test_read_continued(self):
data = """
@@ -284,8 +281,8 @@ class TestReadHeaders:
Header2: three
\r\n
"""
- h = self._read(data)
- assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]]
+ headers = self._read(data)
+ assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]]
def test_read_continued_err(self):
data = "\tfoo: bar\r\n"
@@ -389,7 +386,7 @@ class TestReadResponse(object):
HTTP/1.1 200
"""
assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 200, '', odict.ODictCaseless(), ''
+ (1, 1), 200, '', Headers(), ''
)
def test_simple_message(self):
@@ -397,7 +394,7 @@ class TestReadResponse(object):
HTTP/1.1 200 OK
"""
assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 200, 'OK', odict.ODictCaseless(), ''
+ (1, 1), 200, 'OK', Headers(), ''
)
def test_invalid_http_version(self):
@@ -419,7 +416,7 @@ class TestReadResponse(object):
HTTP/1.1 200 OK
"""
assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
+ (1, 1), 100, 'CONTINUE', Headers(), ''
)
def test_simple_body(self):
diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py
index 8810894f..2b7d7958 100644
--- a/test/http/http2/test_protocol.py
+++ b/test/http/http2/test_protocol.py
@@ -1,8 +1,8 @@
import OpenSSL
import mock
-from netlib import tcp, odict, http, tutils
-from netlib.http import http2
+from netlib import tcp, http, tutils
+from netlib.http import http2, Headers
from netlib.http.http2 import HTTP2Protocol
from netlib.http.http2.frame import *
from ... import tservers
@@ -229,11 +229,11 @@ class TestCreateHeaders():
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_headers(self):
- headers = [
+ headers = http.Headers([
(b':method', b'GET'),
(b':path', b'index.html'),
(b':scheme', b'https'),
- (b'foo', b'bar')]
+ (b'foo', b'bar')])
bytes = HTTP2Protocol(self.c)._create_headers(
headers, 1, end_stream=True)
@@ -248,12 +248,12 @@ class TestCreateHeaders():
.decode('hex')
def test_create_headers_multiple_frames(self):
- headers = [
+ headers = http.Headers([
(b':method', b'GET'),
(b':path', b'/'),
(b':scheme', b'https'),
(b'foo', b'bar'),
- (b'server', b'version')]
+ (b'server', b'version')])
protocol = HTTP2Protocol(self.c)
protocol.http2_settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] = 8
@@ -309,7 +309,7 @@ class TestReadRequest(tservers.ServerTestBase):
req = protocol.read_request()
assert req.stream_id
- assert req.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']]
+ assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']]
assert req.body == b'foobar'
@@ -415,7 +415,7 @@ class TestReadResponse(tservers.ServerTestBase):
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
- assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
+ assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b'foobar'
assert resp.timestamp_end
@@ -442,7 +442,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
- assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
+ assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b''
@@ -490,7 +490,7 @@ class TestAssembleRequest(object):
'',
'/',
(2, 0),
- odict.ODictCaseless([('foo', 'bar')]),
+ http.Headers([('foo', 'bar')]),
'foobar',
))
assert len(bytes) == 2
@@ -528,7 +528,7 @@ class TestAssembleResponse(object):
(2, 0),
200,
'',
- odict.ODictCaseless([('foo', 'bar')]),
+ Headers(foo="bar"),
'foobar'
))
assert len(bytes) == 2
diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py
index 5261e029..17c91fe5 100644
--- a/test/http/test_authentication.py
+++ b/test/http/test_authentication.py
@@ -1,18 +1,18 @@
import binascii
-from netlib import odict, http, tutils
-from netlib.http import authentication
+from netlib import tutils
+from netlib.http import authentication, Headers
def test_parse_http_basic_auth():
vals = ("basic", "foo", "bar")
- assert http.authentication.parse_http_basic_auth(
- http.authentication.assemble_http_basic_auth(*vals)
+ assert authentication.parse_http_basic_auth(
+ authentication.assemble_http_basic_auth(*vals)
) == vals
- assert not http.authentication.parse_http_basic_auth("")
- assert not http.authentication.parse_http_basic_auth("foo bar")
+ assert not authentication.parse_http_basic_auth("")
+ assert not authentication.parse_http_basic_auth("foo bar")
v = "basic " + binascii.b2a_base64("foo")
- assert not http.authentication.parse_http_basic_auth(v)
+ assert not authentication.parse_http_basic_auth(v)
class TestPassManNonAnon:
@@ -65,35 +65,35 @@ class TestBasicProxyAuth:
def test_simple(self):
ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test")
- h = odict.ODictCaseless()
+ headers = Headers()
assert ba.auth_challenge_headers()
- assert not ba.authenticate(h)
+ assert not ba.authenticate(headers)
def test_authenticate_clean(self):
ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test")
- hdrs = odict.ODictCaseless()
+ headers = Headers()
vals = ("basic", "foo", "bar")
- hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)]
- assert ba.authenticate(hdrs)
+ headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals)
+ assert ba.authenticate(headers)
- ba.clean(hdrs)
- assert not ba.AUTH_HEADER in hdrs
+ ba.clean(headers)
+ assert not ba.AUTH_HEADER in headers
- hdrs[ba.AUTH_HEADER] = [""]
- assert not ba.authenticate(hdrs)
+ headers[ba.AUTH_HEADER] = ""
+ assert not ba.authenticate(headers)
- hdrs[ba.AUTH_HEADER] = ["foo"]
- assert not ba.authenticate(hdrs)
+ headers[ba.AUTH_HEADER] = "foo"
+ assert not ba.authenticate(headers)
vals = ("foo", "foo", "bar")
- hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)]
- assert not ba.authenticate(hdrs)
+ headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals)
+ assert not ba.authenticate(headers)
ba = authentication.BasicProxyAuth(authentication.PassMan(), "test")
vals = ("basic", "foo", "bar")
- hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)]
- assert not ba.authenticate(hdrs)
+ headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals)
+ assert not ba.authenticate(headers)
class Bunch:
diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py
index d7c438f7..49588d0a 100644
--- a/test/http/test_exceptions.py
+++ b/test/http/test_exceptions.py
@@ -1,26 +1,6 @@
from netlib.http.exceptions import *
-from netlib import odict
class TestHttpError:
def test_simple(self):
e = HttpError(404, "Not found")
assert str(e)
-
-class TestHttpAuthenticationError:
- def test_init(self):
- headers = odict.ODictCaseless([("foo", "bar")])
- x = HttpAuthenticationError(headers)
- assert str(x)
- assert isinstance(x.headers, odict.ODictCaseless)
- assert x.code == 407
- assert x.headers == headers
- assert "foo" in x.headers.keys()
-
- def test_header_conversion(self):
- headers = {"foo": "bar"}
- x = HttpAuthenticationError(headers)
- assert isinstance(x.headers, odict.ODictCaseless)
- assert x.headers.lst == headers.items()
-
- def test_repr(self):
- assert repr(HttpAuthenticationError()) == "Proxy Authentication Required"
diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py
index 2a799044..22fe992c 100644
--- a/test/http/test_semantics.py
+++ b/test/http/test_semantics.py
@@ -33,7 +33,7 @@ class TestRequest(object):
r = tutils.treq()
assert repr(r)
- def test_headers_odict(self):
+ def test_headers(self):
tutils.raises(AssertionError, semantics.Request,
'form_in',
'method',
@@ -54,7 +54,7 @@ class TestRequest(object):
'path',
(1, 1),
)
- assert isinstance(req.headers, odict.ODictCaseless)
+ assert isinstance(req.headers, http.Headers)
def test_equal(self):
a = tutils.treq()
@@ -76,30 +76,30 @@ class TestRequest(object):
def test_anticache(self):
req = tutils.treq()
- req.headers.add("If-Modified-Since", "foo")
- req.headers.add("If-None-Match", "bar")
+ req.headers["If-Modified-Since"] = "foo"
+ req.headers["If-None-Match"] = "bar"
req.anticache()
assert "If-Modified-Since" not in req.headers
assert "If-None-Match" not in req.headers
def test_anticomp(self):
req = tutils.treq()
- req.headers.add("Accept-Encoding", "foobar")
+ req.headers["Accept-Encoding"] = "foobar"
req.anticomp()
- assert req.headers["Accept-Encoding"] == ["identity"]
+ assert req.headers["Accept-Encoding"] == "identity"
def test_constrain_encoding(self):
req = tutils.treq()
- req.headers.add("Accept-Encoding", "identity, gzip, foo")
+ req.headers["Accept-Encoding"] = "identity, gzip, foo"
req.constrain_encoding()
- assert "foo" not in req.headers.get_first("Accept-Encoding")
+ assert "foo" not in req.headers["Accept-Encoding"]
def test_update_host(self):
req = tutils.treq()
- req.headers.add("Host", "")
+ req.headers["Host"] = ""
req.host = "foobar"
req.update_host_header()
- assert req.headers.get_first("Host") == "foobar"
+ assert req.headers["Host"] == "foobar"
def test_get_form(self):
req = tutils.treq()
@@ -113,7 +113,7 @@ class TestRequest(object):
req = tutils.treq()
req.body = "foobar"
- req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
+ req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED
req.get_form()
assert req.get_form_urlencoded.called
assert not req.get_form_multipart.called
@@ -123,7 +123,7 @@ class TestRequest(object):
def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
req.body = "foobar"
- req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
+ req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART
req.get_form()
assert not req.get_form_urlencoded.called
assert req.get_form_multipart.called
@@ -132,23 +132,25 @@ class TestRequest(object):
req = tutils.treq("foobar")
assert req.get_form_urlencoded() == odict.ODict()
- req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
+ req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED
assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body))
def test_get_form_multipart(self):
req = tutils.treq("foobar")
assert req.get_form_multipart() == odict.ODict()
- req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
+ req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART
assert req.get_form_multipart() == odict.ODict(
utils.multipartdecode(
req.headers,
- req.body))
+ req.body
+ )
+ )
def test_set_form_urlencoded(self):
req = tutils.treq()
req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')]))
- assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED
+ assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED
assert req.body
def test_get_path_components(self):
@@ -176,7 +178,7 @@ class TestRequest(object):
r = tutils.treq()
assert r.pretty_host(True) == "address"
assert r.pretty_host(False) == "address"
- r.headers["host"] = ["other"]
+ r.headers["host"] = "other"
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) == "address"
r.host = None
@@ -187,7 +189,7 @@ class TestRequest(object):
assert r.pretty_host(False) is None
# Invalid IDNA
- r.headers["host"] = [".disqus.com"]
+ r.headers["host"] = ".disqus.com"
assert r.pretty_host(True) == ".disqus.com"
def test_pretty_url(self):
@@ -201,49 +203,37 @@ class TestRequest(object):
assert req.pretty_url(False) == "http://address:22/path"
def test_get_cookies_none(self):
- h = odict.ODictCaseless()
+ headers = http.Headers()
r = tutils.treq()
- r.headers = h
+ r.headers = headers
assert len(r.get_cookies()) == 0
def test_get_cookies_single(self):
- h = odict.ODictCaseless()
- h["Cookie"] = ["cookiename=cookievalue"]
r = tutils.treq()
- r.headers = h
+ r.headers = http.Headers(cookie="cookiename=cookievalue")
result = r.get_cookies()
assert len(result) == 1
assert result['cookiename'] == ['cookievalue']
def test_get_cookies_double(self):
- h = odict.ODictCaseless()
- h["Cookie"] = [
- "cookiename=cookievalue;othercookiename=othercookievalue"
- ]
r = tutils.treq()
- r.headers = h
+ r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['cookievalue']
assert result['othercookiename'] == ['othercookievalue']
def test_get_cookies_withequalsign(self):
- h = odict.ODictCaseless()
- h["Cookie"] = [
- "cookiename=coo=kievalue;othercookiename=othercookievalue"
- ]
r = tutils.treq()
- r.headers = h
+ r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['coo=kievalue']
assert result['othercookiename'] == ['othercookievalue']
def test_set_cookies(self):
- h = odict.ODictCaseless()
- h["Cookie"] = ["cookiename=cookievalue"]
r = tutils.treq()
- r.headers = h
+ r.headers = http.Headers(cookie="cookiename=cookievalue")
result = r.get_cookies()
result["cookiename"] = ["foo"]
r.set_cookies(result)
@@ -348,7 +338,7 @@ class TestEmptyRequest(object):
assert req
class TestResponse(object):
- def test_headers_odict(self):
+ def test_headers(self):
tutils.raises(AssertionError, semantics.Response,
(1, 1),
200,
@@ -359,7 +349,7 @@ class TestResponse(object):
(1, 1),
200,
)
- assert isinstance(resp.headers, odict.ODictCaseless)
+ assert isinstance(resp.headers, http.Headers)
def test_equal(self):
a = tutils.tresp()
@@ -374,32 +364,26 @@ class TestResponse(object):
def test_repr(self):
r = tutils.tresp()
assert "unknown content type" in repr(r)
- r.headers["content-type"] = ["foo"]
+ r.headers["content-type"] = "foo"
assert "foo" in repr(r)
assert repr(tutils.tresp(content=CONTENT_MISSING))
def test_get_cookies_none(self):
- h = odict.ODictCaseless()
resp = tutils.tresp()
- resp.headers = h
+ resp.headers = http.Headers()
assert not resp.get_cookies()
def test_get_cookies_simple(self):
- h = odict.ODictCaseless()
- h["Set-Cookie"] = ["cookiename=cookievalue"]
resp = tutils.tresp()
- resp.headers = h
+ resp.headers = http.Headers(set_cookie="cookiename=cookievalue")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
assert result["cookiename"][0] == ["cookievalue", odict.ODict()]
def test_get_cookies_with_parameters(self):
- h = odict.ODictCaseless()
- h["Set-Cookie"] = [
- "cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"]
resp = tutils.tresp()
- resp.headers = h
+ resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
@@ -412,12 +396,8 @@ class TestResponse(object):
assert attrs["httponly"] == [None]
def test_get_cookies_no_value(self):
- h = odict.ODictCaseless()
- h["Set-Cookie"] = [
- "cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/"
- ]
resp = tutils.tresp()
- resp.headers = h
+ resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
@@ -425,10 +405,11 @@ class TestResponse(object):
assert len(result["cookiename"][0][1]) == 2
def test_get_cookies_twocookies(self):
- h = odict.ODictCaseless()
- h["Set-Cookie"] = ["cookiename=cookievalue", "othercookie=othervalue"]
resp = tutils.tresp()
- resp.headers = h
+ resp.headers = http.Headers([
+ ["Set-Cookie", "cookiename=cookievalue"],
+ ["Set-Cookie", "othercookie=othervalue"]
+ ])
result = resp.get_cookies()
assert len(result) == 2
assert "cookiename" in result
@@ -445,3 +426,148 @@ class TestResponse(object):
v = resp.get_cookies()
assert len(v) == 1
assert v["foo"] == [["bar", odict.ODictCaseless()]]
+
+
+class TestHeaders(object):
+ def _2host(self):
+ return semantics.Headers(
+ [
+ ["Host", "example.com"],
+ ["host", "example.org"]
+ ]
+ )
+
+ def test_init(self):
+ headers = semantics.Headers()
+ assert len(headers) == 0
+
+ headers = semantics.Headers([["Host", "example.com"]])
+ assert len(headers) == 1
+ assert headers["Host"] == "example.com"
+
+ headers = semantics.Headers(Host="example.com")
+ assert len(headers) == 1
+ assert headers["Host"] == "example.com"
+
+ headers = semantics.Headers(
+ [["Host", "invalid"]],
+ Host="example.com"
+ )
+ assert len(headers) == 1
+ assert headers["Host"] == "example.com"
+
+ headers = semantics.Headers(
+ [["Host", "invalid"], ["Accept", "text/plain"]],
+ Host="example.com"
+ )
+ assert len(headers) == 2
+ assert headers["Host"] == "example.com"
+ assert headers["Accept"] == "text/plain"
+
+ def test_getitem(self):
+ headers = semantics.Headers(Host="example.com")
+ assert headers["Host"] == "example.com"
+ assert headers["host"] == "example.com"
+ tutils.raises(KeyError, headers.__getitem__, "Accept")
+
+ headers = self._2host()
+ assert headers["Host"] == "example.com, example.org"
+
+ def test_str(self):
+ headers = semantics.Headers(Host="example.com")
+ assert str(headers) == "Host: example.com\r\n"
+
+ headers = semantics.Headers([
+ ["Host", "example.com"],
+ ["Accept", "text/plain"]
+ ])
+ assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n"
+
+ def test_setitem(self):
+ headers = semantics.Headers()
+ headers["Host"] = "example.com"
+ assert "Host" in headers
+ assert "host" in headers
+ assert headers["Host"] == "example.com"
+
+ headers["host"] = "example.org"
+ assert "Host" in headers
+ assert "host" in headers
+ assert headers["Host"] == "example.org"
+
+ headers["accept"] = "text/plain"
+ assert len(headers) == 2
+ assert "Accept" in headers
+ assert "Host" in headers
+
+ headers = self._2host()
+ assert len(headers.fields) == 2
+ headers["Host"] = "example.com"
+ assert len(headers.fields) == 1
+ assert "Host" in headers
+
+ def test_delitem(self):
+ headers = semantics.Headers(Host="example.com")
+ assert len(headers) == 1
+ del headers["host"]
+ assert len(headers) == 0
+ try:
+ del headers["host"]
+ except KeyError:
+ assert True
+ else:
+ assert False
+
+ headers = self._2host()
+ del headers["Host"]
+ assert len(headers) == 0
+
+ def test_keys(self):
+ headers = semantics.Headers(Host="example.com")
+ assert len(headers.keys()) == 1
+ assert headers.keys()[0] == "Host"
+
+ headers = self._2host()
+ assert len(headers.keys()) == 1
+ assert headers.keys()[0] == "Host"
+
+ def test_eq_ne(self):
+ headers1 = semantics.Headers(Host="example.com")
+ headers2 = semantics.Headers(host="example.com")
+ assert not (headers1 == headers2)
+ assert headers1 != headers2
+
+ headers1 = semantics.Headers(Host="example.com")
+ headers2 = semantics.Headers(Host="example.com")
+ assert headers1 == headers2
+ assert not (headers1 != headers2)
+
+ assert headers1 != 42
+
+ def test_get_all(self):
+ headers = self._2host()
+ assert headers.get_all("host") == ["example.com", "example.org"]
+ assert headers.get_all("accept", 42) is 42
+
+ def test_set_all(self):
+ headers = semantics.Headers(Host="example.com")
+ headers.set_all("Accept", ["text/plain"])
+ assert len(headers) == 2
+ assert "accept" in headers
+
+ headers = self._2host()
+ headers.set_all("Host", ["example.org"])
+ assert headers["host"] == "example.org"
+
+ headers.set_all("Host", ["example.org", "example.net"])
+ assert headers["host"] == "example.org, example.net"
+
+ def test_state(self):
+ headers = self._2host()
+ assert len(headers.get_state()) == 2
+ assert headers == semantics.Headers.from_state(headers.get_state())
+
+ headers2 = semantics.Headers()
+ assert headers != headers2
+ headers2.load_state(headers.get_state())
+ assert headers == headers2
diff --git a/test/test_utils.py b/test/test_utils.py
index fc7174d6..374d09ba 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,5 +1,5 @@
-from netlib import utils, odict, tutils
-
+from netlib import utils, tutils
+from netlib.http import Headers
def test_bidi():
b = utils.BiDi(a=1, b=2)
@@ -88,20 +88,21 @@ def test_urldecode():
def test_get_header_tokens():
- h = odict.ODictCaseless()
- assert utils.get_header_tokens(h, "foo") == []
- h["foo"] = ["bar"]
- assert utils.get_header_tokens(h, "foo") == ["bar"]
- h["foo"] = ["bar, voing"]
- assert utils.get_header_tokens(h, "foo") == ["bar", "voing"]
- h["foo"] = ["bar, voing", "oink"]
- assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"]
+ headers = Headers()
+ assert utils.get_header_tokens(headers, "foo") == []
+ headers["foo"] = "bar"
+ assert utils.get_header_tokens(headers, "foo") == ["bar"]
+ headers["foo"] = "bar, voing"
+ assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"]
+ headers.set_all("foo", ["bar, voing", "oink"])
+ assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]
def test_multipartdecode():
boundary = 'somefancyboundary'
- headers = odict.ODict(
- [('content-type', ('multipart/form-data; boundary=%s' % boundary))])
+ headers = Headers(
+ content_type='multipart/form-data; boundary=%s' % boundary
+ )
content = "--{0}\n" \
"Content-Disposition: form-data; name=\"field1\"\n\n" \
"value1\n" \
diff --git a/test/test_wsgi.py b/test/test_wsgi.py
index 41572d49..e26e1413 100644
--- a/test/test_wsgi.py
+++ b/test/test_wsgi.py
@@ -1,12 +1,12 @@
import cStringIO
import sys
-from netlib import wsgi, odict
+from netlib import wsgi
+from netlib.http import Headers
def tflow():
- h = odict.ODictCaseless()
- h["test"] = ["value"]
- req = wsgi.Request("http", "GET", "/", h, "")
+ headers = Headers(test="value")
+ req = wsgi.Request("http", "GET", "/", headers, "")
return wsgi.Flow(("127.0.0.1", 8888), req)
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py
index be87b20a..57cfd166 100644
--- a/test/websockets/test_websockets.py
+++ b/test/websockets/test_websockets.py
@@ -42,7 +42,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers(key)
- self.wfile.write(headers.format() + "\r\n")
+ self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
self.handshake_done = True
@@ -66,8 +66,8 @@ class WebSocketsClient(tcp.TCPClient):
preamble = 'GET / HTTP/1.1'
self.wfile.write(preamble + "\r\n")
headers = self.protocol.client_handshake_headers()
- self.client_nonce = headers.get_first("sec-websocket-key")
- self.wfile.write(headers.format() + "\r\n")
+ self.client_nonce = headers["sec-websocket-key"]
+ self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
resp = http1_protocol.read_response("GET", None)
@@ -145,13 +145,13 @@ class TestWebSockets(tservers.ServerTestBase):
def test_check_server_handshake(self):
headers = self.protocol.server_handshake_headers("key")
assert self.protocol.check_server_handshake(headers)
- headers["Upgrade"] = ["not_websocket"]
+ headers["Upgrade"] = "not_websocket"
assert not self.protocol.check_server_handshake(headers)
def test_check_client_handshake(self):
headers = self.protocol.client_handshake_headers("key")
assert self.protocol.check_client_handshake(headers) == "key"
- headers["Upgrade"] = ["not_websocket"]
+ headers["Upgrade"] = "not_websocket"
assert not self.protocol.check_client_handshake(headers)
@@ -166,7 +166,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers("malformed key")
- self.wfile.write(headers.format() + "\r\n")
+ self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
self.handshake_done = True