diff options
Diffstat (limited to 'libmproxy')
-rw-r--r-- | libmproxy/flow.py | 19 | ||||
-rw-r--r-- | libmproxy/protocol/__init__.py | 16 | ||||
-rw-r--r-- | libmproxy/protocol/http.py | 137 | ||||
-rw-r--r-- | libmproxy/protocol/primitives.py | 11 | ||||
-rw-r--r-- | libmproxy/proxy.py | 18 |
5 files changed, 134 insertions, 67 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py index bf9171a7..55ff109e 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -249,10 +249,9 @@ class StickyCookieState: """ Returns a (domain, port, path) tuple. """ - raise NotImplementedError return ( - m["domain"] or f.request.host, - f.server_conn.address.port, + m["domain"] or f.request.get_host(), + f.request.get_port(), m["path"] or "/" ) @@ -270,7 +269,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.host, k[0]): + if self.domain_match(f.request.get_host(), k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -278,8 +277,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.host, i[0]), - f.request.port == i[1], + self.domain_match(f.request.get_host(), i[0]), + f.request.get_port() == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -298,12 +297,12 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): - raise NotImplementedError + host = f.request.get_host() if "authorization" in f.request.headers: - self.hosts[f.request.host] = f.request.headers["authorization"] + self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): - if f.request.host in self.hosts: - f.request.headers["authorization"] = self.hosts[f.request.host] + if host in self.hosts: + f.request.headers["authorization"] = self.hosts[host] class State(object): diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index da85500b..f23159b2 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -28,6 +28,22 @@ class ProtocolHandler(object): """ raise error + +class TemporaryServerChangeMixin(object): + """ + This mixin allows safe modification of the target server, + without any need to expose the ConnectionHandler to the Flow. + """ + + def change_server(self): + self._backup_server = True + raise NotImplementedError + + def restore_server(self): + if not hasattr(self,"_backup_server"): + return + raise NotImplementedError + from . import http, tcp protocols = { diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 636e1b07..069030ef 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -3,9 +3,9 @@ from email.utils import parsedate_tz, formatdate, mktime_tz import netlib.utils from netlib import http, tcp, http_status, odict from netlib.odict import ODict, ODictCaseless -from . import ProtocolHandler, ConnectionTypeChange, KILL +from . import ProtocolHandler, ConnectionTypeChange, KILL, TemporaryServerChangeMixin from .. import encoding, utils, version, filt, controller, stateobject -from ..proxy import ProxyError, AddressPriority +from ..proxy import ProxyError, AddressPriority, ServerConnection from .primitives import Flow, Error @@ -55,10 +55,24 @@ class decoded(object): class HTTPMessage(stateobject.SimpleStateObject): - def __init__(self): + def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None): + self.httpversion = httpversion + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + self.flow = None # will usually be set by the flow backref mixin """@type: HTTPFlow""" + _stateobject_attributes = dict( + httpversion=tuple, + headers=ODictCaseless, + content=str, + timestamp_start=float, + timestamp_end=float + ) + def get_decoded_content(self): """ Returns the decoded content based on the current Content-Encoding header. @@ -199,7 +213,7 @@ class HTTPRequest(HTTPMessage): def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content, timestamp_start=None, timestamp_end=None, form_out=None): assert isinstance(headers, ODictCaseless) or not headers - HTTPMessage.__init__(self) + HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end) self.form_in = form_in self.method = method @@ -208,10 +222,6 @@ class HTTPRequest(HTTPMessage): self.port = port self.path = path self.httpversion = httpversion - self.headers = headers - self.content = content - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end self.form_out = form_out or form_in # Have this request's cookies been modified by sticky cookies or auth? @@ -220,18 +230,14 @@ class HTTPRequest(HTTPMessage): # Is this request replayed? self.is_replay = False - _stateobject_attributes = dict( + _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() + _stateobject_attributes.update( form_in=str, method=str, scheme=str, host=str, port=int, path=str, - httpversion=tuple, - headers=ODictCaseless, - content=str, - timestamp_start=float, - timestamp_end=float, form_out=str ) @@ -437,15 +443,13 @@ class HTTPRequest(HTTPMessage): query = utils.urlencode(odict.lst) self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) - def get_url(self, hostheader=False): + def get_host(self, hostheader=False): """ - Returns a URL string, constructed from the Request's URL components. - - If hostheader is True, we use the value specified in the request - Host header to construct the URL. + Heuristic to get the host of the request. + The host is not necessarily equal to the TCP destination of the request, + for example on a transparently proxified absolute-form request to an upstream HTTP proxy. + If hostheader is set to True, the Host: header will be used as additional (and preferred) data source. """ - raise NotImplementedError - # FIXME: Take server_conn into account. host = None if hostheader: host = self.headers.get_first("host") @@ -455,7 +459,35 @@ class HTTPRequest(HTTPMessage): else: host = self.flow.server_conn.address.host host = host.encode("idna") - return utils.unparse_url(self.scheme, host, self.port, self.path).encode('ascii') + return host + + def get_scheme(self): + """ + Returns the request port, either from the request itself or from the flow's server connection + """ + if self.scheme: + return self.scheme + return "https" if self.flow.server_conn.ssl_established else "http" + + def get_port(self): + """ + Returns the request port, either from the request itself or from the flow's server connection + """ + if self.port: + return self.port + return self.flow.server_conn.address.port + + def get_url(self, hostheader=False): + """ + Returns a URL string, constructed from the Request's URL components. + + If hostheader is True, we use the value specified in the request + Host header to construct the URL. + """ + return utils.unparse_url(self.get_scheme(), + self.get_host(hostheader), + self.get_port(), + self.path).encode('ascii') def set_url(self, url): """ @@ -464,12 +496,30 @@ class HTTPRequest(HTTPMessage): Returns False if the URL was invalid, True if the request succeeded. """ - raise NotImplementedError - # FIXME: Needs to update server_conn as well. parts = http.parse_url(url) if not parts: return False - self.scheme, self.host, self.port, self.path = parts + scheme, host, port, path = parts + is_ssl = (True if scheme == "https" else False) + + self.path = path + + if host != self.get_host() or port != self.get_port(): + if self.flow.change_server: + self.flow.change_server((host, port), ssl=is_ssl) + else: + # There's not live server connection, we're just changing the attributes here. + self.flow.server_conn = ServerConnection((host, port), AddressPriority.MANUALLY_CHANGED) + self.flow.server_conn.ssl_established = is_ssl + + # If this is an absolute request, replace the attributes on the request object as well. + if self.host: + self.host = host + if self.port: + self.port = port + if self.scheme: + self.scheme = scheme + return True def get_cookies(self): @@ -521,34 +571,25 @@ class HTTPResponse(HTTPMessage): timestamp_end: Timestamp indicating when request transmission ended """ - def __init__(self, httpversion, code, msg, headers, content, timestamp_start, timestamp_end): + def __init__(self, httpversion, code, msg, headers, content, timestamp_start=None, timestamp_end=None): assert isinstance(headers, ODictCaseless) or headers is None - HTTPMessage.__init__(self) + HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end) - self.httpversion = httpversion self.code = code self.msg = msg - self.headers = headers - self.content = content - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end # Is this request replayed? self.is_replay = False - _stateobject_attributes = dict( - httpversion=tuple, + _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() + _stateobject_attributes.update( code=int, - msg=str, - headers=ODictCaseless, - content=str, - timestamp_start=float, - timestamp_end=float + msg=str ) @classmethod def _from_state(cls, state): - f = cls(None, None, None, None, None, None, None) + f = cls(None, None, None, None, None) f._load_state(state) return f @@ -688,13 +729,15 @@ class HTTPFlow(Flow): intercepting: Is this flow currently being intercepted? """ - def __init__(self, client_conn, server_conn): + def __init__(self, client_conn, server_conn, change_server=None): Flow.__init__(self, "http", client_conn, server_conn) self.request = None + """@type: HTTPRequest""" self.response = None + """@type: HTTPResponse""" + self.change_server = None # Used by flow.request.set_url to change the server address self.intercepting = False # FIXME: Should that rather be an attribute of Flow? - self._backup = None _backrefattr = Flow._backrefattr + ("request", "response") @@ -787,13 +830,15 @@ class HttpAuthenticationError(Exception): return "HttpAuthenticationError" -class HTTPHandler(ProtocolHandler): +class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin): + def handle_messages(self): while self.handle_flow(): pass self.c.close = True def get_response_from_server(self, request): + self.c.establish_server_connection() request_raw = request._assemble() for i in range(2): @@ -818,7 +863,7 @@ class HTTPHandler(ProtocolHandler): raise v def handle_flow(self): - flow = HTTPFlow(self.c.client_conn, self.c.server_conn) + flow = HTTPFlow(self.c.client_conn, self.c.server_conn, self.change_server) try: flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) @@ -833,7 +878,6 @@ class HTTPHandler(ProtocolHandler): if isinstance(request_reply, HTTPResponse): flow.response = request_reply else: - self.c.establish_server_connection() flow.response = self.get_response_from_server(flow.request) self.c.log("response", [flow.response._assemble_first_line()]) @@ -855,7 +899,8 @@ class HTTPHandler(ProtocolHandler): self.ssl_upgrade(flow.request) flow.server_conn = self.c.server_conn - + self.restore_server() # If the user has changed the target server on this connection, + # restore the original target server return True except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e: self.handle_error(e, flow) diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index f3fdd245..d1546ddd 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -3,7 +3,7 @@ from ..proxy import ServerConnection, ClientConnection import copy -class _BackreferenceMixin(object): +class BackreferenceMixin(object): """ If an attribute from the _backrefattr tuple is set, this mixin sets a reference back on the attribute object. @@ -16,7 +16,7 @@ class _BackreferenceMixin(object): _backrefattr = tuple() def __setattr__(self, key, value): - super(_BackreferenceMixin, self).__setattr__(key, value) + super(BackreferenceMixin, self).__setattr__(key, value) if key in self._backrefattr and value is not None: setattr(value, self._backrefname, self) @@ -61,12 +61,17 @@ class Error(stateobject.SimpleStateObject): return c -class Flow(stateobject.SimpleStateObject, _BackreferenceMixin): +class Flow(stateobject.SimpleStateObject, BackreferenceMixin): def __init__(self, conntype, client_conn, server_conn): self.conntype = conntype self.client_conn = client_conn + """@type: ClientConnection""" self.server_conn = server_conn + """@type: ServerConnection""" + self.error = None + """@type: Error""" + self._backup = None _backrefattr = ("error",) _backrefname = "flow" diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 53e3f575..6ff02a36 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -30,6 +30,7 @@ class ProxyError(Exception): def __str__(self): return "ProxyError(%s, %s)" % (self.code, self.msg) + class Log: def __init__(self, msg): self.msg = msg @@ -104,8 +105,9 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): - def __init__(self, address): + def __init__(self, address, priority): tcp.TCPClient.__init__(self, address) + self.priority = priority self.peername = None self.timestamp_start = None @@ -145,7 +147,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): @classmethod def _from_state(cls, state): - f = cls(tuple()) + f = cls(tuple(), None) f._load_state(state) return f @@ -190,7 +192,7 @@ class RequestReplayThread(threading.Thread): def run(self): try: r = self.flow.request - server = ServerConnection(self.flow.server_conn.address()) + server = ServerConnection(self.flow.server_conn.address(), None) server.connect() if self.flow.server_conn.ssl_established: server.establish_ssl(self.config.clientcerts, @@ -202,6 +204,7 @@ class RequestReplayThread(threading.Thread): self.flow.error = protocol.primitives.Error(str(v)) self.channel.ask("error", self.flow.error) + class ConnectionHandler: def __init__(self, config, client_connection, client_address, server, channel, server_version): self.config = config @@ -310,18 +313,17 @@ class ConnectionHandler: @type priority: AddressPriority """ address = tcp.Address.wrap(address) - self.log("Set server address: %s:%s" % (address.host, address.port)) - if self.server_conn and (self.server_address_priority > priority): + self.log("Try to set server address: %s:%s" % (address.host, address.port)) + if self.server_conn and (self.server_conn.priority > priority): self.log("Server address priority too low (is: %s, got: %s)" % (self.server_address_priority, priority)) return - self.address_priority = priority - if self.server_conn and (self.server_conn.address == address): + self.server_conn.priority = priority # Possibly increase priority self.log("Addresses match, skip.") return - server_conn = ServerConnection(address) + server_conn = ServerConnection(address, priority) if self.server_conn and self.server_conn.connection: self.del_server_connection() self.server_conn = server_conn |