aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/mitmproxy/test_contrib_tnetstring.py141
-rw-r--r--test/mitmproxy/test_flow.py3
-rw-r--r--test/mitmproxy/test_protocol_http1.py8
-rw-r--r--test/mitmproxy/test_protocol_http2.py41
4 files changed, 176 insertions, 17 deletions
diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py
new file mode 100644
index 00000000..17654ad9
--- /dev/null
+++ b/test/mitmproxy/test_contrib_tnetstring.py
@@ -0,0 +1,141 @@
+import unittest
+import random
+import math
+import io
+import struct
+
+from mitmproxy.contrib import tnetstring
+
+MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1
+
+FORMAT_EXAMPLES = {
+ b'0:}': {},
+ b'0:]': [],
+ b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}':
+ {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']},
+ b'5:12345#': 12345,
+ b'12:this is cool,': b'this is cool',
+ b'0:,': b'',
+ b'0:~': None,
+ b'4:true!': True,
+ b'5:false!': False,
+ b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
+ b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'],
+ b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3],
+ b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa
+}
+
+
+def get_random_object(random=random, depth=0):
+ """Generate a random serializable object."""
+ # The probability of generating a scalar value increases as the depth increase.
+ # This ensures that we bottom out eventually.
+ if random.randint(depth, 10) <= 4:
+ what = random.randint(0, 1)
+ if what == 0:
+ n = random.randint(0, 10)
+ l = []
+ for _ in range(n):
+ l.append(get_random_object(random, depth + 1))
+ return l
+ if what == 1:
+ n = random.randint(0, 10)
+ d = {}
+ for _ in range(n):
+ n = random.randint(0, 100)
+ k = bytes([random.randint(32, 126) for _ in range(n)])
+ d[k] = get_random_object(random, depth + 1)
+ return d
+ else:
+ what = random.randint(0, 4)
+ if what == 0:
+ return None
+ if what == 1:
+ return True
+ if what == 2:
+ return False
+ if what == 3:
+ if random.randint(0, 1) == 0:
+ return random.randint(0, MAXINT)
+ else:
+ return -1 * random.randint(0, MAXINT)
+ n = random.randint(0, 100)
+ return bytes([random.randint(32, 126) for _ in range(n)])
+
+
+class Test_Format(unittest.TestCase):
+
+ def test_roundtrip_format_examples(self):
+ for data, expect in FORMAT_EXAMPLES.items():
+ self.assertEqual(expect, tnetstring.loads(data))
+ self.assertEqual(
+ expect, tnetstring.loads(tnetstring.dumps(expect)))
+ self.assertEqual((expect, b''), tnetstring.pop(data))
+
+ def test_roundtrip_format_random(self):
+ for _ in range(500):
+ v = get_random_object()
+ self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
+ self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
+
+ def test_unicode_handling(self):
+ with self.assertRaises(ValueError):
+ tnetstring.dumps(u"hello")
+ self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,")
+ self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes)
+
+ def test_roundtrip_format_unicode(self):
+ for _ in range(500):
+ v = get_random_object()
+ self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
+ self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v)))
+
+ def test_roundtrip_big_integer(self):
+ i1 = math.factorial(30000)
+ s = tnetstring.dumps(i1)
+ i2 = tnetstring.loads(s)
+ self.assertEqual(i1, i2)
+
+
+class Test_FileLoading(unittest.TestCase):
+
+ def test_roundtrip_file_examples(self):
+ for data, expect in FORMAT_EXAMPLES.items():
+ s = io.BytesIO()
+ s.write(data)
+ s.write(b'OK')
+ s.seek(0)
+ self.assertEqual(expect, tnetstring.load(s))
+ self.assertEqual(b'OK', s.read())
+ s = io.BytesIO()
+ tnetstring.dump(expect, s)
+ s.write(b'OK')
+ s.seek(0)
+ self.assertEqual(expect, tnetstring.load(s))
+ self.assertEqual(b'OK', s.read())
+
+ def test_roundtrip_file_random(self):
+ for _ in range(500):
+ v = get_random_object()
+ s = io.BytesIO()
+ tnetstring.dump(v, s)
+ s.write(b'OK')
+ s.seek(0)
+ self.assertEqual(v, tnetstring.load(s))
+ self.assertEqual(b'OK', s.read())
+
+ def test_error_on_absurd_lengths(self):
+ s = io.BytesIO()
+ s.write(b'1000000000:pwned!,')
+ s.seek(0)
+ with self.assertRaises(ValueError):
+ tnetstring.load(s)
+ self.assertEqual(s.read(1), b':')
+
+
+def suite():
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ suite.addTest(loader.loadTestsFromTestCase(Test_Format))
+ suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading))
+ return suite
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index af8256c4..9eaab9aa 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -5,7 +5,8 @@ import mock
import netlib.utils
from netlib.http import Headers
-from mitmproxy import filt, controller, tnetstring, flow
+from mitmproxy import filt, controller, flow
+from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException, ScriptException
from mitmproxy.models import Error
from mitmproxy.models import Flow
diff --git a/test/mitmproxy/test_protocol_http1.py b/test/mitmproxy/test_protocol_http1.py
index e0a57b4e..cf7bd598 100644
--- a/test/mitmproxy/test_protocol_http1.py
+++ b/test/mitmproxy/test_protocol_http1.py
@@ -18,14 +18,14 @@ class TestInvalidRequests(tservers.HTTPProxyTest):
p = self.pathoc()
r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port))
assert r.status_code == 400
- assert "Invalid HTTP request form" in r.content
+ assert b"Invalid HTTP request form" in r.content
def test_relative_request(self):
p = self.pathoc_raw()
p.connect()
r = p.request("get:/p/200")
assert r.status_code == 400
- assert "Invalid HTTP request form" in r.content
+ assert b"Invalid HTTP request form" in r.content
class TestExpectHeader(tservers.HTTPProxyTest):
@@ -43,8 +43,8 @@ class TestExpectHeader(tservers.HTTPProxyTest):
)
client.wfile.flush()
- assert client.rfile.readline() == "HTTP/1.1 100 Continue\r\n"
- assert client.rfile.readline() == "\r\n"
+ assert client.rfile.readline() == b"HTTP/1.1 100 Continue\r\n"
+ assert client.rfile.readline() == b"\r\n"
client.wfile.write(b"0123456789abcdef\r\n")
client.wfile.flush()
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py
index 23072260..932c8df2 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/test_protocol_http2.py
@@ -13,6 +13,7 @@ from mitmproxy.cmdline import APP_HOST, APP_PORT
import netlib
from ..netlib import tservers as netlib_tservers
+from netlib.exceptions import HttpException
from netlib.http.http2 import framereader
from . import tservers
@@ -50,6 +51,9 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
try:
raw = b''.join(framereader.http2_read_raw_frame(self.rfile))
events = h2_conn.receive_data(raw)
+ except HttpException:
+ print(traceback.format_exc())
+ assert False
except:
break
self.wfile.write(h2_conn.data_to_send())
@@ -60,9 +64,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile):
done = True
break
- except Exception as e:
- print(repr(e))
- print(traceback.format_exc())
+ except:
done = True
break
@@ -200,9 +202,12 @@ class TestSimple(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
- events = h2_conn.receive_data(b''.join(framereader.http2_read_raw_frame(client.rfile)))
- except:
- break
+ raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ events = h2_conn.receive_data(raw)
+ except HttpException:
+ print(traceback.format_exc())
+ assert False
+
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@@ -270,9 +275,12 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
- events = h2_conn.receive_data(b''.join(framereader.http2_read_raw_frame(client.rfile)))
- except:
- break
+ raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ events = h2_conn.receive_data(raw)
+ except HttpException:
+ print(traceback.format_exc())
+ assert False
+
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@@ -364,6 +372,9 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
try:
raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
+ except HttpException:
+ print(traceback.format_exc())
+ assert False
except:
break
client.wfile.write(h2_conn.data_to_send())
@@ -412,9 +423,12 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
responses = 0
while not done:
try:
- events = h2_conn.receive_data(b''.join(framereader.http2_read_raw_frame(client.rfile)))
- except:
- break
+ raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ events = h2_conn.receive_data(raw)
+ except HttpException:
+ print(traceback.format_exc())
+ assert False
+
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@@ -481,6 +495,9 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase):
try:
raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
h2_conn.receive_data(raw)
+ except HttpException:
+ print(traceback.format_exc())
+ assert False
except:
break
try: