aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/scripting/inlinescripts.rst2
-rw-r--r--examples/custom_contentviews.py2
-rw-r--r--examples/filt.py8
-rw-r--r--examples/flowwriter.py8
-rw-r--r--examples/har_extractor.py11
-rw-r--r--examples/iframe_injector.py7
-rw-r--r--examples/modify_response_body.py8
-rw-r--r--examples/proxapp.py2
-rw-r--r--examples/sslstrip.py2
-rw-r--r--examples/stub.py2
-rw-r--r--examples/tcp_message.py4
-rw-r--r--examples/tls_passthrough.py7
-rw-r--r--mitmproxy/script/script.py26
-rw-r--r--netlib/debug.py8
-rw-r--r--netlib/exceptions.py4
-rw-r--r--netlib/http/http2/connections.py4
-rw-r--r--netlib/tcp.py11
-rw-r--r--netlib/websockets/protocol.py4
-rw-r--r--pathod/language/http.py2
-rw-r--r--pathod/language/http2.py2
-rw-r--r--pathod/language/websockets.py5
-rw-r--r--pathod/log.py2
-rw-r--r--pathod/pathoc.py29
-rw-r--r--pathod/pathod.py18
-rw-r--r--pathod/protocols/websockets.py2
-rw-r--r--pathod/test.py14
-rw-r--r--test/mitmproxy/data/har_extractor.har7
-rw-r--r--test/mitmproxy/data/scripts/a.py6
-rw-r--r--test/mitmproxy/data/scripts/concurrent_decorator_err.py2
-rw-r--r--test/mitmproxy/data/scripts/starterr.py4
-rw-r--r--test/netlib/http/http2/test_connections.py200
-rw-r--r--test/netlib/test_debug.py4
-rw-r--r--test/netlib/test_tcp.py429
-rw-r--r--test/netlib/tservers.py3
-rw-r--r--test/pathod/test_pathoc.py33
-rw-r--r--test/pathod/tutils.py8
-rw-r--r--tox.ini2
37 files changed, 474 insertions, 418 deletions
diff --git a/docs/scripting/inlinescripts.rst b/docs/scripting/inlinescripts.rst
index d282dfa6..2065923d 100644
--- a/docs/scripting/inlinescripts.rst
+++ b/docs/scripting/inlinescripts.rst
@@ -44,7 +44,7 @@ to store any form of state you require.
Script Lifecycle Events
^^^^^^^^^^^^^^^^^^^^^^^
-.. py:function:: start(context, argv)
+.. py:function:: start(context)
Called once on startup, before any other events.
diff --git a/examples/custom_contentviews.py b/examples/custom_contentviews.py
index 034f356c..05ebeb69 100644
--- a/examples/custom_contentviews.py
+++ b/examples/custom_contentviews.py
@@ -62,7 +62,7 @@ class ViewPigLatin(contentviews.View):
pig_view = ViewPigLatin()
-def start(context, argv):
+def start(context):
context.add_contentview(pig_view)
diff --git a/examples/filt.py b/examples/filt.py
index f99b675c..1a423845 100644
--- a/examples/filt.py
+++ b/examples/filt.py
@@ -1,13 +1,13 @@
# This scripts demonstrates how to use mitmproxy's filter pattern in inline scripts.
# Usage: mitmdump -s "filt.py FILTER"
-
+import sys
from mitmproxy import filt
-def start(context, argv):
- if len(argv) != 2:
+def start(context):
+ if len(sys.argv) != 2:
raise ValueError("Usage: -s 'filt.py FILTER'")
- context.filter = filt.parse(argv[1])
+ context.filter = filt.parse(sys.argv[1])
def response(context, flow):
diff --git a/examples/flowwriter.py b/examples/flowwriter.py
index 8fb8cc60..cb5ccb0d 100644
--- a/examples/flowwriter.py
+++ b/examples/flowwriter.py
@@ -4,14 +4,14 @@ import sys
from mitmproxy.flow import FlowWriter
-def start(context, argv):
- if len(argv) != 2:
+def start(context):
+ if len(sys.argv) != 2:
raise ValueError('Usage: -s "flowriter.py filename"')
- if argv[1] == "-":
+ if sys.argv[1] == "-":
f = sys.stdout
else:
- f = open(argv[1], "wb")
+ f = open(sys.argv[1], "wb")
context.flow_writer = FlowWriter(f)
diff --git a/examples/har_extractor.py b/examples/har_extractor.py
index 6806989d..d6b50c21 100644
--- a/examples/har_extractor.py
+++ b/examples/har_extractor.py
@@ -3,6 +3,8 @@
https://github.com/JustusW/harparser to generate a HAR log object.
"""
import six
+import sys
+import pytz
from harparser import HAR
from datetime import datetime
@@ -52,15 +54,15 @@ class _HARLog(HAR.log):
return self.__page_list__
-def start(context, argv):
+def start(context):
"""
On start we create a HARLog instance. You will have to adapt this to
suit your actual needs of HAR generation. As it will probably be
necessary to cluster logs by IPs or reset them from time to time.
"""
context.dump_file = None
- if len(argv) > 1:
- context.dump_file = argv[1]
+ if len(sys.argv) > 1:
+ context.dump_file = sys.argv[1]
else:
raise ValueError(
'Usage: -s "har_extractor.py filename" '
@@ -119,7 +121,7 @@ def response(context, flow):
full_time = sum(v for v in timings.values() if v > -1)
started_date_time = datetime.utcfromtimestamp(
- flow.request.timestamp_start).isoformat()
+ flow.request.timestamp_start).replace(tzinfo=pytz.timezone("UTC")).isoformat()
request_query_string = [{"name": k, "value": v}
for k, v in flow.request.query or {}]
@@ -173,6 +175,7 @@ def response(context, flow):
"startedDateTime": entry['startedDateTime'],
"id": page_id,
"title": flow.request.url,
+ "pageTimings": {}
})
)
context.HARLog.set_page_ref(flow.request.url, page_id)
diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py
index ad844f19..9495da93 100644
--- a/examples/iframe_injector.py
+++ b/examples/iframe_injector.py
@@ -1,13 +1,14 @@
# Usage: mitmdump -s "iframe_injector.py url"
# (this script works best with --anticache)
+import sys
from bs4 import BeautifulSoup
from mitmproxy.models import decoded
-def start(context, argv):
- if len(argv) != 2:
+def start(context):
+ if len(sys.argv) != 2:
raise ValueError('Usage: -s "iframe_injector.py url"')
- context.iframe_url = argv[1]
+ context.iframe_url = sys.argv[1]
def response(context, flow):
diff --git a/examples/modify_response_body.py b/examples/modify_response_body.py
index d68bcf63..3034892e 100644
--- a/examples/modify_response_body.py
+++ b/examples/modify_response_body.py
@@ -1,14 +1,16 @@
# Usage: mitmdump -s "modify_response_body.py mitmproxy bananas"
# (this script works best with --anticache)
+import sys
+
from mitmproxy.models import decoded
-def start(context, argv):
- if len(argv) != 3:
+def start(context):
+ if len(sys.argv) != 3:
raise ValueError('Usage: -s "modify_response_body.py old new"')
# You may want to use Python's argparse for more sophisticated argument
# parsing.
- context.old, context.new = argv[1], argv[2]
+ context.old, context.new = sys.argv[1], sys.argv[2]
def response(context, flow):
diff --git a/examples/proxapp.py b/examples/proxapp.py
index 4d8e7b58..613d3f8b 100644
--- a/examples/proxapp.py
+++ b/examples/proxapp.py
@@ -15,7 +15,7 @@ def hello_world():
# Register the app using the magic domain "proxapp" on port 80. Requests to
# this domain and port combination will now be routed to the WSGI app instance.
-def start(context, argv):
+def start(context):
context.app_registry.add(app, "proxapp", 80)
# SSL works too, but the magic domain needs to be resolvable from the mitmproxy machine due to mitmproxy's design.
diff --git a/examples/sslstrip.py b/examples/sslstrip.py
index 1bc89946..8dde8e3e 100644
--- a/examples/sslstrip.py
+++ b/examples/sslstrip.py
@@ -3,7 +3,7 @@ import re
from six.moves import urllib
-def start(context, argv):
+def start(context):
# set of SSL/TLS capable hosts
context.secure_hosts = set()
diff --git a/examples/stub.py b/examples/stub.py
index 516b71a5..a0f73538 100644
--- a/examples/stub.py
+++ b/examples/stub.py
@@ -3,7 +3,7 @@
"""
-def start(context, argv):
+def start(context):
"""
Called once on script startup, before any other events.
"""
diff --git a/examples/tcp_message.py b/examples/tcp_message.py
index 2c210618..78500c19 100644
--- a/examples/tcp_message.py
+++ b/examples/tcp_message.py
@@ -1,4 +1,4 @@
-'''
+"""
tcp_message Inline Script Hook API Demonstration
------------------------------------------------
@@ -7,7 +7,7 @@ tcp_message Inline Script Hook API Demonstration
example cmdline invocation:
mitmdump -T --host --tcp ".*" -q -s examples/tcp_message.py
-'''
+"""
from netlib import strutils
diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py
index 0c6d450d..50aab65b 100644
--- a/examples/tls_passthrough.py
+++ b/examples/tls_passthrough.py
@@ -24,6 +24,7 @@ from __future__ import (absolute_import, print_function, division)
import collections
import random
+import sys
from enum import Enum
from mitmproxy.exceptions import TlsProtocolException
@@ -110,9 +111,9 @@ class TlsFeedback(TlsLayer):
# inline script hooks below.
-def start(context, argv):
- if len(argv) == 2:
- context.tls_strategy = ProbabilisticStrategy(float(argv[1]))
+def start(context):
+ if len(sys.argv) == 2:
+ context.tls_strategy = ProbabilisticStrategy(float(sys.argv[1]))
else:
context.tls_strategy = ConservativeStrategy()
diff --git a/mitmproxy/script/script.py b/mitmproxy/script/script.py
index 70f74817..9ff79f52 100644
--- a/mitmproxy/script/script.py
+++ b/mitmproxy/script/script.py
@@ -6,15 +6,28 @@ by the mitmproxy-specific ScriptContext.
# Do not import __future__ here, this would apply transitively to the inline scripts.
from __future__ import absolute_import, print_function, division
+import inspect
import os
import shlex
import sys
+import contextlib
+import warnings
import six
from mitmproxy import exceptions
+@contextlib.contextmanager
+def setargs(args):
+ oldargs = sys.argv
+ sys.argv = args
+ try:
+ yield
+ finally:
+ sys.argv = oldargs
+
+
class Script(object):
"""
@@ -89,7 +102,15 @@ class Script(object):
finally:
sys.path.pop()
sys.path.pop()
- return self.run("start", self.args)
+
+ start_fn = self.ns.get("start")
+ if start_fn and len(inspect.getargspec(start_fn).args) == 2:
+ warnings.warn(
+ "The 'args' argument of the start() script hook is deprecated. "
+ "Please use sys.argv instead."
+ )
+ return self.run("start", self.args)
+ return self.run("start")
def unload(self):
try:
@@ -113,7 +134,8 @@ class Script(object):
f = self.ns.get(name)
if f:
try:
- return f(self.ctx, *args, **kwargs)
+ with setargs(self.args):
+ return f(self.ctx, *args, **kwargs)
except Exception:
six.reraise(
exceptions.ScriptException,
diff --git a/netlib/debug.py b/netlib/debug.py
index 303a2f6f..a395afcb 100644
--- a/netlib/debug.py
+++ b/netlib/debug.py
@@ -1,5 +1,6 @@
from __future__ import (absolute_import, print_function, division)
+import os
import sys
import threading
import signal
@@ -93,6 +94,7 @@ def dump_stacks(signal, frame, file=sys.stdout):
print("\n".join(code), file=file)
-def register_info_dumpers(): # pragma: no cover
- signal.signal(signal.SIGUSR1, dump_info)
- signal.signal(signal.SIGUSR2, dump_stacks)
+def register_info_dumpers():
+ if os.name != "nt":
+ signal.signal(signal.SIGUSR1, dump_info)
+ signal.signal(signal.SIGUSR2, dump_stacks)
diff --git a/netlib/exceptions.py b/netlib/exceptions.py
index 05f1054b..dec79c22 100644
--- a/netlib/exceptions.py
+++ b/netlib/exceptions.py
@@ -54,3 +54,7 @@ class TlsException(NetlibException):
class InvalidCertificateException(TlsException):
pass
+
+
+class Timeout(TcpException):
+ pass
diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py
index 8667d370..8f246feb 100644
--- a/netlib/http/http2/connections.py
+++ b/netlib/http/http2/connections.py
@@ -5,7 +5,7 @@ import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
-from netlib import utils
+from netlib import utils, strutils
from netlib.http import url
import netlib.http.headers
import netlib.http.response
@@ -230,7 +230,7 @@ class HTTP2Protocol(object):
headers = response.headers.copy()
if ':status' not in headers:
- headers.insert(0, b':status', str(response.status_code).encode('ascii'))
+ headers.insert(0, b':status', strutils.always_bytes(response.status_code))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
diff --git a/netlib/tcp.py b/netlib/tcp.py
index a8a68139..69dafc1f 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -967,3 +967,14 @@ class TCPServer(object):
"""
Called after server shutdown.
"""
+
+ def wait_for_silence(self, timeout=5):
+ start = time.time()
+ while 1:
+ if time.time() - start >= timeout:
+ raise exceptions.Timeout(
+ "%s service threads still alive" %
+ self.handler_counter.count
+ )
+ if self.handler_counter.count == 0:
+ return
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
index c1b7be2c..af0eef7d 100644
--- a/netlib/websockets/protocol.py
+++ b/netlib/websockets/protocol.py
@@ -20,7 +20,7 @@ import os
import six
-from netlib import http
+from netlib import http, strutils
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
@@ -109,4 +109,4 @@ class WebsocketsProtocol(object):
@classmethod
def create_server_nonce(self, client_nonce):
- return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest())
+ return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest())
diff --git a/pathod/language/http.py b/pathod/language/http.py
index 4cc7db5f..5bd6e385 100644
--- a/pathod/language/http.py
+++ b/pathod/language/http.py
@@ -209,7 +209,7 @@ class Response(_HTTPMessage):
base.TokValueLiteral(i[1].decode()))
)
if not self.raw:
- if not get_header("Content-Length", self.headers):
+ if not get_header(b"Content-Length", self.headers):
if not self.body:
length = 0
else:
diff --git a/pathod/language/http2.py b/pathod/language/http2.py
index ea4fcd27..2693446e 100644
--- a/pathod/language/http2.py
+++ b/pathod/language/http2.py
@@ -188,7 +188,7 @@ class Response(_HTTP2Message):
body = body.string()
resp = http.Response(
- (2, 0),
+ b'HTTP/2.0',
self.status_code.string(),
b'',
headers,
diff --git a/pathod/language/websockets.py b/pathod/language/websockets.py
index 9b752b7e..417944af 100644
--- a/pathod/language/websockets.py
+++ b/pathod/language/websockets.py
@@ -1,10 +1,11 @@
import random
import string
import netlib.websockets
+from netlib import strutils
import pyparsing as pp
from . import base, generators, actions, message
-NESTED_LEADER = "pathod!"
+NESTED_LEADER = b"pathod!"
class WF(base.CaselessLiteral):
@@ -193,7 +194,7 @@ class WebsocketFrame(message.Message):
bodygen = self.rawbody.value.get_generator(settings)
length = len(self.rawbody.value.get_generator(settings))
elif self.nested_frame:
- bodygen = NESTED_LEADER + self.nested_frame.parsed.spec()
+ bodygen = NESTED_LEADER + strutils.always_bytes(self.nested_frame.parsed.spec())
length = len(bodygen)
else:
bodygen = None
diff --git a/pathod/log.py b/pathod/log.py
index 1d3ec356..23e9a2ce 100644
--- a/pathod/log.py
+++ b/pathod/log.py
@@ -62,7 +62,7 @@ class LogCtx(object):
for line in strutils.hexdump(data):
self("\t%s %s %s" % line)
else:
- for i in strutils.clean_bin(data).split("\n"):
+ for i in strutils.clean_bin(data).split(b"\n"):
self("\t%s" % i)
def __call__(self, line):
diff --git a/pathod/pathoc.py b/pathod/pathoc.py
index 21fc9845..478ce2a2 100644
--- a/pathod/pathoc.py
+++ b/pathod/pathoc.py
@@ -42,7 +42,7 @@ class SSLInfo(object):
def __str__(self):
parts = [
- "Application Layer Protocol: %s" % self.alp,
+ "Application Layer Protocol: %s" % strutils.native(self.alp, "utf8"),
"Cipher: %s, %s bit, %s" % self.cipher,
"SSL certificate chain:"
]
@@ -50,18 +50,25 @@ class SSLInfo(object):
parts.append(" Certificate [%s]" % n)
parts.append("\tSubject: ")
for cn in i.get_subject().get_components():
- parts.append("\t\t%s=%s" % cn)
+ parts.append("\t\t%s=%s" % (
+ strutils.native(cn[0], "utf8"),
+ strutils.native(cn[1], "utf8"))
+ )
parts.append("\tIssuer: ")
for cn in i.get_issuer().get_components():
- parts.append("\t\t%s=%s" % cn)
+ parts.append("\t\t%s=%s" % (
+ strutils.native(cn[0], "utf8"),
+ strutils.native(cn[1], "utf8"))
+ )
parts.extend(
[
"\tVersion: %s" % i.get_version(),
"\tValidity: %s - %s" % (
- i.get_notBefore(), i.get_notAfter()
+ strutils.native(i.get_notBefore(), "utf8"),
+ strutils.native(i.get_notAfter(), "utf8")
),
"\tSerial: %s" % i.get_serial_number(),
- "\tAlgorithm: %s" % i.get_signature_algorithm()
+ "\tAlgorithm: %s" % strutils.native(i.get_signature_algorithm(), "utf8")
]
)
pk = i.get_pubkey()
@@ -73,7 +80,7 @@ class SSLInfo(object):
parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))
s = certutils.SSLCert(i)
if s.altnames:
- parts.append("\tSANs: %s" % " ".join(s.altnames))
+ parts.append("\tSANs: %s" % " ".join(strutils.native(n, "utf8") for n in s.altnames))
return "\n".join(parts)
@@ -218,7 +225,7 @@ class Pathoc(tcp.TCPClient):
"HTTP/2 requires ALPN support. "
"Please use OpenSSL >= 1.0.2. "
"Pathoc might not be working as expected without ALPN.",
- timestamp = False
+ timestamp=False
)
self.protocol = http2.HTTP2Protocol(self, dump_frames=self.http2_framedump)
else:
@@ -239,7 +246,7 @@ class Pathoc(tcp.TCPClient):
)
self.wfile.flush()
try:
- resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
+ resp = self.protocol.read_response(self.rfile, treq(method=b"CONNECT"))
if resp.status_code != 200:
raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code)
except exceptions.HttpException as e:
@@ -437,7 +444,7 @@ class Pathoc(tcp.TCPClient):
finally:
if resp:
lg("<< %s %s: %s bytes" % (
- resp.status_code, strutils.bytes_to_escaped_str(resp.reason), len(resp.content)
+ resp.status_code, strutils.bytes_to_escaped_str(resp.reason.encode()), len(resp.content)
))
if resp.status_code in self.ignorecodes:
lg.suppress()
@@ -454,8 +461,8 @@ class Pathoc(tcp.TCPClient):
May raise a exceptions.NetlibException
"""
- if isinstance(r, basestring):
- r = language.parse_pathoc(r, self.use_http2).next()
+ if isinstance(r, six.string_types):
+ r = next(language.parse_pathoc(r, self.use_http2))
if isinstance(r, language.http.Request):
if r.ws:
diff --git a/pathod/pathod.py b/pathod/pathod.py
index 315a04e0..3df86aae 100644
--- a/pathod/pathod.py
+++ b/pathod/pathod.py
@@ -4,19 +4,20 @@ import logging
import os
import sys
import threading
-import urllib
from netlib import tcp
from netlib import certutils
from netlib import websockets
from netlib import version
+
+from six.moves import urllib
from netlib.exceptions import HttpException, HttpReadDisconnect, TcpTimeout, TcpDisconnect, \
TlsException
from . import language, utils, log, protocols
-DEFAULT_CERT_DOMAIN = "pathod.net"
+DEFAULT_CERT_DOMAIN = b"pathod.net"
CONFDIR = "~/.mitmproxy"
CERTSTORE_BASENAME = "mitmproxy"
CA_CERT_NAME = "mitmproxy-ca.pem"
@@ -185,7 +186,7 @@ class PathodHandler(tcp.BaseHandler):
break
else:
if m(path.startswith(self.server.craftanchor)):
- spec = urllib.unquote(path)[len(self.server.craftanchor):]
+ spec = urllib.parse.unquote(path)[len(self.server.craftanchor):]
if spec:
try:
anchor_gen = language.parse_pathod(spec, self.use_http2)
@@ -211,7 +212,7 @@ class PathodHandler(tcp.BaseHandler):
"No valid craft request found"
)])
- spec = anchor_gen.next()
+ spec = next(anchor_gen)
if self.use_http2 and isinstance(spec, language.http2.Response):
spec.stream_id = req.stream_id
@@ -283,15 +284,10 @@ class PathodHandler(tcp.BaseHandler):
return
def addlog(self, log):
- # FIXME: The bytes in the log should not be escaped. We do this at the
- # moment because JSON encoding can't handle binary data, and I don't
- # want to base64 everything.
if self.server.logreq:
- encoded_bytes = self.rfile.get_log().encode("string_escape")
- log["request_bytes"] = encoded_bytes
+ log["request_bytes"] = self.rfile.get_log()
if self.server.logresp:
- encoded_bytes = self.wfile.get_log().encode("string_escape")
- log["response_bytes"] = encoded_bytes
+ log["response_bytes"] = self.wfile.get_log()
self.server.add_log(log)
diff --git a/pathod/protocols/websockets.py b/pathod/protocols/websockets.py
index 2b60e618..a34e75e8 100644
--- a/pathod/protocols/websockets.py
+++ b/pathod/protocols/websockets.py
@@ -37,7 +37,7 @@ class WebsocketsProtocol:
if frm.payload.startswith(ld):
nest = frm.payload[len(ld):]
try:
- wf_gen = language.parse_websocket_frame(nest)
+ wf_gen = language.parse_websocket_frame(nest.decode())
except language.exceptions.ParseException as v:
logger.write(
"Parse error in reflected frame specifcation:"
diff --git a/pathod/test.py b/pathod/test.py
index 3ba541b1..4992945d 100644
--- a/pathod/test.py
+++ b/pathod/test.py
@@ -7,10 +7,6 @@ from . import pathod
from netlib import basethread
-class TimeoutError(Exception):
- pass
-
-
class Daemon:
IFACE = "127.0.0.1"
@@ -45,15 +41,7 @@ class Daemon:
return self.logfp.getvalue()
def wait_for_silence(self, timeout=5):
- start = time.time()
- while 1:
- if time.time() - start >= timeout:
- raise TimeoutError(
- "%s service threads still alive" %
- self.thread.server.handler_counter.count
- )
- if self.thread.server.handler_counter.count == 0:
- return
+ self.thread.server.wait_for_silence(timeout=timeout)
def expect_log(self, n, timeout=5):
l = []
diff --git a/test/mitmproxy/data/har_extractor.har b/test/mitmproxy/data/har_extractor.har
index 6b5e2994..d80dc55f 100644
--- a/test/mitmproxy/data/har_extractor.har
+++ b/test/mitmproxy/data/har_extractor.har
@@ -10,15 +10,16 @@
},
"pages": [
{
- "startedDateTime": "1993-08-24T14:41:12",
+ "startedDateTime": "1993-08-24T14:41:12+00:00",
"id": "autopage_1",
- "title": "http://address:22/path"
+ "title": "http://address:22/path",
+ "pageTimings": {}
}
],
"entries": [
{
"pageref": "autopage_1",
- "startedDateTime": "1993-08-24T14:41:12",
+ "startedDateTime": "1993-08-24T14:41:12+00:00",
"cache": {},
"request": {
"cookies": [],
diff --git a/test/mitmproxy/data/scripts/a.py b/test/mitmproxy/data/scripts/a.py
index d4272ac8..33dbaa64 100644
--- a/test/mitmproxy/data/scripts/a.py
+++ b/test/mitmproxy/data/scripts/a.py
@@ -1,11 +1,13 @@
+import sys
+
from a_helper import parser
var = 0
-def start(ctx, argv):
+def start(ctx):
global var
- var = parser.parse_args(argv[1:]).var
+ var = parser.parse_args(sys.argv[1:]).var
def here(ctx):
diff --git a/test/mitmproxy/data/scripts/concurrent_decorator_err.py b/test/mitmproxy/data/scripts/concurrent_decorator_err.py
index 071b8889..349e5dd6 100644
--- a/test/mitmproxy/data/scripts/concurrent_decorator_err.py
+++ b/test/mitmproxy/data/scripts/concurrent_decorator_err.py
@@ -2,5 +2,5 @@ from mitmproxy.script import concurrent
@concurrent
-def start(context, argv):
+def start(context):
pass
diff --git a/test/mitmproxy/data/scripts/starterr.py b/test/mitmproxy/data/scripts/starterr.py
index b217bdfe..82d773bd 100644
--- a/test/mitmproxy/data/scripts/starterr.py
+++ b/test/mitmproxy/data/scripts/starterr.py
@@ -1,3 +1,3 @@
-def start(ctx, argv):
- raise ValueError
+def start(ctx):
+ raise ValueError()
diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py
index 27cc30ba..2a43627a 100644
--- a/test/netlib/http/http2/test_connections.py
+++ b/test/netlib/http/http2/test_connections.py
@@ -75,10 +75,10 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b'h2'])
- protocol = HTTP2Protocol(c)
- assert protocol.check_alpn()
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b'h2'])
+ protocol = HTTP2Protocol(c)
+ assert protocol.check_alpn()
class TestCheckALPNMismatch(tservers.ServerTestBase):
@@ -91,11 +91,11 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b'h2'])
- protocol = HTTP2Protocol(c)
- with raises(NotImplementedError):
- protocol.check_alpn()
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b'h2'])
+ protocol = HTTP2Protocol(c)
+ with raises(NotImplementedError):
+ protocol.check_alpn()
class TestPerformServerConnectionPreface(tservers.ServerTestBase):
@@ -124,15 +124,15 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- protocol = HTTP2Protocol(c)
+ with c.connect():
+ protocol = HTTP2Protocol(c)
- assert not protocol.connection_preface_performed
- protocol.perform_server_connection_preface()
- assert protocol.connection_preface_performed
+ assert not protocol.connection_preface_performed
+ protocol.perform_server_connection_preface()
+ assert protocol.connection_preface_performed
- with raises(TcpDisconnect):
- protocol.perform_server_connection_preface(force=True)
+ with raises(TcpDisconnect):
+ protocol.perform_server_connection_preface(force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
@@ -160,12 +160,12 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- protocol = HTTP2Protocol(c)
+ with c.connect():
+ protocol = HTTP2Protocol(c)
- assert not protocol.connection_preface_performed
- protocol.perform_client_connection_preface()
- assert protocol.connection_preface_performed
+ assert not protocol.connection_preface_performed
+ protocol.perform_client_connection_preface()
+ assert protocol.connection_preface_performed
class TestClientStreamIds(object):
@@ -209,24 +209,24 @@ class TestApplySettings(tservers.ServerTestBase):
def test_apply_settings(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c)
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
- protocol._apply_settings({
- hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
- hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
- hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
- })
+ protocol._apply_settings({
+ hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
+ hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
+ hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
+ })
- assert c.rfile.safe_read(2) == b"OK"
+ assert c.rfile.safe_read(2) == b"OK"
- assert protocol.http2_settings[
- hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
- assert protocol.http2_settings[
- hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
- assert protocol.http2_settings[
- hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
+ assert protocol.http2_settings[
+ hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
+ assert protocol.http2_settings[
+ hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
+ assert protocol.http2_settings[
+ hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
class TestCreateHeaders(object):
@@ -304,19 +304,19 @@ class TestReadRequest(tservers.ServerTestBase):
def test_read_request(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
+ req = protocol.read_request(NotImplemented)
- assert req.stream_id
- assert req.headers.fields == ()
- assert req.method == "GET"
- assert req.path == "/"
- assert req.scheme == "https"
- assert req.content == b'foobar'
+ assert req.stream_id
+ assert req.headers.fields == ()
+ assert req.method == "GET"
+ assert req.path == "/"
+ assert req.scheme == "https"
+ assert req.content == b'foobar'
class TestReadRequestRelative(tservers.ServerTestBase):
@@ -330,16 +330,16 @@ class TestReadRequestRelative(tservers.ServerTestBase):
def test_asterisk_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
+ req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "relative"
- assert req.method == "OPTIONS"
- assert req.path == "*"
+ assert req.first_line_format == "relative"
+ assert req.method == "OPTIONS"
+ assert req.path == "*"
class TestReadRequestAbsolute(tservers.ServerTestBase):
@@ -353,17 +353,17 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
def test_absolute_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
+ req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "absolute"
- assert req.scheme == "http"
- assert req.host == "address"
- assert req.port == 22
+ assert req.first_line_format == "absolute"
+ assert req.scheme == "http"
+ assert req.host == "address"
+ assert req.port == 22
class TestReadRequestConnect(tservers.ServerTestBase):
@@ -379,22 +379,22 @@ class TestReadRequestConnect(tservers.ServerTestBase):
def test_connect(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "authority"
- assert req.method == "CONNECT"
- assert req.host == "address"
- assert req.port == 22
+ req = protocol.read_request(NotImplemented)
+ assert req.first_line_format == "authority"
+ assert req.method == "CONNECT"
+ assert req.host == "address"
+ assert req.port == 22
- req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "authority"
- assert req.method == "CONNECT"
- assert req.host == "example.com"
- assert req.port == 443
+ req = protocol.read_request(NotImplemented)
+ assert req.first_line_format == "authority"
+ assert req.method == "CONNECT"
+ assert req.host == "example.com"
+ assert req.port == 443
class TestReadResponse(tservers.ServerTestBase):
@@ -411,19 +411,19 @@ class TestReadResponse(tservers.ServerTestBase):
def test_read_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
+ protocol.connection_preface_performed = True
- resp = protocol.read_response(NotImplemented, stream_id=42)
+ resp = protocol.read_response(NotImplemented, stream_id=42)
- assert resp.http_version == "HTTP/2.0"
- assert resp.status_code == 200
- assert resp.reason == ''
- assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
- assert resp.content == b'foobar'
- assert resp.timestamp_end
+ assert resp.http_version == "HTTP/2.0"
+ assert resp.status_code == 200
+ assert resp.reason == ''
+ assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
+ assert resp.content == b'foobar'
+ assert resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase):
@@ -437,19 +437,19 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
def test_read_empty_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
+ protocol.connection_preface_performed = True
- resp = protocol.read_response(NotImplemented, stream_id=42)
+ resp = protocol.read_response(NotImplemented, stream_id=42)
- assert resp.stream_id == 42
- assert resp.http_version == "HTTP/2.0"
- assert resp.status_code == 200
- assert resp.reason == ''
- assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
- assert resp.content == b''
+ assert resp.stream_id == 42
+ assert resp.http_version == "HTTP/2.0"
+ assert resp.status_code == 200
+ assert resp.reason == ''
+ assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
+ assert resp.content == b''
class TestAssembleRequest(object):
diff --git a/test/netlib/test_debug.py b/test/netlib/test_debug.py
index b9315c7f..51710da0 100644
--- a/test/netlib/test_debug.py
+++ b/test/netlib/test_debug.py
@@ -18,3 +18,7 @@ def test_dump_stacks():
def test_sysinfo():
assert debug.sysinfo()
+
+
+def test_register_info_dumpers():
+ debug.register_info_dumpers()
diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py
index 083360b4..590bcc01 100644
--- a/test/netlib/test_tcp.py
+++ b/test/netlib/test_tcp.py
@@ -39,8 +39,21 @@ class ClientCipherListHandler(tcp.BaseHandler):
class HangHandler(tcp.BaseHandler):
def handle(self):
+ # Hang as long as the client connection is alive
while True:
- time.sleep(1)
+ try:
+ self.connection.setblocking(0)
+ ret = self.connection.recv(1)
+ # Client connection is dead...
+ if ret == "" or ret == b"":
+ return
+ except socket.error:
+ pass
+ except SSL.WantReadError:
+ pass
+ except Exception:
+ return
+ time.sleep(0.1)
class ALPNHandler(tcp.BaseHandler):
@@ -61,18 +74,18 @@ class TestServer(tservers.ServerTestBase):
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_thread_start_error(self):
with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m:
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- assert not c.rfile.read(1)
- assert m.called
- assert "nonewthread" in self.q.get_nowait()
+ with c.connect():
+ assert not c.rfile.read(1)
+ assert m.called
+ assert "nonewthread" in self.q.get_nowait()
self.test_echo()
@@ -92,9 +105,9 @@ class TestServerBind(tservers.ServerTestBase):
c = tcp.TCPClient(
("127.0.0.1", self.port), source_address=(
"127.0.0.1", random_port))
- c.connect()
- assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
- return
+ with c.connect():
+ assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
+ return
except TcpException: # port probably already in use
pass
@@ -106,10 +119,10 @@ class TestServerIPv6(tservers.ServerTestBase):
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True))
- c.connect()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
class TestEcho(tservers.ServerTestBase):
@@ -118,10 +131,10 @@ class TestEcho(tservers.ServerTestBase):
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
class HardDisconnectHandler(tcp.BaseHandler):
@@ -140,10 +153,10 @@ class TestFinishFail(tservers.ServerTestBase):
def test_disconnect_in_finish(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.wfile.write(b"foo\n")
- c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
- c.finish()
+ with c.connect():
+ c.wfile.write(b"foo\n")
+ c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
+ c.finish()
class TestServerSSL(tservers.ServerTestBase):
@@ -155,21 +168,21 @@ class TestServerSSL(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_get_current_cipher(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- assert not c.get_current_cipher()
- c.convert_to_ssl(sni=b"foo.com")
- ret = c.get_current_cipher()
- assert ret
- assert "AES" in ret[0]
+ with c.connect():
+ assert not c.get_current_cipher()
+ c.convert_to_ssl(sni=b"foo.com")
+ ret = c.get_current_cipher()
+ assert ret
+ assert "AES" in ret[0]
class TestSSLv3Only(tservers.ServerTestBase):
@@ -181,8 +194,8 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
+ with c.connect():
+ tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@@ -195,49 +208,46 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
def test_mode_default_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl()
+ with c.connect():
+ c.convert_to_ssl()
- # Verification errors should be saved even if connection isn't aborted
- # aborted
- assert c.ssl_verification_error is not None
+ # Verification errors should be saved even if connection isn't aborted
+ # aborted
+ assert c.ssl_verification_error is not None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_mode_none_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl(verify_options=SSL.VERIFY_NONE)
+ with c.connect():
+ c.convert_to_ssl(verify_options=SSL.VERIFY_NONE)
- # Verification errors should be saved even if connection isn't aborted
- assert c.ssl_verification_error is not None
+ # Verification errors should be saved even if connection isn't aborted
+ assert c.ssl_verification_error is not None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_mode_strict_should_fail(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- with tutils.raises(InvalidCertificateException):
- c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
+ with c.connect():
+ with tutils.raises(InvalidCertificateException):
+ c.convert_to_ssl(
+ sni=b"example.mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
- assert c.ssl_verification_error is not None
+ assert c.ssl_verification_error is not None
- # Unknown issuing certificate authority for first certificate
- assert c.ssl_verification_error['errno'] == 18
- assert c.ssl_verification_error['depth'] == 0
+ # Unknown issuing certificate authority for first certificate
+ assert c.ssl_verification_error['errno'] == 18
+ assert c.ssl_verification_error['depth'] == 0
class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
@@ -250,26 +260,23 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
def test_should_fail_without_sni(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- with tutils.raises(TlsException):
- c.convert_to_ssl(
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
+ with c.connect():
+ with tutils.raises(TlsException):
+ c.convert_to_ssl(
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
def test_should_fail(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- with tutils.raises(InvalidCertificateException):
- c.convert_to_ssl(
- sni=b"mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
-
- assert c.ssl_verification_error is not None
+ with c.connect():
+ with tutils.raises(InvalidCertificateException):
+ c.convert_to_ssl(
+ sni=b"mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
+ assert c.ssl_verification_error is not None
class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
@@ -282,37 +289,35 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
def test_mode_strict_w_pemfile_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
+ with c.connect():
+ c.convert_to_ssl(
+ sni=b"example.mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
- assert c.ssl_verification_error is None
+ assert c.ssl_verification_error is None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_mode_strict_w_cadir_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_path=tutils.test_data.path("data/verificationcerts/")
- )
+ with c.connect():
+ c.convert_to_ssl(
+ sni=b"example.mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_path=tutils.test_data.path("data/verificationcerts/")
+ )
- assert c.ssl_verification_error is None
+ assert c.ssl_verification_error is None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
class TestSSLClientCert(tservers.ServerTestBase):
@@ -334,19 +339,19 @@ class TestSSLClientCert(tservers.ServerTestBase):
def test_clientcert(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(
- cert=tutils.test_data.path("data/clientcert/client.pem"))
- assert c.rfile.readline().strip() == b"1"
+ with c.connect():
+ c.convert_to_ssl(
+ cert=tutils.test_data.path("data/clientcert/client.pem"))
+ assert c.rfile.readline().strip() == b"1"
def test_clientcert_err(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises(
- TlsException,
- c.convert_to_ssl,
- cert=tutils.test_data.path("data/clientcert/make")
- )
+ with c.connect():
+ tutils.raises(
+ TlsException,
+ c.convert_to_ssl,
+ cert=tutils.test_data.path("data/clientcert/make")
+ )
class TestSNI(tservers.ServerTestBase):
@@ -365,10 +370,10 @@ class TestSNI(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com")
- assert c.sni == b"foo.com"
- assert c.rfile.readline() == b"foo.com"
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com")
+ assert c.sni == b"foo.com"
+ assert c.rfile.readline() == b"foo.com"
class TestServerCipherList(tservers.ServerTestBase):
@@ -379,9 +384,9 @@ class TestServerCipherList(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com")
- assert c.rfile.readline() == b"['RC4-SHA']"
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com")
+ assert c.rfile.readline() == b"['RC4-SHA']"
class TestServerCurrentCipher(tservers.ServerTestBase):
@@ -399,9 +404,9 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com")
- assert b"RC4-SHA" in c.rfile.readline()
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com")
+ assert b"RC4-SHA" in c.rfile.readline()
class TestServerCipherListError(tservers.ServerTestBase):
@@ -412,8 +417,8 @@ class TestServerCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
+ with c.connect():
+ tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
class TestClientCipherListError(tservers.ServerTestBase):
@@ -424,12 +429,13 @@ class TestClientCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises(
- "cipher specification",
- c.convert_to_ssl,
- sni=b"foo.com",
- cipher_list="bogus")
+ with c.connect():
+ tutils.raises(
+ "cipher specification",
+ c.convert_to_ssl,
+ sni=b"foo.com",
+ cipher_list="bogus"
+ )
class TestSSLDisconnect(tservers.ServerTestBase):
@@ -443,13 +449,13 @@ class TestSSLDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- # Excercise SSL.ZeroReturnError
- c.rfile.read(10)
- c.close()
- tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
- tutils.raises(queue.Empty, self.q.get_nowait)
+ with c.connect():
+ c.convert_to_ssl()
+ # Excercise SSL.ZeroReturnError
+ c.rfile.read(10)
+ c.close()
+ tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
+ tutils.raises(queue.Empty, self.q.get_nowait)
class TestSSLHardDisconnect(tservers.ServerTestBase):
@@ -458,23 +464,23 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- # Exercise SSL.SysCallError
- c.rfile.read(10)
- c.close()
- tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
+ with c.connect():
+ c.convert_to_ssl()
+ # Exercise SSL.SysCallError
+ c.rfile.read(10)
+ c.close()
+ tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
class TestDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.rfile.read(10)
- c.wfile.write(b"foo")
- c.close()
- c.close()
+ with c.connect():
+ c.rfile.read(10)
+ c.wfile.write(b"foo")
+ c.close()
+ c.close()
class TestServerTimeOut(tservers.ServerTestBase):
@@ -491,9 +497,9 @@ class TestServerTimeOut(tservers.ServerTestBase):
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- time.sleep(0.3)
- assert self.last_handler.timeout
+ with c.connect():
+ time.sleep(0.3)
+ assert self.last_handler.timeout
class TestTimeOut(tservers.ServerTestBase):
@@ -501,10 +507,10 @@ class TestTimeOut(tservers.ServerTestBase):
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.settimeout(0.1)
- assert c.gettimeout() == 0.1
- tutils.raises(TcpTimeout, c.rfile.read, 10)
+ with c.connect():
+ c.settimeout(0.1)
+ assert c.gettimeout() == 0.1
+ tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestALPNClient(tservers.ServerTestBase):
@@ -516,25 +522,25 @@ class TestALPNClient(tservers.ServerTestBase):
if tcp.HAS_ALPN:
def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
- assert c.get_alpn_proto_negotiated() == b"bar"
- assert c.rfile.readline().strip() == b"bar"
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
+ assert c.get_alpn_proto_negotiated() == b"bar"
+ assert c.rfile.readline().strip() == b"bar"
def test_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- assert c.get_alpn_proto_negotiated() == b""
- assert c.rfile.readline().strip() == b"NONE"
+ with c.connect():
+ c.convert_to_ssl()
+ assert c.get_alpn_proto_negotiated() == b""
+ assert c.rfile.readline().strip() == b"NONE"
else:
def test_none_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
- assert c.get_alpn_proto_negotiated() == b""
- assert c.rfile.readline() == b"NONE"
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
+ assert c.get_alpn_proto_negotiated() == b""
+ assert c.rfile.readline() == b"NONE"
class TestNoSSLNoALPNClient(tservers.ServerTestBase):
@@ -542,9 +548,9 @@ class TestNoSSLNoALPNClient(tservers.ServerTestBase):
def test_no_ssl_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- assert c.get_alpn_proto_negotiated() == b""
- assert c.rfile.readline().strip() == b"NONE"
+ with c.connect():
+ assert c.get_alpn_proto_negotiated() == b""
+ assert c.rfile.readline().strip() == b"NONE"
class TestSSLTimeOut(tservers.ServerTestBase):
@@ -553,10 +559,10 @@ class TestSSLTimeOut(tservers.ServerTestBase):
def test_timeout_client(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- c.settimeout(0.1)
- tutils.raises(TcpTimeout, c.rfile.read, 10)
+ with c.connect():
+ c.convert_to_ssl()
+ c.settimeout(0.1)
+ tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestDHParams(tservers.ServerTestBase):
@@ -570,10 +576,10 @@ class TestDHParams(tservers.ServerTestBase):
def test_dhparams(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- ret = c.get_current_cipher()
- assert ret[0] == "DHE-RSA-AES256-SHA"
+ with c.connect():
+ c.convert_to_ssl()
+ ret = c.get_current_cipher()
+ assert ret[0] == "DHE-RSA-AES256-SHA"
def test_create_dhparams(self):
with tutils.tmpdir() as d:
@@ -718,33 +724,34 @@ class TestPeek(tservers.ServerTestBase):
handler = EchoHandler
def _connect(self, c):
- c.connect()
+ return c.connect()
def test_peek(self):
testval = b"peek!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
- self._connect(c)
- c.wfile.write(testval)
- c.wfile.flush()
+ with self._connect(c):
+ c.wfile.write(testval)
+ c.wfile.flush()
- assert c.rfile.peek(4) == b"peek"
- assert c.rfile.peek(6) == b"peek!\n"
- assert c.rfile.readline() == testval
+ assert c.rfile.peek(4) == b"peek"
+ assert c.rfile.peek(6) == b"peek!\n"
+ assert c.rfile.readline() == testval
- c.close()
- with tutils.raises(NetlibException):
- if c.rfile.peek(1) == b"":
- # Workaround for Python 2 on Unix:
- # Peeking a closed connection does not raise an exception here.
- raise NetlibException()
+ c.close()
+ with tutils.raises(NetlibException):
+ if c.rfile.peek(1) == b"":
+ # Workaround for Python 2 on Unix:
+ # Peeking a closed connection does not raise an exception here.
+ raise NetlibException()
class TestPeekSSL(TestPeek):
ssl = True
def _connect(self, c):
- c.connect()
- c.convert_to_ssl()
+ with c.connect() as conn:
+ c.convert_to_ssl()
+ return conn.pop()
class TestAddress:
@@ -774,16 +781,16 @@ class TestSSLKeyLogger(tservers.ServerTestBase):
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
- c.finish()
-
- tcp.log_ssl_key.close()
- with open(logfile, "rb") as f:
- assert f.read().count(b"CLIENT_RANDOM") == 2
+ with c.connect():
+ c.convert_to_ssl()
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
+ c.finish()
+
+ tcp.log_ssl_key.close()
+ with open(logfile, "rb") as f:
+ assert f.read().count(b"CLIENT_RANDOM") == 2
tcp.log_ssl_key = _logfun
diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py
index 569745e6..803aaa72 100644
--- a/test/netlib/tservers.py
+++ b/test/netlib/tservers.py
@@ -104,6 +104,9 @@ class ServerTestBase(object):
def teardown_class(cls):
cls.server.shutdown()
+ def teardown(self):
+ self.server.server.wait_for_silence()
+
@property
def last_handler(self):
return self.server.server.last_handler
diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py
index 77d4721c..05cf518d 100644
--- a/test/pathod/test_pathoc.py
+++ b/test/pathod/test_pathoc.py
@@ -1,4 +1,5 @@
from six.moves import cStringIO as StringIO
+from six import BytesIO
from mock import Mock
from netlib import http
@@ -12,7 +13,7 @@ import tutils
def test_response():
- r = http.Response("HTTP/1.1", 200, "Message", {}, None, None)
+ r = http.Response(b"HTTP/1.1", 200, b"Message", {}, None, None)
assert repr(r)
@@ -29,7 +30,7 @@ class PathocTestDaemon(tutils.DaemonTests):
if timeout:
c.settimeout(timeout)
for i in requests:
- r = language.parse_pathoc(i).next()
+ r = next(language.parse_pathoc(i))
if kwargs.get("explain"):
r = r.freeze(language.Settings())
try:
@@ -44,17 +45,17 @@ class TestDaemonSSL(PathocTestDaemon):
ssl = True
ssloptions = dict(
request_client_cert=True,
- sans=["test1.com", "test2.com"],
+ sans=[b"test1.com", b"test2.com"],
alpn_select=b'h2',
)
def test_sni(self):
self.tval(
["get:/p/200"],
- sni="foobar.com"
+ sni=b"foobar.com"
)
log = self.d.log()
- assert log[0]["request"]["sni"] == "foobar.com"
+ assert log[0]["request"]["sni"] == b"foobar.com"
def test_showssl(self):
assert "certificate chain" in self.tval(["get:/p/200"], showssl=True)
@@ -73,7 +74,7 @@ class TestDaemonSSL(PathocTestDaemon):
("127.0.0.1", self.d.port),
use_http2=True,
ssl=False,
- fp = fp
+ fp=fp
)
tutils.raises(NotImplementedError, c.connect)
@@ -171,36 +172,36 @@ class TestDaemon(PathocTestDaemon):
c.rfile, c.wfile = StringIO(), StringIO()
with raises("connect failed"):
c.http_connect(to)
- c.rfile = StringIO(
- "HTTP/1.1 500 OK\r\n"
+ c.rfile = BytesIO(
+ b"HTTP/1.1 500 OK\r\n"
)
with raises("connect failed"):
c.http_connect(to)
- c.rfile = StringIO(
- "HTTP/1.1 200 OK\r\n"
+ c.rfile = BytesIO(
+ b"HTTP/1.1 200 OK\r\n"
)
c.http_connect(to)
def test_socks_connect(self):
to = ("foobar", 80)
c = pathoc.Pathoc(("127.0.0.1", self.d.port), fp=None)
- c.rfile, c.wfile = tutils.treader(""), StringIO()
+ c.rfile, c.wfile = tutils.treader(b""), BytesIO()
tutils.raises(pathoc.PathocError, c.socks_connect, to)
c.rfile = tutils.treader(
- "\x05\xEE"
+ b"\x05\xEE"
)
tutils.raises("SOCKS without authentication", c.socks_connect, ("example.com", 0xDEAD))
c.rfile = tutils.treader(
- "\x05\x00" +
- "\x05\xEE\x00\x03\x0bexample.com\xDE\xAD"
+ b"\x05\x00" +
+ b"\x05\xEE\x00\x03\x0bexample.com\xDE\xAD"
)
tutils.raises("SOCKS server error", c.socks_connect, ("example.com", 0xDEAD))
c.rfile = tutils.treader(
- "\x05\x00" +
- "\x05\x00\x00\x03\x0bexample.com\xDE\xAD"
+ b"\x05\x00" +
+ b"\x05\x00\x00\x03\x0bexample.com\xDE\xAD"
)
c.socks_connect(("example.com", 0xDEAD))
diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py
index bf5e3165..daaa8628 100644
--- a/test/pathod/tutils.py
+++ b/test/pathod/tutils.py
@@ -3,8 +3,8 @@ import re
import shutil
import requests
from six.moves import cStringIO as StringIO
+from six.moves import urllib
from six import BytesIO
-import urllib
from netlib import tcp
from netlib import utils
@@ -20,7 +20,7 @@ def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
- fp = StringIO(bytes)
+ fp = BytesIO(bytes)
return tcp.Reader(fp)
@@ -87,7 +87,7 @@ class DaemonTests(object):
)
with c.connect():
if params:
- path = path + "?" + urllib.urlencode(params)
+ path = path + "?" + urllib.parse.urlencode(params)
resp = c.request("get:%s" % path)
return resp
@@ -100,7 +100,7 @@ class DaemonTests(object):
)
with c.connect():
resp = c.request(
- "get:/p/%s" % urllib.quote(spec).encode("string_escape")
+ "get:/p/%s" % urllib.parse.quote(spec).encode("string_escape")
)
return resp
diff --git a/tox.ini b/tox.ini
index ffd78359..c150e14e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -7,7 +7,7 @@ deps =
codecov>=2.0.5
passenv = CI TRAVIS_BUILD_ID TRAVIS TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT
setenv =
- PY3TESTS = test/netlib test/mitmproxy/script test/pathod/test_utils.py test/pathod/test_log.py test/pathod/test_language_generators.py test/pathod/test_language_writer.py test/pathod/test_language_base.py test/pathod/test_language_http.py test/pathod/test_language_websocket.py test/pathod/test_language_http2.py
+ PY3TESTS = test/netlib test/mitmproxy/script test/pathod/test_utils.py test/pathod/test_log.py test/pathod/test_language_generators.py test/pathod/test_language_writer.py test/pathod/test_language_base.py test/pathod/test_language_http.py test/pathod/test_language_websocket.py test/pathod/test_language_http2.py test/pathod/test_pathoc.py
[testenv:py27]
commands =