diff options
author | Rouli <rouli.net@gmail.com> | 2013-02-28 13:28:57 +0200 |
---|---|---|
committer | Rouli <rouli.net@gmail.com> | 2013-02-28 13:28:57 +0200 |
commit | b6cae7cd2d0105d6a6fe9d35864d0f9b7c5f8924 (patch) | |
tree | a939022f9bbafea95d1d2e88e141b6cceefebdd2 | |
parent | 35f36481b9f9a8050e0316600be168316b60d05e (diff) | |
parent | b077189dd5230b6c440a200d867c70c6ce031b66 (diff) | |
download | mitmproxy-b6cae7cd2d0105d6a6fe9d35864d0f9b7c5f8924.tar.gz mitmproxy-b6cae7cd2d0105d6a6fe9d35864d0f9b7c5f8924.tar.bz2 mitmproxy-b6cae7cd2d0105d6a6fe9d35864d0f9b7c5f8924.zip |
Merge remote-tracking branch 'upstream/master'
-rw-r--r-- | .coveragerc | 3 | ||||
-rw-r--r-- | README.mkd | 26 | ||||
-rw-r--r-- | doc-src/howmitmproxy.html | 2 | ||||
-rw-r--r-- | libmproxy/console/__init__.py | 6 | ||||
-rw-r--r-- | libmproxy/console/common.py | 4 | ||||
-rw-r--r-- | libmproxy/controller.py | 87 | ||||
-rw-r--r-- | libmproxy/dump.py | 22 | ||||
-rw-r--r-- | libmproxy/flow.py | 85 | ||||
-rw-r--r-- | libmproxy/proxy.py | 241 | ||||
-rw-r--r-- | test/.gitignore | 1 | ||||
-rw-r--r-- | test/.pry | 6 | ||||
-rw-r--r-- | test/test_dump.py | 6 | ||||
-rw-r--r-- | test/test_flow.py | 32 | ||||
-rw-r--r-- | test/test_proxy.py | 11 | ||||
-rw-r--r-- | test/test_server.py | 132 | ||||
-rw-r--r-- | test/tservers.py | 198 | ||||
-rw-r--r-- | test/tutils.py | 176 |
17 files changed, 628 insertions, 410 deletions
diff --git a/.coveragerc b/.coveragerc index 696e0eb8..dd57a6e7 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,6 @@ +[rum] +branch = True + [report] omit = *contrib*, *tnetstring*, *platform* include = *libmproxy* @@ -50,28 +50,24 @@ Requirements ------------ * [Python](http://www.python.org) 2.7.x. +* [netlib](http://pypi.python.org/pypi/netlib) 0.2.2 or newer. * [PyOpenSSL](http://pypi.python.org/pypi/pyOpenSSL) 0.13 or newer. * [pyasn1](http://pypi.python.org/pypi/pyasn1) 0.1.2 or newer. * [urwid](http://excess.org/urwid/) version 1.1 or newer. * [PIL](http://www.pythonware.com/products/pil/) version 1.1 or newer. * [lxml](http://lxml.de/) version 2.3 or newer. -* [netlib](http://pypi.python.org/pypi/netlib) 0.2.2 or newer. -The following auxiliary components may be needed if you plan to hack on -mitmproxy: +__mitmproxy__ is tested and developed on OSX, Linux and OpenBSD. Windows is not +officially supported at the moment. -* The test suite uses the [nose](http://readthedocs.org/docs/nose/en/latest/) unit testing - framework and requires [human_curl](https://github.com/Lispython/human_curl) and - [pathod](http://pathod.org). -* Rendering the documentation requires [countershape](http://github.com/cortesi/countershape). -__mitmproxy__ is tested and developed on OSX, Linux and OpenBSD. Windows is not -supported at the moment. +Hacking +------- + +The following components are needed if you plan to hack on mitmproxy: -You should also make sure that your console environment is set up with the -following: +* The test suite uses the [nose](http://readthedocs.org/docs/nose/en/latest/) unit testing + framework and requires [human_curl](https://github.com/Lispython/human_curl), + [pathod](http://pathod.org) and [flask](http://flask.pocoo.org/). +* Rendering the documentation requires [countershape](http://github.com/cortesi/countershape). -* EDITOR environment variable to determine the external editor. -* PAGER environment variable to determine the external pager. -* Appropriate entries in your mailcap files to determine external - viewers for request and response contents. diff --git a/doc-src/howmitmproxy.html b/doc-src/howmitmproxy.html index 6ea723cd..94c895d7 100644 --- a/doc-src/howmitmproxy.html +++ b/doc-src/howmitmproxy.html @@ -71,7 +71,7 @@ flow of requests and responses are completely opaque to the proxy. ## The MITM in mitmproxy -This is where mitmproxy's fundamental trick comes in to play. The MITM in its +This is where mitmproxy's fundamental trick comes into play. The MITM in its name stands for Man-In-The-Middle - a reference to the process we use to intercept and interfere with these theoretially opaque data streams. The basic idea is to pretend to be the server to the client, and pretend to be the client diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index d6c7f5a2..a16cc4dc 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -580,7 +580,7 @@ class ConsoleMaster(flow.FlowMaster): self.view_flowlist() - self.server.start_slave(controller.Slave, self.masterq) + self.server.start_slave(controller.Slave, controller.Channel(self.masterq)) if self.options.rfile: ret = self.load_flows(self.options.rfile) @@ -1002,7 +1002,7 @@ class ConsoleMaster(flow.FlowMaster): if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay(): f.intercept() else: - r._ack() + r.reply() self.sync_list_view() self.refresh_flow(f) @@ -1023,7 +1023,7 @@ class ConsoleMaster(flow.FlowMaster): # Handlers def handle_log(self, l): self.add_event(l.msg) - l._ack() + l.reply() def handle_error(self, r): f = flow.FlowMaster.handle_error(self, r) diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index 2da7f802..1cc0b5b9 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -184,7 +184,7 @@ def format_flow(f, focus, extended=False, padding=2): req_timestamp = f.request.timestamp_start, req_is_replay = f.request.is_replay(), req_method = f.request.method, - req_acked = f.request.acked, + req_acked = f.request.reply.acked, req_url = f.request.get_url(), err_msg = f.error.msg if f.error else None, @@ -200,7 +200,7 @@ def format_flow(f, focus, extended=False, padding=2): d.update(dict( resp_code = f.response.code, resp_is_replay = f.response.is_replay(), - resp_acked = f.response.acked, + resp_acked = f.response.reply.acked, resp_clen = contentdesc )) t = f.response.headers["content-type"] diff --git a/libmproxy/controller.py b/libmproxy/controller.py index f38d1edb..bb22597d 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -17,37 +17,75 @@ import Queue, threading should_exit = False -class Msg: + +class DummyReply: + """ + A reply object that does nothing. Useful when we need an object to seem + like it has a channel, and during testing. + """ def __init__(self): + self.acked = False + + def __call__(self, msg=False): + self.acked = True + + +class Reply: + """ + Messages sent through a channel are decorated with a "reply" attribute. + This object is used to respond to the message through the return + channel. + """ + def __init__(self, obj): + self.obj = obj self.q = Queue.Queue() self.acked = False - def _ack(self, data=False): + def __call__(self, msg=None): if not self.acked: self.acked = True - if data is None: - self.q.put(data) + if msg is None: + self.q.put(self.obj) else: - self.q.put(data or self) + self.q.put(msg) - def _send(self, masterq): - self.acked = False - try: - masterq.put(self, timeout=3) - while not should_exit: # pragma: no cover - try: - g = self.q.get(timeout=0.5) - except Queue.Empty: - continue - return g - except (Queue.Empty, Queue.Full): # pragma: no cover - return None + +class Channel: + def __init__(self, q): + self.q = q + + def ask(self, m): + """ + Decorate a message with a reply attribute, and send it to the + master. then wait for a response. + """ + m.reply = Reply(m) + self.q.put(m) + while not should_exit: + try: + # The timeout is here so we can handle a should_exit event. + g = m.reply.q.get(timeout=0.5) + except Queue.Empty: # pragma: nocover + continue + return g + + def tell(self, m): + """ + Decorate a message with a dummy reply attribute, send it to the + master, then return immediately. + """ + m.reply = DummyReply() + self.q.put(m) class Slave(threading.Thread): - def __init__(self, masterq, server): - self.masterq, self.server = masterq, server - self.server.set_mqueue(masterq) + """ + Slaves get a channel end-point through which they can send messages to + the master. + """ + def __init__(self, channel, server): + self.channel, self.server = channel, server + self.server.set_channel(channel) threading.Thread.__init__(self) def run(self): @@ -55,6 +93,9 @@ class Slave(threading.Thread): class Master: + """ + Masters get and respond to messages from slaves. + """ def __init__(self, server): """ server may be None if no server is needed. @@ -81,18 +122,18 @@ class Master: def run(self): global should_exit should_exit = False - self.server.start_slave(Slave, self.masterq) + self.server.start_slave(Slave, Channel(self.masterq)) while not should_exit: self.tick(self.masterq) self.shutdown() - def handle(self, msg): # pragma: no cover + def handle(self, msg): c = "handle_" + msg.__class__.__name__.lower() m = getattr(self, c, None) if m: m(msg) else: - msg._ack() + msg.reply() def shutdown(self): global should_exit diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 170c701d..3c7eee71 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -150,16 +150,6 @@ class DumpMaster(flow.FlowMaster): print >> self.outfile, e self.outfile.flush() - def handle_log(self, l): - self.add_event(l.msg) - l._ack() - - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) - if f: - r._ack() - return f - def indent(self, n, t): l = str(t).strip().split("\n") return "\n".join(" "*n + i for i in l) @@ -210,10 +200,20 @@ class DumpMaster(flow.FlowMaster): self.outfile.flush() self.state.delete_flow(f) + def handle_log(self, l): + self.add_event(l.msg) + l.reply() + + def handle_request(self, r): + f = flow.FlowMaster.handle_request(self, r) + if f: + r.reply() + return f + def handle_response(self, msg): f = flow.FlowMaster.handle_response(self, msg) if f: - msg._ack() + msg.reply() self._process_flow(f) return f diff --git a/libmproxy/flow.py b/libmproxy/flow.py index af97698c..1f5d01ee 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -196,7 +196,15 @@ class decoded(object): self.o.encode(self.ce) -class HTTPMsg(controller.Msg): +class StateObject: + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: + return False + + +class HTTPMsg(StateObject): def get_decoded_content(self): """ Returns the decoded content based on the current Content-Encoding header. @@ -252,6 +260,7 @@ class HTTPMsg(controller.Msg): return 0 return len(self.content) + class Request(HTTPMsg): """ An HTTP request. @@ -289,7 +298,6 @@ class Request(HTTPMsg): self.timestamp_start = timestamp_start or utils.timestamp() self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start) self.close = False - controller.Msg.__init__(self) # Have this request's cookies been modified by sticky cookies or auth? self.stickycookie = False @@ -388,15 +396,8 @@ class Request(HTTPMsg): def __hash__(self): return id(self) - def __eq__(self, other): - return self._get_state() == other._get_state() - def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) - c.acked = True c.headers = self.headers.copy() return c @@ -603,7 +604,6 @@ class Response(HTTPMsg): self.cert = cert self.timestamp_start = timestamp_start or utils.timestamp() self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start) - controller.Msg.__init__(self) self.replay = False def _refresh_cookie(self, c, delta): @@ -700,15 +700,8 @@ class Response(HTTPMsg): state["timestamp_end"], ) - def __eq__(self, other): - return self._get_state() == other._get_state() - def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) - c.acked = True c.headers = self.headers.copy() return c @@ -773,7 +766,7 @@ class Response(HTTPMsg): cookies.append((cookie_name, (cookie_value, cookie_parameters))) return dict(cookies) -class ClientDisconnect(controller.Msg): +class ClientDisconnect: """ A client disconnection event. @@ -782,11 +775,10 @@ class ClientDisconnect(controller.Msg): client_conn: ClientConnect object. """ def __init__(self, client_conn): - controller.Msg.__init__(self) self.client_conn = client_conn -class ClientConnect(controller.Msg): +class ClientConnect(StateObject): """ A single client connection. Each connection can result in multiple HTTP Requests. @@ -807,10 +799,6 @@ class ClientConnect(controller.Msg): self.close = False self.requestcount = 0 self.error = None - controller.Msg.__init__(self) - - def __eq__(self, other): - return self._get_state() == other._get_state() def __str__(self): if self.address: @@ -839,15 +827,10 @@ class ClientConnect(controller.Msg): return None def copy(self): - """ - Returns a copy of this object. - """ - c = copy.copy(self) - c.acked = True - return c + return copy.copy(self) -class Error(controller.Msg): +class Error(StateObject): """ An Error. @@ -865,18 +848,13 @@ class Error(controller.Msg): def __init__(self, request, msg, timestamp=None): self.request, self.msg = request, msg self.timestamp = timestamp or utils.timestamp() - controller.Msg.__init__(self) def _load_state(self, state): self.msg = state["msg"] self.timestamp = state["timestamp"] def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) - c.acked = True return c def _get_state(self): @@ -893,9 +871,6 @@ class Error(controller.Msg): state["timestamp"], ) - def __eq__(self, other): - return self._get_state() == other._get_state() - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both the headers @@ -1185,10 +1160,11 @@ class Flow: Kill this request. """ self.error = Error(self.request, "Connection killed") - if self.request and not self.request.acked: - self.request._ack(None) - elif self.response and not self.response.acked: - self.response._ack(None) + self.error.reply = controller.DummyReply() + if self.request and not self.request.reply.acked: + self.request.reply(proxy.KILL) + elif self.response and not self.response.reply.acked: + self.response.reply(proxy.KILL) master.handle_error(self.error) self.intercepting = False @@ -1204,10 +1180,10 @@ class Flow: Continue with the flow - called after an intercept(). """ if self.request: - if not self.request.acked: - self.request._ack() - elif self.response and not self.response.acked: - self.response._ack() + if not self.request.reply.acked: + self.request.reply() + elif self.response and not self.response.reply.acked: + self.response.reply() self.intercepting = False def replace(self, pattern, repl, *args, **kwargs): @@ -1469,7 +1445,7 @@ class FlowMaster(controller.Master): flow.response = response if self.refresh_server_playback: response.refresh() - flow.request._ack(response) + flow.request.reply(response) if self.server_playback.count() == 0: self.stop_server_playback() return True @@ -1496,10 +1472,13 @@ class FlowMaster(controller.Master): Loads a flow, and returns a new flow object. """ if f.request: + f.request.reply = controller.DummyReply() fr = self.handle_request(f.request) if f.response: + f.response.reply = controller.DummyReply() self.handle_response(f.response) if f.error: + f.error.reply = controller.DummyReply() self.handle_error(f.error) return fr @@ -1527,7 +1506,7 @@ class FlowMaster(controller.Master): if self.kill_nonreplay: f.kill(self) else: - f.request._ack() + f.request.reply() def process_new_response(self, f): if self.stickycookie_state: @@ -1566,11 +1545,11 @@ class FlowMaster(controller.Master): def handle_clientconnect(self, cc): self.run_script_hook("clientconnect", cc) - cc._ack() + cc.reply() def handle_clientdisconnect(self, r): self.run_script_hook("clientdisconnect", r) - r._ack() + r.reply() def handle_error(self, r): f = self.state.add_error(r) @@ -1578,7 +1557,7 @@ class FlowMaster(controller.Master): self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - r._ack() + r.reply() return f def handle_request(self, r): @@ -1601,7 +1580,7 @@ class FlowMaster(controller.Master): if self.stream: self.stream.add(f) else: - r._ack() + r.reply() return f def shutdown(self): diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index f14e4e3e..7c229064 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -20,6 +20,8 @@ from netlib import odict, tcp, http, wsgi, certutils, http_status import utils, flow, version, platform, controller import authentication +KILL = 0 + class ProxyError(Exception): def __init__(self, code, msg, headers=None): @@ -29,9 +31,8 @@ class ProxyError(Exception): return "ProxyError(%s, %s)"%(self.code, self.msg) -class Log(controller.Msg): +class Log: def __init__(self, msg): - controller.Msg.__init__(self) self.msg = msg @@ -49,45 +50,23 @@ class ProxyConfig: self.certstore = certutils.CertStore(certdir) -class RequestReplayThread(threading.Thread): - def __init__(self, config, flow, masterq): - self.config, self.flow, self.masterq = config, flow, masterq - threading.Thread.__init__(self) - - def run(self): - try: - r = self.flow.request - server = ServerConnection(self.config, r.host, r.port) - server.connect(r.scheme) - server.send(r) - httpversion, code, msg, headers, content = http.read_response( - server.rfile, r.method, self.config.body_size_limit - ) - response = flow.Response( - self.flow.request, httpversion, code, msg, headers, content, server.cert - ) - response._send(self.masterq) - except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) - err._send(self.masterq) - - class ServerConnection(tcp.TCPClient): - def __init__(self, config, host, port): + def __init__(self, config, scheme, host, port, sni): tcp.TCPClient.__init__(self, host, port) self.config = config + self.scheme, self.sni = scheme, sni self.requestcount = 0 - def connect(self, scheme): + def connect(self): tcp.TCPClient.connect(self) - if scheme == "https": + if self.scheme == "https": clientcert = None if self.config.clientcerts: path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" if os.path.exists(path): clientcert = path try: - self.convert_to_ssl(clientcert=clientcert, sni=self.host) + self.convert_to_ssl(cert=clientcert, sni=self.sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -108,42 +87,78 @@ class ServerConnection(tcp.TCPClient): pass +class RequestReplayThread(threading.Thread): + def __init__(self, config, flow, masterq): + self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) + threading.Thread.__init__(self) + + def run(self): + try: + r = self.flow.request + server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host) + server.connect() + server.send(r) + httpversion, code, msg, headers, content = http.read_response( + server.rfile, r.method, self.config.body_size_limit + ) + response = flow.Response( + self.flow.request, httpversion, code, msg, headers, content, server.cert + ) + self.channel.ask(response) + except (ProxyError, http.HttpError, tcp.NetLibError), v: + err = flow.Error(self.flow.request, str(v)) + self.channel.ask(err) + + class ProxyHandler(tcp.BaseHandler): - def __init__(self, config, connection, client_address, server, mqueue, server_version): - self.mqueue, self.server_version = mqueue, server_version + def __init__(self, config, connection, client_address, server, channel, server_version): + self.channel, self.server_version = channel, server_version self.config = config - self.server_conn = None self.proxy_connect_state = None self.sni = None + self.server_conn = None tcp.BaseHandler.__init__(self, connection, client_address, server) + def get_server_connection(self, cc, scheme, host, port, sni): + sc = self.server_conn + if sc and (scheme, host, port, sni) != (sc.scheme, sc.host, sc.port, sc.sni): + sc.terminate() + self.server_conn = None + self.log( + cc, + "switching connection", [ + "%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%( + scheme, host, port, sni, + sc.scheme, sc.host, sc.port, sc.sni + ) + ] + ) + if not self.server_conn: + try: + self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + self.server_conn.connect() + except tcp.NetLibError, v: + raise ProxyError(502, v) + return self.server_conn + + def del_server_connection(self): + self.server_conn = None + def handle(self): cc = flow.ClientConnect(self.client_address) self.log(cc, "connect") - cc._send(self.mqueue) + self.channel.ask(cc) while self.handle_request(cc) and not cc.close: pass cc.close = True - cd = flow.ClientDisconnect(cc) + cd = flow.ClientDisconnect(cc) self.log( cc, "disconnect", [ "handled %s requests"%cc.requestcount] ) - cd._send(self.mqueue) - - def server_connect(self, scheme, host, port): - sc = self.server_conn - if sc and (host, port) != (sc.host, sc.port): - sc.terminate() - self.server_conn = None - if not self.server_conn: - try: - self.server_conn = ServerConnection(self.config, host, port) - self.server_conn.connect(scheme) - except tcp.NetLibError, v: - raise ProxyError(502, v) + self.channel.tell(cd) def handle_request(self, cc): try: @@ -160,45 +175,66 @@ class ProxyHandler(tcp.BaseHandler): self.log(cc, "Error in wsgi app.", err.split("\n")) return else: - request = request._send(self.mqueue) - if request is None: + request_reply = self.channel.ask(request) + if request_reply == KILL: return - - if isinstance(request, flow.Response): - response = request + elif isinstance(request_reply, flow.Response): request = False - response = response._send(self.mqueue) + response = request_reply + response_reply = self.channel.ask(response) else: + request = request_reply if self.config.reverse_proxy: scheme, host, port = self.config.reverse_proxy else: scheme, host, port = request.scheme, request.host, request.port - self.server_connect(scheme, host, port) - self.server_conn.send(request) - self.server_conn.rfile.reset_timestamps() - httpversion, code, msg, headers, content = http.read_response( - self.server_conn.rfile, - request.method, - self.config.body_size_limit - ) + + # If we've already pumped a request over this connection, + # it's possible that the server has timed out. If this is + # the case, we want to reconnect without sending an error + # to the client. + while 1: + sc = self.get_server_connection(cc, scheme, host, port, host) + sc.send(request) + sc.rfile.reset_timestamps() + try: + httpversion, code, msg, headers, content = http.read_response( + sc.rfile, + request.method, + self.config.body_size_limit + ) + except http.HttpErrorConnClosed, v: + self.del_server_connection() + if sc.requestcount > 1: + continue + else: + raise + else: + break + response = flow.Response( - request, httpversion, code, msg, headers, content, self.server_conn.cert, self.server_conn.rfile.first_byte_timestamp, utils.timestamp() + request, httpversion, code, msg, headers, content, sc.cert, + sc.rfile.first_byte_timestamp, utils.timestamp() ) + response_reply = self.channel.ask(response) + # Not replying to the server invalidates the server + # connection, so we terminate. + if response_reply == KILL: + sc.terminate() - response = response._send(self.mqueue) - if response is None: - self.server_conn.terminate() - if response is None: - return - self.send_response(response) - if request and http.request_connection_close(request.httpversion, request.headers): - return - # We could keep the client connection when the server - # connection needs to go away. However, we want to mimic - # behaviour as closely as possible to the client, so we - # disconnect. - if http.response_connection_close(response.httpversion, response.headers): + if response_reply == KILL: return + else: + response = response_reply + self.send_response(response) + if request and http.request_connection_close(request.httpversion, request.headers): + return + # We could keep the client connection when the server + # connection needs to go away. However, we want to mimic + # behaviour as closely as possible to the client, so we + # disconnect. + if http.response_connection_close(response.httpversion, response.headers): + return except (IOError, ProxyError, http.HttpError, tcp.NetLibDisconnect), e: if hasattr(e, "code"): cc.error = "%s: %s"%(e.code, e.msg) @@ -207,7 +243,7 @@ class ProxyHandler(tcp.BaseHandler): if request: err = flow.Error(request, cc.error) - err._send(self.mqueue) + self.channel.ask(err) self.log( cc, cc.error, ["url: %s"%request.get_url()] @@ -228,7 +264,7 @@ class ProxyHandler(tcp.BaseHandler): msg.append(" -> "+i) msg = "\n".join(msg) l = Log(msg) - l._send(self.mqueue) + self.channel.tell(l) def find_cert(self, host, port, sni): if self.config.certfile: @@ -292,25 +328,6 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) - def read_request_reverse(self, client_conn): - line = self.get_line(self.rfile) - if line == "": - return None - scheme, host, port = self.config.reverse_proxy - r = http.parse_init_http(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, path, httpversion = r - headers = self.read_headers(authenticate=False) - content = http.read_http_body_request( - self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit - ) - return flow.Request( - client_conn, httpversion, host, port, "http", method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - - def read_request_proxy(self, client_conn): line = self.get_line(self.rfile) if line == "": @@ -366,6 +383,24 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) + def read_request_reverse(self, client_conn): + line = self.get_line(self.rfile) + if line == "": + return None + scheme, host, port = self.config.reverse_proxy + r = http.parse_init_http(line) + if not r: + raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) + method, path, httpversion = r + headers = self.read_headers(authenticate=False) + content = http.read_http_body_request( + self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit + ) + return flow.Request( + client_conn, httpversion, host, port, "http", method, path, headers, content, + self.rfile.first_byte_timestamp, utils.timestamp() + ) + def read_request(self, client_conn): self.rfile.reset_timestamps() if self.config.transparent_proxy: @@ -431,18 +466,18 @@ class ProxyServer(tcp.TCPServer): tcp.TCPServer.__init__(self, (address, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) - self.masterq = None + self.channel = None self.apps = AppRegistry() - def start_slave(self, klass, masterq): - slave = klass(masterq, self) + def start_slave(self, klass, channel): + slave = klass(channel, self) slave.start() - def set_mqueue(self, q): - self.masterq = q + def set_channel(self, channel): + self.channel = channel def handle_connection(self, request, client_address): - h = ProxyHandler(self.config, request, client_address, self, self.masterq, self.server_version) + h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version) h.handle() try: h.finish() @@ -480,7 +515,7 @@ class DummyServer: def __init__(self, config): self.config = config - def start_slave(self, klass, masterq): + def start_slave(self, klass, channel): pass def shutdown(self): diff --git a/test/.gitignore b/test/.gitignore deleted file mode 100644 index 6350e986..00000000 --- a/test/.gitignore +++ /dev/null @@ -1 +0,0 @@ -.coverage diff --git a/test/.pry b/test/.pry deleted file mode 100644 index f6f18e7b..00000000 --- a/test/.pry +++ /dev/null @@ -1,6 +0,0 @@ -base = .. -coverage = ../libmproxy -exclude = . - ../libmproxy/contrib - ../libmproxy/tnetstring.py - diff --git a/test/test_dump.py b/test/test_dump.py index e1241e29..5d3f9133 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -3,6 +3,7 @@ from cStringIO import StringIO import libpry from libmproxy import dump, flow, proxy import tutils +import mock def test_strfuncs(): t = tutils.tresp() @@ -21,6 +22,7 @@ class TestDumpMaster: req = tutils.treq() req.content = content l = proxy.Log("connect") + l.reply = mock.MagicMock() m.handle_log(l) cc = req.client_conn cc.connection_error = "error" @@ -29,7 +31,9 @@ class TestDumpMaster: m.handle_clientconnect(cc) m.handle_request(req) f = m.handle_response(resp) - m.handle_clientdisconnect(flow.ClientDisconnect(cc)) + cd = flow.ClientDisconnect(cc) + cd.reply = mock.MagicMock() + m.handle_clientdisconnect(cd) return f def _dummy_cycle(self, n, filt, content, **options): diff --git a/test/test_flow.py b/test/test_flow.py index da5b095e..6aa898ad 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -223,16 +223,16 @@ class TestFlow: f = tutils.tflow() f.request = tutils.treq() f.intercept() - assert not f.request.acked + assert not f.request.reply.acked f.kill(fm) - assert f.request.acked + assert f.request.reply.acked f.intercept() f.response = tutils.tresp() f.request = f.response.request - f.request._ack() - assert not f.response.acked + f.request.reply() + assert not f.response.reply.acked f.kill(fm) - assert f.response.acked + assert f.response.reply.acked def test_killall(self): s = flow.State() @@ -245,25 +245,25 @@ class TestFlow: fm.handle_request(r) for i in s.view: - assert not i.request.acked + assert not i.request.reply.acked s.killall(fm) for i in s.view: - assert i.request.acked + assert i.request.reply.acked def test_accept_intercept(self): f = tutils.tflow() f.request = tutils.treq() f.intercept() - assert not f.request.acked + assert not f.request.reply.acked f.accept_intercept() - assert f.request.acked + assert f.request.reply.acked f.response = tutils.tresp() f.request = f.response.request f.intercept() - f.request._ack() - assert not f.response.acked + f.request.reply() + assert not f.response.reply.acked f.accept_intercept() - assert f.response.acked + assert f.response.reply.acked def test_serialization(self): f = flow.Flow(None) @@ -562,9 +562,11 @@ class TestFlowMaster: fm.handle_response(resp) assert fm.script.ns["log"][-1] == "response" dc = flow.ClientDisconnect(req.client_conn) + dc.reply = controller.DummyReply() fm.handle_clientdisconnect(dc) assert fm.script.ns["log"][-1] == "clientdisconnect" err = flow.Error(f.request, "msg") + err.reply = controller.DummyReply() fm.handle_error(err) assert fm.script.ns["log"][-1] == "error" @@ -598,10 +600,12 @@ class TestFlowMaster: assert not fm.handle_response(rx) dc = flow.ClientDisconnect(req.client_conn) + dc.reply = controller.DummyReply() req.client_conn.requestcount = 1 fm.handle_clientdisconnect(dc) err = flow.Error(f.request, "msg") + err.reply = controller.DummyReply() fm.handle_error(err) fm.load_script(tutils.test_data.path("scripts/a.py")) @@ -621,7 +625,9 @@ class TestFlowMaster: fm.tick(q) assert fm.state.flow_count() - fm.handle_error(flow.Error(f.request, "error")) + err = flow.Error(f.request, "error") + err.reply = controller.DummyReply() + fm.handle_error(err) def test_server_playback(self): controller.should_exit = False diff --git a/test/test_proxy.py b/test/test_proxy.py index c73f61d8..3995b393 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -1,7 +1,7 @@ from libmproxy import proxy, flow import tutils from libpathod import test -from netlib import http +from netlib import http, tcp import mock @@ -39,8 +39,8 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http") + sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc.connect() r = tutils.treq() r.path = "/p/200:da" sc.send(r) @@ -53,8 +53,9 @@ class TestServerConnection: sc.terminate() def test_terminate_error(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http") + sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc.connect() sc.connection = mock.Mock() sc.connection.close = mock.Mock(side_effect=IOError) sc.terminate() + diff --git a/test/test_server.py b/test/test_server.py index 0a2f142e..034fab41 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,7 +1,9 @@ import socket, time +import mock from netlib import tcp from libpathod import pathoc -import tutils +import tutils, tservers +from libmproxy import flow, proxy """ Note that the choice of response code in these tests matters more than you @@ -39,7 +41,19 @@ class SanityMixin: assert l.error -class TestHTTP(tutils.HTTPProxTest, SanityMixin): +class TestHTTP(tservers.HTTPProxTest, SanityMixin): + def test_app(self): + p = self.pathoc() + ret = p.request("get:'http://testapp/'") + assert ret[1] == 200 + assert ret[4] == "testapp" + + def test_app_err(self): + p = self.pathoc() + ret = p.request("get:'http://errapp/'") + assert ret[1] == 500 + assert "ValueError" in ret[4] + def test_invalid_http(self): t = tcp.TCPClient("127.0.0.1", self.proxy.port) t.connect() @@ -68,24 +82,83 @@ class TestHTTP(tutils.HTTPProxTest, SanityMixin): assert "host" in l.request.headers assert l.response.code == 304 + def test_connection_close(self): + # Add a body, so we have a content-length header, which combined with + # HTTP1.1 means the connection is kept alive. + response = '%s/p/200:b@1'%self.server.urlbase + + # Lets sanity check that the connection does indeed stay open by + # issuing two requests over the same connection + p = self.pathoc() + assert p.request("get:'%s'"%response) + assert p.request("get:'%s'"%response) + + # Now check that the connection is closed as the client specifies + p = self.pathoc() + assert p.request("get:'%s':h'Connection'='close'"%response) + tutils.raises("disconnect", p.request, "get:'%s'"%response) + + def test_reconnect(self): + req = "get:'%s/p/200:b@1:da'"%self.server.urlbase + p = self.pathoc() + assert p.request(req) + # Server has disconnected. Mitmproxy should detect this, and reconnect. + assert p.request(req) + assert p.request(req) -class TestHTTPS(tutils.HTTPProxTest, SanityMixin): + # However, if the server disconnects on our first try, it's an error. + req = "get:'%s/p/200:b@1:d0'"%self.server.urlbase + p = self.pathoc() + tutils.raises("server disconnect", p.request, req) + + def test_proxy_ioerror(self): + # Tests a difficult-to-trigger condition, where an IOError is raised + # within our read loop. + with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m: + m.side_effect = IOError("error!") + tutils.raises("empty reply", self.pathod, "304") + + def test_get_connection_switching(self): + def switched(l): + for i in l: + if "switching" in i: + return True + req = "get:'%s/p/200:b@1'" + p = self.pathoc() + assert p.request(req%self.server.urlbase) + assert p.request(req%self.server2.urlbase) + assert switched(self.proxy.log) + + def test_get_connection_err(self): + p = self.pathoc() + ret = p.request("get:'http://localhost:0'") + assert ret[1] == 502 + + +class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True clientcerts = True def test_clientcert(self): f = self.pathod("304") - assert self.last_log()["request"]["clientcert"]["keyinfo"] + assert self.server.last_log()["request"]["clientcert"]["keyinfo"] + + +class TestHTTPSCertfile(tservers.HTTPProxTest, SanityMixin): + ssl = True + certfile = True + def test_certfile(self): + assert self.pathod("304") -class TestReverse(tutils.ReverseProxTest, SanityMixin): +class TestReverse(tservers.ReverseProxTest, SanityMixin): reverse = True -class TestTransparent(tutils.TransparentProxTest, SanityMixin): +class TestTransparent(tservers.TransparentProxTest, SanityMixin): transparent = True -class TestProxy(tutils.HTTPProxTest): +class TestProxy(tservers.HTTPProxTest): def test_http(self): f = self.pathod("304") assert f.status_code == 304 @@ -132,3 +205,48 @@ class TestProxy(tutils.HTTPProxTest): request = self.master.state.view[1].request assert request.timestamp_end - request.timestamp_start <= 0.1 + + + +class MasterFakeResponse(tservers.TestMaster): + def handle_request(self, m): + resp = tutils.tresp() + m.reply(resp) + + +class TestFakeResponse(tservers.HTTPProxTest): + masterclass = MasterFakeResponse + def test_kill(self): + p = self.pathoc() + f = self.pathod("200") + assert "header_response" in f.headers.keys() + + + +class MasterKillRequest(tservers.TestMaster): + def handle_request(self, m): + m.reply(proxy.KILL) + + +class TestKillRequest(tservers.HTTPProxTest): + masterclass = MasterKillRequest + def test_kill(self): + p = self.pathoc() + tutils.raises("empty reply", self.pathod, "200") + # Nothing should have hit the server + assert not self.server.last_log() + + +class MasterKillResponse(tservers.TestMaster): + def handle_response(self, m): + m.reply(proxy.KILL) + + +class TestKillResponse(tservers.HTTPProxTest): + masterclass = MasterKillResponse + def test_kill(self): + p = self.pathoc() + tutils.raises("empty reply", self.pathod, "200") + # The server should have seen a request + assert self.server.last_log() + diff --git a/test/tservers.py b/test/tservers.py new file mode 100644 index 00000000..998ad6c6 --- /dev/null +++ b/test/tservers.py @@ -0,0 +1,198 @@ +import threading, Queue +import flask +import human_curl as hurl +import libpathod.test, libpathod.pathoc +from libmproxy import proxy, flow, controller +import tutils + +testapp = flask.Flask(__name__) + +@testapp.route("/") +def hello(): + return "testapp" + +@testapp.route("/error") +def error(): + raise ValueError("An exception...") + + +def errapp(environ, start_response): + raise ValueError("errapp") + + +class TestMaster(flow.FlowMaster): + def __init__(self, testq, config): + s = proxy.ProxyServer(config, 0) + s.apps.add(testapp, "testapp", 80) + s.apps.add(errapp, "errapp", 80) + state = flow.State() + flow.FlowMaster.__init__(self, s, state) + self.testq = testq + self.log = [] + + def handle_request(self, m): + flow.FlowMaster.handle_request(self, m) + m.reply() + + def handle_response(self, m): + flow.FlowMaster.handle_response(self, m) + m.reply() + + def handle_log(self, l): + self.log.append(l.msg) + l.reply() + + +class ProxyThread(threading.Thread): + def __init__(self, tmaster): + threading.Thread.__init__(self) + self.tmaster = tmaster + controller.should_exit = False + + @property + def port(self): + return self.tmaster.server.port + + @property + def log(self): + return self.tmaster.log + + def run(self): + self.tmaster.run() + + def shutdown(self): + self.tmaster.shutdown() + + +class ProxTestBase: + # Test Configuration + ssl = None + clientcerts = False + certfile = None + + masterclass = TestMaster + @classmethod + def setupAll(cls): + cls.tqueue = Queue.Queue() + cls.server = libpathod.test.Daemon(ssl=cls.ssl) + cls.server2 = libpathod.test.Daemon(ssl=cls.ssl) + pconf = cls.get_proxy_config() + config = proxy.ProxyConfig( + cacert = tutils.test_data.path("data/serverkey.pem"), + **pconf + ) + tmaster = cls.masterclass(cls.tqueue, config) + cls.proxy = ProxyThread(tmaster) + cls.proxy.start() + + @property + def master(cls): + return cls.proxy.tmaster + + @classmethod + def teardownAll(cls): + cls.proxy.shutdown() + cls.server.shutdown() + cls.server2.shutdown() + + def setUp(self): + self.master.state.clear() + + @property + def scheme(self): + return "https" if self.ssl else "http" + + @property + def proxies(self): + """ + The URL base for the server instance. + """ + return ( + (self.scheme, ("127.0.0.1", self.proxy.port)) + ) + + @classmethod + def get_proxy_config(cls): + d = dict() + if cls.clientcerts: + d["clientcerts"] = tutils.test_data.path("data/clientcert") + if cls.certfile: + d["certfile"] =tutils.test_data.path("data/testkey.pem") + return d + + +class HTTPProxTest(ProxTestBase): + def pathoc(self, connect_to = None): + """ + Returns a connected Pathoc instance. + """ + p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) + p.connect(connect_to) + return p + + def pathod(self, spec): + """ + Constructs a pathod request, with the appropriate base and proxy. + """ + return hurl.get( + self.server.urlbase + "/p/" + spec, + proxy=self.proxies, + validate_cert=False, + #debug=hurl.utils.stdout_debug + ) + + +class TResolver: + def __init__(self, port): + self.port = port + + def original_addr(self, sock): + return ("127.0.0.1", self.port) + + +class TransparentProxTest(ProxTestBase): + ssl = None + @classmethod + def get_proxy_config(cls): + d = ProxTestBase.get_proxy_config() + d["transparent_proxy"] = dict( + resolver = TResolver(cls.server.port), + sslports = [] + ) + return d + + def pathod(self, spec): + """ + Constructs a pathod request, with the appropriate base and proxy. + """ + r = hurl.get( + "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, + validate_cert=False, + #debug=hurl.utils.stdout_debug + ) + return r + + +class ReverseProxTest(ProxTestBase): + ssl = None + @classmethod + def get_proxy_config(cls): + d = ProxTestBase.get_proxy_config() + d["reverse_proxy"] = ( + "https" if cls.ssl else "http", + "127.0.0.1", + cls.server.port + ) + return d + + def pathod(self, spec): + """ + Constructs a pathod request, with the appropriate base and proxy. + """ + r = hurl.get( + "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, + validate_cert=False, + #debug=hurl.utils.stdout_debug + ) + return r + diff --git a/test/tutils.py b/test/tutils.py index 9868c778..1a1c8724 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,17 +1,18 @@ -import threading, Queue import os, shutil, tempfile from contextlib import contextmanager -from libmproxy import proxy, flow, controller, utils +from libmproxy import flow, utils, controller from netlib import certutils -import human_curl as hurl -import libpathod.test, libpathod.pathoc +import mock def treq(conn=None): if not conn: conn = flow.ClientConnect(("address", 22)) + conn.reply = controller.DummyReply() headers = flow.ODictCaseless() headers["header"] = ["qvalue"] - return flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content") + r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content") + r.reply = controller.DummyReply() + return r def tresp(req=None): @@ -20,7 +21,9 @@ def tresp(req=None): headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert")).read()) - return flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert) + resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert) + resp.reply = controller.DummyReply() + return resp def tflow(): @@ -39,168 +42,10 @@ def tflow_err(): r = treq() f = flow.Flow(r) f.error = flow.Error(r, "error") + f.error.reply = controller.DummyReply() return f -class TestMaster(flow.FlowMaster): - def __init__(self, testq, config): - s = proxy.ProxyServer(config, 0) - state = flow.State() - flow.FlowMaster.__init__(self, s, state) - self.testq = testq - - def handle(self, m): - flow.FlowMaster.handle(self, m) - m._ack() - - -class ProxyThread(threading.Thread): - def __init__(self, testq, config): - self.tmaster = TestMaster(testq, config) - controller.should_exit = False - threading.Thread.__init__(self) - - @property - def port(self): - return self.tmaster.server.port - - def run(self): - self.tmaster.run() - - def shutdown(self): - self.tmaster.shutdown() - - -class ProxTestBase: - @classmethod - def setupAll(cls): - cls.tqueue = Queue.Queue() - cls.server = libpathod.test.Daemon(ssl=cls.ssl) - pconf = cls.get_proxy_config() - config = proxy.ProxyConfig( - certfile=test_data.path("data/testkey.pem"), - **pconf - ) - cls.proxy = ProxyThread(cls.tqueue, config) - cls.proxy.start() - - @property - def master(cls): - return cls.proxy.tmaster - - @classmethod - def teardownAll(cls): - cls.proxy.shutdown() - cls.server.shutdown() - - def setUp(self): - self.master.state.clear() - - @property - def scheme(self): - return "https" if self.ssl else "http" - - @property - def proxies(self): - """ - The URL base for the server instance. - """ - return ( - (self.scheme, ("127.0.0.1", self.proxy.port)) - ) - - @property - def urlbase(self): - """ - The URL base for the server instance. - """ - return self.server.urlbase - - def last_log(self): - return self.server.last_log() - - -class HTTPProxTest(ProxTestBase): - ssl = None - clientcerts = False - @classmethod - def get_proxy_config(cls): - d = dict() - if cls.clientcerts: - d["clientcerts"] = test_data.path("data/clientcert") - return d - - def pathoc(self, connect_to = None): - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) - p.connect(connect_to) - return p - - def pathod(self, spec): - """ - Constructs a pathod request, with the appropriate base and proxy. - """ - return hurl.get( - self.urlbase + "/p/" + spec, - proxy=self.proxies, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) - - -class TResolver: - def __init__(self, port): - self.port = port - - def original_addr(self, sock): - return ("127.0.0.1", self.port) - - -class TransparentProxTest(ProxTestBase): - ssl = None - @classmethod - def get_proxy_config(cls): - return dict( - transparent_proxy = dict( - resolver = TResolver(cls.server.port), - sslports = [] - ) - ) - - def pathod(self, spec): - """ - Constructs a pathod request, with the appropriate base and proxy. - """ - r = hurl.get( - "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) - return r - - -class ReverseProxTest(ProxTestBase): - ssl = None - @classmethod - def get_proxy_config(cls): - return dict( - reverse_proxy = ( - "https" if cls.ssl else "http", - "127.0.0.1", - cls.server.port - ) - ) - - def pathod(self, spec): - """ - Constructs a pathod request, with the appropriate base and proxy. - """ - r = hurl.get( - "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) - return r - @contextmanager def tmpdir(*args, **kwargs): @@ -252,5 +97,4 @@ def raises(exc, obj, *args, **kwargs): ) raise AssertionError("No exception raised.") - test_data = utils.Data(__name__) |