aboutsummaryrefslogtreecommitdiffstats
path: root/libmproxy/flow.py
diff options
context:
space:
mode:
Diffstat (limited to 'libmproxy/flow.py')
-rw-r--r--libmproxy/flow.py1028
1 files changed, 59 insertions, 969 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index 76ca4f47..40786631 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -2,16 +2,19 @@
This module provides more sophisticated flow tracking. These match requests
with their responses, and provide filtering and interception facilities.
"""
-import hashlib, Cookie, cookielib, copy, re, urlparse, threading
-import time, urllib
-import tnetstring, filt, script, utils, encoding, proxy
-from email.utils import parsedate_tz, formatdate, mktime_tz
-from netlib import odict, http, certutils, wsgi
-import controller, version
+import base64
+import hashlib, Cookie, cookielib, re, threading
+import os
+from flask import request
+import requests
+import tnetstring, filt, script
+from netlib import odict, wsgi
+from .proxy import ClientConnection, ServerConnection # FIXME: remove circular dependency
+import controller, version, protocol
import app
-
-HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
-CONTENT_MISSING = 0
+from .protocol import KILL
+from .protocol.http import HTTPResponse, CONTENT_MISSING
+from .proxy import RequestReplayThread
ODict = odict.ODict
ODictCaseless = odict.ODictCaseless
@@ -32,11 +35,11 @@ class AppRegistry:
"""
Returns an WSGIAdaptor instance if request matches an app, or None.
"""
- if (request.host, request.port) in self.apps:
- return self.apps[(request.host, request.port)]
+ if (request.get_host(), request.get_port()) in self.apps:
+ return self.apps[(request.get_host(), request.get_port())]
if "host" in request.headers:
host = request.headers["host"][0]
- return self.apps.get((host, request.port), None)
+ return self.apps.get((host, request.get_port()), None)
class ReplaceHooks:
@@ -143,769 +146,6 @@ class SetHeaders:
f.request.headers.add(header, value)
-class decoded(object):
- """
-
- A context manager that decodes a request, response or error, and then
- re-encodes it with the same encoding after execution of the block.
-
- Example:
-
- with decoded(request):
- request.content = request.content.replace("foo", "bar")
- """
- def __init__(self, o):
- self.o = o
- ce = o.headers.get_first("content-encoding")
- if ce in encoding.ENCODINGS:
- self.ce = ce
- else:
- self.ce = None
-
- def __enter__(self):
- if self.ce:
- self.o.decode()
-
- def __exit__(self, type, value, tb):
- if self.ce:
- self.o.encode(self.ce)
-
-
-class StateObject:
- def __eq__(self, other):
- try:
- return self._get_state() == other._get_state()
- except AttributeError:
- return False
-
-
-class HTTPMsg(StateObject):
- def get_decoded_content(self):
- """
- Returns the decoded content based on the current Content-Encoding header.
- Doesn't change the message iteself or its headers.
- """
- ce = self.headers.get_first("content-encoding")
- if not self.content or ce not in encoding.ENCODINGS:
- return self.content
- return encoding.decode(ce, self.content)
-
- def decode(self):
- """
- Decodes content based on the current Content-Encoding header, then
- removes the header. If there is no Content-Encoding header, no
- action is taken.
-
- Returns True if decoding succeeded, False otherwise.
- """
- ce = self.headers.get_first("content-encoding")
- if not self.content or ce not in encoding.ENCODINGS:
- return False
- data = encoding.decode(
- ce,
- self.content
- )
- if data is None:
- return False
- self.content = data
- del self.headers["content-encoding"]
- return True
-
- def encode(self, e):
- """
- Encodes content with the encoding e, where e is "gzip", "deflate"
- or "identity".
- """
- # FIXME: Error if there's an existing encoding header?
- self.content = encoding.encode(e, self.content)
- self.headers["content-encoding"] = [e]
-
- def size(self, **kwargs):
- """
- Size in bytes of a fully rendered message, including headers and
- HTTP lead-in.
- """
- hl = len(self._assemble_head(**kwargs))
- if self.content:
- return hl + len(self.content)
- else:
- return hl
-
- def get_content_type(self):
- return self.headers.get_first("content-type")
-
- def get_transmitted_size(self):
- # FIXME: this is inprecise in case chunking is used
- # (we should count the chunking headers)
- if not self.content:
- return 0
- return len(self.content)
-
-
-class Request(HTTPMsg):
- """
- An HTTP request.
-
- Exposes the following attributes:
-
- client_conn: ClientConnect object, or None if this is a replay.
-
- headers: ODictCaseless object
-
- content: Content of the request, None, or CONTENT_MISSING if there
- is content associated, but not present. CONTENT_MISSING evaluates
- to False to make checking for the presence of content natural.
-
- scheme: URL scheme (http/https)
-
- host: Host portion of the URL
-
- port: Destination port
-
- path: Path portion of the URL
-
- timestamp_start: Seconds since the epoch signifying request transmission started
-
- method: HTTP method
-
- timestamp_end: Seconds since the epoch signifying request transmission ended
-
- tcp_setup_timestamp: Seconds since the epoch signifying remote TCP connection setup completion time
- (or None, if request didn't results TCP setup)
-
- ssl_setup_timestamp: Seconds since the epoch signifying remote SSL encryption setup completion time
- (or None, if request didn't results SSL setup)
-
- """
- def __init__(
- self, client_conn, httpversion, host, port,
- scheme, method, path, headers, content, timestamp_start=None,
- timestamp_end=None, tcp_setup_timestamp=None,
- ssl_setup_timestamp=None, ip=None):
- assert isinstance(headers, ODictCaseless)
- self.client_conn = client_conn
- self.httpversion = httpversion
- self.host, self.port, self.scheme = host, port, scheme
- self.method, self.path, self.headers, self.content = method, path, headers, content
- self.timestamp_start = timestamp_start or utils.timestamp()
- self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
- self.close = False
- self.tcp_setup_timestamp = tcp_setup_timestamp
- self.ssl_setup_timestamp = ssl_setup_timestamp
- self.ip = ip
-
- # Have this request's cookies been modified by sticky cookies or auth?
- self.stickycookie = False
- self.stickyauth = False
-
- # Live attributes - not serialized
- self.wfile, self.rfile = None, None
-
- def set_live(self, rfile, wfile):
- self.wfile, self.rfile = wfile, rfile
-
- def is_live(self):
- return bool(self.wfile)
-
- def anticache(self):
- """
- Modifies this request to remove headers that might produce a cached
- response. That is, we remove ETags and If-Modified-Since headers.
- """
- delheaders = [
- "if-modified-since",
- "if-none-match",
- ]
- for i in delheaders:
- del self.headers[i]
-
- def anticomp(self):
- """
- Modifies this request to remove headers that will compress the
- resource's data.
- """
- self.headers["accept-encoding"] = ["identity"]
-
- def constrain_encoding(self):
- """
- Limits the permissible Accept-Encoding values, based on what we can
- decode appropriately.
- """
- if self.headers["accept-encoding"]:
- self.headers["accept-encoding"] = [', '.join(
- e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0]
- )]
-
- def _set_replay(self):
- self.client_conn = None
-
- def is_replay(self):
- """
- Is this request a replay?
- """
- if self.client_conn:
- return False
- else:
- return True
-
- def _load_state(self, state):
- if state["client_conn"]:
- if self.client_conn:
- self.client_conn._load_state(state["client_conn"])
- else:
- self.client_conn = ClientConnect._from_state(state["client_conn"])
- else:
- self.client_conn = None
- self.host = state["host"]
- self.port = state["port"]
- self.scheme = state["scheme"]
- self.method = state["method"]
- self.path = state["path"]
- self.headers = ODictCaseless._from_state(state["headers"])
- self.content = state["content"]
- self.timestamp_start = state["timestamp_start"]
- self.timestamp_end = state["timestamp_end"]
- self.tcp_setup_timestamp = state["tcp_setup_timestamp"]
- self.ssl_setup_timestamp = state["ssl_setup_timestamp"]
- self.ip = state["ip"]
-
- def _get_state(self):
- return dict(
- client_conn = self.client_conn._get_state() if self.client_conn else None,
- httpversion = self.httpversion,
- host = self.host,
- port = self.port,
- scheme = self.scheme,
- method = self.method,
- path = self.path,
- headers = self.headers._get_state(),
- content = self.content,
- timestamp_start = self.timestamp_start,
- timestamp_end = self.timestamp_end,
- tcp_setup_timestamp = self.tcp_setup_timestamp,
- ssl_setup_timestamp = self.ssl_setup_timestamp,
- ip = self.ip
- )
-
- @classmethod
- def _from_state(klass, state):
- return klass(
- ClientConnect._from_state(state["client_conn"]),
- tuple(state["httpversion"]),
- str(state["host"]),
- state["port"],
- str(state["scheme"]),
- str(state["method"]),
- str(state["path"]),
- ODictCaseless._from_state(state["headers"]),
- state["content"],
- state["timestamp_start"],
- state["timestamp_end"],
- state["tcp_setup_timestamp"],
- state["ssl_setup_timestamp"],
- state["ip"]
- )
-
- def __hash__(self):
- return id(self)
-
- def copy(self):
- c = copy.copy(self)
- c.headers = self.headers.copy()
- return c
-
- def get_form_urlencoded(self):
- """
- Retrieves the URL-encoded form data, returning an ODict object.
- Returns an empty ODict if there is no data or the content-type
- indicates non-form data.
- """
- if self.content and self.headers.in_any("content-type", HDR_FORM_URLENCODED, True):
- return ODict(utils.urldecode(self.content))
- return ODict([])
-
- def set_form_urlencoded(self, odict):
- """
- Sets the body to the URL-encoded form data, and adds the
- appropriate content-type header. Note that this will destory the
- existing body if there is one.
- """
- # FIXME: If there's an existing content-type header indicating a
- # url-encoded form, leave it alone.
- self.headers["Content-Type"] = [HDR_FORM_URLENCODED]
- self.content = utils.urlencode(odict.lst)
-
- def get_path_components(self):
- """
- Returns the path components of the URL as a list of strings.
-
- Components are unquoted.
- """
- _, _, path, _, _, _ = urlparse.urlparse(self.get_url())
- return [urllib.unquote(i) for i in path.split("/") if i]
-
- def set_path_components(self, lst):
- """
- Takes a list of strings, and sets the path component of the URL.
-
- Components are quoted.
- """
- lst = [urllib.quote(i, safe="") for i in lst]
- path = "/" + "/".join(lst)
- scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url())
- self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]))
-
- def get_query(self):
- """
- Gets the request query string. Returns an ODict object.
- """
- _, _, _, _, query, _ = urlparse.urlparse(self.get_url())
- if query:
- return ODict(utils.urldecode(query))
- return ODict([])
-
- def set_query(self, odict):
- """
- Takes an ODict object, and sets the request query string.
- """
- scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url())
- query = utils.urlencode(odict.lst)
- self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]))
-
- def get_url(self, hostheader=False):
- """
- Returns a URL string, constructed from the Request's URL compnents.
-
- If hostheader is True, we use the value specified in the request
- Host header to construct the URL.
- """
- if hostheader:
- host = self.headers.get_first("host") or self.host
- else:
- host = self.host
- host = host.encode("idna")
- return utils.unparse_url(self.scheme, host, self.port, self.path).encode('ascii')
-
- def set_url(self, url):
- """
- Parses a URL specification, and updates the Request's information
- accordingly.
-
- Returns False if the URL was invalid, True if the request succeeded.
- """
- parts = http.parse_url(url)
- if not parts:
- return False
- self.scheme, self.host, self.port, self.path = parts
- return True
-
- def get_cookies(self):
- cookie_headers = self.headers.get("cookie")
- if not cookie_headers:
- return None
-
- cookies = []
- for header in cookie_headers:
- pairs = [pair.partition("=") for pair in header.split(';')]
- cookies.extend((pair[0],(pair[2],{})) for pair in pairs)
- return dict(cookies)
-
- def get_header_size(self):
- FMT = '%s %s HTTP/%s.%s\r\n%s\r\n'
- assembled_header = FMT % (
- self.method,
- self.path,
- self.httpversion[0],
- self.httpversion[1],
- str(self.headers)
- )
- return len(assembled_header)
-
- def _assemble_head(self, proxy=False):
- FMT = '%s %s HTTP/%s.%s\r\n%s\r\n'
- FMT_PROXY = '%s %s://%s:%s%s HTTP/%s.%s\r\n%s\r\n'
-
- headers = self.headers.copy()
- utils.del_all(
- headers,
- [
- 'proxy-connection',
- 'keep-alive',
- 'connection',
- 'transfer-encoding'
- ]
- )
- if not 'host' in headers:
- headers["host"] = [utils.hostport(self.scheme, self.host, self.port)]
- content = self.content
- if content:
- headers["Content-Length"] = [str(len(content))]
- else:
- content = ""
- if self.close:
- headers["connection"] = ["close"]
- if not proxy:
- return FMT % (
- self.method,
- self.path,
- self.httpversion[0],
- self.httpversion[1],
- str(headers)
- )
- else:
- return FMT_PROXY % (
- self.method,
- self.scheme,
- self.host,
- self.port,
- self.path,
- self.httpversion[0],
- self.httpversion[1],
- str(headers)
- )
-
- def _assemble(self, _proxy = False):
- """
- Assembles the request for transmission to the server. We make some
- modifications to make sure interception works properly.
-
- Returns None if the request cannot be assembled.
- """
- if self.content == CONTENT_MISSING:
- return None
- head = self._assemble_head(_proxy)
- if self.content:
- return head + self.content
- else:
- return head
-
- def replace(self, pattern, repl, *args, **kwargs):
- """
- Replaces a regular expression pattern with repl in both the headers
- and the body of the request. Encoded content will be decoded before
- replacement, and re-encoded afterwards.
-
- Returns the number of replacements made.
- """
- with decoded(self):
- self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs)
- self.path, pc = utils.safe_subn(pattern, repl, self.path, *args, **kwargs)
- c += pc
- c += self.headers.replace(pattern, repl, *args, **kwargs)
- return c
-
-
-class Response(HTTPMsg):
- """
- An HTTP response.
-
- Exposes the following attributes:
-
- request: Request object.
-
- code: HTTP response code
-
- msg: HTTP response message
-
- headers: ODict object
-
- content: Content of the request, None, or CONTENT_MISSING if there
- is content associated, but not present. CONTENT_MISSING evaluates
- to False to make checking for the presence of content natural.
-
- timestamp_start: Seconds since the epoch signifying response transmission started
-
- timestamp_end: Seconds since the epoch signifying response transmission ended
- """
- def __init__(self, request, httpversion, code, msg, headers, content, cert, timestamp_start=None, timestamp_end=None):
- assert isinstance(headers, ODictCaseless)
- self.request = request
- self.httpversion, self.code, self.msg = httpversion, code, msg
- self.headers, self.content = headers, content
- self.cert = cert
- self.timestamp_start = timestamp_start or utils.timestamp()
- self.timestamp_end = timestamp_end or utils.timestamp()
- self.replay = False
-
- def _refresh_cookie(self, c, delta):
- """
- Takes a cookie string c and a time delta in seconds, and returns
- a refreshed cookie string.
- """
- c = Cookie.SimpleCookie(str(c))
- for i in c.values():
- if "expires" in i:
- d = parsedate_tz(i["expires"])
- if d:
- d = mktime_tz(d) + delta
- i["expires"] = formatdate(d)
- else:
- # This can happen when the expires tag is invalid.
- # reddit.com sends a an expires tag like this: "Thu, 31 Dec
- # 2037 23:59:59 GMT", which is valid RFC 1123, but not
- # strictly correct according tot he cookie spec. Browsers
- # appear to parse this tolerantly - maybe we should too.
- # For now, we just ignore this.
- del i["expires"]
- return c.output(header="").strip()
-
- def refresh(self, now=None):
- """
- This fairly complex and heuristic function refreshes a server
- response for replay.
-
- - It adjusts date, expires and last-modified headers.
- - It adjusts cookie expiration.
- """
- if not now:
- now = time.time()
- delta = now - self.timestamp_start
- refresh_headers = [
- "date",
- "expires",
- "last-modified",
- ]
- for i in refresh_headers:
- if i in self.headers:
- d = parsedate_tz(self.headers[i][0])
- if d:
- new = mktime_tz(d) + delta
- self.headers[i] = [formatdate(new)]
- c = []
- for i in self.headers["set-cookie"]:
- c.append(self._refresh_cookie(i, delta))
- if c:
- self.headers["set-cookie"] = c
-
- def _set_replay(self):
- self.replay = True
-
- def is_replay(self):
- """
- Is this response a replay?
- """
- return self.replay
-
- def _load_state(self, state):
- self.code = state["code"]
- self.msg = state["msg"]
- self.headers = ODictCaseless._from_state(state["headers"])
- self.content = state["content"]
- self.timestamp_start = state["timestamp_start"]
- self.timestamp_end = state["timestamp_end"]
- self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None
-
- def _get_state(self):
- return dict(
- httpversion = self.httpversion,
- code = self.code,
- msg = self.msg,
- headers = self.headers._get_state(),
- timestamp_start = self.timestamp_start,
- timestamp_end = self.timestamp_end,
- cert = self.cert.to_pem() if self.cert else None,
- content = self.content,
- )
-
- @classmethod
- def _from_state(klass, request, state):
- return klass(
- request,
- state["httpversion"],
- state["code"],
- str(state["msg"]),
- ODictCaseless._from_state(state["headers"]),
- state["content"],
- certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None,
- state["timestamp_start"],
- state["timestamp_end"],
- )
-
- def copy(self):
- c = copy.copy(self)
- c.headers = self.headers.copy()
- return c
-
- def _assemble_head(self):
- FMT = '%s\r\n%s\r\n'
- headers = self.headers.copy()
- utils.del_all(
- headers,
- ['proxy-connection', 'transfer-encoding']
- )
- if self.content:
- headers["Content-Length"] = [str(len(self.content))]
- elif 'Transfer-Encoding' in self.headers:
- headers["Content-Length"] = ["0"]
- proto = "HTTP/%s.%s %s %s"%(self.httpversion[0], self.httpversion[1], self.code, str(self.msg))
- data = (proto, str(headers))
- return FMT%data
-
- def _assemble(self):
- """
- Assembles the response for transmission to the client. We make some
- modifications to make sure interception works properly.
-
- Returns None if the request cannot be assembled.
- """
- if self.content == CONTENT_MISSING:
- return None
- head = self._assemble_head()
- if self.content:
- return head + self.content
- else:
- return head
-
- def replace(self, pattern, repl, *args, **kwargs):
- """
- Replaces a regular expression pattern with repl in both the headers
- and the body of the response. Encoded content will be decoded
- before replacement, and re-encoded afterwards.
-
- Returns the number of replacements made.
- """
- with decoded(self):
- self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs)
- c += self.headers.replace(pattern, repl, *args, **kwargs)
- return c
-
- def get_header_size(self):
- FMT = '%s\r\n%s\r\n'
- proto = "HTTP/%s.%s %s %s"%(self.httpversion[0], self.httpversion[1], self.code, str(self.msg))
- assembled_header = FMT % (proto, str(self.headers))
- return len(assembled_header)
-
- def get_cookies(self):
- cookie_headers = self.headers.get("set-cookie")
- if not cookie_headers:
- return None
-
- cookies = []
- for header in cookie_headers:
- pairs = [pair.partition("=") for pair in header.split(';')]
- cookie_name = pairs[0][0] # the key of the first key/value pairs
- cookie_value = pairs[0][2] # the value of the first key/value pairs
- cookie_parameters = {key.strip().lower():value.strip() for key,sep,value in pairs[1:]}
- cookies.append((cookie_name, (cookie_value, cookie_parameters)))
- return dict(cookies)
-
-
-class ClientDisconnect:
- """
- A client disconnection event.
-
- Exposes the following attributes:
-
- client_conn: ClientConnect object.
- """
- def __init__(self, client_conn):
- self.client_conn = client_conn
-
-
-class ClientConnect(StateObject):
- """
- A single client connection. Each connection can result in multiple HTTP
- Requests.
-
- Exposes the following attributes:
-
- address: (address, port) tuple, or None if the connection is replayed.
- requestcount: Number of requests created by this client connection.
- close: Is the client connection closed?
- error: Error string or None.
- """
- def __init__(self, address):
- """
- address is an (address, port) tuple, or None if this connection has
- been replayed from within mitmproxy.
- """
- self.address = address
- self.close = False
- self.requestcount = 0
- self.error = None
-
- def __str__(self):
- if self.address:
- return "%s:%d"%(self.address[0],self.address[1])
-
- def _load_state(self, state):
- self.close = True
- self.error = state["error"]
- self.requestcount = state["requestcount"]
-
- def _get_state(self):
- return dict(
- address = list(self.address),
- requestcount = self.requestcount,
- error = self.error,
- )
-
- @classmethod
- def _from_state(klass, state):
- if state:
- k = klass(state["address"])
- k._load_state(state)
- return k
- else:
- return None
-
- def copy(self):
- return copy.copy(self)
-
-
-class Error(StateObject):
- """
- An Error.
-
- This is distinct from an HTTP error response (say, a code 500), which
- is represented by a normal Response object. This class is responsible
- for indicating errors that fall outside of normal HTTP communications,
- like interrupted connections, timeouts, protocol errors.
-
- Exposes the following attributes:
-
- request: Request object
- msg: Message describing the error
- timestamp: Seconds since the epoch
- """
- def __init__(self, request, msg, timestamp=None):
- self.request, self.msg = request, msg
- self.timestamp = timestamp or utils.timestamp()
-
- def _load_state(self, state):
- self.msg = state["msg"]
- self.timestamp = state["timestamp"]
-
- def copy(self):
- c = copy.copy(self)
- return c
-
- def _get_state(self):
- return dict(
- msg = self.msg,
- timestamp = self.timestamp,
- )
-
- @classmethod
- def _from_state(klass, request, state):
- return klass(
- request,
- state["msg"],
- state["timestamp"],
- )
-
- def replace(self, pattern, repl, *args, **kwargs):
- """
- Replaces a regular expression pattern with repl in both the headers
- and the body of the request. Returns the number of replacements
- made.
-
- FIXME: Is replace useful on an Error object??
- """
- self.msg, c = utils.safe_subn(pattern, repl, self.msg, *args, **kwargs)
- return c
-
-
class ClientPlaybackState:
def __init__(self, flows, exit):
self.flows, self.exit = flows, exit
@@ -934,7 +174,7 @@ class ClientPlaybackState:
if self.flows and not self.current:
n = self.flows.pop(0)
n.request.reply = controller.DummyReply()
- n.request.client_conn = None
+ n.client_conn = None
self.current = master.handle_request(n.request)
if not testing and not self.current.response:
master.replay_request(self.current) # pragma: no cover
@@ -997,7 +237,6 @@ class ServerPlaybackState:
return l.pop(0)
-
class StickyCookieState:
def __init__(self, flt):
"""
@@ -1011,8 +250,8 @@ class StickyCookieState:
Returns a (domain, port, path) tuple.
"""
return (
- m["domain"] or f.request.host,
- f.request.port,
+ m["domain"] or f.request.get_host(),
+ f.request.get_port(),
m["path"] or "/"
)
@@ -1030,7 +269,7 @@ class StickyCookieState:
c = Cookie.SimpleCookie(str(i))
m = c.values()[0]
k = self.ckey(m, f)
- if self.domain_match(f.request.host, k[0]):
+ if self.domain_match(f.request.get_host(), k[0]):
self.jar[self.ckey(m, f)] = m
def handle_request(self, f):
@@ -1038,8 +277,8 @@ class StickyCookieState:
if f.match(self.flt):
for i in self.jar.keys():
match = [
- self.domain_match(f.request.host, i[0]),
- f.request.port == i[1],
+ self.domain_match(f.request.get_host(), i[0]),
+ f.request.get_port() == i[1],
f.request.path.startswith(i[2])
]
if all(match):
@@ -1058,177 +297,16 @@ class StickyAuthState:
self.hosts = {}
def handle_request(self, f):
+ host = f.request.get_host()
if "authorization" in f.request.headers:
- self.hosts[f.request.host] = f.request.headers["authorization"]
+ self.hosts[host] = f.request.headers["authorization"]
elif f.match(self.flt):
- if f.request.host in self.hosts:
- f.request.headers["authorization"] = self.hosts[f.request.host]
-
-
-class Flow:
- """
- A Flow is a collection of objects representing a single HTTP
- transaction. The main attributes are:
-
- request: Request object
- response: Response object
- error: Error object
-
- Note that it's possible for a Flow to have both a response and an error
- object. This might happen, for instance, when a response was received
- from the server, but there was an error sending it back to the client.
-
- The following additional attributes are exposed:
-
- intercepting: Is this flow currently being intercepted?
- """
- def __init__(self, request):
- self.request = request
- self.response, self.error = None, None
- self.intercepting = False
- self._backup = None
-
- def copy(self):
- rc = self.request.copy()
- f = Flow(rc)
- if self.response:
- f.response = self.response.copy()
- f.response.request = rc
- if self.error:
- f.error = self.error.copy()
- f.error.request = rc
- return f
-
- @classmethod
- def _from_state(klass, state):
- f = klass(None)
- f._load_state(state)
- return f
-
- def _get_state(self):
- d = dict(
- request = self.request._get_state() if self.request else None,
- response = self.response._get_state() if self.response else None,
- error = self.error._get_state() if self.error else None,
- version = version.IVERSION
- )
- return d
-
- def _load_state(self, state):
- if self.request:
- self.request._load_state(state["request"])
- else:
- self.request = Request._from_state(state["request"])
-
- if state["response"]:
- if self.response:
- self.response._load_state(state["response"])
- else:
- self.response = Response._from_state(self.request, state["response"])
- else:
- self.response = None
-
- if state["error"]:
- if self.error:
- self.error._load_state(state["error"])
- else:
- self.error = Error._from_state(self.request, state["error"])
- else:
- self.error = None
-
- def modified(self):
- """
- Has this Flow been modified?
- """
- # FIXME: Save a serialization in backup, compare current with
- # backup to detect if flow has _really_ been modified.
- if self._backup:
- return True
- else:
- return False
-
- def backup(self, force=False):
- """
- Save a backup of this Flow, which can be reverted to using a
- call to .revert().
- """
- if not self._backup:
- self._backup = self._get_state()
-
- def revert(self):
- """
- Revert to the last backed up state.
- """
- if self._backup:
- self._load_state(self._backup)
- self._backup = None
-
- def match(self, f):
- """
- Match this flow against a compiled filter expression. Returns True
- if matched, False if not.
-
- If f is a string, it will be compiled as a filter expression. If
- the expression is invalid, ValueError is raised.
- """
- if isinstance(f, basestring):
- f = filt.parse(f)
- if not f:
- raise ValueError("Invalid filter expression.")
- if f:
- return f(self)
- return True
-
- def kill(self, master):
- """
- Kill this request.
- """
- self.error = Error(self.request, "Connection killed")
- self.error.reply = controller.DummyReply()
- if self.request and not self.request.reply.acked:
- self.request.reply(proxy.KILL)
- elif self.response and not self.response.reply.acked:
- self.response.reply(proxy.KILL)
- master.handle_error(self.error)
- self.intercepting = False
-
- def intercept(self):
- """
- Intercept this Flow. Processing will stop until accept_intercept is
- called.
- """
- self.intercepting = True
-
- def accept_intercept(self):
- """
- Continue with the flow - called after an intercept().
- """
- if self.request:
- if not self.request.reply.acked:
- self.request.reply()
- elif self.response and not self.response.reply.acked:
- self.response.reply()
- self.intercepting = False
-
- def replace(self, pattern, repl, *args, **kwargs):
- """
- Replaces a regular expression pattern with repl in all parts of the
- flow. Encoded content will be decoded before replacement, and
- re-encoded afterwards.
-
- Returns the number of replacements made.
- """
- c = self.request.replace(pattern, repl, *args, **kwargs)
- if self.response:
- c += self.response.replace(pattern, repl, *args, **kwargs)
- if self.error:
- c += self.error.replace(pattern, repl, *args, **kwargs)
- return c
+ if host in self.hosts:
+ f.request.headers["authorization"] = self.hosts[host]
class State(object):
def __init__(self):
- self._flow_map = {}
self._flow_list = []
self.view = []
@@ -1242,7 +320,7 @@ class State(object):
return self._limit_txt
def flow_count(self):
- return len(self._flow_map)
+ return len(self._flow_list)
def index(self, f):
return self._flow_list.index(f)
@@ -1258,10 +336,8 @@ class State(object):
"""
Add a request to the state. Returns the matching flow.
"""
- f = Flow(req)
+ f = req.flow
self._flow_list.append(f)
- self._flow_map[req] = f
- assert len(self._flow_list) == len(self._flow_map)
if f.match(self._limit):
self.view.append(f)
return f
@@ -1270,10 +346,9 @@ class State(object):
"""
Add a response to the state. Returns the matching flow.
"""
- f = self._flow_map.get(resp.request)
+ f = resp.flow
if not f:
return False
- f.response = resp
if f.match(self._limit) and not f in self.view:
self.view.append(f)
return f
@@ -1283,18 +358,15 @@ class State(object):
Add an error response to the state. Returns the matching flow, or
None if there isn't one.
"""
- f = self._flow_map.get(err.request)
+ f = err.flow
if not f:
return None
- f.error = err
if f.match(self._limit) and not f in self.view:
self.view.append(f)
return f
def load_flows(self, flows):
self._flow_list.extend(flows)
- for i in flows:
- self._flow_map[i.request] = i
self.recalculate_view()
def set_limit(self, txt):
@@ -1327,8 +399,6 @@ class State(object):
self.view = self._flow_list[:]
def delete_flow(self, f):
- if f.request in self._flow_map:
- del self._flow_map[f.request]
self._flow_list.remove(f)
if f in self.view:
self.view.remove(f)
@@ -1383,7 +453,28 @@ class FlowMaster(controller.Master):
port
)
else:
- threading.Thread(target=app.mapp.run,kwargs={
+ @app.mapp.before_request
+ def patch_environ(*args, **kwargs):
+ request.environ["mitmproxy.master"] = self
+
+ # the only absurd way to shut down a flask/werkzeug server.
+ # http://flask.pocoo.org/snippets/67/
+ shutdown_secret = base64.b32encode(os.urandom(30))
+
+ @app.mapp.route('/shutdown/<secret>')
+ def shutdown(secret):
+ if secret == shutdown_secret:
+ request.environ.get('werkzeug.server.shutdown')()
+
+ # Workaround: Monkey-patch shutdown function to stop the app.
+ # Improve this when we switch flask werkzeug for something useful.
+ _shutdown = self.shutdown
+ def _shutdownwrap():
+ _shutdown()
+ requests.get("http://%s:%s/shutdown/%s" % (host, port, shutdown_secret))
+ self.shutdown = _shutdownwrap
+
+ threading.Thread(target=app.mapp.run, kwargs={
"use_reloader": False,
"host": host,
"port": port}).start()
@@ -1474,9 +565,8 @@ class FlowMaster(controller.Master):
rflow = self.server_playback.next_flow(flow)
if not rflow:
return None
- response = Response._from_state(flow.request, rflow.response._get_state())
- response._set_replay()
- flow.response = response
+ response = HTTPResponse._from_state(rflow.response._get_state())
+ response.is_replay = True
if self.refresh_server_playback:
response.refresh()
flow.request.reply(response)
@@ -1555,13 +645,13 @@ class FlowMaster(controller.Master):
if f.request.content == CONTENT_MISSING:
return "Can't replay request with missing content..."
if f.request:
- f.request._set_replay()
+ f.request.is_replay = True
if f.request.content:
f.request.headers["Content-Length"] = [str(len(f.request.content))]
f.response = None
f.error = None
self.process_new_request(f)
- rt = proxy.RequestReplayThread(
+ rt = RequestReplayThread(
self.server.config,
f,
self.masterq,
@@ -1597,13 +687,13 @@ class FlowMaster(controller.Master):
return f
def handle_request(self, r):
- if r.is_live():
+ if r.flow.client_conn and r.flow.client_conn.wfile:
app = self.apps.get(r)
if app:
- err = app.serve(r, r.wfile, **{"mitmproxy.master": self})
+ err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self})
if err:
self.add_event("Error in wsgi app. %s"%err, "error")
- r.reply(proxy.KILL)
+ r.reply(KILL)
return
f = self.state.add_request(r)
self.replacehooks.run(f)
@@ -1676,7 +766,7 @@ class FlowReader:
v = ".".join(str(i) for i in data["version"])
raise FlowReadError("Incompatible serialized data version: %s"%v)
off = self.fo.tell()
- yield Flow._from_state(data)
+ yield protocol.protocols[data["conntype"]]["flow"]._from_state(data)
except ValueError, v:
# Error is due to EOF
if self.fo.tell() == off and self.fo.read() == '':