diff options
-rw-r--r-- | libmproxy/flow.py | 3 | ||||
-rw-r--r-- | libmproxy/protocol.py | 69 | ||||
-rw-r--r-- | libmproxy/proxy.py | 28 | ||||
-rw-r--r-- | test/test_dump.py | 2 | ||||
-rw-r--r-- | test/test_flow.py | 2 | ||||
-rw-r--r-- | test/test_proxy.py | 13 | ||||
-rw-r--r-- | test/test_server.py | 31 | ||||
-rw-r--r-- | test/tservers.py | 12 |
8 files changed, 100 insertions, 60 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py index ee9031ba..acebb71d 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -835,7 +835,7 @@ class ClientConnect(StateObject): def _get_state(self): return dict( address = list(self.address), - requestcount = self.requestcount, + requestcount = -1, # FIXME self.requestcount, error = self.error, ) @@ -1599,6 +1599,7 @@ class FlowMaster(controller.Master): if r.is_live(): app = self.apps.get(r) if app: + # FIXME: for the tcp proxy, use flow.client_conn.wfile err = app.serve(r, r.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. %s"%err, "error") diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py index 279ff015..9d3f805c 100644 --- a/libmproxy/protocol.py +++ b/libmproxy/protocol.py @@ -22,6 +22,10 @@ def handle_messages(conntype, connection_handler): _handle("messages", conntype, connection_handler) +def handle_error(conntype, connection_handler, error): + _handle("error", conntype, connection_handler, error) + + class ConnectionTypeChange(Exception): pass @@ -29,6 +33,17 @@ class ConnectionTypeChange(Exception): class ProtocolHandler(object): def __init__(self, c): self.c = c + def handle_messages(self): + """ + This method gets called if the connection has been established. + """ + raise NotImplementedError + def handle_error(self, error): + """ + This method gets called should there be an uncaught exception during the connection. + This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode. + """ + raise NotImplementedError """ @@ -131,9 +146,13 @@ class HTTPRequest(HTTPMessage): self.form_out = form_out or self.form_in assert isinstance(headers, ODictCaseless) - #FIXME: Compatibility Fix + #FIXME: Compatibility Fixes def is_live(self): return True + @property + def wfile(self): + import mock + return mock.Mock(side_effect=tcp.NetLibDisconnect) def _assemble_request_line(self, form=None): form = form or self.form_out @@ -239,19 +258,19 @@ class HTTPHandler(ProtocolHandler): flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) self.c.log("request", [flow.request._assemble_request_line(flow.request.form_in)]) - self.process_request(flow.request) request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request if LEGACY else flow) if request_reply is None or request_reply == KILL: return False - if isinstance(request_reply, HTTPResponse): + if isinstance(request_reply, HTTPResponse) or (LEGACY and isinstance(request_reply, libmproxy.flow.Response)): flow.response = request_reply else: + self.process_request(flow.request) flow.response = self.get_response_from_server(flow.request) - self.c.log("response", [flow.response._assemble_response_line()]) + self.c.log("response", [flow.response._assemble_response_line() if not LEGACY else flow.response._assemble().splitlines()[0]]) response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", flow.response if LEGACY else flow) if response_reply is None or response_reply == KILL: @@ -269,25 +288,39 @@ class HTTPHandler(ProtocolHandler): if flow.request.form_in == "authority": self.ssl_upgrade(flow.request) return True - except HttpAuthenticationError, e: - self.process_error(flow, code=407, message="Proxy Authentication Required", headers=e.auth_headers) - except (http.HttpError, ProxyError), e: - self.process_error(flow, code=e.code, message=e.msg) - except tcp.NetLibError, e: - self.process_error(flow, message=e.message or e.__class__) + except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e: + self.handle_error(e, flow) return False - def process_error(self, flow, code=None, message=None, headers=None): - try: - err = ("%s: %s" % (code, message)) if code else message + def handle_error(self, error, flow=None): + code, message, headers = None, None, None + if isinstance(error, HttpAuthenticationError): + code, message, headers = 407, "Proxy Authentication Required", error.auth_headers + elif isinstance(error, (http.HttpError, ProxyError)): + code, message = error.code, error.msg + elif isinstance(error, tcp.NetLibError): + code = 502 + message = error.message or error.__class__ + + if code: + err = "%s: %s" % (code, message) + else: + err = message + + self.c.log("error: %s" %err) + + if flow: flow.error = libmproxy.flow.Error(False, err) - self.c.log("error: %s" % err) self.c.channel.ask("error" if LEGACY else "httperror", flow.error if LEGACY else flow) - if code: + else: + pass # FIXME: Is there any use case for persisting errors that occur outside of flows? + + if code: + try: self.send_error(code, message, headers) - except: - pass + except: + pass def send_error(self, code, message, headers): response = http_status.RESPONSES.get(code, "Unknown") @@ -364,6 +397,8 @@ class HTTPHandler(ProtocolHandler): if request.form_in == "authority": pass elif request.form_in == "absolute": + if request.scheme != "http": + raise http.HttpError(400, "Invalid Request") if not self.c.config.forward_proxy: request.form_out = "origin" if ((not self.c.server_conn) or diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 9e3e317b..a7ee9a7b 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -40,8 +40,8 @@ class ProxyConfig: class ClientConnection(tcp.BaseHandler): - def __init__(self, client_connection, address): - tcp.BaseHandler.__init__(self, client_connection, address) + def __init__(self, client_connection, address, server): + tcp.BaseHandler.__init__(self, client_connection, address, server) self.timestamp_start = utils.timestamp() self.timestamp_end = None @@ -72,10 +72,14 @@ class ServerConnection(tcp.TCPClient): self.peername = self.connection.getpeername() self.timestamp_tcp_setup = utils.timestamp() + def send(self, message): + self.wfile.write(message) + self.wfile.flush() + def establish_ssl(self, clientcerts, sni): clientcert = None if clientcerts: - path = os.path.join(clientcerts, self.host.encode("idna")) + ".pem" + path = os.path.join(clientcerts, self.address.host.encode("idna")) + ".pem" if os.path.exists(path): clientcert = path try: @@ -118,7 +122,7 @@ class RequestReplayThread(threading.Thread): class ConnectionHandler: def __init__(self, config, client_connection, client_address, server, channel, server_version): self.config = config - self.client_conn = ClientConnection(client_connection, client_address) + self.client_conn = ClientConnection(client_connection, client_address, server) self.server_conn = None self.channel, self.server_version = channel, server_version @@ -144,6 +148,8 @@ class ConnectionHandler: self.log("clientconnect") self.channel.ask("clientconnect", self) + self.determine_conntype() + try: # Can we already identify the target server and connect to it? server_address = None @@ -159,8 +165,6 @@ class ConnectionHandler: raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") self.log("transparent to %s:%s" % server_address) - self.determine_conntype() - if server_address: self.establish_server_connection(server_address) self._handle_ssl() @@ -171,11 +175,14 @@ class ConnectionHandler: except protocol.ConnectionTypeChange: continue - self.del_server_connection() - except ProxyError, e: + # FIXME: Do we want to persist errors? + except (ProxyError, tcp.NetLibError), e: + protocol.handle_error(self.conntype, self, e) + except Exception, e: + self.log(e.__class__) self.log(str(e)) - # FIXME: We need to persist errors + self.del_server_connection() self.log("clientdisconnect") self.channel.tell("clientdisconnect", self) @@ -279,8 +286,7 @@ class ConnectionHandler: sn = connection.get_servername() if sn and sn != self.sni: self.sni = sn.decode("utf8").encode("idna") - self.establish_server_connection() # reconnect to upstream server with SNI - self.establish_ssl(server=True) # establish SSL with upstream + self.server_reconnect() # reconnect to upstream server with SNI # Now, change client context to reflect changed certificate: new_context = SSL.Context(SSL.TLSv1_METHOD) new_context.use_privatekey_file(self.config.certfile or self.config.cacert) diff --git a/test/test_dump.py b/test/test_dump.py index a958a2ec..031a3f6a 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -30,7 +30,7 @@ class TestDumpMaster: resp = tutils.tresp(req) resp.content = content m.handle_clientconnect(cc) - sc = proxy.ServerConnection(m.o, req.scheme, req.host, req.port, None) + sc = proxy.ServerConnection((req.host, req.port)) sc.reply = mock.MagicMock() m.handle_serverconnection(sc) m.handle_request(req) diff --git a/test/test_flow.py b/test/test_flow.py index f9198f0c..aec04152 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -595,7 +595,7 @@ class TestFlowMaster: req = tutils.treq() fm.handle_clientconnect(req.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" - sc = proxy.ServerConnection(None, req.scheme, req.host, req.port, None) + sc = proxy.ServerConnection((req.host, req.port)) sc.reply = controller.DummyReply() fm.handle_serverconnection(sc) assert fm.scripts[0].ns["log"][-1] == "serverconnect" diff --git a/test/test_proxy.py b/test/test_proxy.py index 371e5ef7..737e4a92 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -19,25 +19,26 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc = proxy.ServerConnection((self.d.IFACE, self.d.port)) sc.connect() r = tutils.treq() r.path = "/p/200:da" - sc.send(r) + sc.send(r._assemble()) assert http.read_response(sc.rfile, r.method, 1000) assert self.d.last_log() r.content = flow.CONTENT_MISSING - tutils.raises("incomplete request", sc.send, r) + tutils.raises("incomplete request", sc.send, r._assemble()) - sc.terminate() + sc.finish() def test_terminate_error(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc = proxy.ServerConnection((self.d.IFACE, self.d.port)) sc.connect() sc.connection = mock.Mock() + sc.connection.recv = mock.Mock(return_value=False) sc.connection.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) - sc.terminate() + sc.finish() class MockParser: diff --git a/test/test_server.py b/test/test_server.py index ba152dc2..d3bf4676 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -19,8 +19,8 @@ class CommonMixin: def test_replay(self): assert self.pathod("304").status_code == 304 - assert len(self.master.state.view) == 1 - l = self.master.state.view[0] + assert len(self.master.state.view) == (2 if self.ssl else 1) + l = self.master.state.view[1 if self.ssl else 0] assert l.response.code == 304 l.request.path = "/p/305" rt = self.master.replay_request(l, block=True) @@ -41,16 +41,17 @@ class CommonMixin: assert f.status_code == 304 l = self.master.state.view[0] - assert l.request.client_conn.address + assert l.client_conn.address assert "host" in l.request.headers assert l.response.code == 304 def test_invalid_http(self): - t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port)) + t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t.connect() t.wfile.write("invalid\r\n\r\n") t.wfile.flush() - assert "Bad Request" in t.rfile.readline() + line = t.rfile.readline() + assert ("Bad Request" in line) or ("Bad Gateway" in line) @@ -70,7 +71,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): assert "ValueError" in ret.content def test_invalid_connect(self): - t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port)) + t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t.connect() t.wfile.write("CONNECT invalid\n\n") t.wfile.flush() @@ -105,22 +106,17 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): assert p.request(req) assert p.request(req) - # However, if the server disconnects on our first try, it's an error. - req = "get:'%s/p/200:b@1:d0'"%self.server.urlbase - p = self.pathoc() - tutils.raises("server disconnect", p.request, req) - def test_proxy_ioerror(self): # Tests a difficult-to-trigger condition, where an IOError is raised # within our read loop. - with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m: + with mock.patch("libmproxy.protocol.HTTPRequest.from_stream") as m: m.side_effect = IOError("error!") tutils.raises("server disconnect", self.pathod, "304") def test_get_connection_switching(self): def switched(l): for i in l: - if "switching" in i: + if "serverdisconnect" in i: return True req = "get:'%s/p/200:b@1'" p = self.pathoc() @@ -230,12 +226,13 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): f = self.pathod("304", sni="testserver.com") assert f.status_code == 304 l = self.server.last_log() - assert self.server.last_log()["request"]["sni"] == "testserver.com" + assert l["request"]["sni"] == "testserver.com" def test_sslerr(self): - p = pathoc.Pathoc("localhost", self.proxy.port) + p = pathoc.Pathoc(("localhost", self.proxy.port)) p.connect() - assert p.request("get:/").status_code == 400 + r = p.request("get:/") + assert r.status_code == 502 class TestProxy(tservers.HTTPProxTest): @@ -335,7 +332,6 @@ class TestFakeResponse(tservers.HTTPProxTest): assert "header_response" in f.headers.keys() - class MasterKillRequest(tservers.TestMaster): def handle_request(self, m): m.reply(proxy.KILL) @@ -376,6 +372,7 @@ class TestTransparentResolveError(tservers.TransparentProxTest): class MasterIncomplete(tservers.TestMaster): def handle_request(self, m): + # FIXME: fails because of a ._assemble().splitlines() log statement. resp = tutils.tresp() resp.content = flow.CONTENT_MISSING m.reply(resp) diff --git a/test/tservers.py b/test/tservers.py index ac95b168..f9008cd6 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -55,7 +55,7 @@ class ProxyThread(threading.Thread): @property def port(self): - return self.tmaster.server.port + return self.tmaster.server.address.port @property def log(self): @@ -134,13 +134,13 @@ class ProxTestBase: class HTTPProxTest(ProxTestBase): def pathoc_raw(self): - return libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port) + return libpathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port)) def pathoc(self, sni=None): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl, sni=sni) + p = libpathod.pathoc.Pathoc(("localhost", self.proxy.port), ssl=self.ssl, sni=sni) if self.ssl: p.connect(("127.0.0.1", self.server.port)) else: @@ -161,7 +161,7 @@ class HTTPProxTest(ProxTestBase): def app(self, page): if self.ssl: - p = libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port, True) + p = libpathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), True) print "PRE" p.connect((APP_HOST, APP_PORT)) print "POST" @@ -211,7 +211,7 @@ class TransparentProxTest(ProxTestBase): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl, sni=sni) + p = libpathod.pathoc.Pathoc(("localhost", self.proxy.port), ssl=self.ssl, sni=sni) p.connect() return p @@ -232,7 +232,7 @@ class ReverseProxTest(ProxTestBase): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl, sni=sni) + p = libpathod.pathoc.Pathoc(("localhost", self.proxy.port), ssl=self.ssl, sni=sni) p.connect() return p |