aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/tutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/tutils.py')
-rw-r--r--netlib/tutils.py70
1 files changed, 46 insertions, 24 deletions
diff --git a/netlib/tutils.py b/netlib/tutils.py
index 951ef3d9..65c4a313 100644
--- a/netlib/tutils.py
+++ b/netlib/tutils.py
@@ -1,9 +1,11 @@
-import cStringIO
+from io import BytesIO
import tempfile
import os
import time
import shutil
from contextlib import contextmanager
+import six
+import sys
from netlib import tcp, utils, http
@@ -12,7 +14,7 @@ def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
- fp = cStringIO.StringIO(bytes)
+ fp = BytesIO(bytes)
return tcp.Reader(fp)
@@ -28,7 +30,24 @@ def tmpdir(*args, **kwargs):
shutil.rmtree(temp_workdir)
-def raises(exc, obj, *args, **kwargs):
+def _check_exception(expected, actual, exc_tb):
+ if isinstance(expected, six.string_types):
+ if expected.lower() not in str(actual).lower():
+ six.reraise(AssertionError, AssertionError(
+ "Expected %s, but caught %s" % (
+ repr(str(expected)), actual
+ )
+ ), exc_tb)
+ else:
+ if not isinstance(actual, expected):
+ six.reraise(AssertionError, AssertionError(
+ "Expected %s, but caught %s %s" % (
+ expected.__name__, actual.__class__.__name__, str(actual)
+ )
+ ), exc_tb)
+
+
+def raises(expected_exception, obj=None, *args, **kwargs):
"""
Assert that a callable raises a specified exception.
@@ -43,28 +62,31 @@ def raises(exc, obj, *args, **kwargs):
:kwargs Arguments to be passed to the callable.
"""
- try:
- ret = obj(*args, **kwargs)
- except Exception as v:
- if isinstance(exc, basestring):
- if exc.lower() in str(v).lower():
- return
- else:
- raise AssertionError(
- "Expected %s, but caught %s" % (
- repr(str(exc)), v
- )
- )
+ if obj is None:
+ return RaisesContext(expected_exception)
+ else:
+ try:
+ ret = obj(*args, **kwargs)
+ except Exception as actual:
+ _check_exception(expected_exception, actual, sys.exc_info()[2])
else:
- if isinstance(v, exc):
- return
- else:
- raise AssertionError(
- "Expected %s, but caught %s %s" % (
- exc.__name__, v.__class__.__name__, str(v)
- )
- )
- raise AssertionError("No exception raised. Return value: {}".format(ret))
+ raise AssertionError("No exception raised. Return value: {}".format(ret))
+
+
+class RaisesContext(object):
+ def __init__(self, expected_exception):
+ self.expected_exception = expected_exception
+
+ def __enter__(self):
+ return
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not exc_type:
+ raise AssertionError("No exception raised.")
+ else:
+ _check_exception(self.expected_exception, exc_val, exc_tb)
+ return True
+
test_data = utils.Data(__name__)