diff options
33 files changed, 1260 insertions, 556 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 914e75eb..689ad5e5 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,16 +1,19 @@ version: '{build}' shallow_clone: true +build: off # Not a C# project environment: matrix: - PYTHON: "C:\\Python27" - PATH: "C:\\Python27;C:\\Python27\\Scripts;%PATH%" + PATH: "%APPDATA%\\Python\\Scripts;C:\\Python27;C:\\Python27\\Scripts;%PATH%" PYINSTALLER_VERSION: "git+https://github.com/pyinstaller/pyinstaller.git" install: - - "pip install --src .. -r requirements.txt" + - "pip install --user -U pip setuptools" + - "pip install --user --src .. -r requirements.txt" - "python -c \"from OpenSSL import SSL; print(SSL.SSLeay_version(SSL.SSLEAY_VERSION))\"" -build: off # Not a C# project test_script: - - "py.test -n 4" + - "py.test -s --cov libmproxy --timeout 30" +cache: + - C:\Users\appveyor\AppData\Local\pip\cache after_test: - | git clone https://github.com/mitmproxy/release.git ..\release diff --git a/.travis.yml b/.travis.yml index ed5a7ad5..8ea3ed32 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,9 +5,6 @@ matrix: fast_finish: true include: - python: 2.7 - - language: generic - os: osx - osx_image: xcode7.1 - python: 2.7 env: OPENSSL=1.0.2 addons: @@ -18,9 +15,6 @@ matrix: - debian-sid packages: - libssl-dev - - python: 2.7 - env: DOCS=1 - script: 'cd docs && make html' - python: pypy - python: pypy env: OPENSSL=1.0.2 @@ -32,9 +26,13 @@ matrix: - debian-sid packages: - libssl-dev + - language: generic + os: osx + osx_image: xcode7.1 + - python: 2.7 + env: DOCS=1 + script: 'cd docs && make html' allow_failures: - # We allow pypy to fail until Travis fixes their infrastructure to a pypy - # with a recent enought CFFI library to run cryptography 1.0+. - python: pypy install: @@ -45,13 +43,27 @@ install: brew outdated openssl || brew upgrade openssl brew install python fi + - | + if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then + export PYENV_ROOT="$HOME/.pyenv" + if [ -f "$PYENV_ROOT/bin/pyenv" ]; then + pushd "$PYENV_ROOT" && git pull && popd + else + rm -rf "$PYENV_ROOT" && git clone --depth 1 https://github.com/yyuu/pyenv.git "$PYENV_ROOT" + fi + export PYPY_VERSION="4.0.1" + "$PYENV_ROOT/bin/pyenv" install --skip-existing "pypy-$PYPY_VERSION" + virtualenv --python="$PYENV_ROOT/versions/pypy-$PYPY_VERSION/bin/python" "$HOME/virtualenvs/pypy-$PYPY_VERSION" + source "$HOME/virtualenvs/pypy-$PYPY_VERSION/bin/activate" + fi + - "pip install -U pip setuptools" - "pip install --src .. -r requirements.txt" before_script: - "openssl version -a" script: - - "py.test -n 4 --cov libmproxy" + - "py.test -s --cov libmproxy --timeout 30" after_success: - coveralls @@ -80,16 +92,8 @@ notifications: on_success: always on_failure: always -# exclude cryptography from cache -# it depends on libssl-dev version -# which needs to be compiled specifically to each version -before_cache: - - pip uninstall -y cryptography - cache: directories: - $HOME/.cache/pip - - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - - /home/travis/virtualenv/python2.7.9/bin - - /home/travis/virtualenv/pypy-2.5.0/site-packages - - /home/travis/virtualenv/pypy-2.5.0/bin + - $HOME/.pyenv + - $HOME/Library/Caches/pip
\ No newline at end of file diff --git a/docs/install.rst b/docs/install.rst index a4a26a42..3300807b 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -16,6 +16,9 @@ This was tested on a fully patched installation of Ubuntu 14.04. Once installation is complete you can run :ref:`mitmproxy` or :ref:`mitmdump` from a terminal. +On **Ubuntu 12.04** (and other systems with an outdated version of pip), +you may need to update pip using ``pip install -U pip`` before installing mitmproxy. + Installation From Source (Ubuntu) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/libmproxy/cmdline.py b/libmproxy/cmdline.py index 111ab145..d8b6000c 100644 --- a/libmproxy/cmdline.py +++ b/libmproxy/cmdline.py @@ -363,6 +363,11 @@ def proxy_options(parser): help="Proxy service port." ) http2 = group.add_mutually_exclusive_group() + # !!! + # Watch out: We raise a RuntimeError in libmproxy.proxy.config if http2 is enabled, + # but the OpenSSL version does not have ALPN support (which is the default on Ubuntu 14.04). + # Do not simply set --http2 as enabled by default. + # !!! http2.add_argument("--http2", action="store_true", dest="http2") http2.add_argument("--no-http2", action="store_false", dest="http2", help="Explicitly enable/disable experimental HTTP2 support. " diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index 0ada3e34..c29ffddc 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -130,7 +130,7 @@ else: SYMBOL_MARK = "[m]" -def raw_format_flow(f, focus, extended, padding): +def raw_format_flow(f, focus, extended): f = dict(f) pile = [] req = [] @@ -160,8 +160,11 @@ def raw_format_flow(f, focus, extended, padding): else: uc = "title" + url = f["req_url"] + if f["req_http_version"] not in ("HTTP/1.0", "HTTP/1.1"): + url += " " + f["req_http_version"] req.append( - urwid.Text([(uc, f["req_url"])]) + urwid.Text([(uc, url)]) ) pile.append(urwid.Columns(req, dividechars=1)) @@ -396,8 +399,7 @@ def ask_save_body(part, master, state, flow): flowcache = utils.LRUCache(800) -def format_flow(f, focus, extended=False, hostheader=False, padding=2, - marked=False): +def format_flow(f, focus, extended=False, hostheader=False, marked=False): d = dict( intercepted = f.intercepted, acked = f.reply.acked, @@ -406,6 +408,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2, req_is_replay = f.request.is_replay, req_method = f.request.method, req_url = f.request.pretty_url if hostheader else f.request.url, + req_http_version = f.request.http_version, err_msg = f.error.msg if f.error else None, resp_code = f.response.status_code if f.response else None, @@ -437,5 +440,5 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2, d["resp_ctype"] = "" return flowcache.get( raw_format_flow, - tuple(sorted(d.items())), focus, extended, padding + tuple(sorted(d.items())), focus, extended ) diff --git a/libmproxy/console/flowview.py b/libmproxy/console/flowview.py index 8102de55..d2b98b68 100644 --- a/libmproxy/console/flowview.py +++ b/libmproxy/console/flowview.py @@ -103,7 +103,6 @@ class FlowViewHeader(urwid.WidgetWrap): f, False, extended=True, - padding=0, hostheader=self.master.showhost ) signals.flow_change.connect(self.sig_flow_change) @@ -114,7 +113,6 @@ class FlowViewHeader(urwid.WidgetWrap): flow, False, extended=True, - padding=0, hostheader=self.master.showhost ) diff --git a/libmproxy/contentviews.py b/libmproxy/contentviews.py index 80955b0f..c0652c18 100644 --- a/libmproxy/contentviews.py +++ b/libmproxy/contentviews.py @@ -35,12 +35,12 @@ from .contrib.wbxml.ASCommandResponse import ASCommandResponse try: import pyamf from pyamf import remoting, flex -except ImportError: # pragma nocover +except ImportError: # pragma no cover pyamf = None try: import cssutils -except ImportError: # pragma nocover +except ImportError: # pragma no cover cssutils = None else: cssutils.log.setLevel(logging.CRITICAL) diff --git a/libmproxy/controller.py b/libmproxy/controller.py index 712ab1d2..9a059856 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -56,7 +56,7 @@ class Channel: try: # The timeout is here so we can handle a should_exit event. g = m.reply.q.get(timeout=0.5) - except Queue.Empty: # pragma: nocover + except Queue.Empty: # pragma: no cover continue return g diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 95be2d27..6dab2ddc 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -247,11 +247,16 @@ class DumpMaster(flow.FlowMaster): url = flow.request.url url = click.style(url, bold=True) - line = "{stickycookie}{client} {method} {url}".format( + httpversion = "" + if flow.request.http_version not in ("HTTP/1.1", "HTTP/1.0"): + httpversion = " " + flow.request.http_version # We hide "normal" HTTP 1. + + line = "{stickycookie}{client} {method} {url}{httpversion}".format( stickycookie=stickycookie, client=client, method=method, - url=url + url=url, + httpversion=httpversion ) self.echo(line) diff --git a/libmproxy/flow_format_compat.py b/libmproxy/flow_format_compat.py index 2b99b805..5af9b762 100644 --- a/libmproxy/flow_format_compat.py +++ b/libmproxy/flow_format_compat.py @@ -21,9 +21,23 @@ def convert_014_015(data): return data +def convert_015_016(data): + for m in ("request", "response"): + if "body" in data[m]: + data[m]["content"] = data[m].pop("body") + if "httpversion" in data[m]: + data[m]["http_version"] = data[m].pop("httpversion") + if "msg" in data["response"]: + data["response"]["reason"] = data["response"].pop("msg") + data["request"].pop("form_out", None) + data["version"] = (0, 16) + return data + + converters = { (0, 13): convert_013_014, (0, 14): convert_014_015, + (0, 15): convert_015_016, } diff --git a/libmproxy/main.py b/libmproxy/main.py index 655d573d..f6664924 100644 --- a/libmproxy/main.py +++ b/libmproxy/main.py @@ -37,7 +37,11 @@ def get_server(dummy_server, options): sys.exit(1) -def mitmproxy(args=None): # pragma: nocover +def mitmproxy(args=None): # pragma: no cover + if os.name == "nt": + print("Error: mitmproxy's console interface is not supported on Windows. " + "You can run mitmdump or mitmweb instead.", file=sys.stderr) + sys.exit(1) from . import console check_pyopenssl_version() @@ -68,7 +72,7 @@ def mitmproxy(args=None): # pragma: nocover pass -def mitmdump(args=None): # pragma: nocover +def mitmdump(args=None): # pragma: no cover from . import dump check_pyopenssl_version() @@ -103,7 +107,7 @@ def mitmdump(args=None): # pragma: nocover pass -def mitmweb(args=None): # pragma: nocover +def mitmweb(args=None): # pragma: no cover from . import web check_pyopenssl_version() diff --git a/libmproxy/models/connections.py b/libmproxy/models/connections.py index a45e1629..d5920256 100644 --- a/libmproxy/models/connections.py +++ b/libmproxy/models/connections.py @@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): return self.ssl_established _stateobject_attributes = dict( + address=tcp.Address, + clientcert=certutils.SSLCert, ssl_established=bool, timestamp_start=float, timestamp_end=float, timestamp_ssl_setup=float ) - def get_state(self, short=False): - d = super(ClientConnection, self).get_state(short) - d.update( - address=({ - "address": self.address(), - "use_ipv6": self.address.use_ipv6} if self.address else {}), - clientcert=self.cert.to_pem() if self.clientcert else None) - return d - - def load_state(self, state): - super(ClientConnection, self).load_state(state) - self.address = tcp.Address( - **state["address"]) if state["address"] else None - self.clientcert = certutils.SSLCert.from_pem( - state["clientcert"]) if state["clientcert"] else None - def copy(self): return copy.copy(self) @@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): @classmethod def from_state(cls, state): f = cls(None, tuple(), None) - f.load_state(state) + f.set_state(state) return f def convert_to_ssl(self, *args, **kwargs): @@ -130,33 +116,11 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): ssl_established=bool, sni=str ) - _stateobject_long_attributes = {"cert"} - - def get_state(self, short=False): - d = super(ServerConnection, self).get_state(short) - d.update( - address=({"address": self.address(), - "use_ipv6": self.address.use_ipv6} if self.address else {}), - source_address=({"address": self.source_address(), - "use_ipv6": self.source_address.use_ipv6} if self.source_address else None), - cert=self.cert.to_pem() if self.cert else None - ) - return d - - def load_state(self, state): - super(ServerConnection, self).load_state(state) - - self.address = tcp.Address( - **state["address"]) if state["address"] else None - self.source_address = tcp.Address( - **state["source_address"]) if state["source_address"] else None - self.cert = certutils.SSLCert.from_pem( - state["cert"]) if state["cert"] else None @classmethod def from_state(cls, state): f = cls(tuple()) - f.load_state(state) + f.set_state(state) return f def copy(self): diff --git a/libmproxy/models/flow.py b/libmproxy/models/flow.py index b4e8cb88..10255dad 100644 --- a/libmproxy/models/flow.py +++ b/libmproxy/models/flow.py @@ -45,7 +45,7 @@ class Error(stateobject.StateObject): # the default implementation assumes an empty constructor. Override # accordingly. f = cls(None) - f.load_state(state) + f.set_state(state) return f def copy(self): @@ -86,16 +86,19 @@ class Flow(stateobject.StateObject): intercepted=bool ) - def get_state(self, short=False): - d = super(Flow, self).get_state(short) + def get_state(self): + d = super(Flow, self).get_state() d.update(version=version.IVERSION) if self._backup and self._backup != d: - if short: - d.update(modified=True) - else: - d.update(backup=self._backup) + d.update(backup=self._backup) return d + def set_state(self, state): + state.pop("version") + if "backup" in state: + self._backup = state.pop("backup") + super(Flow, self).set_state(state) + def __eq__(self, other): return self is other @@ -133,7 +136,7 @@ class Flow(stateobject.StateObject): Revert to the last backed up state. """ if self._backup: - self.load_state(self._backup) + self.set_state(self._backup) self._backup = None def kill(self, master): diff --git a/libmproxy/models/http.py b/libmproxy/models/http.py index e07dff69..3c024e76 100644 --- a/libmproxy/models/http.py +++ b/libmproxy/models/http.py @@ -1,36 +1,20 @@ from __future__ import (absolute_import, print_function, division) import Cookie import copy +import warnings from email.utils import parsedate_tz, formatdate, mktime_tz import time from libmproxy import utils from netlib import encoding -from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING, decoded +from netlib.http import status_codes, Headers, Request, Response, decoded from netlib.tcp import Address -from .. import version, stateobject +from .. import version from .flow import Flow -class MessageMixin(stateobject.StateObject): - _stateobject_attributes = dict( - http_version=bytes, - headers=Headers, - timestamp_start=float, - timestamp_end=float - ) - _stateobject_long_attributes = {"body"} - - def get_state(self, short=False): - ret = super(MessageMixin, self).get_state(short) - if short: - if self.content: - ret["contentLength"] = len(self.content) - elif self.content == CONTENT_MISSING: - ret["contentLength"] = None - else: - ret["contentLength"] = 0 - return ret + +class MessageMixin(object): def get_decoded_content(self): """ @@ -43,33 +27,6 @@ class MessageMixin(stateobject.StateObject): return self.content return encoding.decode(ce, self.content) - def decode(self): - """ - Decodes body based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. - - Returns True if decoding succeeded, False otherwise. - """ - ce = self.headers.get("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return False - data = encoding.decode(ce, self.content) - if data is None: - return False - self.content = data - self.headers.pop("content-encoding", None) - return True - - def encode(self, e): - """ - Encodes body with the encoding e, where e is "gzip", "deflate" - or "identity". - """ - # FIXME: Error if there's an existing encoding header? - self.content = encoding.encode(e, self.content) - self.headers["content-encoding"] = e - def copy(self): c = copy.copy(self) if hasattr(self, "data"): # FIXME remove condition @@ -86,10 +43,12 @@ class MessageMixin(stateobject.StateObject): Returns the number of replacements made. """ - with decoded(self): - self.content, count = utils.safe_subn( - pattern, repl, self.content, *args, **kwargs - ) + count = 0 + if self.content: + with decoded(self): + self.content, count = utils.safe_subn( + pattern, repl, self.content, *args, **kwargs + ) fields = [] for name, value in self.headers.fields: name, c = utils.safe_subn(pattern, repl, name, *args, **kwargs) @@ -161,6 +120,9 @@ class HTTPRequest(MessageMixin, Request): timestamp_start=None, timestamp_end=None, form_out=None, + is_replay=False, + stickycookie=False, + stickyauth=False, ): Request.__init__( self, @@ -179,51 +141,26 @@ class HTTPRequest(MessageMixin, Request): self.form_out = form_out or first_line_format # FIXME remove # Have this request's cookies been modified by sticky cookies or auth? - self.stickycookie = False - self.stickyauth = False + self.stickycookie = stickycookie + self.stickyauth = stickyauth # Is this request replayed? - self.is_replay = False - - _stateobject_attributes = MessageMixin._stateobject_attributes.copy() - _stateobject_attributes.update( - content=bytes, - first_line_format=str, - method=bytes, - scheme=bytes, - host=bytes, - port=int, - path=bytes, - form_out=str, - is_replay=bool - ) - - @classmethod - def from_state(cls, state): - f = cls( - None, - b"", - None, - None, - None, - None, - None, - None, - None, - None, - None) - f.load_state(state) - return f + self.is_replay = is_replay + + def get_state(self): + state = super(HTTPRequest, self).get_state() + state.update( + stickycookie = self.stickycookie, + stickyauth = self.stickyauth, + is_replay = self.is_replay, + ) + return state - @classmethod - def from_protocol( - self, - protocol, - *args, - **kwargs - ): - req = protocol.read_request(*args, **kwargs) - return self.wrap(req) + def set_state(self, state): + self.stickycookie = state.pop("stickycookie") + self.stickyauth = state.pop("stickyauth") + self.is_replay = state.pop("is_replay") + super(HTTPRequest, self).set_state(state) @classmethod def wrap(self, request): @@ -241,10 +178,17 @@ class HTTPRequest(MessageMixin, Request): timestamp_end=request.timestamp_end, form_out=(request.form_out if hasattr(request, 'form_out') else None), ) - if hasattr(request, 'stream_id'): - req.stream_id = request.stream_id return req + @property + def form_out(self): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_out.setter + def form_out(self, value): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + def __hash__(self): return id(self) @@ -297,6 +241,7 @@ class HTTPResponse(MessageMixin, Response): content, timestamp_start=None, timestamp_end=None, + is_replay = False ): Response.__init__( self, @@ -310,32 +255,9 @@ class HTTPResponse(MessageMixin, Response): ) # Is this request replayed? - self.is_replay = False + self.is_replay = is_replay self.stream = False - _stateobject_attributes = MessageMixin._stateobject_attributes.copy() - _stateobject_attributes.update( - body=bytes, - status_code=int, - msg=bytes - ) - - @classmethod - def from_state(cls, state): - f = cls(None, None, None, None, None) - f.load_state(state) - return f - - @classmethod - def from_protocol( - self, - protocol, - *args, - **kwargs - ): - resp = protocol.read_response(*args, **kwargs) - return self.wrap(resp) - @classmethod def wrap(self, response): resp = HTTPResponse( @@ -347,8 +269,6 @@ class HTTPResponse(MessageMixin, Response): timestamp_start=response.timestamp_start, timestamp_end=response.timestamp_end, ) - if hasattr(response, 'stream_id'): - resp.stream_id = response.stream_id return resp def _refresh_cookie(self, c, delta): @@ -448,7 +368,7 @@ class HTTPFlow(Flow): @classmethod def from_state(cls, state): f = cls(None, None) - f.load_state(state) + f.set_state(state) return f def __repr__(self): diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index d46f16f5..ea958d06 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -27,15 +27,19 @@ as late as possible; this makes server replay without any outgoing connections p from __future__ import (absolute_import, print_function, division) from .base import Layer, ServerConnectionMixin, Kill -from .http import Http1Layer, UpstreamConnectLayer, Http2Layer from .tls import TlsLayer from .tls import is_tls_record_magic from .tls import TlsClientHello +from .http import UpstreamConnectLayer +from .http1 import Http1Layer +from .http2 import Http2Layer from .rawtcp import RawTCPLayer __all__ = [ "Layer", "ServerConnectionMixin", "Kill", - "Http1Layer", "UpstreamConnectLayer", "Http2Layer", "TlsLayer", "is_tls_record_magic", "TlsClientHello", + "UpstreamConnectLayer", + "Http1Layer", + "Http2Layer", "RawTCPLayer", ] diff --git a/libmproxy/protocol/base.py b/libmproxy/protocol/base.py index 4eb034c0..40fcaf65 100644 --- a/libmproxy/protocol/base.py +++ b/libmproxy/protocol/base.py @@ -14,7 +14,7 @@ class _LayerCodeCompletion(object): Dummy class that provides type hinting in PyCharm, which simplifies development a lot. """ - def __init__(self, **mixin_args): # pragma: nocover + def __init__(self, **mixin_args): # pragma: no cover super(_LayerCodeCompletion, self).__init__(**mixin_args) if True: return diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 12d09e71..f3240b85 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -1,26 +1,30 @@ from __future__ import (absolute_import, print_function, division) + import sys import traceback - import six from netlib import tcp from netlib.exceptions import HttpException, HttpReadDisconnect, NetlibException -from netlib.http import http1, Headers -from netlib.http import CONTENT_MISSING -from netlib.tcp import Address -from netlib.http.http2.connections import HTTP2Protocol -from netlib.http.http2.frame import GoAwayFrame, PriorityFrame, WindowUpdateFrame +from netlib.http import Headers, CONTENT_MISSING + +from h2.exceptions import H2Error + from .. import utils from ..exceptions import HttpProtocolException, ProtocolException from ..models import ( - HTTPFlow, HTTPRequest, HTTPResponse, make_error_response, make_connect_response, Error, expect_continue_response + HTTPFlow, + HTTPResponse, + make_error_response, + make_connect_response, + Error, + expect_continue_response ) + from .base import Layer, Kill -class _HttpLayer(Layer): - supports_streaming = False +class _HttpTransmissionLayer(Layer): def read_request(self): raise NotImplementedError() @@ -32,37 +36,18 @@ class _HttpLayer(Layer): raise NotImplementedError() def read_response(self, request): - raise NotImplementedError() - - def send_response(self, response): - raise NotImplementedError() - - def check_close_connection(self, flow): - raise NotImplementedError() - - -class _StreamingHttpLayer(_HttpLayer): - supports_streaming = True - - def read_response_headers(self): - raise NotImplementedError - - def read_response_body(self, request, response): - raise NotImplementedError() - yield "this is a generator" # pragma: no cover - - def read_response(self, request): response = self.read_response_headers() response.data.content = b"".join( self.read_response_body(request, response) ) return response - def send_response_headers(self, response): - raise NotImplementedError + def read_response_headers(self): + raise NotImplementedError() - def send_response_body(self, response, chunks): + def read_response_body(self, request, response): raise NotImplementedError() + yield "this is a generator" # pragma: no cover def send_response(self, response): if response.content == CONTENT_MISSING: @@ -70,164 +55,14 @@ class _StreamingHttpLayer(_HttpLayer): self.send_response_headers(response) self.send_response_body(response, [response.content]) - -class Http1Layer(_StreamingHttpLayer): - - def __init__(self, ctx, mode): - super(Http1Layer, self).__init__(ctx) - self.mode = mode - - def read_request(self): - req = http1.read_request(self.client_conn.rfile, body_size_limit=self.config.body_size_limit) - return HTTPRequest.wrap(req) - - def read_request_body(self, request): - expected_size = http1.expected_http_body_size(request) - return http1.read_body(self.client_conn.rfile, expected_size, self.config.body_size_limit) - - def send_request(self, request): - self.server_conn.wfile.write(http1.assemble_request(request)) - self.server_conn.wfile.flush() - - def read_response_headers(self): - resp = http1.read_response_head(self.server_conn.rfile) - return HTTPResponse.wrap(resp) - - def read_response_body(self, request, response): - expected_size = http1.expected_http_body_size(request, response) - return http1.read_body(self.server_conn.rfile, expected_size, self.config.body_size_limit) - def send_response_headers(self, response): - raw = http1.assemble_response_head(response) - self.client_conn.wfile.write(raw) - self.client_conn.wfile.flush() + raise NotImplementedError() def send_response_body(self, response, chunks): - for chunk in http1.assemble_body(response.headers, chunks): - self.client_conn.wfile.write(chunk) - self.client_conn.wfile.flush() - - def check_close_connection(self, flow): - request_close = http1.connection_close( - flow.request.http_version, - flow.request.headers - ) - response_close = http1.connection_close( - flow.response.http_version, - flow.response.headers - ) - read_until_eof = http1.expected_http_body_size(flow.request, flow.response) == -1 - close_connection = request_close or response_close or read_until_eof - if flow.request.form_in == "authority" and flow.response.status_code == 200: - # Workaround for https://github.com/mitmproxy/mitmproxy/issues/313: - # Charles Proxy sends a CONNECT response with HTTP/1.0 - # and no Content-Length header - - return False - return close_connection - - def __call__(self): - layer = HttpLayer(self, self.mode) - layer() - - -# TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite. -class Http2Layer(_HttpLayer): - - def __init__(self, ctx, mode): - super(Http2Layer, self).__init__(ctx) - self.mode = mode - self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, - unhandled_frame_cb=self.handle_unexpected_frame_from_client) - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame_from_server) - - def read_request(self): - request = HTTPRequest.from_protocol( - self.client_protocol, - body_size_limit=self.config.body_size_limit - ) - self._stream_id = request.stream_id - return request - - def send_request(self, message): - # TODO: implement flow control and WINDOW_UPDATE frames - self.server_conn.send(self.server_protocol.assemble(message)) - - def read_response(self, request): - return HTTPResponse.from_protocol( - self.server_protocol, - request_method=request.method, - body_size_limit=self.config.body_size_limit, - include_body=True, - stream_id=self._stream_id - ) - - def send_response(self, message): - # TODO: implement flow control to prevent client buffer filling up - # maintain a send buffer size, and read WindowUpdateFrames from client to increase the send buffer - self.client_conn.send(self.client_protocol.assemble(message)) + raise NotImplementedError() def check_close_connection(self, flow): - # TODO: add a timer to disconnect after a 10 second timeout - return False - - def connect(self): - self.ctx.connect() - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame_from_server) - self.server_protocol.perform_connection_preface() - - def set_server(self, *args, **kwargs): - self.ctx.set_server(*args, **kwargs) - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame_from_server) - self.server_protocol.perform_connection_preface() - - def __call__(self): - self.server_protocol.perform_connection_preface() - layer = HttpLayer(self, self.mode) - layer() - - # terminate the connection - self.client_conn.send(GoAwayFrame().to_bytes()) - - def handle_unexpected_frame_from_client(self, frame): - if isinstance(frame, WindowUpdateFrame): - # Clients are sending WindowUpdate frames depending on their flow control algorithm. - # Since we cannot predict these frames, and we do not need to respond to them, - # simply accept them, and hide them from the log. - # Ideally we should keep track of our own flow control window and - # stall transmission if the outgoing flow control buffer is full. - return - if isinstance(frame, PriorityFrame): - # Clients are sending Priority frames depending on their implementation. - # The RFC does not clearly state when or which priority preferences should be set. - # Since we cannot predict these frames, and we do not need to respond to them, - # simply accept them, and hide them from the log. - # Ideally we should forward them to the server. - return - if isinstance(frame, GoAwayFrame): - # Client wants to terminate the connection, - # relay it to the server. - self.server_conn.send(frame.to_bytes()) - return - self.log("Unexpected HTTP2 frame from client: %s" % frame.human_readable(), "info") - - def handle_unexpected_frame_from_server(self, frame): - if isinstance(frame, WindowUpdateFrame): - # Servers are sending WindowUpdate frames depending on their flow control algorithm. - # Since we cannot predict these frames, and we do not need to respond to them, - # simply accept them, and hide them from the log. - # Ideally we should keep track of our own flow control window and - # stall transmission if the outgoing flow control buffer is full. - return - if isinstance(frame, GoAwayFrame): - # Server wants to terminate the connection, - # relay it to the client. - self.client_conn.send(frame.to_bytes()) - return - self.log("Unexpected HTTP2 frame from server: %s" % frame.human_readable(), "info") + raise NotImplementedError() class ConnectServerConnection(object): @@ -285,7 +120,7 @@ class UpstreamConnectLayer(Layer): def set_server(self, address, server_tls=None, sni=None): if self.ctx.server_conn: self.ctx.disconnect() - address = Address.wrap(address) + address = tcp.Address.wrap(address) self.connect_request.host = address.host self.connect_request.port = address.port self.server_conn.address = address @@ -400,7 +235,8 @@ class HttpLayer(Layer): try: response = make_error_response(code, message) self.send_response(response) - except NetlibException: + except (NetlibException, H2Error): + self.log(traceback.format_exc(), "debug") pass def change_upstream_proxy_server(self, address): @@ -420,7 +256,7 @@ class HttpLayer(Layer): layer() def send_response_to_client(self, flow): - if not (self.supports_streaming and flow.response.stream): + if not flow.response.stream: # no streaming: # we already received the full response from the server and can # send it to the client straight away. @@ -441,10 +277,7 @@ class HttpLayer(Layer): def get_response_from_server(self, flow): def get_response(): self.send_request(flow.request) - if self.supports_streaming: - flow.response = self.read_response_headers() - else: - flow.response = self.read_response(flow.request) + flow.response = self.read_response_headers() try: get_response() @@ -474,15 +307,14 @@ class HttpLayer(Layer): if flow == Kill: raise Kill() - if self.supports_streaming: - if flow.response.stream: - flow.response.data.content = CONTENT_MISSING - else: - flow.response.data.content = b"".join(self.read_response_body( - flow.request, - flow.response - )) - flow.response.timestamp_end = utils.timestamp() + if flow.response.stream: + flow.response.data.content = CONTENT_MISSING + else: + flow.response.data.content = b"".join(self.read_response_body( + flow.request, + flow.response + )) + flow.response.timestamp_end = utils.timestamp() # no further manipulation of self.server_conn beyond this point # we can safely set it as the final attribute value here. diff --git a/libmproxy/protocol/http1.py b/libmproxy/protocol/http1.py new file mode 100644 index 00000000..fc2cf07a --- /dev/null +++ b/libmproxy/protocol/http1.py @@ -0,0 +1,70 @@ +from __future__ import (absolute_import, print_function, division) + +import six + +from netlib import tcp +from netlib.http import http1 + +from .http import _HttpTransmissionLayer, HttpLayer +from .. import utils +from ..models import HTTPRequest, HTTPResponse + + +class Http1Layer(_HttpTransmissionLayer): + + def __init__(self, ctx, mode): + super(Http1Layer, self).__init__(ctx) + self.mode = mode + + def read_request(self): + req = http1.read_request(self.client_conn.rfile, body_size_limit=self.config.body_size_limit) + return HTTPRequest.wrap(req) + + def read_request_body(self, request): + expected_size = http1.expected_http_body_size(request) + return http1.read_body(self.client_conn.rfile, expected_size, self.config.body_size_limit) + + def send_request(self, request): + self.server_conn.wfile.write(http1.assemble_request(request)) + self.server_conn.wfile.flush() + + def read_response_headers(self): + resp = http1.read_response_head(self.server_conn.rfile) + return HTTPResponse.wrap(resp) + + def read_response_body(self, request, response): + expected_size = http1.expected_http_body_size(request, response) + return http1.read_body(self.server_conn.rfile, expected_size, self.config.body_size_limit) + + def send_response_headers(self, response): + raw = http1.assemble_response_head(response) + self.client_conn.wfile.write(raw) + self.client_conn.wfile.flush() + + def send_response_body(self, response, chunks): + for chunk in http1.assemble_body(response.headers, chunks): + self.client_conn.wfile.write(chunk) + self.client_conn.wfile.flush() + + def check_close_connection(self, flow): + request_close = http1.connection_close( + flow.request.http_version, + flow.request.headers + ) + response_close = http1.connection_close( + flow.response.http_version, + flow.response.headers + ) + read_until_eof = http1.expected_http_body_size(flow.request, flow.response) == -1 + close_connection = request_close or response_close or read_until_eof + if flow.request.form_in == "authority" and flow.response.status_code == 200: + # Workaround for https://github.com/mitmproxy/mitmproxy/issues/313: + # Charles Proxy sends a CONNECT response with HTTP/1.0 + # and no Content-Length header + + return False + return close_connection + + def __call__(self): + layer = HttpLayer(self, self.mode) + layer() diff --git a/libmproxy/protocol/http2.py b/libmproxy/protocol/http2.py new file mode 100644 index 00000000..04ff8bf6 --- /dev/null +++ b/libmproxy/protocol/http2.py @@ -0,0 +1,434 @@ +from __future__ import (absolute_import, print_function, division) + +import threading +import time +import Queue + +from netlib.tcp import ssl_read_select +from netlib.exceptions import HttpException +from netlib.http import Headers +from netlib.utils import http2_read_raw_frame + +import hyperframe +import h2 +from h2.connection import H2Connection +from h2.events import * + +from .base import Layer +from .http import _HttpTransmissionLayer, HttpLayer +from .. import utils +from ..models import HTTPRequest, HTTPResponse +from ..exceptions import HttpProtocolException +from ..exceptions import ProtocolException + + +class SafeH2Connection(H2Connection): + + def __init__(self, conn, *args, **kwargs): + super(SafeH2Connection, self).__init__(*args, **kwargs) + self.conn = conn + self.lock = threading.RLock() + + def safe_close_connection(self, error_code): + with self.lock: + self.close_connection(error_code) + self.conn.send(self.data_to_send()) + + def safe_increment_flow_control(self, stream_id, length): + if length == 0: + return + + with self.lock: + self.increment_flow_control_window(length) + self.conn.send(self.data_to_send()) + with self.lock: + if stream_id in self.streams and not self.streams[stream_id].closed: + self.increment_flow_control_window(length, stream_id=stream_id) + self.conn.send(self.data_to_send()) + + def safe_reset_stream(self, stream_id, error_code): + with self.lock: + try: + self.reset_stream(stream_id, error_code) + except h2.exceptions.StreamClosedError: + # stream is already closed - good + pass + self.conn.send(self.data_to_send()) + + def safe_update_settings(self, new_settings): + with self.lock: + self.update_settings(new_settings) + self.conn.send(self.data_to_send()) + + def safe_send_headers(self, is_zombie, stream_id, headers): + with self.lock: + if is_zombie(): + return + self.send_headers(stream_id, headers) + self.conn.send(self.data_to_send()) + + def safe_send_body(self, is_zombie, stream_id, chunks): + for chunk in chunks: + position = 0 + while position < len(chunk): + self.lock.acquire() + if is_zombie(): + self.lock.release() + return + max_outbound_frame_size = self.max_outbound_frame_size + frame_chunk = chunk[position:position + max_outbound_frame_size] + if self.local_flow_control_window(stream_id) < len(frame_chunk): + self.lock.release() + time.sleep(0) + continue + self.send_data(stream_id, frame_chunk) + self.conn.send(self.data_to_send()) + self.lock.release() + position += max_outbound_frame_size + with self.lock: + if is_zombie(): + return + self.end_stream(stream_id) + self.conn.send(self.data_to_send()) + + +class Http2Layer(Layer): + + def __init__(self, ctx, mode): + super(Http2Layer, self).__init__(ctx) + self.mode = mode + self.streams = dict() + self.client_reset_streams = [] + self.server_reset_streams = [] + self.server_to_client_stream_ids = dict([(0, 0)]) + self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False) + + # make sure that we only pass actual SSL.Connection objects in here, + # because otherwise ssl_read_select fails! + self.active_conns = [self.client_conn.connection] + + def _initiate_server_conn(self): + self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True) + self.server_conn.h2.initiate_connection() + self.server_conn.send(self.server_conn.h2.data_to_send()) + self.active_conns.append(self.server_conn.connection) + + def connect(self): # pragma: no cover + raise ValueError("CONNECT inside an HTTP2 stream is not supported.") + # self.ctx.connect() + # self.server_conn.connect() + # self._initiate_server_conn() + + def set_server(self): # pragma: no cover + raise NotImplementedError("Cannot change server for HTTP2 connections.") + + def disconnect(self): # pragma: no cover + raise NotImplementedError("Cannot dis- or reconnect in HTTP2 connections.") + + def next_layer(self): # pragma: no cover + # WebSockets over HTTP/2? + # CONNECT for proxying? + raise NotImplementedError() + + def _handle_event(self, event, source_conn, other_conn, is_server): + if hasattr(event, 'stream_id'): + if is_server and event.stream_id % 2 == 1: + eid = self.server_to_client_stream_ids[event.stream_id] + else: + eid = event.stream_id + + if isinstance(event, RequestReceived): + headers = Headers([[str(k), str(v)] for k, v in event.headers]) + self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) + self.streams[eid].timestamp_start = time.time() + self.streams[eid].start() + elif isinstance(event, ResponseReceived): + headers = Headers([[str(k), str(v)] for k, v in event.headers]) + self.streams[eid].queued_data_length = 0 + self.streams[eid].timestamp_start = time.time() + self.streams[eid].response_headers = headers + self.streams[eid].response_arrived.set() + elif isinstance(event, DataReceived): + if self.config.body_size_limit and self.streams[eid].queued_data_length > self.config.body_size_limit: + raise HttpException("HTTP body too large. Limit is {}.".format(self.config.body_size_limit)) + self.streams[eid].data_queue.put(event.data) + self.streams[eid].queued_data_length += len(event.data) + source_conn.h2.safe_increment_flow_control(event.stream_id, event.flow_controlled_length) + elif isinstance(event, StreamEnded): + self.streams[eid].timestamp_end = time.time() + self.streams[eid].data_finished.set() + elif isinstance(event, StreamReset): + self.streams[eid].zombie = time.time() + self.client_reset_streams.append(self.streams[eid].client_stream_id) + if self.streams[eid].server_stream_id: + self.server_reset_streams.append(self.streams[eid].server_stream_id) + if eid in self.streams and event.error_code == 0x8: + if is_server: + other_stream_id = self.streams[eid].client_stream_id + else: + other_stream_id = self.streams[eid].server_stream_id + if other_stream_id is not None: + other_conn.h2.safe_reset_stream(other_stream_id, event.error_code) + elif isinstance(event, RemoteSettingsChanged): + new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()]) + other_conn.h2.safe_update_settings(new_settings) + elif isinstance(event, ConnectionTerminated): + # Do not immediately terminate the other connection. + # Some streams might be still sending data to the client. + return False + elif isinstance(event, PushedStreamReceived): + # pushed stream ids should be uniq and not dependent on race conditions + # only the parent stream id must be looked up first + parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] + with self.client_conn.h2.lock: + self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) + + headers = Headers([[str(k), str(v)] for k, v in event.headers]) + headers['x-mitmproxy-pushed'] = 'true' + self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) + self.streams[event.pushed_stream_id].timestamp_start = time.time() + self.streams[event.pushed_stream_id].pushed = True + self.streams[event.pushed_stream_id].parent_stream_id = parent_eid + self.streams[event.pushed_stream_id].timestamp_end = time.time() + self.streams[event.pushed_stream_id].request_data_finished.set() + self.streams[event.pushed_stream_id].start() + elif isinstance(event, TrailersReceived): + raise NotImplementedError() + + return True + + def _cleanup_streams(self): + death_time = time.time() - 10 + for stream_id in self.streams.keys(): + zombie = self.streams[stream_id].zombie + if zombie and zombie <= death_time: + self.streams.pop(stream_id, None) + + def __call__(self): + if self.server_conn: + self._initiate_server_conn() + + preamble = self.client_conn.rfile.read(24) + self.client_conn.h2.initiate_connection() + self.client_conn.h2.receive_data(preamble) + self.client_conn.send(self.client_conn.h2.data_to_send()) + + while True: + r = ssl_read_select(self.active_conns, 1) + for conn in r: + source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn + other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + is_server = (conn == self.server_conn.connection) + + with source_conn.h2.lock: + try: + raw_frame = b''.join(http2_read_raw_frame(source_conn.rfile)) + except: + for stream in self.streams.values(): + stream.zombie = time.time() + return + + + frame, _ = hyperframe.frame.Frame.parse_frame_header(raw_frame[:9]) + + if is_server: + list = self.server_reset_streams + else: + list = self.client_reset_streams + if frame.stream_id in list: + # this frame belongs to a reset stream - just ignore it + if isinstance(frame, hyperframe.frame.HeadersFrame) or isinstance(frame, hyperframe.frame.ContinuationFrame): + # we need to keep the hpack-decoder happy too + source_conn.h2.decoder.decode(raw_frame[9:]) + continue + + events = source_conn.h2.receive_data(raw_frame) + source_conn.send(source_conn.h2.data_to_send()) + + for event in events: + if not self._handle_event(event, source_conn, other_conn, is_server): + return + + self._cleanup_streams() + + +class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): + + def __init__(self, ctx, stream_id, request_headers): + super(Http2SingleStreamLayer, self).__init__(ctx) + self.zombie = None + self.client_stream_id = stream_id + self.server_stream_id = None + self.request_headers = request_headers + self.response_headers = None + self.pushed = False + + self.request_data_queue = Queue.Queue() + self.request_queued_data_length = 0 + self.request_data_finished = threading.Event() + + self.response_arrived = threading.Event() + self.response_data_queue = Queue.Queue() + self.response_queued_data_length = 0 + self.response_data_finished = threading.Event() + + @property + def data_queue(self): + if self.response_arrived.is_set(): + return self.response_data_queue + else: + return self.request_data_queue + + @property + def queued_data_length(self): + if self.response_arrived.is_set(): + return self.response_queued_data_length + else: + return self.request_queued_data_length + + @property + def data_finished(self): + if self.response_arrived.is_set(): + return self.response_data_finished + else: + return self.request_data_finished + + @queued_data_length.setter + def queued_data_length(self, v): + if self.response_arrived.is_set(): + return self.response_queued_data_length + else: + return self.request_queued_data_length + + def is_zombie(self): + return self.zombie is not None + + def read_request(self): + self.request_data_finished.wait() + + authority = self.request_headers.get(':authority', '') + method = self.request_headers.get(':method', 'GET') + scheme = self.request_headers.get(':scheme', 'https') + path = self.request_headers.get(':path', '/') + host = None + port = None + + if path == '*' or path.startswith("/"): + form_in = "relative" + elif method == 'CONNECT': # pragma: no cover + raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") + else: # pragma: no cover + form_in = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = utils.parse_url(path) + + if authority: + host, _, port = authority.partition(':') + + if not host: + host = 'localhost' + if not port: + port = 443 if scheme == 'https' else 80 + port = int(port) + + data = [] + while self.request_data_queue.qsize() > 0: + data.append(self.request_data_queue.get()) + data = b"".join(data) + + return HTTPRequest( + form_in, + method, + scheme, + host, + port, + path, + b"HTTP/2.0", + self.request_headers, + data, + timestamp_start=self.timestamp_start, + timestamp_end=self.timestamp_end, + ) + + def send_request(self, message): + if self.pushed: + # nothing to do here + return + + with self.server_conn.h2.lock: + # We must not assign a stream id if we are already a zombie. + if self.zombie: + return + + self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() + self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id + + self.server_conn.h2.safe_send_headers( + self.is_zombie, + self.server_stream_id, + message.headers + ) + self.server_conn.h2.safe_send_body( + self.is_zombie, + self.server_stream_id, + message.body + ) + + def read_response_headers(self): + self.response_arrived.wait() + + status_code = int(self.response_headers.get(':status', 502)) + + return HTTPResponse( + http_version=b"HTTP/2.0", + status_code=status_code, + reason='', + headers=self.response_headers, + content=None, + timestamp_start=self.timestamp_start, + timestamp_end=self.timestamp_end, + ) + + def read_response_body(self, request, response): + while True: + try: + yield self.response_data_queue.get(timeout=1) + except Queue.Empty: + pass + if self.response_data_finished.is_set(): + while self.response_data_queue.qsize() > 0: + yield self.response_data_queue.get() + return + if self.zombie: + return + + def send_response_headers(self, response): + self.client_conn.h2.safe_send_headers( + self.is_zombie, + self.client_stream_id, + response.headers + ) + + def send_response_body(self, _response, chunks): + self.client_conn.h2.safe_send_body( + self.is_zombie, + self.client_stream_id, + chunks + ) + + def check_close_connection(self, flow): + # This layer only handles a single stream. + # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream. + return True + + def connect(self): # pragma: no cover + raise ValueError("CONNECT inside an HTTP2 stream is not supported.") + + def set_server(self, *args, **kwargs): # pragma: no cover + # do not mess with the server connection - all streams share it. + pass + + def run(self): + layer = HttpLayer(self, self.mode) + layer() + self.zombie = time.time() diff --git a/libmproxy/protocol/tls.py b/libmproxy/protocol/tls.py index af1a6055..986eb964 100644 --- a/libmproxy/protocol/tls.py +++ b/libmproxy/protocol/tls.py @@ -349,7 +349,7 @@ class TlsLayer(Layer): layer = self.ctx.next_layer(self) layer() - def __repr__(self): + def __repr__(self): # pragma: no cover if self._client_tls and self._server_tls: return "TlsLayer(client and server)" elif self._client_tls: @@ -560,5 +560,7 @@ class TlsLayer(Layer): if self._sni_from_server_change: sans.add(self._sni_from_server_change) - sans.discard(host) + # Some applications don't consider the CN and expect the hostname to be in the SANs. + # For example, Thunderbird 38 will display a warning if the remote host is only the CN. + sans.add(host) return self.config.certstore.get_cert(host, list(sans)) diff --git a/libmproxy/proxy/config.py b/libmproxy/proxy/config.py index bf765d81..a635ab19 100644 --- a/libmproxy/proxy/config.py +++ b/libmproxy/proxy/config.py @@ -180,6 +180,9 @@ def process_proxy_options(parser, options): parser.error("Certificate file does not exist: %s" % parts[1]) certs.append(parts) + if options.http2 and not tcp.HAS_ALPN: + raise RuntimeError("HTTP2 support requires OpenSSL 1.0.2 or above.") + return ProxyConfig( host=options.addr, port=options.port, diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index 750cb1a4..d208cff5 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -103,9 +103,9 @@ class ConnectionHandler(object): return Socks5Proxy(root_context) elif mode == "regular": return HttpProxy(root_context) - elif callable(mode): # pragma: nocover + elif callable(mode): # pragma: no cover return mode(root_context) - else: # pragma: nocover + else: # pragma: no cover raise ValueError("Unknown proxy mode: %s" % mode) def handle(self): diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py index 52a8347f..9600ab09 100644 --- a/libmproxy/stateobject.py +++ b/libmproxy/stateobject.py @@ -1,52 +1,51 @@ from __future__ import absolute_import +from netlib.utils import Serializable -class StateObject(object): - +class StateObject(Serializable): """ - An object with serializable state. + An object with serializable state. - State attributes can either be serializable types(str, tuple, bool, ...) - or StateObject instances themselves. + State attributes can either be serializable types(str, tuple, bool, ...) + or StateObject instances themselves. """ - # An attribute-name -> class-or-type dict containing all attributes that - # should be serialized. If the attribute is a class, it must implement the - # StateObject protocol. - _stateobject_attributes = None - # A set() of attributes that should be ignored for short state - _stateobject_long_attributes = frozenset([]) - def from_state(self, state): - raise NotImplementedError() + _stateobject_attributes = None + """ + An attribute-name -> class-or-type dict containing all attributes that + should be serialized. If the attribute is a class, it must implement the + Serializable protocol. + """ - def get_state(self, short=False): + def get_state(self): """ - Retrieve object state. If short is true, return an abbreviated - format with long data elided. + Retrieve object state. """ state = {} for attr, cls in self._stateobject_attributes.iteritems(): - if short and attr in self._stateobject_long_attributes: - continue val = getattr(self, attr) if hasattr(val, "get_state"): - state[attr] = val.get_state(short) + state[attr] = val.get_state() else: state[attr] = val return state - def load_state(self, state): + def set_state(self, state): """ - Load object state from data returned by a get_state call. + Load object state from data returned by a get_state call. """ + state = state.copy() for attr, cls in self._stateobject_attributes.iteritems(): - if state.get(attr, None) is None: - setattr(self, attr, None) + if state.get(attr) is None: + setattr(self, attr, state.pop(attr)) else: curr = getattr(self, attr) - if hasattr(curr, "load_state"): - curr.load_state(state[attr]) + if hasattr(curr, "set_state"): + curr.set_state(state.pop(attr)) elif hasattr(cls, "from_state"): - setattr(self, attr, cls.from_state(state[attr])) - else: - setattr(self, attr, cls(state[attr])) + obj = cls.from_state(state.pop(attr)) + setattr(self, attr, obj) + else: # primitive types such as int, str, ... + setattr(self, attr, cls(state.pop(attr))) + if state: + raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state)) diff --git a/libmproxy/version.py b/libmproxy/version.py index 03c2f256..25c56706 100644 --- a/libmproxy/version.py +++ b/libmproxy/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 15) +IVERSION = (0, 16) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "mitmproxy" diff --git a/libmproxy/web/__init__.py b/libmproxy/web/__init__.py index 43fc993d..c48b3d09 100644 --- a/libmproxy/web/__init__.py +++ b/libmproxy/web/__init__.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function import collections import tornado.ioloop import tornado.httpserver + from .. import controller, flow from . import app @@ -20,7 +21,7 @@ class WebFlowView(flow.FlowView): app.ClientConnection.broadcast( type="flows", cmd="add", - data=f.get_state(short=True) + data=app._strip_content(f.get_state()) ) def _update(self, f): @@ -28,7 +29,7 @@ class WebFlowView(flow.FlowView): app.ClientConnection.broadcast( type="flows", cmd="update", - data=f.get_state(short=True) + data=app._strip_content(f.get_state()) ) def _remove(self, f): diff --git a/libmproxy/web/app.py b/libmproxy/web/app.py index 79f76013..958b8669 100644 --- a/libmproxy/web/app.py +++ b/libmproxy/web/app.py @@ -4,9 +4,38 @@ import tornado.web import tornado.websocket import logging import json + +from netlib.http import CONTENT_MISSING from .. import version, filt +def _strip_content(flow_state): + """ + Remove flow message content and cert to save transmission space. + + Args: + flow_state: The original flow state. Will be left unmodified + """ + for attr in ("request", "response"): + if attr in flow_state: + message = flow_state[attr] + if message["content"]: + message["contentLength"] = len(message["content"]) + elif message["content"] == CONTENT_MISSING: + message["contentLength"] = None + else: + message["contentLength"] = 0 + del message["content"] + + if "backup" in flow_state: + del flow_state["backup"] + flow_state["modified"] = True + + flow_state.get("server_conn", {}).pop("cert", None) + + return flow_state + + class APIError(tornado.web.HTTPError): pass @@ -100,7 +129,7 @@ class Flows(RequestHandler): def get(self): self.write(dict( - data=[f.get_state(short=True) for f in self.state.flows] + data=[_strip_content(f.get_state()) for f in self.state.flows] )) @@ -141,7 +170,7 @@ class FlowHandler(RequestHandler): elif k == "port": request.port = int(v) elif k == "headers": - request.headers.load_state(v) + request.headers.set_state(v) else: print "Warning: Unknown update {}.{}: {}".format(a, k, v) @@ -155,7 +184,7 @@ class FlowHandler(RequestHandler): elif k == "http_version": response.http_version = str(v) elif k == "headers": - response.headers.load_state(v) + response.headers.set_state(v) else: print "Warning: Unknown update {}.{}: {}".format(a, k, v) else: diff --git a/requirements.txt b/requirements.txt index 3832c953..49f86b9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -e git+https://github.com/mitmproxy/netlib.git#egg=netlib -e git+https://github.com/mitmproxy/pathod.git#egg=pathod --e .[dev,examples,contentviews]
\ No newline at end of file +-e .[dev,examples,contentviews] @@ -1,7 +1,6 @@ from setuptools import setup, find_packages from codecs import open import os -import sys from libmproxy import version # Based on https://github.com/pypa/sampleproject/blob/master/setup.py @@ -12,68 +11,6 @@ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() -# Core dependencies -# Do not use the "~=" compatible release specifier. -# This will break `pip install` on systems with old setuptools versions. -deps = { - "netlib>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION), - "tornado>=4.3.0, <4.4", - "configargparse>=0.10.0, <0.11", - "pyperclip>=1.5.22, <1.6", - "blinker>=1.4, <1.5", - "pyparsing>=2.0.5, <2.1", - "html2text==2015.11.4", - "construct>=2.5.2, <2.6", - "six>=1.10.0, <1.11", - "lxml==3.4.4", # there are no Windows wheels for newer versions, so we pin this. - "Pillow>=3.0.0, <3.2", - "watchdog>=0.8.3, <0.9", -} -# A script -> additional dependencies dict. -scripts = { - "mitmproxy": { - "urwid>=1.3.1, <1.4", - }, - "mitmdump": { - "click>=6.2, <6.3", - }, - "mitmweb": set() -} -# Developer dependencies -dev_deps = { - "mock>=1.0.1", - "pytest>=2.8.0", - "pytest-xdist>=1.13.1", - "pytest-cov>=2.1.0", - "coveralls>=0.4.1", - "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION), - "sphinx>=1.3.1", - "sphinx-autobuild>=0.5.2", - "sphinxcontrib-documentedlist>=0.2", -} -example_deps = { - "pytz==2015.7", - "harparser>=0.2, <0.3", - "beautifulsoup4>=4.4.1, <4.5", -} -# Add *all* script dependencies to developer dependencies. -for script_deps in scripts.values(): - dev_deps.update(script_deps) - -# Remove mitmproxy for Windows support. -if os.name == "nt": - del scripts["mitmproxy"] - deps.add("pydivert>=0.0.7") # Transparent proxying on Windows - -# Add dependencies for available scripts as core dependencies. -for script_deps in scripts.values(): - deps.update(script_deps) - -if sys.version_info < (3, 4): - example_deps.add("enum34>=1.0.4, <1.1") - -console_scripts = ["%s = libmproxy.main:%s" % (s, s) for s in scripts.keys()] - setup( name="mitmproxy", version=version.VERSION, @@ -101,18 +38,67 @@ setup( "Topic :: Internet :: Proxy Servers", "Topic :: Software Development :: Testing" ], - packages=find_packages(), + packages=find_packages(exclude=["test", "test.*"]), include_package_data=True, entry_points={ - 'console_scripts': console_scripts}, - install_requires=list(deps), + 'console_scripts': [ + 'mitmproxy = libmproxy.main:mitmproxy', + 'mitmdump = libmproxy.main:mitmdump', + 'mitmweb = libmproxy.main:mitmweb' + ] + }, + # https://packaging.python.org/en/latest/requirements/#install-requires + # It is not considered best practice to use install_requires to pin dependencies to specific versions. + install_requires=[ + "netlib>={}, <{}".format(version.MINORVERSION, version.NEXT_MINORVERSION), + "h2>=2.1.0, <2.2", + "tornado>=4.3, <4.4", + "configargparse>=0.10, <0.11", + "pyperclip>=1.5.22, <1.6", + "blinker>=1.4, <1.5", + "pyparsing>=2.0.5, <2.1", + "html2text==2016.1.8", + "construct>=2.5.2, <2.6", + "six>=1.10, <1.11", + "Pillow>=3.1, <3.2", + "watchdog>=0.8.3, <0.9", + "click>=6.2, <7.0", + "urwid>=1.3.1, <1.4", + ], extras_require={ - 'dev': list(dev_deps), + ':sys_platform == "win32"': [ + "pydivert>=0.0.7, <0.1", + "lxml==3.4.4", # there are no Windows wheels for newer versions, so we pin this. + ], + ':sys_platform != "win32"': [ + "lxml>=3.5.0, <3.6", + ], + # Do not use a range operator here: https://bitbucket.org/pypa/setuptools/issues/380 + # Ubuntu Trusty and other still ship with setuptools < 17.1 + ':python_version == "2.7"': [ + "enum34>=1.1.2, <1.2", + ], + 'dev': [ + "mock>=1.3.0, <1.4", + "pytest>=2.8.7, <2.9", + "pytest-xdist>=1.14, <1.15", + "pytest-cov>=2.2.1, <2.3", + "pytest-timeout>=1.0.0, <1.1", + "coveralls>=1.1, <1.2", + "pathod>={}, <{}".format(version.MINORVERSION, version.NEXT_MINORVERSION), + "sphinx>=1.3.5, <1.4", + "sphinx-autobuild>=0.5.2, <0.6", + "sphinxcontrib-documentedlist>=0.3.0, <0.4" + ], 'contentviews': [ - "pyamf>=0.7.2, <0.8", + "pyamf>=0.8.0, <0.9", "protobuf>=2.6.1, <2.7", "cssutils>=1.0.1, <1.1" ], - 'examples': list(example_deps) + 'examples': [ + "pytz==2015.7.0", + "harparser>=0.2, <0.3", + "beautifulsoup4>=4.4.1, <4.5", + ] } ) diff --git a/test/test_filt.py b/test/test_filt.py index b1fd2ad9..b1f3a21f 100644 --- a/test/test_filt.py +++ b/test/test_filt.py @@ -1,7 +1,7 @@ import cStringIO from libmproxy import filt -from libmproxy.protocol import http from libmproxy.models import Error +from libmproxy.models import http from netlib.http import Headers from . import tutils diff --git a/test/test_flow.py b/test/test_flow.py index b8d1fad3..51b88fff 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -422,7 +422,7 @@ class TestFlow(object): assert not f == f2 f2.error = Error("e2") assert not f == f2 - f.load_state(f2.get_state()) + f.set_state(f2.get_state()) assert f.get_state() == f2.get_state() def test_kill(self): @@ -463,6 +463,11 @@ class TestFlow(object): f.response.content = "\xc2foo" f.replace("foo", u"bar") + def test_replace_no_content(self): + f = tutils.tflow() + f.request.content = CONTENT_MISSING + assert f.replace("foo", "bar") == 0 + def test_replace(self): f = tutils.tflow(resp=True) f.request.headers["foo"] = "foo" @@ -1199,7 +1204,7 @@ class TestError: e2 = Error("bar") assert not e == e2 - e.load_state(e2.get_state()) + e.set_state(e2.get_state()) assert e.get_state() == e2.get_state() e3 = e.copy() @@ -1219,7 +1224,7 @@ class TestClientConnection: assert not c == c2 c2.timestamp_start = 42 - c.load_state(c2.get_state()) + c.set_state(c2.get_state()) assert c.timestamp_start == 42 c3 = c.copy() diff --git a/test/test_protocol_http.py b/test/test_protocol_http1.py index 489be3f9..a1485f1b 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http1.py @@ -1,4 +1,3 @@ -from io import BytesIO from netlib.exceptions import HttpSyntaxException from netlib.http import http1 from netlib.tcp import TCPClient @@ -6,33 +5,6 @@ from netlib.tutils import treq, raises from . import tutils, tservers -class TestHTTPResponse: - - def test_read_from_stringio(self): - s = ( - b"HTTP/1.1 200 OK\r\n" - b"Content-Length: 7\r\n" - b"\r\n" - b"content\r\n" - b"HTTP/1.1 204 OK\r\n" - b"\r\n" - ) - rfile = BytesIO(s) - r = http1.read_response(rfile, treq()) - assert r.status_code == 200 - assert r.content == b"content" - assert http1.read_response(rfile, treq()).status_code == 204 - - rfile = BytesIO(s) - # HEAD must not have content by spec. We should leave it on the pipe. - r = http1.read_response(rfile, treq(method=b"HEAD")) - assert r.status_code == 200 - assert r.content == b"" - - with raises(HttpSyntaxException): - http1.read_response(rfile, treq()) - - class TestHTTPFlow(object): def test_repr(self): diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py new file mode 100644 index 00000000..38cfdfc3 --- /dev/null +++ b/test/test_protocol_http2.py @@ -0,0 +1,431 @@ +from __future__ import (absolute_import, print_function, division) + +import OpenSSL +import pytest +import traceback +import os +import tempfile +import sys + +from libmproxy.proxy.config import ProxyConfig +from libmproxy.proxy.server import ProxyServer +from libmproxy.cmdline import APP_HOST, APP_PORT + +import logging +logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) +logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) +logging.getLogger("passlib.utils.compat").setLevel(logging.WARNING) +logging.getLogger("passlib.registry").setLevel(logging.WARNING) +logging.getLogger("PIL.Image").setLevel(logging.WARNING) +logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) + +import netlib +from netlib import tservers as netlib_tservers +from netlib.utils import http2_read_raw_frame + +import h2 +from hyperframe.frame import Frame + +from libmproxy import utils +from . import tservers + +requires_alpn = pytest.mark.skipif( + not OpenSSL._util.lib.Cryptography_HAS_ALPN, + reason="requires OpenSSL with ALPN support") + + +class _Http2ServerBase(netlib_tservers.ServerTestBase): + ssl = dict(alpn_select=b'h2') + + class handler(netlib.tcp.BaseHandler): + + def handle(self): + h2_conn = h2.connection.H2Connection(client_side=False) + + preamble = self.rfile.read(24) + h2_conn.initiate_connection() + h2_conn.receive_data(preamble) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(http2_read_raw_frame(self.rfile)) + events = h2_conn.receive_data(raw) + except: + break + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + for event in events: + try: + if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): + done = True + break + except Exception as e: + print(repr(e)) + print(traceback.format_exc()) + done = True + break + + def handle_server_event(self, h2_conn, rfile, wfile): + raise NotImplementedError() + + +class _Http2TestBase(object): + + @classmethod + def setup_class(self): + self.config = ProxyConfig(**self.get_proxy_config()) + + tmaster = tservers.TestMaster(self.config) + tmaster.start_app(APP_HOST, APP_PORT) + self.proxy = tservers.ProxyThread(tmaster) + self.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @property + def master(self): + return self.proxy.tmaster + + @classmethod + def get_proxy_config(cls): + cls.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return dict( + no_upstream_cert = False, + cadir = cls.cadir, + authenticator = None, + ) + + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_server_event = self.handle_server_event + + def _setup_connection(self): + self.config.http2 = True + + client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + # send CONNECT request + client.wfile.write( + b"CONNECT localhost:%d HTTP/1.1\r\n" + b"Host: localhost:%d\r\n" + b"\r\n" % (self.server.server.address.port, self.server.server.address.port) + ) + client.wfile.flush() + + # read CONNECT response + while client.rfile.readline() != "\r\n": + pass + + client.convert_to_ssl(alpn_protos=[b'h2']) + + h2_conn = h2.connection.H2Connection(client_side=True) + h2_conn.initiate_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + return client, h2_conn + + def _send_request(self, wfile, h2_conn, stream_id=1, headers=[], body=b''): + h2_conn.send_headers( + stream_id=stream_id, + headers=headers, + end_stream=(len(body) == 0), + ) + if body: + h2_conn.send_data(stream_id, body) + h2_conn.end_stream(stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + +@requires_alpn +class TestSimple(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(1, [ + (':status', '200'), + ('foo', 'bar'), + ]) + h2_conn.send_data(1, b'foobar') + h2_conn.end_stream(1) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_simple(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], body='my request body echoed back to me') + + done = False + while not done: + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.status_code == 200 + assert self.master.state.flows[0].response.headers['foo'] == 'bar' + assert self.master.state.flows[0].response.body == b'foobar' + + +@requires_alpn +class TestWithBodies(_Http2TestBase, _Http2ServerBase): + tmp_data_buffer_foobar = b'' + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + if isinstance(event, h2.events.DataReceived): + self.tmp_data_buffer_foobar += event.data + elif isinstance(event, h2.events.StreamEnded): + h2_conn.send_headers(1, [ + (':status', '200'), + ]) + h2_conn.send_data(1, self.tmp_data_buffer_foobar) + h2_conn.end_stream(1) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_with_bodies(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + body='foobar with request body', + ) + + done = False + while not done: + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert self.master.state.flows[0].response.body == b'foobar with request body' + + +@requires_alpn +class TestPushPromise(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + if event.stream_id != 1: + # ignore requests initiated by push promises + return True + + h2_conn.send_headers(1, [(':status', '200')]) + h2_conn.push_stream(1, 2, [ + (':authority', "127.0.0.1:%s" % self.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_foo'), + ('foo', 'bar') + ]) + h2_conn.push_stream(1, 4, [ + (':authority', "127.0.0.1:%s" % self.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_bar'), + ('foo', 'bar') + ]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_headers(2, [(':status', '200')]) + h2_conn.send_headers(4, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_data(1, b'regular_stream') + h2_conn.send_data(2, b'pushed_stream_foo') + h2_conn.send_data(4, b'pushed_stream_bar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + h2_conn.end_stream(1) + h2_conn.end_stream(2) + h2_conn.end_stream(4) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_push_promise(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + raw = b''.join(http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses == 3 and ended_streams == 3 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert ended_streams == 3 + assert pushed_streams == 2 + + bodies = [flow.response.body for flow in self.master.state.flows] + assert len(bodies) == 3 + assert b'regular_stream' in bodies + assert b'pushed_stream_foo' in bodies + assert b'pushed_stream_bar' in bodies + + def test_push_promise_reset(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses >= 1 and ended_streams >= 1 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + bodies = [flow.response.body for flow in self.master.state.flows if flow.response] + assert len(bodies) >= 1 + assert b'regular_stream' in bodies + # the other two bodies might not be transmitted before the reset diff --git a/test/tutils.py b/test/tutils.py index 5bd91307..2ce0884d 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -76,7 +76,11 @@ def tclient_conn(): """ c = ClientConnection.from_state(dict( address=dict(address=("address", 22), use_ipv6=True), - clientcert=None + clientcert=None, + ssl_established=False, + timestamp_start=1, + timestamp_ssl_setup=2, + timestamp_end=3, )) c.reply = controller.DummyReply() return c @@ -88,9 +92,15 @@ def tserver_conn(): """ c = ServerConnection.from_state(dict( address=dict(address=("address", 22), use_ipv6=True), - state=[], source_address=dict(address=("address", 22), use_ipv6=True), - cert=None + cert=None, + timestamp_start=1, + timestamp_tcp_setup=2, + timestamp_ssl_setup=3, + timestamp_end=4, + ssl_established=False, + sni="address", + via=None )) c.reply = controller.DummyReply() return c |