diff options
Diffstat (limited to 'libmproxy/models')
-rw-r--r-- | libmproxy/models/__init__.py | 16 | ||||
-rw-r--r-- | libmproxy/models/connections.py | 194 | ||||
-rw-r--r-- | libmproxy/models/flow.py | 166 | ||||
-rw-r--r-- | libmproxy/models/http.py | 554 |
4 files changed, 930 insertions, 0 deletions
diff --git a/libmproxy/models/__init__.py b/libmproxy/models/__init__.py new file mode 100644 index 00000000..3947847c --- /dev/null +++ b/libmproxy/models/__init__.py @@ -0,0 +1,16 @@ +from __future__ import (absolute_import, print_function, division) + +from .http import ( + HTTPFlow, HTTPRequest, HTTPResponse, decoded, + make_error_response, make_connect_request, make_connect_response +) +from .connections import ClientConnection, ServerConnection +from .flow import Flow, Error + +__all__ = [ + "HTTPFlow", "HTTPRequest", "HTTPResponse", "decoded" + "make_error_response", "make_connect_request", + "make_connect_response", + "ClientConnection", "ServerConnection", + "Flow", "Error", +] diff --git a/libmproxy/models/connections.py b/libmproxy/models/connections.py new file mode 100644 index 00000000..98bae3cc --- /dev/null +++ b/libmproxy/models/connections.py @@ -0,0 +1,194 @@ +from __future__ import absolute_import + +import copy +import os + +from netlib import tcp, certutils +from .. import stateobject, utils + + +class ClientConnection(tcp.BaseHandler, stateobject.StateObject): + def __init__(self, client_connection, address, server): + # Eventually, this object is restored from state. We don't have a + # connection then. + if client_connection: + super(ClientConnection, self).__init__(client_connection, address, server) + else: + self.connection = None + self.server = None + self.wfile = None + self.rfile = None + self.address = None + self.clientcert = None + self.ssl_established = None + + self.timestamp_start = utils.timestamp() + self.timestamp_end = None + self.timestamp_ssl_setup = None + self.protocol = None + + def __nonzero__(self): + return bool(self.connection) and not self.finished + + def __repr__(self): + return "<ClientConnection: {ssl}{host}:{port}>".format( + ssl="[ssl] " if self.ssl_established else "", + host=self.address.host, + port=self.address.port + ) + + @property + def tls_established(self): + return self.ssl_established + + _stateobject_attributes = dict( + ssl_established=bool, + timestamp_start=float, + timestamp_end=float, + timestamp_ssl_setup=float + ) + + def get_state(self, short=False): + d = super(ClientConnection, self).get_state(short) + d.update( + address={ + "address": self.address(), + "use_ipv6": self.address.use_ipv6}, + clientcert=self.cert.to_pem() if self.clientcert else None) + return d + + def load_state(self, state): + super(ClientConnection, self).load_state(state) + self.address = tcp.Address( + **state["address"]) if state["address"] else None + self.clientcert = certutils.SSLCert.from_pem( + state["clientcert"]) if state["clientcert"] else None + + def copy(self): + return copy.copy(self) + + def send(self, message): + if isinstance(message, list): + message = b''.join(message) + self.wfile.write(message) + self.wfile.flush() + + @classmethod + def from_state(cls, state): + f = cls(None, tuple(), None) + f.load_state(state) + return f + + def convert_to_ssl(self, *args, **kwargs): + super(ClientConnection, self).convert_to_ssl(*args, **kwargs) + self.timestamp_ssl_setup = utils.timestamp() + + def finish(self): + super(ClientConnection, self).finish() + self.timestamp_end = utils.timestamp() + + +class ServerConnection(tcp.TCPClient, stateobject.StateObject): + def __init__(self, address): + tcp.TCPClient.__init__(self, address) + + self.via = None + self.timestamp_start = None + self.timestamp_end = None + self.timestamp_tcp_setup = None + self.timestamp_ssl_setup = None + self.protocol = None + + def __nonzero__(self): + return bool(self.connection) and not self.finished + + def __repr__(self): + if self.ssl_established and self.sni: + ssl = "[ssl: {0}] ".format(self.sni) + elif self.ssl_established: + ssl = "[ssl] " + else: + ssl = "" + return "<ServerConnection: {ssl}{host}:{port}>".format( + ssl=ssl, + host=self.address.host, + port=self.address.port + ) + + @property + def tls_established(self): + return self.ssl_established + + _stateobject_attributes = dict( + timestamp_start=float, + timestamp_end=float, + timestamp_tcp_setup=float, + timestamp_ssl_setup=float, + address=tcp.Address, + source_address=tcp.Address, + cert=certutils.SSLCert, + ssl_established=bool, + sni=str + ) + _stateobject_long_attributes = {"cert"} + + def get_state(self, short=False): + d = super(ServerConnection, self).get_state(short) + d.update( + address={"address": self.address(), + "use_ipv6": self.address.use_ipv6}, + source_address=({"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None), + cert=self.cert.to_pem() if self.cert else None + ) + return d + + def load_state(self, state): + super(ServerConnection, self).load_state(state) + + self.address = tcp.Address( + **state["address"]) if state["address"] else None + self.source_address = tcp.Address( + **state["source_address"]) if state["source_address"] else None + self.cert = certutils.SSLCert.from_pem( + state["cert"]) if state["cert"] else None + + @classmethod + def from_state(cls, state): + f = cls(tuple()) + f.load_state(state) + return f + + def copy(self): + return copy.copy(self) + + def connect(self): + self.timestamp_start = utils.timestamp() + tcp.TCPClient.connect(self) + self.timestamp_tcp_setup = utils.timestamp() + + def send(self, message): + if isinstance(message, list): + message = b''.join(message) + self.wfile.write(message) + self.wfile.flush() + + def establish_ssl(self, clientcerts, sni, **kwargs): + clientcert = None + if clientcerts: + path = os.path.join( + clientcerts, + self.address.host.encode("idna")) + ".pem" + if os.path.exists(path): + clientcert = path + + self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) + self.sni = sni + self.timestamp_ssl_setup = utils.timestamp() + + def finish(self): + tcp.TCPClient.finish(self) + self.timestamp_end = utils.timestamp() + + +ServerConnection._stateobject_attributes["via"] = ServerConnection diff --git a/libmproxy/models/flow.py b/libmproxy/models/flow.py new file mode 100644 index 00000000..58287e5b --- /dev/null +++ b/libmproxy/models/flow.py @@ -0,0 +1,166 @@ +from __future__ import absolute_import +import copy +import uuid + +from .. import stateobject, utils, version +from .connections import ClientConnection, ServerConnection + + +class Error(stateobject.StateObject): + """ + An Error. + + This is distinct from an protocol error response (say, a HTTP code 500), + which is represented by a normal HTTPResponse object. This class is + responsible for indicating errors that fall outside of normal protocol + communications, like interrupted connections, timeouts, protocol errors. + + Exposes the following attributes: + + flow: Flow object + msg: Message describing the error + timestamp: Seconds since the epoch + """ + + def __init__(self, msg, timestamp=None): + """ + @type msg: str + @type timestamp: float + """ + self.flow = None # will usually be set by the flow backref mixin + self.msg = msg + self.timestamp = timestamp or utils.timestamp() + + _stateobject_attributes = dict( + msg=str, + timestamp=float + ) + + def __str__(self): + return self.msg + + @classmethod + def from_state(cls, state): + # the default implementation assumes an empty constructor. Override + # accordingly. + f = cls(None) + f.load_state(state) + return f + + def copy(self): + c = copy.copy(self) + return c + + +class Flow(stateobject.StateObject): + """ + A Flow is a collection of objects representing a single transaction. + This class is usually subclassed for each protocol, e.g. HTTPFlow. + """ + + def __init__(self, type, client_conn, server_conn, live=None): + self.type = type + self.id = str(uuid.uuid4()) + self.client_conn = client_conn + """@type: ClientConnection""" + self.server_conn = server_conn + """@type: ServerConnection""" + self.live = live + """@type: LiveConnection""" + + self.error = None + """@type: Error""" + self.intercepted = False + """@type: bool""" + self._backup = None + self.reply = None + + _stateobject_attributes = dict( + id=str, + error=Error, + client_conn=ClientConnection, + server_conn=ServerConnection, + type=str, + intercepted=bool + ) + + def get_state(self, short=False): + d = super(Flow, self).get_state(short) + d.update(version=version.IVERSION) + if self._backup and self._backup != d: + if short: + d.update(modified=True) + else: + d.update(backup=self._backup) + return d + + def __eq__(self, other): + return self is other + + def copy(self): + f = copy.copy(self) + + f.id = str(uuid.uuid4()) + f.live = False + f.client_conn = self.client_conn.copy() + f.server_conn = self.server_conn.copy() + + if self.error: + f.error = self.error.copy() + return f + + def modified(self): + """ + Has this Flow been modified? + """ + if self._backup: + return self._backup != self.get_state() + else: + return False + + def backup(self, force=False): + """ + Save a backup of this Flow, which can be reverted to using a + call to .revert(). + """ + if not self._backup: + self._backup = self.get_state() + + def revert(self): + """ + Revert to the last backed up state. + """ + if self._backup: + self.load_state(self._backup) + self._backup = None + + def kill(self, master): + """ + Kill this request. + """ + from ..protocol import Kill + + self.error = Error("Connection killed") + self.intercepted = False + self.reply(Kill) + master.handle_error(self) + + def intercept(self, master): + """ + Intercept this Flow. Processing will stop until accept_intercept is + called. + """ + if self.intercepted: + return + self.intercepted = True + master.handle_intercept(self) + + def accept_intercept(self, master): + """ + Continue with the flow - called after an intercept(). + """ + if not self.intercepted: + return + self.intercepted = False + self.reply() + master.handle_accept_intercept(self) diff --git a/libmproxy/models/http.py b/libmproxy/models/http.py new file mode 100644 index 00000000..fb2f305b --- /dev/null +++ b/libmproxy/models/http.py @@ -0,0 +1,554 @@ +from __future__ import (absolute_import, print_function, division) +import Cookie +import copy +from email.utils import parsedate_tz, formatdate, mktime_tz +import time + +from libmproxy import utils +from netlib import odict, encoding +from netlib.http import status_codes +from netlib.tcp import Address +from netlib.http.semantics import Request, Response, CONTENT_MISSING +from .. import version, stateobject +from .flow import Flow + + +class MessageMixin(stateobject.StateObject): + _stateobject_attributes = dict( + httpversion=tuple, + headers=odict.ODictCaseless, + body=str, + timestamp_start=float, + timestamp_end=float + ) + _stateobject_long_attributes = {"body"} + + def get_state(self, short=False): + ret = super(MessageMixin, self).get_state(short) + if short: + if self.body: + ret["contentLength"] = len(self.body) + elif self.body == CONTENT_MISSING: + ret["contentLength"] = None + else: + ret["contentLength"] = 0 + return ret + + def get_decoded_content(self): + """ + Returns the decoded content based on the current Content-Encoding + header. + Doesn't change the message iteself or its headers. + """ + ce = self.headers.get_first("content-encoding") + if not self.body or ce not in encoding.ENCODINGS: + return self.body + return encoding.decode(ce, self.body) + + def decode(self): + """ + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns True if decoding succeeded, False otherwise. + """ + ce = self.headers.get_first("content-encoding") + if not self.body or ce not in encoding.ENCODINGS: + return False + data = encoding.decode(ce, self.body) + if data is None: + return False + self.body = data + del self.headers["content-encoding"] + return True + + def encode(self, e): + """ + Encodes body with the encoding e, where e is "gzip", "deflate" + or "identity". + """ + # FIXME: Error if there's an existing encoding header? + self.body = encoding.encode(e, self.body) + self.headers["content-encoding"] = [e] + + def copy(self): + c = copy.copy(self) + c.headers = self.headers.copy() + return c + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the message. Encoded body will be decoded + before replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + with decoded(self): + self.body, c = utils.safe_subn( + pattern, repl, self.body, *args, **kwargs + ) + c += self.headers.replace(pattern, repl, *args, **kwargs) + return c + + +class HTTPRequest(MessageMixin, Request): + """ + An HTTP request. + + Exposes the following attributes: + + method: HTTP method + + scheme: URL scheme (http/https) + + host: Target hostname of the request. This is not neccessarily the + directy upstream server (which could be another proxy), but it's always + the target server we want to reach at the end. This attribute is either + inferred from the request itself (absolute-form, authority-form) or from + the connection metadata (e.g. the host in reverse proxy mode). + + port: Destination port + + path: Path portion of the URL (not present in authority-form) + + httpversion: HTTP version tuple, e.g. (1,1) + + headers: odict.ODictCaseless object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. + + form_in: The request form which mitmproxy has received. The following + values are possible: + + - relative (GET /index.html, OPTIONS *) (covers origin form and + asterisk form) + - absolute (GET http://example.com:80/index.html) + - authority-form (CONNECT example.com:443) + Details: http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-25#section-5.3 + + form_out: The request form which mitmproxy will send out to the + destination + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + body, + timestamp_start=None, + timestamp_end=None, + form_out=None, + ): + Request.__init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + body, + timestamp_start, + timestamp_end, + ) + self.form_out = form_out or form_in + + # Have this request's cookies been modified by sticky cookies or auth? + self.stickycookie = False + self.stickyauth = False + + # Is this request replayed? + self.is_replay = False + + _stateobject_attributes = MessageMixin._stateobject_attributes.copy() + _stateobject_attributes.update( + form_in=str, + method=str, + scheme=str, + host=str, + port=int, + path=str, + form_out=str, + is_replay=bool + ) + + @classmethod + def from_state(cls, state): + f = cls( + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None) + f.load_state(state) + return f + + @classmethod + def from_protocol( + self, + protocol, + *args, + **kwargs + ): + req = protocol.read_request(*args, **kwargs) + return self.wrap(req) + + @classmethod + def wrap(self, request): + req = HTTPRequest( + form_in=request.form_in, + method=request.method, + scheme=request.scheme, + host=request.host, + port=request.port, + path=request.path, + httpversion=request.httpversion, + headers=request.headers, + body=request.body, + timestamp_start=request.timestamp_start, + timestamp_end=request.timestamp_end, + form_out=(request.form_out if hasattr(request, 'form_out') else None), + ) + if hasattr(request, 'stream_id'): + req.stream_id = request.stream_id + return req + + def __hash__(self): + return id(self) + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in the headers, the + request path and the body of the request. Encoded content will be + decoded before replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + c = MessageMixin.replace(self, pattern, repl, *args, **kwargs) + self.path, pc = utils.safe_subn( + pattern, repl, self.path, *args, **kwargs + ) + c += pc + return c + + +class HTTPResponse(MessageMixin, Response): + """ + An HTTP response. + + Exposes the following attributes: + + httpversion: HTTP version tuple, e.g. (1, 0), (1, 1), or (2, 0) + + status_code: HTTP response status code + + msg: HTTP response message + + headers: ODict Caseless object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + body, + timestamp_start=None, + timestamp_end=None, + ): + Response.__init__( + self, + httpversion, + status_code, + msg, + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + + # Is this request replayed? + self.is_replay = False + self.stream = False + + _stateobject_attributes = MessageMixin._stateobject_attributes.copy() + _stateobject_attributes.update( + status_code=int, + msg=str + ) + + @classmethod + def from_state(cls, state): + f = cls(None, None, None, None, None) + f.load_state(state) + return f + + @classmethod + def from_protocol( + self, + protocol, + *args, + **kwargs + ): + resp = protocol.read_response(*args, **kwargs) + return self.wrap(resp) + + @classmethod + def wrap(self, response): + resp = HTTPResponse( + httpversion=response.httpversion, + status_code=response.status_code, + msg=response.msg, + headers=response.headers, + body=response.body, + timestamp_start=response.timestamp_start, + timestamp_end=response.timestamp_end, + ) + if hasattr(response, 'stream_id'): + resp.stream_id = response.stream_id + return resp + + def _refresh_cookie(self, c, delta): + """ + Takes a cookie string c and a time delta in seconds, and returns + a refreshed cookie string. + """ + c = Cookie.SimpleCookie(str(c)) + for i in c.values(): + if "expires" in i: + d = parsedate_tz(i["expires"]) + if d: + d = mktime_tz(d) + delta + i["expires"] = formatdate(d) + else: + # This can happen when the expires tag is invalid. + # reddit.com sends a an expires tag like this: "Thu, 31 Dec + # 2037 23:59:59 GMT", which is valid RFC 1123, but not + # strictly correct according to the cookie spec. Browsers + # appear to parse this tolerantly - maybe we should too. + # For now, we just ignore this. + del i["expires"] + return c.output(header="").strip() + + def refresh(self, now=None): + """ + This fairly complex and heuristic function refreshes a server + response for replay. + + - It adjusts date, expires and last-modified headers. + - It adjusts cookie expiration. + """ + if not now: + now = time.time() + delta = now - self.timestamp_start + refresh_headers = [ + "date", + "expires", + "last-modified", + ] + for i in refresh_headers: + if i in self.headers: + d = parsedate_tz(self.headers[i][0]) + if d: + new = mktime_tz(d) + delta + self.headers[i] = [formatdate(new)] + c = [] + for i in self.headers["set-cookie"]: + c.append(self._refresh_cookie(i, delta)) + if c: + self.headers["set-cookie"] = c + + +class HTTPFlow(Flow): + """ + A HTTPFlow is a collection of objects representing a single HTTP + transaction. The main attributes are: + + request: HTTPRequest object + response: HTTPResponse object + error: Error object + server_conn: ServerConnection object + client_conn: ClientConnection object + + Note that it's possible for a Flow to have both a response and an error + object. This might happen, for instance, when a response was received + from the server, but there was an error sending it back to the client. + + The following additional attributes are exposed: + + intercepted: Is this flow currently being intercepted? + live: Does this flow have a live client connection? + """ + + def __init__(self, client_conn, server_conn, live=None): + super(HTTPFlow, self).__init__("http", client_conn, server_conn, live) + self.request = None + """@type: HTTPRequest""" + self.response = None + """@type: HTTPResponse""" + + _stateobject_attributes = Flow._stateobject_attributes.copy() + _stateobject_attributes.update( + request=HTTPRequest, + response=HTTPResponse + ) + + @classmethod + def from_state(cls, state): + f = cls(None, None) + f.load_state(state) + return f + + def __repr__(self): + s = "<HTTPFlow" + for a in ("request", "response", "error", "client_conn", "server_conn"): + if getattr(self, a, False): + s += "\r\n %s = {flow.%s}" % (a, a) + s += ">" + return s.format(flow=self) + + def copy(self): + f = super(HTTPFlow, self).copy() + if self.request: + f.request = self.request.copy() + if self.response: + f.response = self.response.copy() + return f + + def match(self, f): + """ + Match this flow against a compiled filter expression. Returns True + if matched, False if not. + + If f is a string, it will be compiled as a filter expression. If + the expression is invalid, ValueError is raised. + """ + if isinstance(f, basestring): + from .. import filt + + f = filt.parse(f) + if not f: + raise ValueError("Invalid filter expression.") + if f: + return f(self) + return True + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both request and + response of the flow. Encoded content will be decoded before + replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + c = self.request.replace(pattern, repl, *args, **kwargs) + if self.response: + c += self.response.replace(pattern, repl, *args, **kwargs) + return c + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, o): + self.o = o + ce = o.headers.get_first("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + self.o.decode() + + def __exit__(self, type, value, tb): + if self.ce: + self.o.encode(self.ce) + + +def make_error_response(status_code, message, headers=None): + response = status_codes.RESPONSES.get(status_code, "Unknown") + body = """ + <html> + <head> + <title>%d %s</title> + </head> + <body>%s</body> + </html> + """.strip() % (status_code, response, message) + + if not headers: + headers = odict.ODictCaseless() + headers["Server"] = [version.NAMEVERSION] + headers["Connection"] = ["close"] + headers["Content-Length"] = [len(body)] + headers["Content-Type"] = ["text/html"] + + return HTTPResponse( + (1, 1), # FIXME: Should be a string. + status_code, + response, + headers, + body, + ) + + +def make_connect_request(address): + address = Address.wrap(address) + return HTTPRequest( + "authority", "CONNECT", None, address.host, address.port, None, (1, 1), + odict.ODictCaseless(), "" + ) + + +def make_connect_response(httpversion): + headers = odict.ODictCaseless([ + ["Content-Length", "0"], + ["Proxy-Agent", version.NAMEVERSION] + ]) + return HTTPResponse( + httpversion, + 200, + "Connection established", + headers, + "", + ) |