Diffstat (limited to 'libmproxy')
-rw-r--r-- | libmproxy/app.py                     |    6
-rw-r--r-- | libmproxy/console/__init__.py        |    4
-rw-r--r-- | libmproxy/console/common.py          |    6
-rw-r--r-- | libmproxy/console/flowdetailview.py  |   35
-rw-r--r-- | libmproxy/controller.py              |    1
-rw-r--r-- | libmproxy/dump.py                    |    6
-rw-r--r-- | libmproxy/filt.py                    |    2
-rw-r--r-- | libmproxy/flow.py                    | 1028
-rw-r--r-- | libmproxy/protocol/__init__.py       |  101
-rw-r--r-- | libmproxy/protocol/http.py           | 1046
-rw-r--r-- | libmproxy/protocol/primitives.py     |  130
-rw-r--r-- | libmproxy/protocol/tcp.py            |   59
-rw-r--r-- | libmproxy/proxy.py                   |  848
-rw-r--r-- | libmproxy/script.py                  |    2
-rw-r--r-- | libmproxy/stateobject.py             |   73
15 files changed, 1896 insertions, 1451 deletions
diff --git a/libmproxy/app.py b/libmproxy/app.py index b0692cf2..b046f712 100644 --- a/libmproxy/app.py +++ b/libmproxy/app.py @@ -4,9 +4,11 @@ import os.path mapp = flask.Flask(__name__) mapp.debug = True + def master(): return flask.request.environ["mitmproxy.master"] + @mapp.route("/") def index(): return flask.render_template("index.html", section="home") @@ -16,12 +18,12 @@ def index(): def certs_pem(): capath = master().server.config.cacert p = os.path.splitext(capath)[0] + "-cert.pem" - return flask.Response(open(p).read(), mimetype='application/x-x509-ca-cert') + return flask.Response(open(p, "rb").read(), mimetype='application/x-x509-ca-cert') @mapp.route("/cert/p12") def certs_p12(): capath = master().server.config.cacert p = os.path.splitext(capath)[0] + "-cert.p12" - return flask.Response(open(p).read(), mimetype='application/x-pkcs12') + return flask.Response(open(p, "rb").read(), mimetype='application/x-pkcs12') diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index a316602c..d92561f2 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -197,7 +197,7 @@ class StatusBar(common.WWrap): ] if self.master.server.bound: - boundaddr = "[%s:%s]"%(self.master.server.address or "*", self.master.server.port) + boundaddr = "[%s:%s]"%(self.master.server.address.host or "*", self.master.server.address.port) else: boundaddr = "" t.extend(self.get_status()) @@ -1008,7 +1008,7 @@ class ConsoleMaster(flow.FlowMaster): self.statusbar.refresh_flow(c) def process_flow(self, f, r): - if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay(): + if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: f.intercept() else: r.reply() diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index 951d2c2a..715bed80 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -172,7 +172,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): intercepting = f.intercepting, req_timestamp = f.request.timestamp_start, - req_is_replay = f.request.is_replay(), + req_is_replay = f.request.is_replay, req_method = f.request.method, req_acked = f.request.reply.acked, req_url = f.request.get_url(hostheader=hostheader), @@ -189,12 +189,12 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): contentdesc = "[no content]" delta = f.response.timestamp_end - f.response.timestamp_start - size = len(f.response.content) + f.response.get_header_size() + size = f.response.size() rate = utils.pretty_size(size / ( delta if delta > 0 else 1 ) ) d.update(dict( resp_code = f.response.code, - resp_is_replay = f.response.is_replay(), + resp_is_replay = f.response.is_replay, resp_acked = f.response.reply.acked, resp_clen = contentdesc, resp_rate = "{0}/s".format(rate), diff --git a/libmproxy/console/flowdetailview.py b/libmproxy/console/flowdetailview.py index a26e5308..436d8f07 100644 --- a/libmproxy/console/flowdetailview.py +++ b/libmproxy/console/flowdetailview.py @@ -1,5 +1,6 @@ import urwid import common +from .. 
import utils footer = [ ('heading_key', "q"), ":back ", @@ -33,8 +34,17 @@ class FlowDetailsView(urwid.ListBox): title = urwid.AttrWrap(title, "heading") text.append(title) - if self.flow.response: - c = self.flow.response.cert + if self.flow.server_conn: + text.append(urwid.Text([("head", "Server Connection:")])) + sc = self.flow.server_conn + parts = [ + ["Address", "%s:%s" % sc.peername], + ["Start time", utils.format_timestamp(sc.timestamp_start)], + ["End time", utils.format_timestamp(sc.timestamp_end) if sc.timestamp_end else "active"], + ] + text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) + + c = self.flow.server_conn.cert if c: text.append(urwid.Text([("head", "Server Certificate:")])) parts = [ @@ -43,19 +53,13 @@ class FlowDetailsView(urwid.ListBox): ["Valid to", str(c.notafter)], ["Valid from", str(c.notbefore)], ["Serial", str(c.serial)], - ] - - parts.append( [ "Subject", urwid.BoxAdapter( urwid.ListBox(common.format_keyvals(c.subject, key="highlight", val="text")), len(c.subject) ) - ] - ) - - parts.append( + ], [ "Issuer", urwid.BoxAdapter( @@ -63,7 +67,7 @@ class FlowDetailsView(urwid.ListBox): len(c.issuer) ) ] - ) + ] if c.altnames: parts.append( @@ -74,13 +78,14 @@ class FlowDetailsView(urwid.ListBox): ) text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) - if self.flow.request.client_conn: + if self.flow.client_conn: text.append(urwid.Text([("head", "Client Connection:")])) - cc = self.flow.request.client_conn + cc = self.flow.client_conn parts = [ - ["Address", "%s:%s"%tuple(cc.address)], - ["Requests", "%s"%cc.requestcount], - ["Closed", "%s"%cc.close], + ["Address", "%s:%s" % cc.address()], + ["Start time", utils.format_timestamp(cc.timestamp_start)], + # ["Requests", "%s"%cc.requestcount], + ["End time", utils.format_timestamp(cc.timestamp_end) if cc.timestamp_end else "active"], ] text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) diff --git a/libmproxy/controller.py b/libmproxy/controller.py index b662b6d5..470d88fc 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -72,6 +72,7 @@ class Slave(threading.Thread): self.channel, self.server = channel, server self.server.set_channel(channel) threading.Thread.__init__(self) + self.name = "SlaveThread (%s:%s)" % (self.server.address.host, self.server.address.port) def run(self): self.server.serve_forever() diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 8bd29ae5..6cf5e688 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -42,14 +42,14 @@ class Options(object): def str_response(resp): r = "%s %s"%(resp.code, resp.msg) - if resp.is_replay(): + if resp.is_replay: r = "[replay] " + r return r def str_request(req, showhost): - if req.client_conn: - c = req.client_conn.address[0] + if req.flow.client_conn: + c = req.flow.client_conn.address.host else: c = "[replay]" r = "%s %s %s"%(c, req.method, req.get_url(showhost)) diff --git a/libmproxy/filt.py b/libmproxy/filt.py index 6a0c3075..95076eed 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -198,7 +198,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.host, re.IGNORECASE)) + return bool(re.search(self.expr, f.request.get_host(), re.IGNORECASE)) class FUrl(_Rex): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 76ca4f47..40786631 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -2,16 +2,19 @@ This module provides more sophisticated flow tracking. 
These match requests with their responses, and provide filtering and interception facilities. """ -import hashlib, Cookie, cookielib, copy, re, urlparse, threading -import time, urllib -import tnetstring, filt, script, utils, encoding, proxy -from email.utils import parsedate_tz, formatdate, mktime_tz -from netlib import odict, http, certutils, wsgi -import controller, version +import base64 +import hashlib, Cookie, cookielib, re, threading +import os +from flask import request +import requests +import tnetstring, filt, script +from netlib import odict, wsgi +from .proxy import ClientConnection, ServerConnection # FIXME: remove circular dependency +import controller, version, protocol import app - -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -CONTENT_MISSING = 0 +from .protocol import KILL +from .protocol.http import HTTPResponse, CONTENT_MISSING +from .proxy import RequestReplayThread ODict = odict.ODict ODictCaseless = odict.ODictCaseless @@ -32,11 +35,11 @@ class AppRegistry: """ Returns an WSGIAdaptor instance if request matches an app, or None. """ - if (request.host, request.port) in self.apps: - return self.apps[(request.host, request.port)] + if (request.get_host(), request.get_port()) in self.apps: + return self.apps[(request.get_host(), request.get_port())] if "host" in request.headers: host = request.headers["host"][0] - return self.apps.get((host, request.port), None) + return self.apps.get((host, request.get_port()), None) class ReplaceHooks: @@ -143,769 +146,6 @@ class SetHeaders: f.request.headers.add(header, value) -class decoded(object): - """ - - A context manager that decodes a request, response or error, and then - re-encodes it with the same encoding after execution of the block. - - Example: - - with decoded(request): - request.content = request.content.replace("foo", "bar") - """ - def __init__(self, o): - self.o = o - ce = o.headers.get_first("content-encoding") - if ce in encoding.ENCODINGS: - self.ce = ce - else: - self.ce = None - - def __enter__(self): - if self.ce: - self.o.decode() - - def __exit__(self, type, value, tb): - if self.ce: - self.o.encode(self.ce) - - -class StateObject: - def __eq__(self, other): - try: - return self._get_state() == other._get_state() - except AttributeError: - return False - - -class HTTPMsg(StateObject): - def get_decoded_content(self): - """ - Returns the decoded content based on the current Content-Encoding header. - Doesn't change the message iteself or its headers. - """ - ce = self.headers.get_first("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return self.content - return encoding.decode(ce, self.content) - - def decode(self): - """ - Decodes content based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. - - Returns True if decoding succeeded, False otherwise. - """ - ce = self.headers.get_first("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return False - data = encoding.decode( - ce, - self.content - ) - if data is None: - return False - self.content = data - del self.headers["content-encoding"] - return True - - def encode(self, e): - """ - Encodes content with the encoding e, where e is "gzip", "deflate" - or "identity". - """ - # FIXME: Error if there's an existing encoding header? 
- self.content = encoding.encode(e, self.content) - self.headers["content-encoding"] = [e] - - def size(self, **kwargs): - """ - Size in bytes of a fully rendered message, including headers and - HTTP lead-in. - """ - hl = len(self._assemble_head(**kwargs)) - if self.content: - return hl + len(self.content) - else: - return hl - - def get_content_type(self): - return self.headers.get_first("content-type") - - def get_transmitted_size(self): - # FIXME: this is inprecise in case chunking is used - # (we should count the chunking headers) - if not self.content: - return 0 - return len(self.content) - - -class Request(HTTPMsg): - """ - An HTTP request. - - Exposes the following attributes: - - client_conn: ClientConnect object, or None if this is a replay. - - headers: ODictCaseless object - - content: Content of the request, None, or CONTENT_MISSING if there - is content associated, but not present. CONTENT_MISSING evaluates - to False to make checking for the presence of content natural. - - scheme: URL scheme (http/https) - - host: Host portion of the URL - - port: Destination port - - path: Path portion of the URL - - timestamp_start: Seconds since the epoch signifying request transmission started - - method: HTTP method - - timestamp_end: Seconds since the epoch signifying request transmission ended - - tcp_setup_timestamp: Seconds since the epoch signifying remote TCP connection setup completion time - (or None, if request didn't results TCP setup) - - ssl_setup_timestamp: Seconds since the epoch signifying remote SSL encryption setup completion time - (or None, if request didn't results SSL setup) - - """ - def __init__( - self, client_conn, httpversion, host, port, - scheme, method, path, headers, content, timestamp_start=None, - timestamp_end=None, tcp_setup_timestamp=None, - ssl_setup_timestamp=None, ip=None): - assert isinstance(headers, ODictCaseless) - self.client_conn = client_conn - self.httpversion = httpversion - self.host, self.port, self.scheme = host, port, scheme - self.method, self.path, self.headers, self.content = method, path, headers, content - self.timestamp_start = timestamp_start or utils.timestamp() - self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start) - self.close = False - self.tcp_setup_timestamp = tcp_setup_timestamp - self.ssl_setup_timestamp = ssl_setup_timestamp - self.ip = ip - - # Have this request's cookies been modified by sticky cookies or auth? - self.stickycookie = False - self.stickyauth = False - - # Live attributes - not serialized - self.wfile, self.rfile = None, None - - def set_live(self, rfile, wfile): - self.wfile, self.rfile = wfile, rfile - - def is_live(self): - return bool(self.wfile) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - del self.headers[i] - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = ["identity"] - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. 
- """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] - )] - - def _set_replay(self): - self.client_conn = None - - def is_replay(self): - """ - Is this request a replay? - """ - if self.client_conn: - return False - else: - return True - - def _load_state(self, state): - if state["client_conn"]: - if self.client_conn: - self.client_conn._load_state(state["client_conn"]) - else: - self.client_conn = ClientConnect._from_state(state["client_conn"]) - else: - self.client_conn = None - self.host = state["host"] - self.port = state["port"] - self.scheme = state["scheme"] - self.method = state["method"] - self.path = state["path"] - self.headers = ODictCaseless._from_state(state["headers"]) - self.content = state["content"] - self.timestamp_start = state["timestamp_start"] - self.timestamp_end = state["timestamp_end"] - self.tcp_setup_timestamp = state["tcp_setup_timestamp"] - self.ssl_setup_timestamp = state["ssl_setup_timestamp"] - self.ip = state["ip"] - - def _get_state(self): - return dict( - client_conn = self.client_conn._get_state() if self.client_conn else None, - httpversion = self.httpversion, - host = self.host, - port = self.port, - scheme = self.scheme, - method = self.method, - path = self.path, - headers = self.headers._get_state(), - content = self.content, - timestamp_start = self.timestamp_start, - timestamp_end = self.timestamp_end, - tcp_setup_timestamp = self.tcp_setup_timestamp, - ssl_setup_timestamp = self.ssl_setup_timestamp, - ip = self.ip - ) - - @classmethod - def _from_state(klass, state): - return klass( - ClientConnect._from_state(state["client_conn"]), - tuple(state["httpversion"]), - str(state["host"]), - state["port"], - str(state["scheme"]), - str(state["method"]), - str(state["path"]), - ODictCaseless._from_state(state["headers"]), - state["content"], - state["timestamp_start"], - state["timestamp_end"], - state["tcp_setup_timestamp"], - state["ssl_setup_timestamp"], - state["ip"] - ) - - def __hash__(self): - return id(self) - - def copy(self): - c = copy.copy(self) - c.headers = self.headers.copy() - return c - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.content and self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): - return ODict(utils.urldecode(self.content)) - return ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] - self.content = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) - return [urllib.unquote(i) for i in path.split("/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. 
- """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) - if query: - return ODict(utils.urldecode(query)) - return ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) - query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) - - def get_url(self, hostheader=False): - """ - Returns a URL string, constructed from the Request's URL compnents. - - If hostheader is True, we use the value specified in the request - Host header to construct the URL. - """ - if hostheader: - host = self.headers.get_first("host") or self.host - else: - host = self.host - host = host.encode("idna") - return utils.unparse_url(self.scheme, host, self.port, self.path).encode('ascii') - - def set_url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Returns False if the URL was invalid, True if the request succeeded. - """ - parts = http.parse_url(url) - if not parts: - return False - self.scheme, self.host, self.port, self.path = parts - return True - - def get_cookies(self): - cookie_headers = self.headers.get("cookie") - if not cookie_headers: - return None - - cookies = [] - for header in cookie_headers: - pairs = [pair.partition("=") for pair in header.split(';')] - cookies.extend((pair[0],(pair[2],{})) for pair in pairs) - return dict(cookies) - - def get_header_size(self): - FMT = '%s %s HTTP/%s.%s\r\n%s\r\n' - assembled_header = FMT % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - str(self.headers) - ) - return len(assembled_header) - - def _assemble_head(self, proxy=False): - FMT = '%s %s HTTP/%s.%s\r\n%s\r\n' - FMT_PROXY = '%s %s://%s:%s%s HTTP/%s.%s\r\n%s\r\n' - - headers = self.headers.copy() - utils.del_all( - headers, - [ - 'proxy-connection', - 'keep-alive', - 'connection', - 'transfer-encoding' - ] - ) - if not 'host' in headers: - headers["host"] = [utils.hostport(self.scheme, self.host, self.port)] - content = self.content - if content: - headers["Content-Length"] = [str(len(content))] - else: - content = "" - if self.close: - headers["connection"] = ["close"] - if not proxy: - return FMT % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - str(headers) - ) - else: - return FMT_PROXY % ( - self.method, - self.scheme, - self.host, - self.port, - self.path, - self.httpversion[0], - self.httpversion[1], - str(headers) - ) - - def _assemble(self, _proxy = False): - """ - Assembles the request for transmission to the server. We make some - modifications to make sure interception works properly. - - Returns None if the request cannot be assembled. - """ - if self.content == CONTENT_MISSING: - return None - head = self._assemble_head(_proxy) - if self.content: - return head + self.content - else: - return head - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the request. Encoded content will be decoded before - replacement, and re-encoded afterwards. 
- - Returns the number of replacements made. - """ - with decoded(self): - self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs) - self.path, pc = utils.safe_subn(pattern, repl, self.path, *args, **kwargs) - c += pc - c += self.headers.replace(pattern, repl, *args, **kwargs) - return c - - -class Response(HTTPMsg): - """ - An HTTP response. - - Exposes the following attributes: - - request: Request object. - - code: HTTP response code - - msg: HTTP response message - - headers: ODict object - - content: Content of the request, None, or CONTENT_MISSING if there - is content associated, but not present. CONTENT_MISSING evaluates - to False to make checking for the presence of content natural. - - timestamp_start: Seconds since the epoch signifying response transmission started - - timestamp_end: Seconds since the epoch signifying response transmission ended - """ - def __init__(self, request, httpversion, code, msg, headers, content, cert, timestamp_start=None, timestamp_end=None): - assert isinstance(headers, ODictCaseless) - self.request = request - self.httpversion, self.code, self.msg = httpversion, code, msg - self.headers, self.content = headers, content - self.cert = cert - self.timestamp_start = timestamp_start or utils.timestamp() - self.timestamp_end = timestamp_end or utils.timestamp() - self.replay = False - - def _refresh_cookie(self, c, delta): - """ - Takes a cookie string c and a time delta in seconds, and returns - a refreshed cookie string. - """ - c = Cookie.SimpleCookie(str(c)) - for i in c.values(): - if "expires" in i: - d = parsedate_tz(i["expires"]) - if d: - d = mktime_tz(d) + delta - i["expires"] = formatdate(d) - else: - # This can happen when the expires tag is invalid. - # reddit.com sends a an expires tag like this: "Thu, 31 Dec - # 2037 23:59:59 GMT", which is valid RFC 1123, but not - # strictly correct according tot he cookie spec. Browsers - # appear to parse this tolerantly - maybe we should too. - # For now, we just ignore this. - del i["expires"] - return c.output(header="").strip() - - def refresh(self, now=None): - """ - This fairly complex and heuristic function refreshes a server - response for replay. - - - It adjusts date, expires and last-modified headers. - - It adjusts cookie expiration. - """ - if not now: - now = time.time() - delta = now - self.timestamp_start - refresh_headers = [ - "date", - "expires", - "last-modified", - ] - for i in refresh_headers: - if i in self.headers: - d = parsedate_tz(self.headers[i][0]) - if d: - new = mktime_tz(d) + delta - self.headers[i] = [formatdate(new)] - c = [] - for i in self.headers["set-cookie"]: - c.append(self._refresh_cookie(i, delta)) - if c: - self.headers["set-cookie"] = c - - def _set_replay(self): - self.replay = True - - def is_replay(self): - """ - Is this response a replay? 
- """ - return self.replay - - def _load_state(self, state): - self.code = state["code"] - self.msg = state["msg"] - self.headers = ODictCaseless._from_state(state["headers"]) - self.content = state["content"] - self.timestamp_start = state["timestamp_start"] - self.timestamp_end = state["timestamp_end"] - self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None - - def _get_state(self): - return dict( - httpversion = self.httpversion, - code = self.code, - msg = self.msg, - headers = self.headers._get_state(), - timestamp_start = self.timestamp_start, - timestamp_end = self.timestamp_end, - cert = self.cert.to_pem() if self.cert else None, - content = self.content, - ) - - @classmethod - def _from_state(klass, request, state): - return klass( - request, - state["httpversion"], - state["code"], - str(state["msg"]), - ODictCaseless._from_state(state["headers"]), - state["content"], - certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None, - state["timestamp_start"], - state["timestamp_end"], - ) - - def copy(self): - c = copy.copy(self) - c.headers = self.headers.copy() - return c - - def _assemble_head(self): - FMT = '%s\r\n%s\r\n' - headers = self.headers.copy() - utils.del_all( - headers, - ['proxy-connection', 'transfer-encoding'] - ) - if self.content: - headers["Content-Length"] = [str(len(self.content))] - elif 'Transfer-Encoding' in self.headers: - headers["Content-Length"] = ["0"] - proto = "HTTP/%s.%s %s %s"%(self.httpversion[0], self.httpversion[1], self.code, str(self.msg)) - data = (proto, str(headers)) - return FMT%data - - def _assemble(self): - """ - Assembles the response for transmission to the client. We make some - modifications to make sure interception works properly. - - Returns None if the request cannot be assembled. - """ - if self.content == CONTENT_MISSING: - return None - head = self._assemble_head() - if self.content: - return head + self.content - else: - return head - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the response. Encoded content will be decoded - before replacement, and re-encoded afterwards. - - Returns the number of replacements made. - """ - with decoded(self): - self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs) - c += self.headers.replace(pattern, repl, *args, **kwargs) - return c - - def get_header_size(self): - FMT = '%s\r\n%s\r\n' - proto = "HTTP/%s.%s %s %s"%(self.httpversion[0], self.httpversion[1], self.code, str(self.msg)) - assembled_header = FMT % (proto, str(self.headers)) - return len(assembled_header) - - def get_cookies(self): - cookie_headers = self.headers.get("set-cookie") - if not cookie_headers: - return None - - cookies = [] - for header in cookie_headers: - pairs = [pair.partition("=") for pair in header.split(';')] - cookie_name = pairs[0][0] # the key of the first key/value pairs - cookie_value = pairs[0][2] # the value of the first key/value pairs - cookie_parameters = {key.strip().lower():value.strip() for key,sep,value in pairs[1:]} - cookies.append((cookie_name, (cookie_value, cookie_parameters))) - return dict(cookies) - - -class ClientDisconnect: - """ - A client disconnection event. - - Exposes the following attributes: - - client_conn: ClientConnect object. - """ - def __init__(self, client_conn): - self.client_conn = client_conn - - -class ClientConnect(StateObject): - """ - A single client connection. 
Each connection can result in multiple HTTP - Requests. - - Exposes the following attributes: - - address: (address, port) tuple, or None if the connection is replayed. - requestcount: Number of requests created by this client connection. - close: Is the client connection closed? - error: Error string or None. - """ - def __init__(self, address): - """ - address is an (address, port) tuple, or None if this connection has - been replayed from within mitmproxy. - """ - self.address = address - self.close = False - self.requestcount = 0 - self.error = None - - def __str__(self): - if self.address: - return "%s:%d"%(self.address[0],self.address[1]) - - def _load_state(self, state): - self.close = True - self.error = state["error"] - self.requestcount = state["requestcount"] - - def _get_state(self): - return dict( - address = list(self.address), - requestcount = self.requestcount, - error = self.error, - ) - - @classmethod - def _from_state(klass, state): - if state: - k = klass(state["address"]) - k._load_state(state) - return k - else: - return None - - def copy(self): - return copy.copy(self) - - -class Error(StateObject): - """ - An Error. - - This is distinct from an HTTP error response (say, a code 500), which - is represented by a normal Response object. This class is responsible - for indicating errors that fall outside of normal HTTP communications, - like interrupted connections, timeouts, protocol errors. - - Exposes the following attributes: - - request: Request object - msg: Message describing the error - timestamp: Seconds since the epoch - """ - def __init__(self, request, msg, timestamp=None): - self.request, self.msg = request, msg - self.timestamp = timestamp or utils.timestamp() - - def _load_state(self, state): - self.msg = state["msg"] - self.timestamp = state["timestamp"] - - def copy(self): - c = copy.copy(self) - return c - - def _get_state(self): - return dict( - msg = self.msg, - timestamp = self.timestamp, - ) - - @classmethod - def _from_state(klass, request, state): - return klass( - request, - state["msg"], - state["timestamp"], - ) - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the request. Returns the number of replacements - made. - - FIXME: Is replace useful on an Error object?? - """ - self.msg, c = utils.safe_subn(pattern, repl, self.msg, *args, **kwargs) - return c - - class ClientPlaybackState: def __init__(self, flows, exit): self.flows, self.exit = flows, exit @@ -934,7 +174,7 @@ class ClientPlaybackState: if self.flows and not self.current: n = self.flows.pop(0) n.request.reply = controller.DummyReply() - n.request.client_conn = None + n.client_conn = None self.current = master.handle_request(n.request) if not testing and not self.current.response: master.replay_request(self.current) # pragma: no cover @@ -997,7 +237,6 @@ class ServerPlaybackState: return l.pop(0) - class StickyCookieState: def __init__(self, flt): """ @@ -1011,8 +250,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. 
""" return ( - m["domain"] or f.request.host, - f.request.port, + m["domain"] or f.request.get_host(), + f.request.get_port(), m["path"] or "/" ) @@ -1030,7 +269,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.host, k[0]): + if self.domain_match(f.request.get_host(), k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -1038,8 +277,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.host, i[0]), - f.request.port == i[1], + self.domain_match(f.request.get_host(), i[0]), + f.request.get_port() == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -1058,177 +297,16 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): + host = f.request.get_host() if "authorization" in f.request.headers: - self.hosts[f.request.host] = f.request.headers["authorization"] + self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): - if f.request.host in self.hosts: - f.request.headers["authorization"] = self.hosts[f.request.host] - - -class Flow: - """ - A Flow is a collection of objects representing a single HTTP - transaction. The main attributes are: - - request: Request object - response: Response object - error: Error object - - Note that it's possible for a Flow to have both a response and an error - object. This might happen, for instance, when a response was received - from the server, but there was an error sending it back to the client. - - The following additional attributes are exposed: - - intercepting: Is this flow currently being intercepted? - """ - def __init__(self, request): - self.request = request - self.response, self.error = None, None - self.intercepting = False - self._backup = None - - def copy(self): - rc = self.request.copy() - f = Flow(rc) - if self.response: - f.response = self.response.copy() - f.response.request = rc - if self.error: - f.error = self.error.copy() - f.error.request = rc - return f - - @classmethod - def _from_state(klass, state): - f = klass(None) - f._load_state(state) - return f - - def _get_state(self): - d = dict( - request = self.request._get_state() if self.request else None, - response = self.response._get_state() if self.response else None, - error = self.error._get_state() if self.error else None, - version = version.IVERSION - ) - return d - - def _load_state(self, state): - if self.request: - self.request._load_state(state["request"]) - else: - self.request = Request._from_state(state["request"]) - - if state["response"]: - if self.response: - self.response._load_state(state["response"]) - else: - self.response = Response._from_state(self.request, state["response"]) - else: - self.response = None - - if state["error"]: - if self.error: - self.error._load_state(state["error"]) - else: - self.error = Error._from_state(self.request, state["error"]) - else: - self.error = None - - def modified(self): - """ - Has this Flow been modified? - """ - # FIXME: Save a serialization in backup, compare current with - # backup to detect if flow has _really_ been modified. - if self._backup: - return True - else: - return False - - def backup(self, force=False): - """ - Save a backup of this Flow, which can be reverted to using a - call to .revert(). - """ - if not self._backup: - self._backup = self._get_state() - - def revert(self): - """ - Revert to the last backed up state. 
- """ - if self._backup: - self._load_state(self._backup) - self._backup = None - - def match(self, f): - """ - Match this flow against a compiled filter expression. Returns True - if matched, False if not. - - If f is a string, it will be compiled as a filter expression. If - the expression is invalid, ValueError is raised. - """ - if isinstance(f, basestring): - f = filt.parse(f) - if not f: - raise ValueError("Invalid filter expression.") - if f: - return f(self) - return True - - def kill(self, master): - """ - Kill this request. - """ - self.error = Error(self.request, "Connection killed") - self.error.reply = controller.DummyReply() - if self.request and not self.request.reply.acked: - self.request.reply(proxy.KILL) - elif self.response and not self.response.reply.acked: - self.response.reply(proxy.KILL) - master.handle_error(self.error) - self.intercepting = False - - def intercept(self): - """ - Intercept this Flow. Processing will stop until accept_intercept is - called. - """ - self.intercepting = True - - def accept_intercept(self): - """ - Continue with the flow - called after an intercept(). - """ - if self.request: - if not self.request.reply.acked: - self.request.reply() - elif self.response and not self.response.reply.acked: - self.response.reply() - self.intercepting = False - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in all parts of the - flow. Encoded content will be decoded before replacement, and - re-encoded afterwards. - - Returns the number of replacements made. - """ - c = self.request.replace(pattern, repl, *args, **kwargs) - if self.response: - c += self.response.replace(pattern, repl, *args, **kwargs) - if self.error: - c += self.error.replace(pattern, repl, *args, **kwargs) - return c + if host in self.hosts: + f.request.headers["authorization"] = self.hosts[host] class State(object): def __init__(self): - self._flow_map = {} self._flow_list = [] self.view = [] @@ -1242,7 +320,7 @@ class State(object): return self._limit_txt def flow_count(self): - return len(self._flow_map) + return len(self._flow_list) def index(self, f): return self._flow_list.index(f) @@ -1258,10 +336,8 @@ class State(object): """ Add a request to the state. Returns the matching flow. """ - f = Flow(req) + f = req.flow self._flow_list.append(f) - self._flow_map[req] = f - assert len(self._flow_list) == len(self._flow_map) if f.match(self._limit): self.view.append(f) return f @@ -1270,10 +346,9 @@ class State(object): """ Add a response to the state. Returns the matching flow. """ - f = self._flow_map.get(resp.request) + f = resp.flow if not f: return False - f.response = resp if f.match(self._limit) and not f in self.view: self.view.append(f) return f @@ -1283,18 +358,15 @@ class State(object): Add an error response to the state. Returns the matching flow, or None if there isn't one. 
""" - f = self._flow_map.get(err.request) + f = err.flow if not f: return None - f.error = err if f.match(self._limit) and not f in self.view: self.view.append(f) return f def load_flows(self, flows): self._flow_list.extend(flows) - for i in flows: - self._flow_map[i.request] = i self.recalculate_view() def set_limit(self, txt): @@ -1327,8 +399,6 @@ class State(object): self.view = self._flow_list[:] def delete_flow(self, f): - if f.request in self._flow_map: - del self._flow_map[f.request] self._flow_list.remove(f) if f in self.view: self.view.remove(f) @@ -1383,7 +453,28 @@ class FlowMaster(controller.Master): port ) else: - threading.Thread(target=app.mapp.run,kwargs={ + @app.mapp.before_request + def patch_environ(*args, **kwargs): + request.environ["mitmproxy.master"] = self + + # the only absurd way to shut down a flask/werkzeug server. + # http://flask.pocoo.org/snippets/67/ + shutdown_secret = base64.b32encode(os.urandom(30)) + + @app.mapp.route('/shutdown/<secret>') + def shutdown(secret): + if secret == shutdown_secret: + request.environ.get('werkzeug.server.shutdown')() + + # Workaround: Monkey-patch shutdown function to stop the app. + # Improve this when we switch flask werkzeug for something useful. + _shutdown = self.shutdown + def _shutdownwrap(): + _shutdown() + requests.get("http://%s:%s/shutdown/%s" % (host, port, shutdown_secret)) + self.shutdown = _shutdownwrap + + threading.Thread(target=app.mapp.run, kwargs={ "use_reloader": False, "host": host, "port": port}).start() @@ -1474,9 +565,8 @@ class FlowMaster(controller.Master): rflow = self.server_playback.next_flow(flow) if not rflow: return None - response = Response._from_state(flow.request, rflow.response._get_state()) - response._set_replay() - flow.response = response + response = HTTPResponse._from_state(rflow.response._get_state()) + response.is_replay = True if self.refresh_server_playback: response.refresh() flow.request.reply(response) @@ -1555,13 +645,13 @@ class FlowMaster(controller.Master): if f.request.content == CONTENT_MISSING: return "Can't replay request with missing content..." if f.request: - f.request._set_replay() + f.request.is_replay = True if f.request.content: f.request.headers["Content-Length"] = [str(len(f.request.content))] f.response = None f.error = None self.process_new_request(f) - rt = proxy.RequestReplayThread( + rt = RequestReplayThread( self.server.config, f, self.masterq, @@ -1597,13 +687,13 @@ class FlowMaster(controller.Master): return f def handle_request(self, r): - if r.is_live(): + if r.flow.client_conn and r.flow.client_conn.wfile: app = self.apps.get(r) if app: - err = app.serve(r, r.wfile, **{"mitmproxy.master": self}) + err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. 
%s"%err, "error") - r.reply(proxy.KILL) + r.reply(KILL) return f = self.state.add_request(r) self.replacehooks.run(f) @@ -1676,7 +766,7 @@ class FlowReader: v = ".".join(str(i) for i in data["version"]) raise FlowReadError("Incompatible serialized data version: %s"%v) off = self.fo.tell() - yield Flow._from_state(data) + yield protocol.protocols[data["conntype"]]["flow"]._from_state(data) except ValueError, v: # Error is due to EOF if self.fo.tell() == off and self.fo.read() == '': diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py new file mode 100644 index 00000000..580d693c --- /dev/null +++ b/libmproxy/protocol/__init__.py @@ -0,0 +1,101 @@ +from ..proxy import ServerConnection, AddressPriority + +KILL = 0 # const for killed requests + +class ConnectionTypeChange(Exception): + """ + Gets raised if the connetion type has been changed (e.g. after HTTP/1.1 101 Switching Protocols). + It's up to the raising ProtocolHandler to specify the new conntype before raising the exception. + """ + pass + + +class ProtocolHandler(object): + def __init__(self, c): + self.c = c + """@type: libmproxy.proxy.ConnectionHandler""" + + def handle_messages(self): + """ + This method gets called if a client connection has been made. Depending on the proxy settings, + a server connection might already exist as well. + """ + raise NotImplementedError # pragma: nocover + + def handle_error(self, error): + """ + This method gets called should there be an uncaught exception during the connection. + This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode. + """ + raise error # pragma: nocover + + +class TemporaryServerChangeMixin(object): + """ + This mixin allows safe modification of the target server, + without any need to expose the ConnectionHandler to the Flow. + """ + + def change_server(self, address, ssl): + if address == self.c.server_conn.address(): + return + priority = AddressPriority.MANUALLY_CHANGED + + if self.c.server_conn.priority > priority: + self.log("Attempt to change server address, " + "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) + return + + self.log("Temporarily change server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, + address.host, + address.port + )) + + if not hasattr(self, "_backup_server_conn"): + self._backup_server_conn = self.c.server_conn + self.c.server_conn = None + else: # This is at least the second temporary change. We can kill the current connection. + self.c.del_server_connection() + + self.c.set_server_address(address, priority) + if ssl: + self.establish_ssl(server=True) + + def restore_server(self): + if not hasattr(self, "_backup_server_conn"): + return + + self.log("Restore original server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, + self._backup_server_conn.host, + self._backup_server_conn.port + )) + + self.c.del_server_connection() + self.c.server_conn = self._backup_server_conn + del self._backup_server_conn + +from . import http, tcp + +protocols = { + 'http': dict(handler=http.HTTPHandler, flow=http.HTTPFlow), + 'tcp': dict(handler=tcp.TCPHandler) +} # PyCharm type hinting behaves bad if this is a dict constructor... 
+ + +def _handler(conntype, connection_handler): + if conntype in protocols: + return protocols[conntype]["handler"](connection_handler) + + raise NotImplementedError # pragma: nocover + + +def handle_messages(conntype, connection_handler): + return _handler(conntype, connection_handler).handle_messages() + + +def handle_error(conntype, connection_handler, error): + return _handler(conntype, connection_handler).handle_error(error)
\ No newline at end of file diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py new file mode 100644 index 00000000..95de6606 --- /dev/null +++ b/libmproxy/protocol/http.py @@ -0,0 +1,1046 @@ +import Cookie, urllib, urlparse, time, copy +from email.utils import parsedate_tz, formatdate, mktime_tz +import netlib.utils +from netlib import http, tcp, http_status, odict +from netlib.odict import ODict, ODictCaseless +from . import ProtocolHandler, ConnectionTypeChange, KILL, TemporaryServerChangeMixin +from .. import encoding, utils, version, filt, controller, stateobject +from ..proxy import ProxyError, AddressPriority, ServerConnection +from .primitives import Flow, Error + + +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +CONTENT_MISSING = 0 + + +def get_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": # Possible leftover from previous message + line = fp.readline() + if line == "": + raise tcp.NetLibDisconnect + return line + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, o): + self.o = o + ce = o.headers.get_first("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + self.o.decode() + + def __exit__(self, type, value, tb): + if self.ce: + self.o.encode(self.ce) + + +class HTTPMessage(stateobject.SimpleStateObject): + def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None): + self.httpversion = httpversion + self.headers = headers + """@type: ODictCaseless""" + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + self.flow = None # will usually be set by the flow backref mixin + """@type: HTTPFlow""" + + _stateobject_attributes = dict( + httpversion=tuple, + headers=ODictCaseless, + content=str, + timestamp_start=float, + timestamp_end=float + ) + + def get_decoded_content(self): + """ + Returns the decoded content based on the current Content-Encoding header. + Doesn't change the message iteself or its headers. + """ + ce = self.headers.get_first("content-encoding") + if not self.content or ce not in encoding.ENCODINGS: + return self.content + return encoding.decode(ce, self.content) + + def decode(self): + """ + Decodes content based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns True if decoding succeeded, False otherwise. + """ + ce = self.headers.get_first("content-encoding") + if not self.content or ce not in encoding.ENCODINGS: + return False + data = encoding.decode(ce, self.content) + if data is None: + return False + self.content = data + del self.headers["content-encoding"] + return True + + def encode(self, e): + """ + Encodes content with the encoding e, where e is "gzip", "deflate" + or "identity". + """ + # FIXME: Error if there's an existing encoding header? + self.content = encoding.encode(e, self.content) + self.headers["content-encoding"] = [e] + + def size(self, **kwargs): + """ + Size in bytes of a fully rendered message, including headers and + HTTP lead-in. 
+ """ + hl = len(self._assemble_head(**kwargs)) + if self.content: + return hl + len(self.content) + else: + return hl + + def copy(self): + c = copy.copy(self) + c.headers = self.headers.copy() + return c + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the message. Encoded content will be decoded + before replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + with decoded(self): + self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs) + c += self.headers.replace(pattern, repl, *args, **kwargs) + return c + + @classmethod + def from_stream(cls, rfile, include_content=True, body_size_limit=None): + """ + Parse an HTTP message from a file stream + """ + raise NotImplementedError # pragma: nocover + + def _assemble_first_line(self): + """ + Returns the assembled request/response line + """ + raise NotImplementedError # pragma: nocover + + def _assemble_headers(self): + """ + Returns the assembled headers + """ + raise NotImplementedError # pragma: nocover + + def _assemble_head(self): + """ + Returns the assembled request/response line plus headers + """ + raise NotImplementedError # pragma: nocover + + def _assemble(self): + """ + Returns the assembled request/response + """ + raise NotImplementedError # pragma: nocover + + +class HTTPRequest(HTTPMessage): + """ + An HTTP request. + + Exposes the following attributes: + + flow: Flow object the request belongs to + + headers: ODictCaseless object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. + + form_in: The request form which mitmproxy has received. The following values are possible: + - origin (GET /index.html) + - absolute (GET http://example.com:80/index.html) + - authority-form (CONNECT example.com:443) + - asterisk-form (OPTIONS *) + Details: http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-25#section-5.3 + + form_out: The request form which mitmproxy has send out to the destination + + method: HTTP method + + scheme: URL scheme (http/https) (absolute-form only) + + host: Host portion of the URL (absolute-form and authority-form only) + + port: Destination port (absolute-form and authority-form only) + + path: Path portion of the URL (not present in authority-form) + + httpversion: HTTP version tuple + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content, + timestamp_start=None, timestamp_end=None, form_out=None): + assert isinstance(headers, ODictCaseless) or not headers + HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end) + + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.form_out = form_out or form_in + + # Have this request's cookies been modified by sticky cookies or auth? + self.stickycookie = False + self.stickyauth = False + # Is this request replayed? 
+ self.is_replay = False + + _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() + _stateobject_attributes.update( + form_in=str, + method=str, + scheme=str, + host=str, + port=int, + path=str, + form_out=str + ) + + @classmethod + def _from_state(cls, state): + f = cls(None, None, None, None, None, None, None, None, None, None, None) + f._load_state(state) + return f + + @classmethod + def from_stream(cls, rfile, include_content=True, body_size_limit=None): + """ + Parse an HTTP request from a file stream + """ + httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \ + = None, None, None, None, None, None, None, None, None, None + + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + request_line = get_line(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + timestamp_start = rfile.first_byte_timestamp + else: + timestamp_start = utils.timestamp() + + request_line_parts = http.parse_init(request_line) + if not request_line_parts: + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + method, path, httpversion = request_line_parts + + if path == '*': + form_in = "asterisk" + elif path.startswith("/"): + form_in = "origin" + if not netlib.utils.isascii(path): + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = http.parse_init_connect(request_line) + if not r: + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = http.parse_init_proxy(request_line) + if not r: + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + _, scheme, host, port, path, _ = r + + headers = http.read_headers(rfile) + if headers is None: + raise http.HttpError(400, "Invalid headers") + + if include_content: + content = http.read_http_body(rfile, headers, body_size_limit, True) + timestamp_end = utils.timestamp() + + return HTTPRequest(form_in, method, scheme, host, port, path, httpversion, headers, content, + timestamp_start, timestamp_end) + + def _assemble_first_line(self, form=None): + form = form or self.form_out + + if form == "asterisk" or \ + form == "origin": + request_line = '%s %s HTTP/%s.%s' % (self.method, self.path, self.httpversion[0], self.httpversion[1]) + elif form == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port, + self.httpversion[0], self.httpversion[1]) + elif form == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % \ + (self.method, self.scheme, self.host, self.port, self.path, + self.httpversion[0], self.httpversion[1]) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_headers(self): + headers = self.headers.copy() + utils.del_all( + headers, + [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding' + ] + ) + if not 'host' in headers: + headers["Host"] = [utils.hostport(self.scheme, + self.host or self.flow.server_conn.address.host, + self.port or self.flow.server_conn.address.port)] + + if self.content: + headers["Content-Length"] = [str(len(self.content))] + elif 'Transfer-Encoding' in self.headers: # content-length for e.g. 
chuncked transfer-encoding with no content + headers["Content-Length"] = ["0"] + + return str(headers) + + def _assemble_head(self, form=None): + return "%s\r\n%s\r\n" % (self._assemble_first_line(form), self._assemble_headers()) + + def _assemble(self, form=None): + """ + Assembles the request for transmission to the server. We make some + modifications to make sure interception works properly. + + Raises an Exception if the request cannot be assembled. + """ + if self.content == CONTENT_MISSING: + raise ProxyError(502, "Cannot assemble flow with CONTENT_MISSING") + head = self._assemble_head(form) + if self.content: + return head + self.content + else: + return head + + def __hash__(self): + return id(self) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + del self.headers[i] + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = ["identity"] + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + if self.headers["accept-encoding"]: + self.headers["accept-encoding"] = [', '.join( + e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] + )] + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.content and self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + return ODict(utils.urldecode(self.content)) + return ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.content = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) + return [urllib.unquote(i) for i in path.split("/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.quote(i, safe="") for i in lst] + path = "/" + "/".join(lst) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. + """ + _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) + if query: + return ODict(utils.urldecode(query)) + return ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) + query = utils.urlencode(odict.lst) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + + def get_host(self, hostheader=False): + """ + Heuristic to get the host of the request. 
+ The host is not necessarily equal to the TCP destination of the request, + for example on a transparently proxified absolute-form request to an upstream HTTP proxy. + If hostheader is set to True, the Host: header will be used as additional (and preferred) data source. + """ + host = None + if hostheader: + host = self.headers.get_first("host") + if not host: + if self.host: + host = self.host + else: + host = self.flow.server_conn.address.host + host = host.encode("idna") + return host + + def get_scheme(self): + """ + Returns the request port, either from the request itself or from the flow's server connection + """ + if self.scheme: + return self.scheme + return "https" if self.flow.server_conn.ssl_established else "http" + + def get_port(self): + """ + Returns the request port, either from the request itself or from the flow's server connection + """ + if self.port: + return self.port + return self.flow.server_conn.address.port + + def get_url(self, hostheader=False): + """ + Returns a URL string, constructed from the Request's URL components. + + If hostheader is True, we use the value specified in the request + Host header to construct the URL. + """ + return utils.unparse_url(self.get_scheme(), + self.get_host(hostheader), + self.get_port(), + self.path).encode('ascii') + + def set_url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Returns False if the URL was invalid, True if the request succeeded. + """ + parts = http.parse_url(url) + if not parts: + return False + scheme, host, port, path = parts + is_ssl = (True if scheme == "https" else False) + + self.path = path + + if host != self.get_host() or port != self.get_port(): + if self.flow.change_server: + self.flow.change_server((host, port), ssl=is_ssl) + else: + # There's not live server connection, we're just changing the attributes here. + self.flow.server_conn = ServerConnection((host, port), AddressPriority.MANUALLY_CHANGED) + self.flow.server_conn.ssl_established = is_ssl + + # If this is an absolute request, replace the attributes on the request object as well. + if self.host: + self.host = host + if self.port: + self.port = port + if self.scheme: + self.scheme = scheme + + return True + + def get_cookies(self): + cookie_headers = self.headers.get("cookie") + if not cookie_headers: + return None + + cookies = [] + for header in cookie_headers: + pairs = [pair.partition("=") for pair in header.split(';')] + cookies.extend((pair[0], (pair[2], {})) for pair in pairs) + return dict(cookies) + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in the headers, the request path + and the body of the request. Encoded content will be decoded before + replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + c = HTTPMessage.replace(self, pattern, repl, *args, **kwargs) + self.path, pc = utils.safe_subn(pattern, repl, self.path, *args, **kwargs) + c += pc + return c + + +class HTTPResponse(HTTPMessage): + """ + An HTTP response. + + Exposes the following attributes: + + flow: Flow object the request belongs to + + code: HTTP response code + + msg: HTTP response message + + headers: ODict object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. 
+ + httpversion: HTTP version tuple + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + def __init__(self, httpversion, code, msg, headers, content, timestamp_start=None, timestamp_end=None): + assert isinstance(headers, ODictCaseless) or headers is None + HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end) + + self.code = code + self.msg = msg + + # Is this request replayed? + self.is_replay = False + + _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() + _stateobject_attributes.update( + code=int, + msg=str + ) + + @classmethod + def _from_state(cls, state): + f = cls(None, None, None, None, None) + f._load_state(state) + return f + + @classmethod + def from_stream(cls, rfile, request_method, include_content=True, body_size_limit=None): + """ + Parse an HTTP response from a file stream + """ + if not include_content: + raise NotImplementedError # pragma: nocover + + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + httpversion, code, msg, headers, content = http.read_response( + rfile, + request_method, + body_size_limit) + + if hasattr(rfile, "first_byte_timestamp"): + timestamp_start = rfile.first_byte_timestamp + else: + timestamp_start = utils.timestamp() + + timestamp_end = utils.timestamp() + return HTTPResponse(httpversion, code, msg, headers, content, timestamp_start, timestamp_end) + + def _assemble_first_line(self): + return 'HTTP/%s.%s %s %s' % (self.httpversion[0], self.httpversion[1], self.code, self.msg) + + def _assemble_headers(self): + headers = self.headers.copy() + utils.del_all( + headers, + [ + 'Proxy-Connection', + 'Transfer-Encoding' + ] + ) + if self.content: + headers["Content-Length"] = [str(len(self.content))] + elif 'Transfer-Encoding' in self.headers: # add content-length for chuncked transfer-encoding with no content + headers["Content-Length"] = ["0"] + + return str(headers) + + def _assemble_head(self): + return '%s\r\n%s\r\n' % (self._assemble_first_line(), self._assemble_headers()) + + def _assemble(self): + """ + Assembles the response for transmission to the client. We make some + modifications to make sure interception works properly. + + Raises an Exception if the request cannot be assembled. + """ + if self.content == CONTENT_MISSING: + raise ProxyError(502, "Cannot assemble flow with CONTENT_MISSING") + head = self._assemble_head() + if self.content: + return head + self.content + else: + return head + + def _refresh_cookie(self, c, delta): + """ + Takes a cookie string c and a time delta in seconds, and returns + a refreshed cookie string. + """ + c = Cookie.SimpleCookie(str(c)) + for i in c.values(): + if "expires" in i: + d = parsedate_tz(i["expires"]) + if d: + d = mktime_tz(d) + delta + i["expires"] = formatdate(d) + else: + # This can happen when the expires tag is invalid. + # reddit.com sends a an expires tag like this: "Thu, 31 Dec + # 2037 23:59:59 GMT", which is valid RFC 1123, but not + # strictly correct according to the cookie spec. Browsers + # appear to parse this tolerantly - maybe we should too. + # For now, we just ignore this. + del i["expires"] + return c.output(header="").strip() + + def refresh(self, now=None): + """ + This fairly complex and heuristic function refreshes a server + response for replay. + + - It adjusts date, expires and last-modified headers. + - It adjusts cookie expiration. 
+ """ + if not now: + now = time.time() + delta = now - self.timestamp_start + refresh_headers = [ + "date", + "expires", + "last-modified", + ] + for i in refresh_headers: + if i in self.headers: + d = parsedate_tz(self.headers[i][0]) + if d: + new = mktime_tz(d) + delta + self.headers[i] = [formatdate(new)] + c = [] + for i in self.headers["set-cookie"]: + c.append(self._refresh_cookie(i, delta)) + if c: + self.headers["set-cookie"] = c + + def get_cookies(self): + cookie_headers = self.headers.get("set-cookie") + if not cookie_headers: + return None + + cookies = [] + for header in cookie_headers: + pairs = [pair.partition("=") for pair in header.split(';')] + cookie_name = pairs[0][0] # the key of the first key/value pairs + cookie_value = pairs[0][2] # the value of the first key/value pairs + cookie_parameters = {key.strip().lower(): value.strip() for key, sep, value in pairs[1:]} + cookies.append((cookie_name, (cookie_value, cookie_parameters))) + return dict(cookies) + + +class HTTPFlow(Flow): + """ + A Flow is a collection of objects representing a single HTTP + transaction. The main attributes are: + + request: HTTPRequest object + response: HTTPResponse object + error: Error object + + Note that it's possible for a Flow to have both a response and an error + object. This might happen, for instance, when a response was received + from the server, but there was an error sending it back to the client. + + The following additional attributes are exposed: + + intercepting: Is this flow currently being intercepted? + """ + def __init__(self, client_conn, server_conn, change_server=None): + Flow.__init__(self, "http", client_conn, server_conn) + self.request = None + """@type: HTTPRequest""" + self.response = None + """@type: HTTPResponse""" + self.change_server = change_server # Used by flow.request.set_url to change the server address + + self.intercepting = False # FIXME: Should that rather be an attribute of Flow? + + _backrefattr = Flow._backrefattr + ("request", "response") + + _stateobject_attributes = Flow._stateobject_attributes.copy() + _stateobject_attributes.update( + request=HTTPRequest, + response=HTTPResponse + ) + + @classmethod + def _from_state(cls, state): + f = cls(None, None) + f._load_state(state) + return f + + def copy(self): + f = super(HTTPFlow, self).copy() + if self.request: + f.request = self.request.copy() + if self.response: + f.response = self.response.copy() + return f + + def match(self, f): + """ + Match this flow against a compiled filter expression. Returns True + if matched, False if not. + + If f is a string, it will be compiled as a filter expression. If + the expression is invalid, ValueError is raised. + """ + if isinstance(f, basestring): + f = filt.parse(f) + if not f: + raise ValueError("Invalid filter expression.") + if f: + return f(self) + return True + + def kill(self, master): + """ + Kill this request. + """ + self.error = Error("Connection killed") + self.error.reply = controller.DummyReply() + if self.request and not self.request.reply.acked: + self.request.reply(KILL) + elif self.response and not self.response.reply.acked: + self.response.reply(KILL) + master.handle_error(self.error) + self.intercepting = False + + def intercept(self): + """ + Intercept this Flow. Processing will stop until accept_intercept is + called. + """ + self.intercepting = True + + def accept_intercept(self): + """ + Continue with the flow - called after an intercept(). 
+ """ + if self.request: + if not self.request.reply.acked: + self.request.reply() + elif self.response and not self.response.reply.acked: + self.response.reply() + self.intercepting = False + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both request and response of the + flow. Encoded content will be decoded before replacement, and + re-encoded afterwards. + + Returns the number of replacements made. + """ + c = self.request.replace(pattern, repl, *args, **kwargs) + if self.response: + c += self.response.replace(pattern, repl, *args, **kwargs) + return c + + +class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): + self.auth_headers = auth_headers + + def __str__(self): + return "HttpAuthenticationError" + + +class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin): + + def handle_messages(self): + while self.handle_flow(): + pass + self.c.close = True + + def get_response_from_server(self, request): + self.c.establish_server_connection() + request_raw = request._assemble() + + for i in range(2): + try: + self.c.server_conn.send(request_raw) + return HTTPResponse.from_stream(self.c.server_conn.rfile, request.method, + body_size_limit=self.c.config.body_size_limit) + except (tcp.NetLibDisconnect, http.HttpErrorConnClosed), v: + self.c.log("error in server communication: %s" % str(v)) + if i < 1: + # In any case, we try to reconnect at least once. + # This is necessary because it might be possible that we already initiated an upstream connection + # after clientconnect that has already been expired, e.g consider the following event log: + # > clientconnect (transparent mode destination known) + # > serverconnect + # > read n% of large request + # > server detects timeout, disconnects + # > read (100-n)% of large request + # > send large request upstream + self.c.server_reconnect() + else: + raise v + + def handle_flow(self): + flow = HTTPFlow(self.c.client_conn, self.c.server_conn, self.change_server) + try: + flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, + body_size_limit=self.c.config.body_size_limit) + self.c.log("request", [flow.request._assemble_first_line(flow.request.form_in)]) + self.process_request(flow.request) + + request_reply = self.c.channel.ask("request", flow.request) + flow.server_conn = self.c.server_conn + + if request_reply is None or request_reply == KILL: + return False + + if isinstance(request_reply, HTTPResponse): + flow.response = request_reply + else: + flow.response = self.get_response_from_server(flow.request) + + flow.server_conn = self.c.server_conn # no further manipulation of self.c.server_conn beyond this point + # we can safely set it as the final attribute value here. 
+ + self.c.log("response", [flow.response._assemble_first_line()]) + response_reply = self.c.channel.ask("response", flow.response) + if response_reply is None or response_reply == KILL: + return False + + self.c.client_conn.send(flow.response._assemble()) + flow.timestamp_end = utils.timestamp() + + if (http.connection_close(flow.request.httpversion, flow.request.headers) or + http.connection_close(flow.response.httpversion, flow.response.headers)): + return False + + if flow.request.form_in == "authority": + self.ssl_upgrade() + + self.restore_server() # If the user has changed the target server on this connection, + # restore the original target server + return True + except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e: + self.handle_error(e, flow) + return False + + def handle_error(self, error, flow=None): + code, message, headers = None, None, None + if isinstance(error, HttpAuthenticationError): + code = 407 + message = "Proxy Authentication Required" + headers = error.auth_headers + elif isinstance(error, (http.HttpError, ProxyError)): + code = error.code + message = error.msg + elif isinstance(error, tcp.NetLibError): + code = 502 + message = error.message or error.__class__ + + if code: + err = "%s: %s" % (code, message) + else: + err = error.__class__ + + self.c.log("error: %s" % err) + + if flow: + flow.error = Error(err) + if flow.request and not flow.response: + # FIXME: no flows without request or with both request and response at the moement. + self.c.channel.ask("error", flow.error) + else: + pass # FIXME: Is there any use case for persisting errors that occur outside of flows? + + if code: + try: + self.send_error(code, message, headers) + except: + pass + + def send_error(self, code, message, headers): + response = http_status.RESPONSES.get(code, "Unknown") + html_content = '<html><head>\n<title>%d %s</title>\n</head>\n<body>\n%s\n</body>\n</html>' % \ + (code, response, message) + self.c.client_conn.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response)) + self.c.client_conn.wfile.write("Server: %s\r\n" % self.c.server_version) + self.c.client_conn.wfile.write("Content-type: text/html\r\n") + self.c.client_conn.wfile.write("Content-Length: %d\r\n" % len(html_content)) + if headers: + for key, value in headers.items(): + self.c.client_conn.wfile.write("%s: %s\r\n" % (key, value)) + self.c.client_conn.wfile.write("Connection: close\r\n") + self.c.client_conn.wfile.write("\r\n") + self.c.client_conn.wfile.write(html_content) + self.c.client_conn.wfile.flush() + + def hook_reconnect(self, upstream_request): + self.c.log("Hook reconnect function") + original_reconnect_func = self.c.server_reconnect + + def reconnect_http_proxy(): + self.c.log("Hooked reconnect function") + self.c.log("Hook: Run original reconnect") + original_reconnect_func(no_ssl=True) + self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()]) + self.c.server_conn.send(upstream_request._assemble()) + self.c.log("Hook: Read answer to CONNECT request from proxy") + resp = HTTPResponse.from_stream(self.c.server_conn.rfile, upstream_request.method) + if resp.code != 200: + raise ProxyError(resp.code, + "Cannot reestablish SSL connection with upstream proxy: \r\n" + str(resp.headers)) + self.c.log("Hook: Establish SSL with upstream proxy") + self.c.establish_ssl(server=True) + + self.c.server_reconnect = reconnect_http_proxy + + def ssl_upgrade(self): + """ + Upgrade the connection to SSL after an authority (CONNECT) request has been made. 
+        If the authority request has been forwarded upstream (because we have another proxy server there),
+        monkey-patch the ConnectionHandler.server_reconnect function to resend the request on reconnect.
+
+        This isn't particularly beautiful code, but it isolates this rare edge case from the
+        protocol-agnostic ConnectionHandler.
+        """
+        self.c.log("Received CONNECT request. Upgrading to SSL...")
+        self.c.mode = "transparent"
+        self.c.determine_conntype()
+        self.c.establish_ssl(server=True, client=True)
+        self.c.log("Upgrade to SSL completed.")
+        raise ConnectionTypeChange
+
+    def process_request(self, request):
+        if self.c.mode == "regular":
+            self.authenticate(request)
+        if request.form_in == "authority" and self.c.client_conn.ssl_established:
+            raise http.HttpError(502, "Must not CONNECT on already encrypted connection")
+
+        # If we have a CONNECT request, we might need to intercept
+        if request.form_in == "authority":
+            directly_addressed_at_mitmproxy = (self.c.mode == "regular" and not self.c.config.forward_proxy)
+            if directly_addressed_at_mitmproxy:
+                self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL)
+                request.flow.server_conn = self.c.server_conn  # Update server_conn attribute on the flow
+                self.c.client_conn.wfile.write(
+                    'HTTP/1.1 200 Connection established\r\n' +
+                    ('Proxy-agent: %s\r\n' % self.c.server_version) +
+                    '\r\n'
+                )
+                self.c.client_conn.wfile.flush()
+                self.ssl_upgrade()  # raises ConnectionTypeChange exception
+
+        if self.c.mode == "regular":
+            if request.form_in == "authority":  # forward mode
+                self.hook_reconnect(request)
+            elif request.form_in == "absolute":
+                if request.scheme != "http":
+                    raise http.HttpError(400, "Invalid Request")
+                if not self.c.config.forward_proxy:
+                    request.form_out = "origin"
+                    self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL)
+                    request.flow.server_conn = self.c.server_conn  # Update server_conn attribute on the flow
+            else:
+                raise http.HttpError(400, "Invalid request form (absolute-form or authority-form required)")
+
+    def authenticate(self, request):
+        if self.c.config.authenticator:
+            if self.c.config.authenticator.authenticate(request.headers):
+                self.c.config.authenticator.clean(request.headers)
+            else:
+                raise HttpAuthenticationError(self.c.config.authenticator.auth_challenge_headers())
+        return request.headers
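The branching in HTTPRequest.from_stream and process_request above hinges on the four HTTP request-target forms (asterisk, origin, authority and absolute). For illustration, a rough standalone sketch of that classification, with a hypothetical helper name and none of netlib's parsing or validation:

def classify_request_form(request_line):
    """Return 'asterisk', 'origin', 'authority' or 'absolute' for an HTTP request line."""
    method, target = request_line.split(" ")[0:2]
    if target == "*":
        return "asterisk"    # e.g. "OPTIONS * HTTP/1.1"
    if target.startswith("/"):
        return "origin"      # e.g. "GET /index.html HTTP/1.1" (plain, transparent or reverse-proxied request)
    if method.upper() == "CONNECT":
        return "authority"   # e.g. "CONNECT example.com:443 HTTP/1.1" (tunnel request to a proxy)
    return "absolute"        # e.g. "GET http://example.com/ HTTP/1.1" (request to an HTTP proxy)

assert classify_request_form("CONNECT example.com:443 HTTP/1.1") == "authority"
assert classify_request_form("GET http://example.com/ HTTP/1.1") == "absolute"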
\ No newline at end of file diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py new file mode 100644 index 00000000..90191eeb --- /dev/null +++ b/libmproxy/protocol/primitives.py @@ -0,0 +1,130 @@ +from .. import stateobject, utils, version +from ..proxy import ServerConnection, ClientConnection +import copy + + +class BackreferenceMixin(object): + """ + If an attribute from the _backrefattr tuple is set, + this mixin sets a reference back on the attribute object. + Example: + e = Error() + f = Flow() + f.error = e + assert f is e.flow + """ + _backrefattr = tuple() + + def __setattr__(self, key, value): + super(BackreferenceMixin, self).__setattr__(key, value) + if key in self._backrefattr and value is not None: + setattr(value, self._backrefname, self) + + +class Error(stateobject.SimpleStateObject): + """ + An Error. + + This is distinct from an HTTP error response (say, a code 500), which + is represented by a normal Response object. This class is responsible + for indicating errors that fall outside of normal HTTP communications, + like interrupted connections, timeouts, protocol errors. + + Exposes the following attributes: + + flow: Flow object + msg: Message describing the error + timestamp: Seconds since the epoch + """ + def __init__(self, msg, timestamp=None): + """ + @type msg: str + @type timestamp: float + """ + self.flow = None # will usually be set by the flow backref mixin + self.msg = msg + self.timestamp = timestamp or utils.timestamp() + + _stateobject_attributes = dict( + msg=str, + timestamp=float + ) + + def __str__(self): + return self.msg + + @classmethod + def _from_state(cls, state): + f = cls(None) # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f + + def copy(self): + c = copy.copy(self) + return c + + +class Flow(stateobject.SimpleStateObject, BackreferenceMixin): + def __init__(self, conntype, client_conn, server_conn): + self.conntype = conntype + self.client_conn = client_conn + """@type: ClientConnection""" + self.server_conn = server_conn + """@type: ServerConnection""" + + self.error = None + """@type: Error""" + self._backup = None + + _backrefattr = ("error",) + _backrefname = "flow" + + _stateobject_attributes = dict( + error=Error, + client_conn=ClientConnection, + server_conn=ServerConnection, + conntype=str + ) + + def _get_state(self): + d = super(Flow, self)._get_state() + d.update(version=version.IVERSION) + return d + + def __eq__(self, other): + return self is other + + def copy(self): + f = copy.copy(self) + + f.client_conn = self.client_conn.copy() + f.server_conn = self.server_conn.copy() + + if self.error: + f.error = self.error.copy() + return f + + def modified(self): + """ + Has this Flow been modified? + """ + if self._backup: + return self._backup != self._get_state() + else: + return False + + def backup(self, force=False): + """ + Save a backup of this Flow, which can be reverted to using a + call to .revert(). + """ + if not self._backup: + self._backup = self._get_state() + + def revert(self): + """ + Revert to the last backed up state. + """ + if self._backup: + self._load_state(self._backup) + self._backup = None
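The backup()/modified()/revert() trio above operates purely on the state dictionaries produced by _get_state(). A minimal standalone sketch of the same pattern (hypothetical class, not mitmproxy code) showing the intended semantics:

class Snapshottable(object):
    # Keeps a one-deep snapshot of its state dict, like Flow.backup() above.
    def __init__(self, msg):
        self.msg = msg
        self._backup = None

    def _get_state(self):
        return dict(msg=self.msg)

    def _load_state(self, state):
        self.msg = state["msg"]

    def backup(self):
        if not self._backup:
            self._backup = self._get_state()

    def modified(self):
        return bool(self._backup) and self._backup != self._get_state()

    def revert(self):
        if self._backup:
            self._load_state(self._backup)
            self._backup = None

s = Snapshottable("original")
s.backup()
s.msg = "edited"
assert s.modified()
s.revert()
assert s.msg == "original" and not s.modified()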
\ No newline at end of file
diff --git a/libmproxy/protocol/tcp.py b/libmproxy/protocol/tcp.py
new file mode 100644
index 00000000..406a6f7b
--- /dev/null
+++ b/libmproxy/protocol/tcp.py
@@ -0,0 +1,59 @@
+from . import ProtocolHandler
+import select, socket
+from cStringIO import StringIO
+
+
+class TCPHandler(ProtocolHandler):
+    """
+    TCPHandler acts as a generic TCP forwarder.
+    Data will be .log()ed, but not stored any further.
+    """
+    def handle_messages(self):
+        conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
+        while not self.c.close:
+            r, _, _ = select.select(conns, [], [], 10)
+            for rfile in r:
+                if self.c.client_conn.rfile == rfile:
+                    src, dst = self.c.client_conn, self.c.server_conn
+                    direction = "-> tcp ->"
+                    dst_str = "%s:%s" % self.c.server_conn.address()[:2]
+                else:
+                    dst, src = self.c.client_conn, self.c.server_conn
+                    direction = "<- tcp <-"
+                    dst_str = "client"
+
+                data = StringIO()
+                while range(4096):
+                    # Do a non-blocking select() to see if there is further data in the buffer.
+                    r, _, _ = select.select([rfile], [], [], 0)
+                    if len(r):
+                        d = rfile.read(1)
+                        if d == "":  # connection closed
+                            break
+                        data.write(d)
+
+                        """
+                        OpenSSL connections have an internal buffer that might contain data although everything has
+                        been read from the socket. Thankfully, connection.pending() returns the number of bytes in
+                        this buffer, so we can read it completely at once.
+                        """
+                        if src.ssl_established:
+                            data.write(rfile.read(src.connection.pending()))
+                    else:  # no data left, but not closed yet
+                        break
+                data = data.getvalue()
+
+                if data == "":  # no data received, rfile is closed
+                    self.c.log("Close writing connection to %s" % dst_str)
+                    conns.remove(rfile)
+                    if dst.ssl_established:
+                        dst.connection.shutdown()
+                    else:
+                        dst.connection.shutdown(socket.SHUT_WR)
+                    if len(conns) == 0:
+                        self.c.close = True
+                    break
+
+                self.c.log("%s %s\r\n%s" % (direction, dst_str, data))
+                dst.wfile.write(data)
+                dst.wfile.flush()
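TCPHandler above is essentially a select()-driven bidirectional pump between the client and server file objects, plus special handling for OpenSSL's internal buffer. A rough standalone sketch of the core pump over plain sockets (assumed helper name, none of the SSL or logging details):

import select
import socket

def pump(a, b, timeout=10):
    """Forward bytes between two connected sockets until both directions are closed."""
    conns = [a, b]
    while conns:
        readable, _, _ = select.select(conns, [], [], timeout)
        for src in readable:
            dst = b if src is a else a
            data = src.recv(4096)
            if not data:                      # peer closed its sending side
                conns.remove(src)
                dst.shutdown(socket.SHUT_WR)  # propagate the half-close
                continue
            dst.sendall(data)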
\ No newline at end of file diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 0d53aef8..b6480822 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -1,13 +1,26 @@ -import os, socket, time -import threading +import os, socket, time, threading, copy from OpenSSL import SSL -from netlib import tcp, http, certutils, http_status, http_auth -import utils, flow, version, platform, controller - +from netlib import tcp, http, certutils, http_auth +import utils, version, platform, controller, stateobject TRANSPARENT_SSL_PORTS = [443, 8443] -KILL = 0 + +class AddressPriority(object): + """ + Enum that signifies the priority of the given address when choosing the destination host. + Higher is better (None < i) + """ + FORCE = 5 + """forward mode""" + MANUALLY_CHANGED = 4 + """user changed the target address in the ui""" + FROM_SETTINGS = 3 + """reverse proxy mode""" + FROM_CONNECTION = 2 + """derived from transparent resolver""" + FROM_PROTOCOL = 1 + """derived from protocol (e.g. absolute-form http requests)""" class ProxyError(Exception): @@ -15,7 +28,7 @@ class ProxyError(Exception): self.code, self.msg, self.headers = code, msg, headers def __str__(self): - return "ProxyError(%s, %s)"%(self.code, self.msg) + return "ProxyError(%s, %s)" % (self.code, self.msg) class Log: @@ -24,7 +37,8 @@ class Log: class ProxyConfig: - def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): + def __init__(self, certfile=None, cacert=None, clientcerts=None, no_upstream_cert=False, body_size_limit=None, + reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): self.certfile = certfile self.cacert = cacert self.clientcerts = clientcerts @@ -37,49 +51,146 @@ class ProxyConfig: self.certstore = certutils.CertStore() -class ServerConnection(tcp.TCPClient): - def __init__(self, config, scheme, host, port, sni): - tcp.TCPClient.__init__(self, host, port) - self.config = config - self.scheme, self.sni = scheme, sni - self.requestcount = 0 - self.tcp_setup_timestamp = None - self.ssl_setup_timestamp = None +class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): + def __init__(self, client_connection, address, server): + if client_connection: # Eventually, this object is restored from state. We don't have a connection then. 
+ tcp.BaseHandler.__init__(self, client_connection, address, server) + else: + self.connection = None + self.server = None + self.wfile = None + self.rfile = None + self.address = None + self.clientcert = None + + self.timestamp_start = utils.timestamp() + self.timestamp_end = None + self.timestamp_ssl_setup = None + + _stateobject_attributes = dict( + timestamp_start=float, + timestamp_end=float, + timestamp_ssl_setup=float + ) + + def _get_state(self): + d = super(ClientConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + clientcert=self.cert.to_pem() if self.clientcert else None + ) + return d + + def _load_state(self, state): + super(ClientConnection, self)._load_state(state) + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None + + def copy(self): + return copy.copy(self) + + def send(self, message): + self.wfile.write(message) + self.wfile.flush() + + @classmethod + def _from_state(cls, state): + f = cls(None, tuple(), None) + f._load_state(state) + return f + + def convert_to_ssl(self, *args, **kwargs): + tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) + self.timestamp_ssl_setup = utils.timestamp() + + def finish(self): + tcp.BaseHandler.finish(self) + self.timestamp_end = utils.timestamp() + + +class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): + def __init__(self, address, priority): + tcp.TCPClient.__init__(self, address) + self.priority = priority + + self.peername = None + self.timestamp_start = None + self.timestamp_end = None + self.timestamp_tcp_setup = None + self.timestamp_ssl_setup = None + + _stateobject_attributes = dict( + peername=tuple, + timestamp_start=float, + timestamp_end=float, + timestamp_tcp_setup=float, + timestamp_ssl_setup=float, + address=tcp.Address, + source_address=tcp.Address, + cert=certutils.SSLCert, + ssl_established=bool, + sni=str + ) + + def _get_state(self): + d = super(ServerConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + source_address= {"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None, + cert=self.cert.to_pem() if self.cert else None + ) + return d + + def _load_state(self, state): + super(ServerConnection, self)._load_state(state) + + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None + self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None + + @classmethod + def _from_state(cls, state): + f = cls(tuple(), None) + f._load_state(state) + return f + + def copy(self): + return copy.copy(self) def connect(self): + self.timestamp_start = utils.timestamp() tcp.TCPClient.connect(self) - self.tcp_setup_timestamp = time.time() - if self.scheme == "https": - clientcert = None - if self.config.clientcerts: - path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" - if os.path.exists(path): - clientcert = path - try: - self.convert_to_ssl(cert=clientcert, sni=self.sni) - self.ssl_setup_timestamp = time.time() - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - - def send(self, request): - self.requestcount += 1 - d = request._assemble() - if not d: - raise ProxyError(502, "Cannot transmit an incomplete request.") - 
self.wfile.write(d) + self.peername = self.connection.getpeername() + self.timestamp_tcp_setup = utils.timestamp() + + def send(self, message): + self.wfile.write(message) self.wfile.flush() - def terminate(self): - if self.connection: - try: - self.wfile.flush() - except tcp.NetLibDisconnect: # pragma: no cover - pass - self.connection.close() + def establish_ssl(self, clientcerts, sni): + clientcert = None + if clientcerts: + path = os.path.join(clientcerts, self.address.host.encode("idna")) + ".pem" + if os.path.exists(path): + clientcert = path + try: + self.convert_to_ssl(cert=clientcert, sni=sni) + self.timestamp_ssl_setup = utils.timestamp() + except tcp.NetLibError, v: + raise ProxyError(400, str(v)) + + def finish(self): + tcp.TCPClient.finish(self) + self.timestamp_end = utils.timestamp() +from . import protocol +from .protocol.http import HTTPResponse class RequestReplayThread(threading.Thread): + name="RequestReplayThread" + def __init__(self, config, flow, masterq): self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) threading.Thread.__init__(self) @@ -87,448 +198,275 @@ class RequestReplayThread(threading.Thread): def run(self): try: r = self.flow.request - server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host) + server = ServerConnection(self.flow.server_conn.address(), None) server.connect() - server.send(r) - httpversion, code, msg, headers, content = http.read_response( - server.rfile, r.method, self.config.body_size_limit - ) - response = flow.Response( - self.flow.request, httpversion, code, msg, headers, content, server.cert, - server.rfile.first_byte_timestamp - ) - self.channel.ask("response", response) + if self.flow.server_conn.ssl_established: + server.establish_ssl(self.config.clientcerts, + self.flow.server_conn.sni) + server.send(r._assemble()) + self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) + self.channel.ask("response", self.flow.response) except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) - self.channel.ask("error", err) + self.flow.error = protocol.primitives.Error(str(v)) + self.channel.ask("error", self.flow.error) + + +class ConnectionHandler: + def __init__(self, config, client_connection, client_address, server, channel, server_version): + self.config = config + self.client_conn = ClientConnection(client_connection, client_address, server) + self.server_conn = None + self.channel, self.server_version = channel, server_version + + self.close = False + self.conntype = None + self.sni = None + + self.mode = "regular" + if self.config.reverse_proxy: + self.mode = "reverse" + if self.config.transparent_proxy: + self.mode = "transparent" + def handle(self): + self.log("clientconnect") + self.channel.ask("clientconnect", self) -class HandleSNI: - def __init__(self, handler, client_conn, host, port, key): - self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port - self.key = key + self.determine_conntype() - def __call__(self, client_connection): try: - sn = client_connection.get_servername() - if sn: - self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn) - dummycert = self.handler.find_cert(self.client_conn, self.host, self.port, sn) - new_context = SSL.Context(SSL.TLSv1_METHOD) - new_context.use_privatekey_file(self.key) - new_context.use_certificate(dummycert.x509) - client_connection.set_context(new_context) - 
self.handler.sni = sn.decode("utf8").encode("idna") - # An unhandled exception in this method will core dump PyOpenSSL, so - # make dang sure it doesn't happen. - except Exception: # pragma: no cover - pass + try: + # Can we already identify the target server and connect to it? + server_address = None + address_priority = None + if self.config.forward_proxy: + server_address = self.config.forward_proxy[1:] + address_priority = AddressPriority.FORCE + elif self.config.reverse_proxy: + server_address = self.config.reverse_proxy[1:] + address_priority = AddressPriority.FROM_SETTINGS + elif self.config.transparent_proxy: + server_address = self.config.transparent_proxy["resolver"].original_addr( + self.client_conn.connection) + if not server_address: + raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") + address_priority = AddressPriority.FROM_CONNECTION + self.log("transparent to %s:%s" % server_address) + + if server_address: + self.set_server_address(server_address, address_priority) + self._handle_ssl() + + while not self.close: + try: + protocol.handle_messages(self.conntype, self) + except protocol.ConnectionTypeChange: + self.log("Connection Type Changed: %s" % self.conntype) + continue + + # FIXME: Do we want to persist errors? + except (ProxyError, tcp.NetLibError), e: + protocol.handle_error(self.conntype, self, e) + except Exception, e: + self.log(e.__class__) + import traceback + self.log(traceback.format_exc()) + self.log(str(e)) + self.del_server_connection() + self.log("clientdisconnect") + self.channel.tell("clientdisconnect", self) -class ProxyHandler(tcp.BaseHandler): - def __init__(self, config, connection, client_address, server, channel, server_version): - self.channel, self.server_version = channel, server_version - self.config = config - self.proxy_connect_state = None - self.sni = None - self.server_conn = None - tcp.BaseHandler.__init__(self, connection, client_address, server) + def _handle_ssl(self): + """ + Helper function of .handle() + Check if we can already identify SSL connections. + If so, connect to the server and establish an SSL connection + """ + client_ssl = False + server_ssl = False - def get_server_connection(self, cc, scheme, host, port, sni, request=None): + if self.config.transparent_proxy: + client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"]) + elif self.config.reverse_proxy: + client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") + # TODO: Make protocol generic (as with transparent proxies) + # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) + if client_ssl or server_ssl: + self.establish_server_connection() + self.establish_ssl(client=client_ssl, server=server_ssl) + + def del_server_connection(self): """ - When SNI is in play, this means we have an SSL-encrypted - connection, which means that the entire handler is dedicated to a - single server connection - no multiplexing. If this assumption ever - breaks, we'll have to do something different with the SNI host - variable on the handler object. - - `conn_info` holds the initial connection's parameters, as the - hook might change them. Also, the hook might require an initial - request to figure out connection settings; in this case it can - set require_request, which will cause the connection to be - re-opened after the client's request arrives. + Deletes an existing server connection. 
""" - sc = self.server_conn - if not sni: - sni = host - conn_info = (scheme, host, port, sni) - if sc and (conn_info != sc.conn_info or (request and sc.require_request)): - sc.terminate() - self.server_conn = None - self.log( - cc, - "switching connection", [ - "%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%( - scheme, host, port, sni, - sc.scheme, sc.host, sc.port, sc.sni - ) - ] - ) - if not self.server_conn: - try: - self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + if self.server_conn and self.server_conn.connection: + self.server_conn.finish() + self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)]) + self.channel.tell("serverdisconnect", self) + self.server_conn = None + self.sni = None - # Additional attributes, used if the server_connect hook - # needs to change parameters - self.server_conn.request = request - self.server_conn.require_request = False + def determine_conntype(self): + #TODO: Add ruleset to select correct protocol depending on mode/target port etc. + self.conntype = "http" - self.server_conn.conn_info = conn_info - self.channel.ask("serverconnect", self.server_conn) - self.server_conn.connect() - except tcp.NetLibError, v: - raise ProxyError(502, v) - return self.server_conn + def set_server_address(self, address, priority): + """ + Sets a new server address with the given priority. + Does not re-establish either connection or SSL handshake. + @type priority: AddressPriority + """ + address = tcp.Address.wrap(address) - def del_server_connection(self): if self.server_conn: - self.server_conn.terminate() - self.server_conn = None + if self.server_conn.priority > priority: + self.log("Attempt to change server address, " + "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) + return + if self.server_conn.address == address: + self.server_conn.priority = priority # Possibly increase priority + return - def handle(self): - cc = flow.ClientConnect(self.client_address) - self.log(cc, "connect") - self.channel.ask("clientconnect", cc) - while self.handle_request(cc) and not cc.close: - pass - cc.close = True - self.del_server_connection() + self.del_server_connection() - cd = flow.ClientDisconnect(cc) - self.log( - cc, "disconnect", - [ - "handled %s requests"%cc.requestcount] - ) - self.channel.tell("clientdisconnect", cd) + self.log("Set new server address: %s:%s" % (address.host, address.port)) + self.server_conn = ServerConnection(address, priority) - def handle_request(self, cc): + def establish_server_connection(self): + """ + Establishes a new server connection. + If there is already an existing server connection, the function returns immediately. 
+ """ + if self.server_conn.connection: + return + self.log("serverconnect", ["%s:%s" % self.server_conn.address()[:2]]) + self.channel.tell("serverconnect", self) try: - request, err = None, None - request = self.read_request(cc) - if request is None: - return - cc.requestcount += 1 + self.server_conn.connect() + except tcp.NetLibError, v: + raise ProxyError(502, v) - request_reply = self.channel.ask("request", request) - if request_reply is None or request_reply == KILL: - return - elif isinstance(request_reply, flow.Response): - request = False - response = request_reply - response_reply = self.channel.ask("response", response) - else: - request = request_reply - if self.config.reverse_proxy: - scheme, host, port = self.config.reverse_proxy - elif self.config.forward_proxy: - scheme, host, port = self.config.forward_proxy - else: - scheme, host, port = request.scheme, request.host, request.port - - # If we've already pumped a request over this connection, - # it's possible that the server has timed out. If this is - # the case, we want to reconnect without sending an error - # to the client. - while 1: - sc = self.get_server_connection(cc, scheme, host, port, self.sni, request=request) - sc.send(request) - if sc.requestcount == 1: # add timestamps only for first request (others are not directly affected) - request.tcp_setup_timestamp = sc.tcp_setup_timestamp - request.ssl_setup_timestamp = sc.ssl_setup_timestamp - sc.rfile.reset_timestamps() - try: - peername = sc.connection.getpeername() - if peername: - request.ip = peername[0] - httpversion, code, msg, headers, content = http.read_response( - sc.rfile, - request.method, - self.config.body_size_limit - ) - except http.HttpErrorConnClosed: - self.del_server_connection() - if sc.requestcount > 1: - continue - else: - raise - except http.HttpError: - raise ProxyError(502, "Invalid server response.") - else: - break - - response = flow.Response( - request, httpversion, code, msg, headers, content, sc.cert, - sc.rfile.first_byte_timestamp - ) - response_reply = self.channel.ask("response", response) - # Not replying to the server invalidates the server - # connection, so we terminate. - if response_reply == KILL: - sc.terminate() - - if response_reply == KILL: - return - else: - response = response_reply - self.send_response(response) - if request and http.connection_close(request.httpversion, request.headers): - return - # We could keep the client connection when the server - # connection needs to go away. However, we want to mimic - # behaviour as closely as possible to the client, so we - # disconnect. - if http.connection_close(response.httpversion, response.headers): - return - except (IOError, ProxyError, http.HttpError, tcp.NetLibError), e: - if hasattr(e, "code"): - cc.error = "%s: %s"%(e.code, e.msg) - else: - cc.error = str(e) - - if request: - err = flow.Error(request, cc.error) - self.channel.ask("error", err) - self.log( - cc, cc.error, - ["url: %s"%request.get_url()] - ) - else: - self.log(cc, cc.error) - if isinstance(e, ProxyError): - self.send_error(e.code, e.msg, e.headers) - else: - return True + def establish_ssl(self, client=False, server=False): + """ + Establishes SSL on the existing connection(s) to the server or the client, + as specified by the parameters. If the target server is on the pass-through list, + the conntype attribute will be changed and the SSL connection won't be wrapped. 
+ A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening + """ + # TODO: Implement SSL pass-through handling and change conntype + passthrough = [ + # "echo.websocket.org", + # "174.129.224.73" # echo.websocket.org, transparent mode + ] + if self.server_conn.address.host in passthrough or self.sni in passthrough: + self.conntype = "tcp" + return + + # Logging + if client or server: + subs = [] + if client: + subs.append("with client") + if server: + subs.append("with server (sni: %s)" % self.sni) + self.log("Establish SSL", subs) + + if server: + if self.server_conn.ssl_established: + raise ProxyError(502, "SSL to Server already established.") + self.establish_server_connection() # make sure there is a server connection. + self.server_conn.establish_ssl(self.config.clientcerts, self.sni) + if client: + if self.client_conn.ssl_established: + raise ProxyError(502, "SSL to Client already established.") + dummycert = self.find_cert() + self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, + handle_sni=self.handle_sni) + + def server_reconnect(self, no_ssl=False): + address = self.server_conn.address + had_ssl = self.server_conn.ssl_established + priority = self.server_conn.priority + sni = self.sni + self.log("(server reconnect follows)") + self.del_server_connection() + self.set_server_address(address, priority) + self.establish_server_connection() + if had_ssl and not no_ssl: + self.sni = sni + self.establish_ssl(server=True) + + def finish(self): + self.client_conn.finish() - def log(self, cc, msg, subs=()): + def log(self, msg, subs=()): msg = [ - "%s:%s: "%cc.address + msg + "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg) ] for i in subs: - msg.append(" -> "+i) + msg.append(" -> " + i) msg = "\n".join(msg) - l = Log(msg) - self.channel.tell("log", l) + self.channel.tell("log", Log(msg)) - def find_cert(self, cc, host, port, sni): + def find_cert(self): if self.config.certfile: with open(self.config.certfile, "rb") as f: return certutils.SSLCert.from_pem(f.read()) else: + host = self.server_conn.address.host sans = [] - if not self.config.no_upstream_cert: - conn = self.get_server_connection(cc, "https", host, port, sni) - sans = conn.cert.altnames - if conn.cert.cn: - host = conn.cert.cn.decode("utf8").encode("idna") + if not self.config.no_upstream_cert or not self.server_conn.ssl_established: + upstream_cert = self.server_conn.cert + if upstream_cert.cn: + host = upstream_cert.cn.decode("utf8").encode("idna") + sans = upstream_cert.altnames + ret = self.config.certstore.get_cert(host, sans, self.config.cacert) if not ret: raise ProxyError(502, "Unable to generate dummy cert.") return ret - def establish_ssl(self, client_conn, host, port): - dummycert = self.find_cert(client_conn, host, port, host) - sni = HandleSNI( - self, client_conn, host, port, self.config.certfile or self.config.cacert - ) - try: - self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - - def get_line(self, fp): + def handle_sni(self, connection): """ - Get a line, possibly preceded by a blank. + This callback gets called during the SSL handshake with the client. + The client has just sent the Sever Name Indication (SNI). We now connect upstream to + figure out which certificate needs to be served. 
""" - line = fp.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message - line = fp.readline() - return line - - def read_request(self, client_conn): - self.rfile.reset_timestamps() - if self.config.transparent_proxy: - return self.read_request_transparent(client_conn) - elif self.config.reverse_proxy: - return self.read_request_reverse(client_conn) - else: - return self.read_request_proxy(client_conn) - - def read_request_transparent(self, client_conn): - orig = self.config.transparent_proxy["resolver"].original_addr(self.connection) - if not orig: - raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") - self.log(client_conn, "transparent to %s:%s"%orig) - - host, port = orig - if port in self.config.transparent_proxy["sslports"]: - scheme = "https" - else: - scheme = "http" - - return self._read_request_origin_form(client_conn, scheme, host, port) - - def read_request_reverse(self, client_conn): - scheme, host, port = self.config.reverse_proxy - return self._read_request_origin_form(client_conn, scheme, host, port) - - def read_request_proxy(self, client_conn): - # Check for a CONNECT command. - if not self.proxy_connect_state: - line = self.get_line(self.rfile) - if line == "": - return None - self.proxy_connect_state = self._read_request_authority_form(line) - - # Check for an actual request - if self.proxy_connect_state: - host, port, _ = self.proxy_connect_state - return self._read_request_origin_form(client_conn, "https", host, port) - else: - # noinspection PyUnboundLocalVariable - return self._read_request_absolute_form(client_conn, line) - - def _read_request_authority_form(self, line): - """ - The authority-form of request-target is only used for CONNECT requests. - The CONNECT method is used to request a tunnel to the destination server. - This function sends a "200 Connection established" response to the client - and returns the host information that can be used to process further requests in origin-form. - An example authority-form request line would be: - CONNECT www.example.com:80 HTTP/1.1 - """ - connparts = http.parse_init_connect(line) - if connparts: - self.read_headers(authenticate=True) - # respond according to http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.2 - self.wfile.write( - 'HTTP/1.1 200 Connection established\r\n' + - ('Proxy-agent: %s\r\n'%self.server_version) + - '\r\n' - ) - self.wfile.flush() - return connparts - - def _read_request_absolute_form(self, client_conn, line): - """ - When making a request to a proxy (other than CONNECT or OPTIONS), - a client must send the target uri in absolute-form. - An example absolute-form request line would be: - GET http://www.example.com/foo.html HTTP/1.1 - """ - r = http.parse_init_proxy(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, scheme, host, port, path, httpversion = r - headers = self.read_headers(authenticate=True) - self.handle_expect_header(headers, httpversion) - content = http.read_http_body( - self.rfile, headers, self.config.body_size_limit, True - ) - r = flow.Request( - client_conn, httpversion, host, port, scheme, method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - r.set_live(self.rfile, self.wfile) - return r - - def _read_request_origin_form(self, client_conn, scheme, host, port): - """ - Read a HTTP request with regular (origin-form) request line. 
- An example origin-form request line would be: - GET /foo.html HTTP/1.1 - - The request destination is already known from one of the following sources: - 1) transparent proxy: destination provided by platform resolver - 2) reverse proxy: fixed destination - 3) regular proxy: known from CONNECT command. - """ - if scheme.lower() == "https" and not self.ssl_established: - self.establish_ssl(client_conn, host, port) - - line = self.get_line(self.rfile) - if line == "": - return None - - r = http.parse_init_http(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, path, httpversion = r - headers = self.read_headers(authenticate=False) - self.handle_expect_header(headers, httpversion) - content = http.read_http_body( - self.rfile, headers, self.config.body_size_limit, True - ) - r = flow.Request( - client_conn, httpversion, host, port, scheme, method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - r.set_live(self.rfile, self.wfile) - return r - - def handle_expect_header(self, headers, httpversion): - if "expect" in headers: - if "100-continue" in headers['expect'] and httpversion >= (1, 1): - #FIXME: Check if content-length is over limit - self.wfile.write('HTTP/1.1 100 Continue\r\n' - '\r\n') - del headers['expect'] - - def read_headers(self, authenticate=False): - headers = http.read_headers(self.rfile) - if headers is None: - raise ProxyError(400, "Invalid headers") - if authenticate and self.config.authenticator: - if self.config.authenticator.authenticate(headers): - self.config.authenticator.clean(headers) - else: - raise ProxyError( - 407, - "Proxy Authentication Required", - self.config.authenticator.auth_challenge_headers() - ) - return headers - - def send_response(self, response): - d = response._assemble() - if not d: - raise ProxyError(502, "Cannot transmit an incomplete response.") - self.wfile.write(d) - self.wfile.flush() - - def send_error(self, code, body, headers): try: - response = http_status.RESPONSES.get(code, "Unknown") - html_content = '<html><head>\n<title>%d %s</title>\n</head>\n<body>\n%s\n</body>\n</html>'%(code, response, body) - self.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response)) - self.wfile.write("Server: %s\r\n"%self.server_version) - self.wfile.write("Content-type: text/html\r\n") - self.wfile.write("Content-Length: %d\r\n"%len(html_content)) - if headers: - for key, value in headers.items(): - self.wfile.write("%s: %s\r\n"%(key, value)) - self.wfile.write("Connection: close\r\n") - self.wfile.write("\r\n") - self.wfile.write(html_content) - self.wfile.flush() - except: + sn = connection.get_servername() + if sn and sn != self.sni: + self.sni = sn.decode("utf8").encode("idna") + self.log("SNI received: %s" % self.sni) + self.server_reconnect() # reconnect to upstream server with SNI + # Now, change client context to reflect changed certificate: + new_context = SSL.Context(SSL.TLSv1_METHOD) + new_context.use_privatekey_file(self.config.certfile or self.config.cacert) + dummycert = self.find_cert() + new_context.use_certificate(dummycert.x509) + connection.set_context(new_context) + # An unhandled exception in this method will core dump PyOpenSSL, so + # make dang sure it doesn't happen. 
+ except Exception, e: # pragma: no cover pass -class ProxyServerError(Exception): pass +class ProxyServerError(Exception): + pass class ProxyServer(tcp.TCPServer): allow_reuse_address = True bound = True - def __init__(self, config, port, address='', server_version=version.NAMEVERSION): + + def __init__(self, config, port, host='', server_version=version.NAMEVERSION): """ Raises ProxyServerError if there's a startup problem. """ - self.config, self.port, self.address = config, port, address + self.config = config self.server_version = server_version try: - tcp.TCPServer.__init__(self, (address, port)) + tcp.TCPServer.__init__(self, (host, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.channel = None @@ -540,14 +478,15 @@ class ProxyServer(tcp.TCPServer): def set_channel(self, channel): self.channel = channel - def handle_connection(self, request, client_address): - h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version) + def handle_client_connection(self, conn, client_address): + h = ConnectionHandler(self.config, conn, client_address, self, self.channel, self.server_version) h.handle() h.finish() class DummyServer: bound = False + def __init__(self, config): self.config = config @@ -563,22 +502,21 @@ def certificate_option_group(parser): group = parser.add_argument_group("SSL") group.add_argument( "--cert", action="store", - type = str, dest="cert", default=None, - help = "User-created SSL certificate file." + type=str, dest="cert", default=None, + help="User-created SSL certificate file." ) group.add_argument( "--client-certs", action="store", - type = str, dest = "clientcerts", default=None, - help = "Client certificate directory." + type=str, dest="clientcerts", default=None, + help="Client certificate directory." 
) - def process_proxy_options(parser, options): if options.cert: options.cert = os.path.expanduser(options.cert) if not os.path.exists(options.cert): - return parser.error("Manually created certificate does not exist: %s"%options.cert) + return parser.error("Manually created certificate does not exist: %s" % options.cert) cacert = os.path.join(options.confdir, "mitmproxy-ca.pem") cacert = os.path.expanduser(cacert) @@ -592,8 +530,8 @@ def process_proxy_options(parser, options): if not platform.resolver: return parser.error("Transparent mode not supported on this platform.") trans = dict( - resolver = platform.resolver(), - sslports = TRANSPARENT_SSL_PORTS + resolver=platform.resolver(), + sslports=TRANSPARENT_SSL_PORTS ) else: trans = None @@ -601,14 +539,14 @@ def process_proxy_options(parser, options): if options.reverse_proxy: rp = utils.parse_proxy_spec(options.reverse_proxy) if not rp: - return parser.error("Invalid reverse proxy specification: %s"%options.reverse_proxy) + return parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy) else: rp = None if options.forward_proxy: fp = utils.parse_proxy_spec(options.forward_proxy) if not fp: - return parser.error("Invalid forward proxy specification: %s"%options.forward_proxy) + return parser.error("Invalid forward proxy specification: %s" % options.forward_proxy) else: fp = None @@ -616,8 +554,8 @@ def process_proxy_options(parser, options): options.clientcerts = os.path.expanduser(options.clientcerts) if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts): return parser.error( - "Client certificate directory does not exist or is not a directory: %s"%options.clientcerts - ) + "Client certificate directory does not exist or is not a directory: %s" % options.clientcerts + ) if (options.auth_nonanonymous or options.auth_singleuser or options.auth_htpasswd): if options.auth_singleuser: @@ -637,13 +575,13 @@ def process_proxy_options(parser, options): authenticator = http_auth.NullProxyAuth(None) return ProxyConfig( - certfile = options.cert, - cacert = cacert, - clientcerts = options.clientcerts, - body_size_limit = body_size_limit, - no_upstream_cert = options.no_upstream_cert, - reverse_proxy = rp, - forward_proxy = fp, - transparent_proxy = trans, - authenticator = authenticator + certfile=options.cert, + cacert=cacert, + clientcerts=options.clientcerts, + body_size_limit=body_size_limit, + no_upstream_cert=options.no_upstream_cert, + reverse_proxy=rp, + forward_proxy=fp, + transparent_proxy=trans, + authenticator=authenticator ) diff --git a/libmproxy/script.py b/libmproxy/script.py index 0912c9ae..d34d3383 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -108,7 +108,7 @@ def _handle_concurrent_reply(fn, o, args=[], kwargs={}): def run(): fn(*args, **kwargs) reply(o) - threading.Thread(target=run).start() + threading.Thread(target=run, name="ScriptThread").start() def concurrent(fn): diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py new file mode 100644 index 00000000..a752999d --- /dev/null +++ b/libmproxy/stateobject.py @@ -0,0 +1,73 @@ +class StateObject(object): + def _get_state(self): + raise NotImplementedError # pragma: nocover + + def _load_state(self, state): + raise NotImplementedError # pragma: nocover + + @classmethod + def _from_state(cls, state): + raise NotImplementedError # pragma: nocover + # Usually, this function roughly equals to the following code: + # f = cls() + # f._load_state(state) + # return f + + def __eq__(self, other): + try: + 
return self._get_state() == other._get_state()
+        except AttributeError:  # we may compare with something that's not a StateObject
+            return False
+
+
+class SimpleStateObject(StateObject):
+    """
+    A StateObject with opinionated conventions that tries to keep everything DRY.
+
+    Simply put, you agree on a list of attributes and their type.
+    Attributes can either be primitive types (str, tuple, bool, ...) or StateObject instances themselves.
+    SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
+    Overriding _get_state or _load_state to add custom adjustments is always possible.
+    """
+
+    _stateobject_attributes = None  # None by default to raise an exception if the definition was forgotten
+    """
+    An attribute-name -> class-or-type dict containing all attributes that should be serialized.
+    If the attribute is a class, this class must be a subclass of StateObject.
+    """
+
+    def _get_state(self):
+        return {attr: self._get_state_attr(attr, cls)
+                for attr, cls in self._stateobject_attributes.iteritems()}
+
+    def _get_state_attr(self, attr, cls):
+        """
+        Helper for _get_state: returns the value of the given attribute.
+        """
+        val = getattr(self, attr)
+        if hasattr(val, "_get_state"):
+            return val._get_state()
+        else:
+            return val
+
+    def _load_state(self, state):
+        for attr, cls in self._stateobject_attributes.iteritems():
+            self._load_state_attr(attr, cls, state)
+
+    def _load_state_attr(self, attr, cls, state):
+        """
+        Helper for _load_state: loads the given attribute from the state.
+        """
+        if state.get(attr, None) is None:
+            setattr(self, attr, None)
+            return
+
+        curr = getattr(self, attr)
+        if hasattr(curr, "_load_state"):
+            curr._load_state(state[attr])
+        elif hasattr(cls, "_from_state"):
+            setattr(self, attr, cls._from_state(state[attr]))
+        else:
+            setattr(self, attr, cls(state[attr]))
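To make the convention concrete, here is a minimal usage sketch of SimpleStateObject; the Inner and Outer classes are hypothetical and exist only to illustrate the attribute map and nested StateObjects:

class Inner(SimpleStateObject):
    _stateobject_attributes = dict(value=int)

    def __init__(self, value=None):
        self.value = value

    @classmethod
    def _from_state(cls, state):
        f = cls()
        f._load_state(state)
        return f


class Outer(SimpleStateObject):
    _stateobject_attributes = dict(name=str, inner=Inner)

    def __init__(self, name=None, inner=None):
        self.name = name
        self.inner = inner

    @classmethod
    def _from_state(cls, state):
        f = cls()
        f._load_state(state)
        return f


o = Outer("example", Inner(42))
state = o._get_state()               # {'name': 'example', 'inner': {'value': 42}}
restored = Outer._from_state(state)  # rebuilds the nested Inner from its state dict
assert restored.inner.value == 42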
\ No newline at end of file