aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/tutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/tutils.py')
-rw-r--r--netlib/tutils.py144
1 files changed, 76 insertions, 68 deletions
diff --git a/netlib/tutils.py b/netlib/tutils.py
index 951ef3d9..05791c49 100644
--- a/netlib/tutils.py
+++ b/netlib/tutils.py
@@ -1,18 +1,22 @@
-import cStringIO
+from io import BytesIO
import tempfile
import os
import time
import shutil
from contextlib import contextmanager
+import six
+import sys
-from netlib import tcp, utils, http
+from . import utils
+from .http import Request, Response, Headers
def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
- fp = cStringIO.StringIO(bytes)
+ from . import tcp # TODO: move to top once cryptography is on Python 3.5
+ fp = BytesIO(bytes)
return tcp.Reader(fp)
@@ -28,7 +32,24 @@ def tmpdir(*args, **kwargs):
shutil.rmtree(temp_workdir)
-def raises(exc, obj, *args, **kwargs):
+def _check_exception(expected, actual, exc_tb):
+ if isinstance(expected, six.string_types):
+ if expected.lower() not in str(actual).lower():
+ six.reraise(AssertionError, AssertionError(
+ "Expected %s, but caught %s" % (
+ repr(expected), repr(actual)
+ )
+ ), exc_tb)
+ else:
+ if not isinstance(actual, expected):
+ six.reraise(AssertionError, AssertionError(
+ "Expected %s, but caught %s %s" % (
+ expected.__name__, actual.__class__.__name__, repr(actual)
+ )
+ ), exc_tb)
+
+
+def raises(expected_exception, obj=None, *args, **kwargs):
"""
Assert that a callable raises a specified exception.
@@ -43,81 +64,68 @@ def raises(exc, obj, *args, **kwargs):
:kwargs Arguments to be passed to the callable.
"""
- try:
- ret = obj(*args, **kwargs)
- except Exception as v:
- if isinstance(exc, basestring):
- if exc.lower() in str(v).lower():
- return
- else:
- raise AssertionError(
- "Expected %s, but caught %s" % (
- repr(str(exc)), v
- )
- )
+ if obj is None:
+ return RaisesContext(expected_exception)
+ else:
+ try:
+ ret = obj(*args, **kwargs)
+ except Exception as actual:
+ _check_exception(expected_exception, actual, sys.exc_info()[2])
else:
- if isinstance(v, exc):
- return
- else:
- raise AssertionError(
- "Expected %s, but caught %s %s" % (
- exc.__name__, v.__class__.__name__, str(v)
- )
- )
- raise AssertionError("No exception raised. Return value: {}".format(ret))
+ raise AssertionError("No exception raised. Return value: {}".format(ret))
-test_data = utils.Data(__name__)
+class RaisesContext(object):
+ def __init__(self, expected_exception):
+ self.expected_exception = expected_exception
-def treq(content="content", scheme="http", host="address", port=22):
- """
- @return: libmproxy.protocol.http.HTTPRequest
- """
- headers = http.Headers()
- headers["header"] = "qvalue"
- req = http.Request(
- "relative",
- "GET",
- scheme,
- host,
- port,
- "/path",
- (1, 1),
- headers,
- content,
- None,
- None,
- )
- return req
+ def __enter__(self):
+ return
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not exc_type:
+ raise AssertionError("No exception raised.")
+ else:
+ _check_exception(self.expected_exception, exc_val, exc_tb)
+ return True
-def treq_absolute(content="content"):
- """
- @return: libmproxy.protocol.http.HTTPRequest
- """
- r = treq(content)
- r.form_in = r.form_out = "absolute"
- r.host = "address"
- r.port = 22
- r.scheme = "http"
- return r
+test_data = utils.Data(__name__)
-def tresp(content="message"):
+
+def treq(**kwargs):
"""
- @return: libmproxy.protocol.http.HTTPResponse
+ Returns:
+ netlib.http.Request
"""
+ default = dict(
+ form_in="relative",
+ method=b"GET",
+ scheme=b"http",
+ host=b"address",
+ port=22,
+ path=b"/path",
+ httpversion=b"HTTP/1.1",
+ headers=Headers(header=b"qvalue"),
+ body=b"content"
+ )
+ default.update(kwargs)
+ return Request(**default)
- headers = http.Headers()
- headers["header_response"] = "svalue"
- resp = http.semantics.Response(
- (1, 1),
- 200,
- "OK",
- headers,
- content,
+def tresp(**kwargs):
+ """
+ Returns:
+ netlib.http.Response
+ """
+ default = dict(
+ httpversion=b"HTTP/1.1",
+ status_code=200,
+ msg=b"OK",
+ headers=Headers(header_response=b"svalue"),
+ body=b"message",
timestamp_start=time.time(),
- timestamp_end=time.time(),
+ timestamp_end=time.time()
)
- return resp
+ default.update(kwargs)
+ return Response(**default)