 mitmproxy/addons/session.py           | 324
 mitmproxy/io/protobuf.py              |   2
 mitmproxy/io/sql/session_create.sql   |   2
 test/mitmproxy/addons/test_session.py | 187
 4 files changed, 507 insertions(+), 8 deletions(-)
diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py
index c49b95c4..63e382ec 100644
--- a/mitmproxy/addons/session.py
+++ b/mitmproxy/addons/session.py
@@ -1,12 +1,34 @@
+import collections
import tempfile
+import asyncio
+import typing
+import bisect
import shutil
import sqlite3
+import copy
import os
-from mitmproxy.exceptions import SessionLoadException
+from mitmproxy import flowfilter
+from mitmproxy import types
+from mitmproxy import http
+from mitmproxy import ctx
+from mitmproxy.io import protobuf
+from mitmproxy.exceptions import SessionLoadException, CommandError
from mitmproxy.utils.data import pkg_data
+class KeyifyList(object):
+ def __init__(self, inner, key):
+ self.inner = inner
+ self.key = key
+
+ def __len__(self):
+ return len(self.inner)
+
+ def __getitem__(self, k):
+ return self.key(self.inner[k])
+
+
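
`KeyifyList` exists because `bisect` only grew a `key=` parameter in Python 3.10; wrapping the list lets `bisect_left` compare projected keys instead of raw items. A minimal usage sketch against the class as defined above (the sample data is illustrative):

```python
import bisect

# _view-style data: (order_key, flow_id) tuples kept sorted by order_key.
view = [(1, "id-a"), (3, "id-c")]

# Find the insertion point for order key 2, comparing only the first
# element of each tuple via the KeyifyList adapter.
idx = bisect.bisect_left(KeyifyList(view, lambda t: t[0]), 2)
view.insert(idx, (2, "id-b"))
assert view == [(1, "id-a"), (2, "id-b"), (3, "id-c")]
```
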
# Could be implemented using async libraries
class SessionDB:
"""
@@ -14,6 +36,13 @@ class SessionDB:
for Sessions and handles creation,
retrieval and insertion in tables.
"""
+ content_threshold = 1000
+ type_mappings = {
+ "body": {
+ 1: "request",
+ 2: "response"
+ }
+ }
def __init__(self, db_path=None):
"""
@@ -21,8 +50,13 @@ class SessionDB:
or create a new one with optional path.
:param db_path:
"""
- self.tempdir = None
- self.con = None
+ self.live_components: typing.Dict[str, tuple] = {}
+ self.tempdir: tempfile.TemporaryDirectory = None
+ self.con: sqlite3.Connection = None
+ # This is used for fast look-ups over bodies already dumped to the database.
+ # It lets us enforce a one-to-one relationship between the flow and body tables.
+ self.body_ledger: typing.Set[str] = set()
+ self.id_ledger: typing.Set[str] = set()
if db_path is not None and os.path.isfile(db_path):
self._load_session(db_path)
else:
@@ -40,6 +74,12 @@ class SessionDB:
if self.tempdir:
shutil.rmtree(self.tempdir)
+ def __contains__(self, fid):
+ return fid in self.id_ledger
+
+ def __len__(self):
+ return len(self.id_ledger)
+
def _load_session(self, path):
if not self.is_session_db(path):
raise SessionLoadException('Given path does not point to a valid Session')
@@ -48,8 +88,8 @@ class SessionDB:
def _create_session(self):
script_path = pkg_data.path("io/sql/session_create.sql")
qry = open(script_path, 'r').read()
- with self.con:
- self.con.executescript(qry)
+ self.con.executescript(qry)
+ self.con.commit()
@staticmethod
def is_session_db(path):
@@ -58,6 +98,7 @@ class SessionDB:
is a valid Session SQLite DB.
:return: True if valid, False if invalid.
"""
+ c = None
try:
c = sqlite3.connect(f'file:{path}?mode=rw', uri=True)
cursor = c.cursor()
@@ -67,7 +108,278 @@ class SessionDB:
if all(elem in rows for elem in tables):
c.close()
return True
- except:
+ except sqlite3.Error:
if c:
c.close()
return False
+
+ def _disassemble(self, flow):
+ # Some live components of a flow cannot be serialized, but they are needed for correct functionality.
+ # We solve this by keeping a dict that saves those components as a tuple, keyed by flow id,
+ # and adding them back when needed.
+ self.live_components[flow.id] = (
+ flow.client_conn.wfile,
+ flow.client_conn.rfile,
+ flow.client_conn.reply,
+ flow.server_conn.wfile,
+ flow.server_conn.rfile,
+ flow.server_conn.reply,
+ (flow.server_conn.via.wfile, flow.server_conn.via.rfile,
+ flow.server_conn.via.reply) if flow.server_conn.via else None,
+ flow.reply
+ )
+
+ def _reassemble(self, flow):
+ if flow.id in self.live_components:
+ cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id]
+ flow.client_conn.wfile = cwf
+ flow.client_conn.rfile = crf
+ flow.client_conn.reply = crp
+ flow.server_conn.wfile = swf
+ flow.server_conn.rfile = srf
+ flow.server_conn.reply = srp
+ flow.reply = rep
+ if via:
+ flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via
+ return flow
+
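
The `_disassemble`/`_reassemble` pair implements a stash-and-restore pattern: live handles that cannot survive protobuf serialization are parked in memory, keyed by flow id, and re-attached when the flow comes back from the database. A stripped-down sketch of the same idea (names are illustrative, not mitmproxy API):

```python
live = {}  # flow id -> tuple of non-serializable components

def stash(flow):
    # Record the live parts before the flow is serialized.
    live[flow.id] = (flow.client_conn.reply, flow.reply)

def restore(flow):
    # Re-attach them to a flow deserialized from the database.
    if flow.id in live:
        flow.client_conn.reply, flow.reply = live[flow.id]
    return flow
```
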
+ def store_flows(self, flows):
+ body_buf = []
+ flow_buf = []
+ for flow in flows:
+ self.id_ledger.add(flow.id)
+ self._disassemble(flow)
+ f = copy.copy(flow)
+ f.request = copy.deepcopy(flow.request)
+ if flow.response:
+ f.response = copy.deepcopy(flow.response)
+ f.id = flow.id
+ if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger:
+ body_buf.append((f.id, 1, f.request.content))
+ f.request.content = b""
+ self.body_ledger.add(f.id)
+ if f.response and f.id not in self.body_ledger:
+ if len(f.response.content) > self.content_threshold:
+ body_buf.append((f.id, 2, f.response.content))
+ f.response.content = b""
+ flow_buf.append((f.id, protobuf.dumps(f)))
+ self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf)
+ if body_buf:
+ self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf)
+ self.con.commit()
+
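
`store_flows` keeps serialized flows small by offloading any body above `content_threshold` (1000 bytes) into the `body` table and blanking it in the flow blob; `body_ledger` guarantees at most one body row per flow. A self-contained sketch of that routing decision, mirroring the logic above:

```python
content_threshold = 1000
body_ledger = set()

def route_bodies(fid, request_content, response_content=None):
    rows = []  # (flow_id, type_id, content) rows destined for the body table
    if len(request_content) > content_threshold and fid not in body_ledger:
        rows.append((fid, 1, request_content))       # type_id 1 = request
        body_ledger.add(fid)
    if response_content is not None and fid not in body_ledger:
        if len(response_content) > content_threshold:
            rows.append((fid, 2, response_content))  # type_id 2 = response
    return rows

assert route_bodies("f1", b"A" * 1001) == [("f1", 1, b"A" * 1001)]
# The ledger now blocks a second body row for f1 (one body per flow):
assert route_bodies("f1", b"", b"B" * 2000) == []
# A flow whose only large body is the response is still offloaded:
assert route_bodies("f2", b"", b"B" * 2000) == [("f2", 2, b"B" * 2000)]
```
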
+ def retrieve_flows(self, ids=None):
+ flows = []
+ with self.con as con:
+ if not ids:
+ sql = "SELECT f.content, b.type_id, b.content " \
+ "FROM flow f " \
+ "LEFT OUTER JOIN body b ON f.id = b.flow_id;"
+ rows = con.execute(sql).fetchall()
+ else:
+ sql = "SELECT f.content, b.type_id, b.content " \
+ "FROM flow f " \
+ "LEFT OUTER JOIN body b ON f.id = b.flow_id " \
+ f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});"
+ rows = con.execute(sql, ids).fetchall()
+ for row in rows:
+ flow = protobuf.loads(row[0])
+ if row[1]:
+ typ = self.type_mappings["body"][row[1]]
+ if typ and row[2]:
+ setattr(getattr(flow, typ), "content", row[2])
+ flow = self._reassemble(flow)
+ flows.append(flow)
+ return flows
+
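
The LEFT OUTER JOIN in `retrieve_flows` matters: flows whose bodies stayed inline have no `body` row, and an inner join would silently drop them. A standalone demonstration with a simplified schema:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE flow (id TEXT PRIMARY KEY, content BLOB);
    CREATE TABLE body (flow_id TEXT, type_id INTEGER, content BLOB);
""")
con.execute("INSERT INTO flow VALUES ('big', X'01'), ('small', X'02');")
con.execute("INSERT INTO body VALUES ('big', 1, X'AA');")
rows = con.execute(
    "SELECT f.id, b.type_id, b.content "
    "FROM flow f LEFT OUTER JOIN body b ON f.id = b.flow_id "
    "ORDER BY f.id;"
).fetchall()
# 'small' survives the join with NULL body columns.
assert rows == [("big", 1, b"\xaa"), ("small", None, None)]
```
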
+ def clear(self):
+ self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;")
+
+
+matchall = flowfilter.parse(".")
+
+orders = [
+ "time",
+ "method",
+ "url",
+ "size"
+]
+
+
+class Session:
+
+ _FP_RATE = 150
+ _FP_DECREMENT = 0.9
+ _FP_DEFAULT = 3.0
+
+ def __init__(self):
+ self.db_store: SessionDB = None
+ self._hot_store: collections.OrderedDict = collections.OrderedDict()
+ self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str]]] = {}
+ self._view: typing.List[typing.Tuple[typing.Union[int, float, str], str]] = []
+ self.order: str = orders[0]
+ self.filter = matchall
+ self._flush_period: float = self._FP_DEFAULT
+ self._flush_rate: int = self._FP_RATE
+ self.started: bool = False
+
+ def load(self, loader):
+ loader.add_option(
+ "session_path", typing.Optional[types.Path], None,
+ "Path of session to load or to create."
+ )
+ loader.add_option(
+ "view_order", str, "time",
+ "Flow sort order.",
+ choices=orders
+ )
+ loader.add_option(
+ "view_filter", typing.Optional[str], None,
+ "Limit the view to matching flows."
+ )
+
+ def running(self):
+ if not self.started:
+ self.started = True
+ self.db_store = SessionDB(ctx.options.session_path)
+ loop = asyncio.get_event_loop()
+ loop.create_task(self._writer())
+
+ def configure(self, updated):
+ if "view_order" in updated:
+ self.set_order(ctx.options.view_order)
+ if "view_filter" in updated:
+ self.set_filter(ctx.options.view_filter)
+
+ async def _writer(self):
+ while True:
+ await asyncio.sleep(self._flush_period)
+ batches = -(-len(self._hot_store) // self._flush_rate)
+ self._flush_period = self._flush_period * self._FP_DECREMENT if batches > 1 else self._FP_DEFAULT
+ while batches:
+ tof = []
+ to_dump = min(len(self._hot_store), self._flush_rate)
+ for _ in range(to_dump):
+ tof.append(self._hot_store.popitem(last=False)[1])
+ self.db_store.store_flows(tof)
+ batches -= 1
+ await asyncio.sleep(0.01)
+
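
`_writer` wakes every `_flush_period` seconds and drains the hot store in batches of `_flush_rate` flows; `-(-n // k)` is ceiling division, and whenever more than one batch is pending the period shrinks by `_FP_DECREMENT`, flushing faster under load. Worked numbers, using the class constants:

```python
def ceil_div(n, k):
    return -(-n // k)  # ceiling division without importing math

assert ceil_div(500, 150) == 4  # 500 queued flows -> 4 batches this cycle
assert ceil_div(120, 150) == 1  # a light backlog fits in a single batch

period = 3.0  # _FP_DEFAULT
for pending in (500, 500, 120):
    batches = ceil_div(pending, 150)  # 150 = _FP_RATE
    period = period * 0.9 if batches > 1 else 3.0
# 3.0 -> 2.7 -> 2.43, then back to 3.0 once the backlog fits one batch.
assert period == 3.0
```
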
+ def load_view(self) -> typing.Sequence[http.HTTPFlow]:
+ ids = [fid for _, fid in self._view]
+ flows = self.load_storage(ids)
+ return sorted(flows, key=lambda f: self._generate_order(self.order, f))
+
+ def load_storage(self, ids=None) -> typing.Sequence[http.HTTPFlow]:
+ flows = []
+ ids_from_store = []
+ if ids is not None:
+ for fid in ids:
+ # The same flow can be in hot and db storage at the same time. We want the most recent version.
+ if fid in self._hot_store:
+ flows.append(self._hot_store[fid])
+ elif fid in self.db_store:
+ ids_from_store.append(fid)
+ flows += self.db_store.retrieve_flows(ids_from_store)
+ else:
+ for flow in self._hot_store.values():
+ flows.append(flow)
+ for flow in self.db_store.retrieve_flows():
+ if flow.id not in self._hot_store:
+ flows.append(flow)
+ return flows
+
+ def clear_storage(self):
+ self.db_store.clear()
+ self._hot_store.clear()
+ self._view = []
+
+ def store_count(self) -> int:
+ ln = 0
+ for fid in self._hot_store.keys():
+ if fid not in self.db_store:
+ ln += 1
+ return ln + len(self.db_store)
+
+ @staticmethod
+ def _generate_order(o: str, f: http.HTTPFlow) -> typing.Optional[typing.Union[str, int, float]]:
+ if o == "time":
+ return f.request.timestamp_start or 0
+ if o == "method":
+ return f.request.method
+ if o == "url":
+ return f.request.url
+ if o == "size":
+ s = 0
+ if f.request.raw_content:
+ s += len(f.request.raw_content)
+ if f.response and f.response.raw_content:
+ s += len(f.response.raw_content)
+ return s
+ return None
+
+ def _store_order(self, f: http.HTTPFlow):
+ self._order_store[f.id] = {}
+ for order in orders:
+ self._order_store[f.id][order] = self._generate_order(order, f)
+
+ def set_order(self, order: str) -> None:
+ if order not in orders:
+ raise CommandError(
+ "Unknown flow order: %s" % order
+ )
+ if order != self.order:
+ self.order = order
+ newview = [
+ (self._order_store[t[1]][order], t[1]) for t in self._view
+ ]
+ self._view = sorted(newview)
+
+ def _refilter(self):
+ self._view = []
+ flows = self.load_storage()
+ for f in flows:
+ if self.filter(f):
+ self.update_view(f)
+
+ def set_filter(self, input_filter: typing.Optional[str]) -> None:
+ filt = matchall if not input_filter else flowfilter.parse(input_filter)
+ if not filt:
+ raise CommandError(
+ "Invalid interception filter: %s" % filt
+ )
+ self.filter = filt
+ self._refilter()
+
+ def update_view(self, f):
+ if any([f.id == t[1] for t in self._view]):
+ self._view = [(order, fid) for order, fid in self._view if fid != f.id]
+ o = self._order_store[f.id][self.order]
+ self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
+
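
`update_view` upholds two invariants: `_view` stays sorted by the current order key, and a flow appears at most once. Updating a flow therefore removes its stale `(key, id)` pair before the bisect insert, so a changed order key moves the entry rather than duplicating it. An illustrative reduction of that logic:

```python
import bisect

def upsert(view, order_key, fid):
    view[:] = [t for t in view if t[1] != fid]  # drop any stale entry
    keys = [t[0] for t in view]                 # stand-in for KeyifyList
    view.insert(bisect.bisect_left(keys, order_key), (order_key, fid))

view = [("GET", "f1"), ("PUT", "f2")]
upsert(view, "POST", "f1")  # f1's method changed from GET to POST
assert view == [("POST", "f1"), ("PUT", "f2")]
```
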
+ def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
+ for f in flows:
+ self._store_order(f)
+ if f.id in self._hot_store:
+ self._hot_store.pop(f.id)
+ self._hot_store[f.id] = f
+ if self.filter(f):
+ self.update_view(f)
+
+ def request(self, f):
+ self.update([f])
+
+ def error(self, f):
+ self.update([f])
+
+ def response(self, f):
+ self.update([f])
+
+ def intercept(self, f):
+ self.update([f])
+
+ def resume(self, f):
+ self.update([f])
+
+ def kill(self, f):
+ self.update([f])
diff --git a/mitmproxy/io/protobuf.py b/mitmproxy/io/protobuf.py
index 9a00eacf..c8ca3acc 100644
--- a/mitmproxy/io/protobuf.py
+++ b/mitmproxy/io/protobuf.py
@@ -189,7 +189,7 @@ def load_http(hf: http_pb2.HTTPFlow) -> HTTPFlow:
return f
-def loads(b: bytes, typ="http") -> flow.Flow:
+def loads(b: bytes, typ="http") -> typing.Union[HTTPFlow]:
if typ != 'http':
raise exceptions.TypeError("Flow types different than HTTP not supported yet!")
else:
diff --git a/mitmproxy/io/sql/session_create.sql b/mitmproxy/io/sql/session_create.sql
index bfc98b94..b9c28c03 100644
--- a/mitmproxy/io/sql/session_create.sql
+++ b/mitmproxy/io/sql/session_create.sql
@@ -1,3 +1,5 @@
+PRAGMA foreign_keys = ON;
+
CREATE TABLE flow (
id VARCHAR(36) PRIMARY KEY,
content BLOB
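
`PRAGMA foreign_keys = ON` is needed because SQLite disables foreign key enforcement by default, and the setting is per-connection rather than stored in the database file. A minimal demonstration (the `REFERENCES` clause here is an assumption about the truncated schema above):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    PRAGMA foreign_keys = ON;  -- off by default, must be set per connection
    CREATE TABLE flow (id TEXT PRIMARY KEY);
    CREATE TABLE body (flow_id TEXT REFERENCES flow(id));
""")
try:
    con.execute("INSERT INTO body VALUES ('no-such-flow');")
except sqlite3.IntegrityError as e:
    print(e)  # FOREIGN KEY constraint failed
```
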
diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py
index d4b1109b..20feb69d 100644
--- a/test/mitmproxy/addons/test_session.py
+++ b/test/mitmproxy/addons/test_session.py
@@ -1,13 +1,40 @@
import sqlite3
+import asyncio
import pytest
import os
+from mitmproxy import ctx
+from mitmproxy import http
+from mitmproxy.test import tflow, tutils
+from mitmproxy.test import taddons
from mitmproxy.addons import session
-from mitmproxy.exceptions import SessionLoadException
+from mitmproxy.exceptions import SessionLoadException, CommandError
from mitmproxy.utils.data import pkg_data
class TestSession:
+
+ @staticmethod
+ def tft(*, method="GET", start=0):
+ f = tflow.tflow()
+ f.request.method = method
+ f.request.timestamp_start = start
+ return f
+
+ @staticmethod
+ def start_session(fp=None):
+ s = session.Session()
+ with taddons.context() as tctx:
+ tctx.master.addons.add(s)
+ tctx.options.session_path = None
+ tctx.options.view_filter = None
+ # To make tests quicker
+ if fp:
+ s._flush_period = fp
+ s._FP_DEFAULT = fp
+ s.running()
+ return s
+
def test_session_temporary(self):
s = session.SessionDB()
td = s.tempdir
@@ -56,3 +83,161 @@ class TestSession:
assert len(rows) == 1
con.close()
os.remove(path)
+
+ def test_session_order_generators(self):
+ s = session.Session()
+ tf = tflow.tflow(resp=True)
+ assert s._generate_order('time', tf) == 946681200
+ assert s._generate_order('method', tf) == tf.request.method
+ assert s._generate_order('url', tf) == tf.request.url
+ assert s._generate_order('size', tf) == len(tf.request.raw_content) + len(tf.response.raw_content)
+ assert not s._generate_order('invalid', tf)
+
+ def test_storage_simple(self):
+ s = session.Session()
+ ctx.options = taddons.context()
+ ctx.options.session_path = None
+ s.running()
+ f = self.tft(start=1)
+ assert s.store_count() == 0
+ s.request(f)
+ assert s._view == [(1, f.id)]
+ assert s._order_store[f.id]['time'] == 1
+ assert s._order_store[f.id]['method'] == f.request.method
+ assert s._order_store[f.id]['url'] == f.request.url
+ assert s._order_store[f.id]['size'] == len(f.request.raw_content)
+ assert s.load_view() == [f]
+ assert s.load_storage(['nonexistent']) == []
+
+ s.error(f)
+ s.response(f)
+ s.intercept(f)
+ s.resume(f)
+ s.kill(f)
+
+ # Verify that flow has been updated, not duplicated
+ assert s._view == [(1, f.id)]
+ assert s._order_store[f.id]['time'] == 1
+ assert s._order_store[f.id]['method'] == f.request.method
+ assert s._order_store[f.id]['url'] == f.request.url
+ assert s._order_store[f.id]['size'] == len(f.request.raw_content)
+ assert s.store_count() == 1
+
+ f2 = self.tft(start=3)
+ s.request(f2)
+ assert s._view == [(1, f.id), (3, f2.id)]
+ s.request(f2)
+ assert s._view == [(1, f.id), (3, f2.id)]
+
+ f3 = self.tft(start=2)
+ s.request(f3)
+ assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
+ s.request(f3)
+ assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
+ assert s.store_count() == 3
+
+ s.clear_storage()
+ assert len(s._view) == 0
+ assert s.store_count() == 0
+
+ def test_storage_filter(self):
+ s = self.start_session()
+ s.request(self.tft(method="get"))
+ s.request(self.tft(method="put"))
+ s.request(self.tft(method="get"))
+ s.request(self.tft(method="put"))
+ assert len(s._view) == 4
+ with taddons.context() as tctx:
+ tctx.master.addons.add(s)
+ tctx.options.view_filter = '~m get'
+ s.configure({"view_filter"})
+ assert [f.request.method for f in s.load_view()] == ["GET", "GET"]
+ assert s.store_count() == 4
+ with pytest.raises(CommandError):
+ s.set_filter("~notafilter")
+ s.set_filter(None)
+ assert len(s._view) == 4
+
+ @pytest.mark.asyncio
+ async def test_storage_flush_with_specials(self):
+ s = self.start_session(fp=0.5)
+ f = self.tft()
+ s.request(f)
+ await asyncio.sleep(1)
+ assert len(s._hot_store) == 0
+ f.response = http.HTTPResponse.wrap(tutils.tresp())
+ s.response(f)
+ assert len(s._hot_store) == 1
+ assert s.load_storage() == [f]
+ await asyncio.sleep(1)
+ assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))])
+
+ f.server_conn.via = tflow.tserver_conn()
+ s.request(f)
+ await asyncio.sleep(0.6)
+ assert len(s._hot_store) == 0
+ assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))])
+
+ flows = [self.tft() for _ in range(500)]
+ s.update(flows)
+ await asyncio.sleep(0.6)
+ assert s._flush_period == s._FP_DEFAULT * s._FP_DECREMENT
+ await asyncio.sleep(3)
+ assert s._flush_period == s._FP_DEFAULT
+
+ @pytest.mark.asyncio
+ async def test_storage_bodies(self):
+ # Need to test for configure
+ # Need to test for set_order
+ s = self.start_session(fp=0.5)
+ f = self.tft()
+ f2 = self.tft(start=1)
+ f.request.content = b"A" * 1001
+ s.request(f)
+ s.request(f2)
+ await asyncio.sleep(1.0)
+ content = s.db_store.con.execute(
+ "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
+ ).fetchall()[0]
+ assert content == (1, b"A" * 1001)
+ assert s.db_store.body_ledger == {f.id}
+ f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001))
+ f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001))
+ # The wrapped response's content-length header does not match the custom body -- fix it manually
+ f.response.headers['content-length'] = b"1001"
+ f2.response.headers['content-length'] = b"1001"
+ s.response(f)
+ s.response(f2)
+ await asyncio.sleep(1.0)
+ rows = s.db_store.con.execute(
+ "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
+ ).fetchall()
+ assert len(rows) == 1
+ rows = s.db_store.con.execute(
+ "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f2.id]
+ ).fetchall()
+ assert len(rows) == 1
+ assert s.db_store.body_ledger == {f.id}
+ assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))])
+
+ @pytest.mark.asyncio
+ async def test_storage_order(self):
+ s = self.start_session(fp=0.5)
+ s.request(self.tft(method="GET", start=4))
+ s.request(self.tft(method="PUT", start=2))
+ s.request(self.tft(method="GET", start=3))
+ s.request(self.tft(method="PUT", start=1))
+ assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4]
+ await asyncio.sleep(1.0)
+ assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4]
+ with taddons.context() as tctx:
+ tctx.master.addons.add(s)
+ tctx.options.view_order = "method"
+ s.configure({"view_order"})
+ assert [i.request.method for i in s.load_view()] == ["GET", "GET", "PUT", "PUT"]
+
+ s.set_order("time")
+ assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4]
+
+ with pytest.raises(CommandError):
+ s.set_order("not_an_order")