aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2014-02-05 20:26:47 +0100
committerMaximilian Hils <git@maximilianhils.com>2014-02-05 20:26:47 +0100
commitf26d91cb814436fa5c1290459f5313e6831bd53c (patch)
treee54d13d395f7746b5b92637dd642645ceb34cd58
parent9a55cd733268ff66c19ff6fead18291ec8342d8c (diff)
downloadmitmproxy-f26d91cb814436fa5c1290459f5313e6831bd53c.tar.gz
mitmproxy-f26d91cb814436fa5c1290459f5313e6831bd53c.tar.bz2
mitmproxy-f26d91cb814436fa5c1290459f5313e6831bd53c.zip
add skeleton to change destinatin server during intercept, fix all testcases on windows
-rw-r--r--libmproxy/flow.py19
-rw-r--r--libmproxy/protocol/__init__.py16
-rw-r--r--libmproxy/protocol/http.py137
-rw-r--r--libmproxy/protocol/primitives.py11
-rw-r--r--libmproxy/proxy.py18
-rw-r--r--test/test_dump.py4
-rw-r--r--test/test_flow.py27
-rw-r--r--test/test_proxy.py4
-rw-r--r--test/test_script.py2
9 files changed, 159 insertions, 79 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index bf9171a7..55ff109e 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -249,10 +249,9 @@ class StickyCookieState:
"""
Returns a (domain, port, path) tuple.
"""
- raise NotImplementedError
return (
- m["domain"] or f.request.host,
- f.server_conn.address.port,
+ m["domain"] or f.request.get_host(),
+ f.request.get_port(),
m["path"] or "/"
)
@@ -270,7 +269,7 @@ class StickyCookieState:
c = Cookie.SimpleCookie(str(i))
m = c.values()[0]
k = self.ckey(m, f)
- if self.domain_match(f.request.host, k[0]):
+ if self.domain_match(f.request.get_host(), k[0]):
self.jar[self.ckey(m, f)] = m
def handle_request(self, f):
@@ -278,8 +277,8 @@ class StickyCookieState:
if f.match(self.flt):
for i in self.jar.keys():
match = [
- self.domain_match(f.request.host, i[0]),
- f.request.port == i[1],
+ self.domain_match(f.request.get_host(), i[0]),
+ f.request.get_port() == i[1],
f.request.path.startswith(i[2])
]
if all(match):
@@ -298,12 +297,12 @@ class StickyAuthState:
self.hosts = {}
def handle_request(self, f):
- raise NotImplementedError
+ host = f.request.get_host()
if "authorization" in f.request.headers:
- self.hosts[f.request.host] = f.request.headers["authorization"]
+ self.hosts[host] = f.request.headers["authorization"]
elif f.match(self.flt):
- if f.request.host in self.hosts:
- f.request.headers["authorization"] = self.hosts[f.request.host]
+ if host in self.hosts:
+ f.request.headers["authorization"] = self.hosts[host]
class State(object):
diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py
index da85500b..f23159b2 100644
--- a/libmproxy/protocol/__init__.py
+++ b/libmproxy/protocol/__init__.py
@@ -28,6 +28,22 @@ class ProtocolHandler(object):
"""
raise error
+
+class TemporaryServerChangeMixin(object):
+ """
+ This mixin allows safe modification of the target server,
+ without any need to expose the ConnectionHandler to the Flow.
+ """
+
+ def change_server(self):
+ self._backup_server = True
+ raise NotImplementedError
+
+ def restore_server(self):
+ if not hasattr(self,"_backup_server"):
+ return
+ raise NotImplementedError
+
from . import http, tcp
protocols = {
diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py
index 636e1b07..069030ef 100644
--- a/libmproxy/protocol/http.py
+++ b/libmproxy/protocol/http.py
@@ -3,9 +3,9 @@ from email.utils import parsedate_tz, formatdate, mktime_tz
import netlib.utils
from netlib import http, tcp, http_status, odict
from netlib.odict import ODict, ODictCaseless
-from . import ProtocolHandler, ConnectionTypeChange, KILL
+from . import ProtocolHandler, ConnectionTypeChange, KILL, TemporaryServerChangeMixin
from .. import encoding, utils, version, filt, controller, stateobject
-from ..proxy import ProxyError, AddressPriority
+from ..proxy import ProxyError, AddressPriority, ServerConnection
from .primitives import Flow, Error
@@ -55,10 +55,24 @@ class decoded(object):
class HTTPMessage(stateobject.SimpleStateObject):
- def __init__(self):
+ def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None):
+ self.httpversion = httpversion
+ self.headers = headers
+ self.content = content
+ self.timestamp_start = timestamp_start
+ self.timestamp_end = timestamp_end
+
self.flow = None # will usually be set by the flow backref mixin
"""@type: HTTPFlow"""
+ _stateobject_attributes = dict(
+ httpversion=tuple,
+ headers=ODictCaseless,
+ content=str,
+ timestamp_start=float,
+ timestamp_end=float
+ )
+
def get_decoded_content(self):
"""
Returns the decoded content based on the current Content-Encoding header.
@@ -199,7 +213,7 @@ class HTTPRequest(HTTPMessage):
def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content,
timestamp_start=None, timestamp_end=None, form_out=None):
assert isinstance(headers, ODictCaseless) or not headers
- HTTPMessage.__init__(self)
+ HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end)
self.form_in = form_in
self.method = method
@@ -208,10 +222,6 @@ class HTTPRequest(HTTPMessage):
self.port = port
self.path = path
self.httpversion = httpversion
- self.headers = headers
- self.content = content
- self.timestamp_start = timestamp_start
- self.timestamp_end = timestamp_end
self.form_out = form_out or form_in
# Have this request's cookies been modified by sticky cookies or auth?
@@ -220,18 +230,14 @@ class HTTPRequest(HTTPMessage):
# Is this request replayed?
self.is_replay = False
- _stateobject_attributes = dict(
+ _stateobject_attributes = HTTPMessage._stateobject_attributes.copy()
+ _stateobject_attributes.update(
form_in=str,
method=str,
scheme=str,
host=str,
port=int,
path=str,
- httpversion=tuple,
- headers=ODictCaseless,
- content=str,
- timestamp_start=float,
- timestamp_end=float,
form_out=str
)
@@ -437,15 +443,13 @@ class HTTPRequest(HTTPMessage):
query = utils.urlencode(odict.lst)
self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]))
- def get_url(self, hostheader=False):
+ def get_host(self, hostheader=False):
"""
- Returns a URL string, constructed from the Request's URL components.
-
- If hostheader is True, we use the value specified in the request
- Host header to construct the URL.
+ Heuristic to get the host of the request.
+ The host is not necessarily equal to the TCP destination of the request,
+ for example on a transparently proxified absolute-form request to an upstream HTTP proxy.
+ If hostheader is set to True, the Host: header will be used as additional (and preferred) data source.
"""
- raise NotImplementedError
- # FIXME: Take server_conn into account.
host = None
if hostheader:
host = self.headers.get_first("host")
@@ -455,7 +459,35 @@ class HTTPRequest(HTTPMessage):
else:
host = self.flow.server_conn.address.host
host = host.encode("idna")
- return utils.unparse_url(self.scheme, host, self.port, self.path).encode('ascii')
+ return host
+
+ def get_scheme(self):
+ """
+ Returns the request port, either from the request itself or from the flow's server connection
+ """
+ if self.scheme:
+ return self.scheme
+ return "https" if self.flow.server_conn.ssl_established else "http"
+
+ def get_port(self):
+ """
+ Returns the request port, either from the request itself or from the flow's server connection
+ """
+ if self.port:
+ return self.port
+ return self.flow.server_conn.address.port
+
+ def get_url(self, hostheader=False):
+ """
+ Returns a URL string, constructed from the Request's URL components.
+
+ If hostheader is True, we use the value specified in the request
+ Host header to construct the URL.
+ """
+ return utils.unparse_url(self.get_scheme(),
+ self.get_host(hostheader),
+ self.get_port(),
+ self.path).encode('ascii')
def set_url(self, url):
"""
@@ -464,12 +496,30 @@ class HTTPRequest(HTTPMessage):
Returns False if the URL was invalid, True if the request succeeded.
"""
- raise NotImplementedError
- # FIXME: Needs to update server_conn as well.
parts = http.parse_url(url)
if not parts:
return False
- self.scheme, self.host, self.port, self.path = parts
+ scheme, host, port, path = parts
+ is_ssl = (True if scheme == "https" else False)
+
+ self.path = path
+
+ if host != self.get_host() or port != self.get_port():
+ if self.flow.change_server:
+ self.flow.change_server((host, port), ssl=is_ssl)
+ else:
+ # There's not live server connection, we're just changing the attributes here.
+ self.flow.server_conn = ServerConnection((host, port), AddressPriority.MANUALLY_CHANGED)
+ self.flow.server_conn.ssl_established = is_ssl
+
+ # If this is an absolute request, replace the attributes on the request object as well.
+ if self.host:
+ self.host = host
+ if self.port:
+ self.port = port
+ if self.scheme:
+ self.scheme = scheme
+
return True
def get_cookies(self):
@@ -521,34 +571,25 @@ class HTTPResponse(HTTPMessage):
timestamp_end: Timestamp indicating when request transmission ended
"""
- def __init__(self, httpversion, code, msg, headers, content, timestamp_start, timestamp_end):
+ def __init__(self, httpversion, code, msg, headers, content, timestamp_start=None, timestamp_end=None):
assert isinstance(headers, ODictCaseless) or headers is None
- HTTPMessage.__init__(self)
+ HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end)
- self.httpversion = httpversion
self.code = code
self.msg = msg
- self.headers = headers
- self.content = content
- self.timestamp_start = timestamp_start
- self.timestamp_end = timestamp_end
# Is this request replayed?
self.is_replay = False
- _stateobject_attributes = dict(
- httpversion=tuple,
+ _stateobject_attributes = HTTPMessage._stateobject_attributes.copy()
+ _stateobject_attributes.update(
code=int,
- msg=str,
- headers=ODictCaseless,
- content=str,
- timestamp_start=float,
- timestamp_end=float
+ msg=str
)
@classmethod
def _from_state(cls, state):
- f = cls(None, None, None, None, None, None, None)
+ f = cls(None, None, None, None, None)
f._load_state(state)
return f
@@ -688,13 +729,15 @@ class HTTPFlow(Flow):
intercepting: Is this flow currently being intercepted?
"""
- def __init__(self, client_conn, server_conn):
+ def __init__(self, client_conn, server_conn, change_server=None):
Flow.__init__(self, "http", client_conn, server_conn)
self.request = None
+ """@type: HTTPRequest"""
self.response = None
+ """@type: HTTPResponse"""
+ self.change_server = None # Used by flow.request.set_url to change the server address
self.intercepting = False # FIXME: Should that rather be an attribute of Flow?
- self._backup = None
_backrefattr = Flow._backrefattr + ("request", "response")
@@ -787,13 +830,15 @@ class HttpAuthenticationError(Exception):
return "HttpAuthenticationError"
-class HTTPHandler(ProtocolHandler):
+class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
+
def handle_messages(self):
while self.handle_flow():
pass
self.c.close = True
def get_response_from_server(self, request):
+ self.c.establish_server_connection()
request_raw = request._assemble()
for i in range(2):
@@ -818,7 +863,7 @@ class HTTPHandler(ProtocolHandler):
raise v
def handle_flow(self):
- flow = HTTPFlow(self.c.client_conn, self.c.server_conn)
+ flow = HTTPFlow(self.c.client_conn, self.c.server_conn, self.change_server)
try:
flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile,
body_size_limit=self.c.config.body_size_limit)
@@ -833,7 +878,6 @@ class HTTPHandler(ProtocolHandler):
if isinstance(request_reply, HTTPResponse):
flow.response = request_reply
else:
- self.c.establish_server_connection()
flow.response = self.get_response_from_server(flow.request)
self.c.log("response", [flow.response._assemble_first_line()])
@@ -855,7 +899,8 @@ class HTTPHandler(ProtocolHandler):
self.ssl_upgrade(flow.request)
flow.server_conn = self.c.server_conn
-
+ self.restore_server() # If the user has changed the target server on this connection,
+ # restore the original target server
return True
except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e:
self.handle_error(e, flow)
diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py
index f3fdd245..d1546ddd 100644
--- a/libmproxy/protocol/primitives.py
+++ b/libmproxy/protocol/primitives.py
@@ -3,7 +3,7 @@ from ..proxy import ServerConnection, ClientConnection
import copy
-class _BackreferenceMixin(object):
+class BackreferenceMixin(object):
"""
If an attribute from the _backrefattr tuple is set,
this mixin sets a reference back on the attribute object.
@@ -16,7 +16,7 @@ class _BackreferenceMixin(object):
_backrefattr = tuple()
def __setattr__(self, key, value):
- super(_BackreferenceMixin, self).__setattr__(key, value)
+ super(BackreferenceMixin, self).__setattr__(key, value)
if key in self._backrefattr and value is not None:
setattr(value, self._backrefname, self)
@@ -61,12 +61,17 @@ class Error(stateobject.SimpleStateObject):
return c
-class Flow(stateobject.SimpleStateObject, _BackreferenceMixin):
+class Flow(stateobject.SimpleStateObject, BackreferenceMixin):
def __init__(self, conntype, client_conn, server_conn):
self.conntype = conntype
self.client_conn = client_conn
+ """@type: ClientConnection"""
self.server_conn = server_conn
+ """@type: ServerConnection"""
+
self.error = None
+ """@type: Error"""
+ self._backup = None
_backrefattr = ("error",)
_backrefname = "flow"
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 53e3f575..6ff02a36 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -30,6 +30,7 @@ class ProxyError(Exception):
def __str__(self):
return "ProxyError(%s, %s)" % (self.code, self.msg)
+
class Log:
def __init__(self, msg):
self.msg = msg
@@ -104,8 +105,9 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
- def __init__(self, address):
+ def __init__(self, address, priority):
tcp.TCPClient.__init__(self, address)
+ self.priority = priority
self.peername = None
self.timestamp_start = None
@@ -145,7 +147,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
@classmethod
def _from_state(cls, state):
- f = cls(tuple())
+ f = cls(tuple(), None)
f._load_state(state)
return f
@@ -190,7 +192,7 @@ class RequestReplayThread(threading.Thread):
def run(self):
try:
r = self.flow.request
- server = ServerConnection(self.flow.server_conn.address())
+ server = ServerConnection(self.flow.server_conn.address(), None)
server.connect()
if self.flow.server_conn.ssl_established:
server.establish_ssl(self.config.clientcerts,
@@ -202,6 +204,7 @@ class RequestReplayThread(threading.Thread):
self.flow.error = protocol.primitives.Error(str(v))
self.channel.ask("error", self.flow.error)
+
class ConnectionHandler:
def __init__(self, config, client_connection, client_address, server, channel, server_version):
self.config = config
@@ -310,18 +313,17 @@ class ConnectionHandler:
@type priority: AddressPriority
"""
address = tcp.Address.wrap(address)
- self.log("Set server address: %s:%s" % (address.host, address.port))
- if self.server_conn and (self.server_address_priority > priority):
+ self.log("Try to set server address: %s:%s" % (address.host, address.port))
+ if self.server_conn and (self.server_conn.priority > priority):
self.log("Server address priority too low (is: %s, got: %s)" % (self.server_address_priority, priority))
return
- self.address_priority = priority
-
if self.server_conn and (self.server_conn.address == address):
+ self.server_conn.priority = priority # Possibly increase priority
self.log("Addresses match, skip.")
return
- server_conn = ServerConnection(address)
+ server_conn = ServerConnection(address, priority)
if self.server_conn and self.server_conn.connection:
self.del_server_connection()
self.server_conn = server_conn
diff --git a/test/test_dump.py b/test/test_dump.py
index f6688b1a..314356fc 100644
--- a/test/test_dump.py
+++ b/test/test_dump.py
@@ -26,12 +26,12 @@ class TestDumpMaster:
m.handle_log(l)
cc = req.flow.client_conn
cc.reply = mock.MagicMock()
- resp = tutils.tresp(req, content=content)
m.handle_clientconnect(cc)
- sc = proxy.ServerConnection((req.host, req.port))
+ sc = proxy.ServerConnection((req.host, req.port), None)
sc.reply = mock.MagicMock()
m.handle_serverconnection(sc)
m.handle_request(req)
+ resp = tutils.tresp(req, content=content)
f = m.handle_response(resp)
m.handle_clientdisconnect(cc)
return f
diff --git a/test/test_flow.py b/test/test_flow.py
index 5ae6c8d6..f28697c1 100644
--- a/test/test_flow.py
+++ b/test/test_flow.py
@@ -4,6 +4,7 @@ import email.utils
from libmproxy import filt, protocol, controller, utils, tnetstring, proxy, flow
from libmproxy.protocol.primitives import Error, Flow
from libmproxy.protocol.http import decoded
+from netlib import tcp
import tutils
@@ -32,7 +33,7 @@ class TestStickyCookieState:
def _response(self, cookie, host):
s = flow.StickyCookieState(filt.parse(".*"))
f = tutils.tflow_full()
- f.request.host = host
+ f.server_conn.address = tcp.Address((host, 80))
f.response.headers["Set-Cookie"] = [cookie]
s.handle_response(f)
return s, f
@@ -68,7 +69,7 @@ class TestStickyAuthState:
f = tutils.tflow_full()
f.request.headers["authorization"] = ["foo"]
s.handle_request(f)
- assert "host" in s.hosts
+ assert "address" in s.hosts
f = tutils.tflow_full()
s.handle_request(f)
@@ -586,7 +587,7 @@ class TestFlowMaster:
req = tutils.treq()
fm.handle_clientconnect(req.flow.client_conn)
assert fm.scripts[0].ns["log"][-1] == "clientconnect"
- sc = proxy.ServerConnection((req.host, req.port))
+ sc = proxy.ServerConnection((req.host, req.port), None)
sc.reply = controller.DummyReply()
fm.handle_serverconnection(sc)
assert fm.scripts[0].ns["log"][-1] == "serverconnect"
@@ -800,11 +801,23 @@ class TestRequest:
def test_get_url(self):
r = tutils.tflow().request
- assert r.get_url() == "https://host:22/"
- assert r.get_url(hostheader=True) == "https://host:22/"
+
+ assert r.get_url() == "http://address:22/path"
+
+ r.flow.server_conn.ssl_established = True
+ assert r.get_url() == "https://address:22/path"
+
+ r.flow.server_conn.address = tcp.Address(("host", 42))
+ assert r.get_url() == "https://host:42/path"
+
+ r.host = "address"
+ r.port = 22
+ assert r.get_url() == "https://address:22/path"
+
+ assert r.get_url(hostheader=True) == "https://address:22/path"
r.headers["Host"] = ["foo.com"]
- assert r.get_url() == "https://host:22/"
- assert r.get_url(hostheader=True) == "https://foo.com:22/"
+ assert r.get_url() == "https://address:22/path"
+ assert r.get_url(hostheader=True) == "https://foo.com:22/path"
def test_path_components(self):
r = tutils.treq()
diff --git a/test/test_proxy.py b/test/test_proxy.py
index 41d41d0c..c42d66e7 100644
--- a/test/test_proxy.py
+++ b/test/test_proxy.py
@@ -19,7 +19,7 @@ class TestServerConnection:
self.d.shutdown()
def test_simple(self):
- sc = proxy.ServerConnection((self.d.IFACE, self.d.port))
+ sc = proxy.ServerConnection((self.d.IFACE, self.d.port), None)
sc.connect()
r = tutils.treq()
r.flow.server_conn = sc
@@ -31,7 +31,7 @@ class TestServerConnection:
sc.finish()
def test_terminate_error(self):
- sc = proxy.ServerConnection((self.d.IFACE, self.d.port))
+ sc = proxy.ServerConnection((self.d.IFACE, self.d.port), None)
sc.connect()
sc.connection = mock.Mock()
sc.connection.recv = mock.Mock(return_value=False)
diff --git a/test/test_script.py b/test/test_script.py
index 7ee85f2c..2e48081b 100644
--- a/test/test_script.py
+++ b/test/test_script.py
@@ -75,7 +75,7 @@ class TestScript:
# Two instantiations
assert m.call_count == 2
assert (time.time() - t_start) < 0.09
- time.sleep(0.2)
+ time.sleep(0.3 - (time.time() - t_start))
# Plus two invocations
assert m.call_count == 4