diff options
author | Maximilian Hils <git@maximilianhils.com> | 2014-09-03 16:57:56 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2014-09-03 16:57:56 +0200 |
commit | b0cfeff06d9dd99a16dfae19c5df3c73c5864fb9 (patch) | |
tree | 9204cb67becedded525d57581efbed947225b3e2 /libmproxy | |
parent | 951a6fcc36780a0bd5a1f1ff718327d1c6d4fc5e (diff) | |
download | mitmproxy-b0cfeff06d9dd99a16dfae19c5df3c73c5864fb9.tar.gz mitmproxy-b0cfeff06d9dd99a16dfae19c5df3c73c5864fb9.tar.bz2 mitmproxy-b0cfeff06d9dd99a16dfae19c5df3c73c5864fb9.zip |
fix #341 - work on flows instead of request/response internally.
Diffstat (limited to 'libmproxy')
-rw-r--r-- | libmproxy/app.py | 8 | ||||
-rw-r--r-- | libmproxy/console/__init__.py | 26 | ||||
-rw-r--r-- | libmproxy/dump.py | 32 | ||||
-rw-r--r-- | libmproxy/filt.py | 4 | ||||
-rw-r--r-- | libmproxy/flow.py | 106 | ||||
-rw-r--r-- | libmproxy/protocol/http.py | 102 | ||||
-rw-r--r-- | libmproxy/protocol/primitives.py | 23 | ||||
-rw-r--r-- | libmproxy/script.py | 11 |
8 files changed, 132 insertions, 180 deletions
diff --git a/libmproxy/app.py b/libmproxy/app.py index 9941d6ea..ed7ec72a 100644 --- a/libmproxy/app.py +++ b/libmproxy/app.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import flask -import os.path, os -from . import proxy +import os +from .proxy import config mapp = flask.Flask(__name__) mapp.debug = True @@ -18,12 +18,12 @@ def index(): @mapp.route("/cert/pem") def certs_pem(): - p = os.path.join(master().server.config.confdir, proxy.config.CONF_BASENAME + "-ca-cert.pem") + p = os.path.join(master().server.config.confdir, config.CONF_BASENAME + "-ca-cert.pem") return flask.Response(open(p, "rb").read(), mimetype='application/x-x509-ca-cert') @mapp.route("/cert/p12") def certs_p12(): - p = os.path.join(master().server.config.confdir, proxy.config.CONF_BASENAME + "-ca-cert.p12") + p = os.path.join(master().server.config.confdir, config.CONF_BASENAME + "-ca-cert.p12") return flask.Response(open(p, "rb").read(), mimetype='application/x-pkcs12') diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index 1325aae5..a5920915 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -268,8 +268,8 @@ class ConsoleState(flow.State): d = self.flowsettings.get(flow, {}) return d.get(key, default) - def add_request(self, req): - f = flow.State.add_request(self, req) + def add_request(self, f): + flow.State.add_request(self, f) if self.focus is None: self.set_focus(0) elif self.follow_focus: @@ -996,11 +996,11 @@ class ConsoleMaster(flow.FlowMaster): if hasattr(self.statusbar, "refresh_flow"): self.statusbar.refresh_flow(c) - def process_flow(self, f, r): + def process_flow(self, f): if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: f.intercept() else: - r.reply() + f.reply() self.sync_list_view() self.refresh_flow(f) @@ -1022,20 +1022,20 @@ class ConsoleMaster(flow.FlowMaster): self.eventlist.set_focus(len(self.eventlist)-1) # Handlers - def handle_error(self, r): - f = flow.FlowMaster.handle_error(self, r) + def handle_error(self, f): + f = flow.FlowMaster.handle_error(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f diff --git a/libmproxy/dump.py b/libmproxy/dump.py index aeb34cc3..8ecd56e7 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -50,13 +50,13 @@ def str_response(resp): return r -def str_request(req, showhost): - if req.flow.client_conn: - c = req.flow.client_conn.address.host +def str_request(f, showhost): + if f.client_conn: + c = f.client_conn.address.host else: c = "[replay]" - r = "%s %s %s"%(c, req.method, req.get_url(showhost)) - if req.stickycookie: + r = "%s %s %s"%(c, f.request.method, f.request.get_url(showhost, f)) + if f.request.stickycookie: r = "[stickycookie] " + r return r @@ -185,16 +185,16 @@ class DumpMaster(flow.FlowMaster): result = " << %s"%f.error.msg if self.o.flow_detail == 1: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, result elif self.o.flow_detail == 2: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, self.indent(4, f.request.headers) print >> self.outfile print >> self.outfile, result print >> self.outfile, "\n" elif self.o.flow_detail >= 3: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, self.indent(4, f.request.headers) if utils.isBin(f.request.content): print >> self.outfile, self.indent(4, netlib.utils.hexdump(f.request.content)) @@ -206,21 +206,21 @@ class DumpMaster(flow.FlowMaster): if self.o.flow_detail: self.outfile.flush() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, msg): - f = flow.FlowMaster.handle_response(self, msg) + def handle_response(self, f): + flow.FlowMaster.handle_response(self, f) if f: - msg.reply() + f.reply() self._process_flow(f) return f - def handle_error(self, msg): - f = flow.FlowMaster.handle_error(self, msg) + def handle_error(self, f): + flow.FlowMaster.handle_error(self, f) if f: self._process_flow(f) return f diff --git a/libmproxy/filt.py b/libmproxy/filt.py index e17ed735..925dbfbb 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -208,7 +208,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.get_host(), re.IGNORECASE)) + return bool(re.search(self.expr, f.request.get_host(False, f), re.IGNORECASE)) class FUrl(_Rex): @@ -222,7 +222,7 @@ class FUrl(_Rex): return klass(*toks) def __call__(self, f): - return re.search(self.expr, f.request.get_url()) + return re.search(self.expr, f.request.get_url(False, f)) class _Int(_Action): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 2540435e..eb183d9f 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -34,11 +34,11 @@ class AppRegistry: """ Returns an WSGIAdaptor instance if request matches an app, or None. """ - if (request.get_host(), request.get_port()) in self.apps: - return self.apps[(request.get_host(), request.get_port())] + if (request.host, request.port) in self.apps: + return self.apps[(request.host, request.port)] if "host" in request.headers: host = request.headers["host"][0] - return self.apps.get((host, request.get_port()), None) + return self.apps.get((host, request.port), None) class ReplaceHooks: @@ -185,11 +185,11 @@ class ClientPlaybackState: n = self.flows.pop(0) n.request.reply = controller.DummyReply() n.client_conn = None - self.current = master.handle_request(n.request) + self.current = master.handle_request(n) if not testing and not self.current.response: - master.replay_request(self.current) # pragma: no cover + master.replay_request(self.current) # pragma: no cover elif self.current.response: - master.handle_response(self.current.response) + master.handle_response(self.current) class ServerPlaybackState: @@ -260,8 +260,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. """ return ( - m["domain"] or f.request.get_host(), - f.request.get_port(), + m["domain"] or f.request.get_host(False, f), + f.request.get_port(f), m["path"] or "/" ) @@ -279,7 +279,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.get_host(), k[0]): + if self.domain_match(f.request.get_host(False, f), k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -287,8 +287,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.get_host(), i[0]), - f.request.get_port() == i[1], + self.domain_match(f.request.get_host(False, f), i[0]), + f.request.get_port(f) == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -307,7 +307,7 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): - host = f.request.get_host() + host = f.request.get_host(False, f) if "authorization" in f.request.headers: self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): @@ -342,33 +342,30 @@ class State(object): c += 1 return c - def add_request(self, req): + def add_request(self, flow): """ Add a request to the state. Returns the matching flow. """ - f = req.flow - self._flow_list.append(f) - if f.match(self._limit): - self.view.append(f) - return f + self._flow_list.append(flow) + if flow.match(self._limit): + self.view.append(flow) + return flow - def add_response(self, resp): + def add_response(self, f): """ Add a response to the state. Returns the matching flow. """ - f = resp.flow if not f: return False if f.match(self._limit) and not f in self.view: self.view.append(f) return f - def add_error(self, err): + def add_error(self, f): """ Add an error response to the state. Returns the matching flow, or None if there isn't one. """ - f = err.flow if not f: return None if f.match(self._limit) and not f in self.view: @@ -586,7 +583,7 @@ class FlowMaster(controller.Master): response.is_replay = True if self.refresh_server_playback: response.refresh() - flow.request.reply(response) + flow.reply(response) if self.server_playback.count() == 0: self.stop_server_playback() return True @@ -612,16 +609,14 @@ class FlowMaster(controller.Master): """ Loads a flow, and returns a new flow object. """ + f.reply = controller.DummyReply() if f.request: - f.request.reply = controller.DummyReply() - fr = self.handle_request(f.request) + self.handle_request(f) if f.response: - f.response.reply = controller.DummyReply() - self.handle_response(f.response) + self.handle_response(f) if f.error: - f.error.reply = controller.DummyReply() - self.handle_error(f.error) - return fr + self.handle_error(f) + return f def load_flows(self, fr): """ @@ -647,7 +642,7 @@ class FlowMaster(controller.Master): if self.kill_nonreplay: f.kill(self) else: - f.request.reply() + f.reply() def process_new_response(self, f): if self.stickycookie_state: @@ -694,54 +689,49 @@ class FlowMaster(controller.Master): self.run_script_hook("serverconnect", sc) sc.reply() - def handle_error(self, r): - f = self.state.add_error(r) - if f: - self.run_script_hook("error", f) + def handle_error(self, f): + self.state.add_error(f) + self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - r.reply() + f.reply() return f - def handle_request(self, r): - if r.flow.live: - app = self.apps.get(r) + def handle_request(self, f): + if f.live: + app = self.apps.get(f.request) if app: - err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self}) + err = app.serve(f, f.client_conn.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. %s"%err, "error") - r.reply(protocol.KILL) + f.reply(protocol.KILL) return - f = self.state.add_request(r) + self.state.add_request(f) self.replacehooks.run(f) self.setheaders.run(f) self.run_script_hook("request", f) self.process_new_request(f) return f - def handle_responseheaders(self, resp): - f = resp.flow + def handle_responseheaders(self, f): self.run_script_hook("responseheaders", f) if self.stream_large_bodies: self.stream_large_bodies.run(f, False) - resp.reply() + f.reply() return f - def handle_response(self, r): - f = self.state.add_response(r) - if f: - self.replacehooks.run(f) - self.setheaders.run(f) - self.run_script_hook("response", f) - if self.client_playback: - self.client_playback.clear(f) - self.process_new_response(f) - if self.stream: - self.stream.add(f) - else: - r.reply() + def handle_response(self, f): + self.state.add_response(f) + self.replacehooks.run(f) + self.setheaders.run(f) + self.run_script_hook("response", f) + if self.client_playback: + self.client_playback.clear(f) + self.process_new_response(f) + if self.stream: + self.stream.add(f) return f def shutdown(self): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 658c08ed..3f9eecb3 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -77,9 +77,6 @@ class HTTPMessage(stateobject.SimpleStateObject): self.timestamp_start = timestamp_start if timestamp_start is not None else utils.timestamp() self.timestamp_end = timestamp_end if timestamp_end is not None else utils.timestamp() - self.flow = None # will usually be set by the flow backref mixin - """@type: HTTPFlow""" - _stateobject_attributes = dict( httpversion=tuple, headers=ODictCaseless, @@ -346,10 +343,10 @@ class HTTPRequest(HTTPMessage): del headers[k] if headers["Upgrade"] == ["h2c"]: # Suppress HTTP2 https://http2.github.io/http2-spec/index.html#discover-http del headers["Upgrade"] - if not 'host' in headers: + if not 'host' in headers and self.scheme and self.host and self.port: headers["Host"] = [utils.hostport(self.scheme, - self.host or self.flow.server_conn.address.host, - self.port or self.flow.server_conn.address.port)] + self.host, + self.port)] if self.content: headers["Content-Length"] = [str(len(self.content))] @@ -429,16 +426,16 @@ class HTTPRequest(HTTPMessage): self.headers["Content-Type"] = [HDR_FORM_URLENCODED] self.content = utils.urlencode(odict.lst) - def get_path_components(self): + def get_path_components(self, f): """ Returns the path components of the URL as a list of strings. Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) + _, _, path, _, _, _ = urlparse.urlparse(self.get_url(False, f)) return [urllib.unquote(i) for i in path.split("/") if i] - def set_path_components(self, lst): + def set_path_components(self, lst, f): """ Takes a list of strings, and sets the path component of the URL. @@ -446,27 +443,27 @@ class HTTPRequest(HTTPMessage): """ lst = [urllib.quote(i, safe="") for i in lst] path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url(False, f)) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) - def get_query(self): + def get_query(self, f): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) + _, _, _, _, query, _ = urlparse.urlparse(self.get_url(False, f)) if query: return ODict(utils.urldecode(query)) return ODict([]) - def set_query(self, odict): + def set_query(self, odict, f): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url(False, f)) query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) - def get_host(self, hostheader=False): + def get_host(self, hostheader, flow): """ Heuristic to get the host of the request. @@ -484,16 +481,16 @@ class HTTPRequest(HTTPMessage): if self.host: host = self.host else: - for s in self.flow.server_conn.state: + for s in flow.server_conn.state: if s[0] == "http" and s[1]["state"] == "connect": host = s[1]["host"] break if not host: - host = self.flow.server_conn.address.host + host = flow.server_conn.address.host host = host.encode("idna") return host - def get_scheme(self): + def get_scheme(self, flow): """ Returns the request port, either from the request itself or from the flow's server connection """ @@ -501,20 +498,20 @@ class HTTPRequest(HTTPMessage): return self.scheme if self.form_out == "authority": # On SSLed connections, the original CONNECT request is still unencrypted. return "http" - return "https" if self.flow.server_conn.ssl_established else "http" + return "https" if flow.server_conn.ssl_established else "http" - def get_port(self): + def get_port(self, flow): """ Returns the request port, either from the request itself or from the flow's server connection """ if self.port: return self.port - for s in self.flow.server_conn.state: + for s in flow.server_conn.state: if s[0] == "http" and s[1].get("state") == "connect": return s[1]["port"] - return self.flow.server_conn.address.port + return flow.server_conn.address.port - def get_url(self, hostheader=False): + def get_url(self, hostheader, flow): """ Returns a URL string, constructed from the Request's URL components. @@ -522,13 +519,13 @@ class HTTPRequest(HTTPMessage): Host header to construct the URL. """ if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.get_host(hostheader), self.get_port()) - return utils.unparse_url(self.get_scheme(), - self.get_host(hostheader), - self.get_port(), + return "%s:%s" % (self.get_host(hostheader, flow), self.get_port(flow)) + return utils.unparse_url(self.get_scheme(flow), + self.get_host(hostheader, flow), + self.get_port(flow), self.path).encode('ascii') - def set_url(self, url): + def set_url(self, url, flow): """ Parses a URL specification, and updates the Request's information accordingly. @@ -543,14 +540,14 @@ class HTTPRequest(HTTPMessage): self.path = path - if host != self.get_host() or port != self.get_port(): - if self.flow.live: - self.flow.live.change_server((host, port), ssl=is_ssl) + if host != self.get_host(False, flow) or port != self.get_port(flow): + if flow.live: + flow.live.change_server((host, port), ssl=is_ssl) else: # There's not live server connection, we're just changing the attributes here. - self.flow.server_conn = ServerConnection((host, port), + flow.server_conn = ServerConnection((host, port), proxy.AddressPriority.MANUALLY_CHANGED) - self.flow.server_conn.ssl_established = is_ssl + flow.server_conn.ssl_established = is_ssl # If this is an absolute request, replace the attributes on the request object as well. if self.host: @@ -802,8 +799,6 @@ class HTTPFlow(Flow): self.intercepting = False # FIXME: Should that rather be an attribute of Flow? - _backrefattr = Flow._backrefattr + ("request", "response") - _stateobject_attributes = Flow._stateobject_attributes.copy() _stateobject_attributes.update( request=HTTPRequest, @@ -855,13 +850,10 @@ class HTTPFlow(Flow): Kill this request. """ self.error = Error("Connection killed") - self.error.reply = controller.DummyReply() - if self.request and not self.request.reply.acked: - self.request.reply(KILL) - elif self.response and not self.response.reply.acked: - self.response.reply(KILL) - master.handle_error(self.error) self.intercepting = False + self.reply(KILL) + self.reply = controller.DummyReply() + master.handle_error(self) def intercept(self): """ @@ -874,12 +866,8 @@ class HTTPFlow(Flow): """ Continue with the flow - called after an intercept(). """ - if self.request: - if not self.request.reply.acked: - self.request.reply() - elif self.response and not self.response.reply.acked: - self.response.reply() - self.intercepting = False + self.intercepting = False + self.reply() def replace(self, pattern, repl, *args, **kwargs): """ @@ -961,7 +949,7 @@ class HTTPHandler(ProtocolHandler): # in an Error object that has an attached request that has not been # sent through to the Master. flow.request = req - request_reply = self.c.channel.ask("request", flow.request) + request_reply = self.c.channel.ask("request", flow) self.determine_server_address(flow, flow.request) # The inline script may have changed request.host flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow @@ -976,7 +964,7 @@ class HTTPHandler(ProtocolHandler): flow.response = self.get_response_from_server(flow.request, include_body=False) # call the appropriate script hook - this is an opportunity for an inline script to set flow.stream = True - self.c.channel.ask("responseheaders", flow.response) + self.c.channel.ask("responseheaders", flow) # now get the rest of the request body, if body still needs to be read but not streaming this response if flow.response.stream: @@ -991,7 +979,7 @@ class HTTPHandler(ProtocolHandler): flow.server_conn = self.c.server_conn self.c.log("response", "debug", [flow.response._assemble_first_line()]) - response_reply = self.c.channel.ask("response", flow.response) + response_reply = self.c.channel.ask("response", flow) if response_reply is None or response_reply == KILL: return False @@ -1079,7 +1067,7 @@ class HTTPHandler(ProtocolHandler): # TODO: no flows without request or with both request and response at the moment. if flow.request and not flow.response: flow.error = Error(message) - self.c.channel.ask("error", flow.error) + self.c.channel.ask("error", flow) try: code = getattr(error, "code", 502) @@ -1204,12 +1192,12 @@ class RequestReplayThread(threading.Thread): except proxy.ProxyError: pass if not server_address: - server_address = (r.get_host(), r.get_port()) + server_address = (r.get_host(False, self.flow), r.get_port(self.flow)) server = ServerConnection(server_address, None) server.connect() - if server_ssl or r.get_scheme() == "https": + if server_ssl or r.get_scheme(self.flow) == "https": if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT send_connect_request(server, r.get_host(), r.get_port()) r.form_out = "relative" @@ -1218,9 +1206,9 @@ class RequestReplayThread(threading.Thread): server.send(r._assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) - self.channel.ask("response", self.flow.response) + self.channel.ask("response", self.flow) except (proxy.ProxyError, http.HttpError, tcp.NetLibError), v: self.flow.error = Error(repr(v)) - self.channel.ask("error", self.flow.error) + self.channel.ask("error", self.flow) finally: r.form_out = form_out_backup
\ No newline at end of file diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index a227d904..a84b4061 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -9,24 +9,6 @@ from ..proxy.connection import ClientConnection, ServerConnection KILL = 0 # const for killed requests -class BackreferenceMixin(object): - """ - If an attribute from the _backrefattr tuple is set, - this mixin sets a reference back on the attribute object. - Example: - e = Error() - f = Flow() - f.error = e - assert f is e.flow - """ - _backrefattr = tuple() - - def __setattr__(self, key, value): - super(BackreferenceMixin, self).__setattr__(key, value) - if key in self._backrefattr and value is not None: - setattr(value, self._backrefname, self) - - class Error(stateobject.SimpleStateObject): """ An Error. @@ -70,7 +52,7 @@ class Error(stateobject.SimpleStateObject): return c -class Flow(stateobject.SimpleStateObject, BackreferenceMixin): +class Flow(stateobject.SimpleStateObject): def __init__(self, conntype, client_conn, server_conn, live=None): self.conntype = conntype self.client_conn = client_conn @@ -84,9 +66,6 @@ class Flow(stateobject.SimpleStateObject, BackreferenceMixin): """@type: Error""" self._backup = None - _backrefattr = ("error",) - _backrefname = "flow" - _stateobject_attributes = dict( error=Error, client_conn=ClientConnection, diff --git a/libmproxy/script.py b/libmproxy/script.py index e582c4e8..706d84d5 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -125,13 +125,8 @@ def _handle_concurrent_reply(fn, o, *args, **kwargs): def concurrent(fn): - if fn.func_name in ["request", "response", "error"]: - def _concurrent(ctx, flow): - r = getattr(flow, fn.func_name) - _handle_concurrent_reply(fn, r, ctx, flow) - return _concurrent - elif fn.func_name in ["clientconnect", "serverconnect", "clientdisconnect"]: - def _concurrent(ctx, conn): - _handle_concurrent_reply(fn, conn, ctx, conn) + if fn.func_name in ("request", "response", "error", "clientconnect", "serverconnect", "clientdisconnect"): + def _concurrent(ctx, obj): + _handle_concurrent_reply(fn, obj, ctx, obj) return _concurrent raise NotImplementedError("Concurrent decorator not supported for this method.") |