diff options
-rw-r--r-- | mitmproxy/addons/session.py | 324 | ||||
-rw-r--r-- | mitmproxy/io/protobuf.py | 2 | ||||
-rw-r--r-- | mitmproxy/io/sql/session_create.sql | 2 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_session.py | 187 |
4 files changed, 507 insertions, 8 deletions
diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index c49b95c4..63e382ec 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -1,12 +1,34 @@ +import collections import tempfile +import asyncio +import typing +import bisect import shutil import sqlite3 +import copy import os -from mitmproxy.exceptions import SessionLoadException +from mitmproxy import flowfilter +from mitmproxy import types +from mitmproxy import http +from mitmproxy import ctx +from mitmproxy.io import protobuf +from mitmproxy.exceptions import SessionLoadException, CommandError from mitmproxy.utils.data import pkg_data +class KeyifyList(object): + def __init__(self, inner, key): + self.inner = inner + self.key = key + + def __len__(self): + return len(self.inner) + + def __getitem__(self, k): + return self.key(self.inner[k]) + + # Could be implemented using async libraries class SessionDB: """ @@ -14,6 +36,13 @@ class SessionDB: for Sessions and handles creation, retrieving and insertion in tables. """ + content_threshold = 1000 + type_mappings = { + "body": { + 1: "request", + 2: "response" + } + } def __init__(self, db_path=None): """ @@ -21,8 +50,13 @@ class SessionDB: or create a new one with optional path. :param db_path: """ - self.tempdir = None - self.con = None + self.live_components: typing.Dict[str, tuple] = {} + self.tempdir: tempfile.TemporaryDirectory = None + self.con: sqlite3.Connection = None + # This is used for fast look-ups over bodies already dumped to database. + # This permits to enforce one-to-one relationship between flow and body table. + self.body_ledger: typing.Set[str] = set() + self.id_ledger: typing.Set[str] = set() if db_path is not None and os.path.isfile(db_path): self._load_session(db_path) else: @@ -40,6 +74,12 @@ class SessionDB: if self.tempdir: shutil.rmtree(self.tempdir) + def __contains__(self, fid): + return fid in self.id_ledger + + def __len__(self): + return len(self.id_ledger) + def _load_session(self, path): if not self.is_session_db(path): raise SessionLoadException('Given path does not point to a valid Session') @@ -48,8 +88,8 @@ class SessionDB: def _create_session(self): script_path = pkg_data.path("io/sql/session_create.sql") qry = open(script_path, 'r').read() - with self.con: - self.con.executescript(qry) + self.con.executescript(qry) + self.con.commit() @staticmethod def is_session_db(path): @@ -58,6 +98,7 @@ class SessionDB: is a valid Session SQLite DB. :return: True if valid, False if invalid. """ + c = None try: c = sqlite3.connect(f'file:{path}?mode=rw', uri=True) cursor = c.cursor() @@ -67,7 +108,278 @@ class SessionDB: if all(elem in rows for elem in tables): c.close() return True - except: + except sqlite3.Error: if c: c.close() return False + + def _disassemble(self, flow): + # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality. + # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually + # adding them back when needed. + self.live_components[flow.id] = ( + flow.client_conn.wfile, + flow.client_conn.rfile, + flow.client_conn.reply, + flow.server_conn.wfile, + flow.server_conn.rfile, + flow.server_conn.reply, + (flow.server_conn.via.wfile, flow.server_conn.via.rfile, + flow.server_conn.via.reply) if flow.server_conn.via else None, + flow.reply + ) + + def _reassemble(self, flow): + if flow.id in self.live_components: + cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id] + flow.client_conn.wfile = cwf + flow.client_conn.rfile = crf + flow.client_conn.reply = crp + flow.server_conn.wfile = swf + flow.server_conn.rfile = srf + flow.server_conn.reply = srp + flow.reply = rep + if via: + flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via + return flow + + def store_flows(self, flows): + body_buf = [] + flow_buf = [] + for flow in flows: + self.id_ledger.add(flow.id) + self._disassemble(flow) + f = copy.copy(flow) + f.request = copy.deepcopy(flow.request) + if flow.response: + f.response = copy.deepcopy(flow.response) + f.id = flow.id + if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger: + body_buf.append((f.id, 1, f.request.content)) + f.request.content = b"" + self.body_ledger.add(f.id) + if f.response and f.id not in self.body_ledger: + if len(f.response.content) > self.content_threshold: + body_buf.append((f.id, 2, f.response.content)) + f.response.content = b"" + flow_buf.append((f.id, protobuf.dumps(f))) + self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf) + if body_buf: + self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf) + self.con.commit() + + def retrieve_flows(self, ids=None): + flows = [] + with self.con as con: + if not ids: + sql = "SELECT f.content, b.type_id, b.content " \ + "FROM flow f " \ + "LEFT OUTER JOIN body b ON f.id = b.flow_id;" + rows = con.execute(sql).fetchall() + else: + sql = "SELECT f.content, b.type_id, b.content " \ + "FROM flow f " \ + "LEFT OUTER JOIN body b ON f.id = b.flow_id " \ + f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});" + rows = con.execute(sql, ids).fetchall() + for row in rows: + flow = protobuf.loads(row[0]) + if row[1]: + typ = self.type_mappings["body"][row[1]] + if typ and row[2]: + setattr(getattr(flow, typ), "content", row[2]) + flow = self._reassemble(flow) + flows.append(flow) + return flows + + def clear(self): + self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;") + + +matchall = flowfilter.parse(".") + +orders = [ + "time", + "method", + "url", + "size" +] + + +class Session: + + _FP_RATE = 150 + _FP_DECREMENT = 0.9 + _FP_DEFAULT = 3.0 + + def __init__(self): + self.db_store: SessionDB = None + self._hot_store: collections.OrderedDict = collections.OrderedDict() + self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str]]] = {} + self._view: typing.List[typing.Tuple[typing.Union[int, float, str], str]] = [] + self.order: str = orders[0] + self.filter = matchall + self._flush_period: float = self._FP_DEFAULT + self._flush_rate: int = self._FP_RATE + self.started: bool = False + + def load(self, loader): + loader.add_option( + "session_path", typing.Optional[types.Path], None, + "Path of session to load or to create." + ) + loader.add_option( + "view_order", str, "time", + "Flow sort order.", + choices=list(map(lambda c: c[1], orders)) + ) + loader.add_option( + "view_filter", typing.Optional[str], None, + "Limit the view to matching flows." + ) + + def running(self): + if not self.started: + self.started = True + self.db_store = SessionDB(ctx.options.session_path) + loop = asyncio.get_event_loop() + loop.create_task(self._writer()) + + def configure(self, updated): + if "view_order" in updated: + self.set_order(ctx.options.view_order) + if "view_filter" in updated: + self.set_filter(ctx.options.view_filter) + + async def _writer(self): + while True: + await asyncio.sleep(self._flush_period) + batches = -(-len(self._hot_store) // self._flush_rate) + self._flush_period = self._flush_period * self._FP_DECREMENT if batches > 1 else self._FP_DEFAULT + while batches: + tof = [] + to_dump = min(len(self._hot_store), self._flush_rate) + for _ in range(to_dump): + tof.append(self._hot_store.popitem(last=False)[1]) + self.db_store.store_flows(tof) + batches -= 1 + await asyncio.sleep(0.01) + + def load_view(self) -> typing.Sequence[http.HTTPFlow]: + ids = [fid for _, fid in self._view] + flows = self.load_storage(ids) + return sorted(flows, key=lambda f: self._generate_order(self.order, f)) + + def load_storage(self, ids=None) -> typing.Sequence[http.HTTPFlow]: + flows = [] + ids_from_store = [] + if ids is not None: + for fid in ids: + # A same flow could be at the same time in hot and db storage. We want the most updated version. + if fid in self._hot_store: + flows.append(self._hot_store[fid]) + elif fid in self.db_store: + ids_from_store.append(fid) + flows += self.db_store.retrieve_flows(ids_from_store) + else: + for flow in self._hot_store.values(): + flows.append(flow) + for flow in self.db_store.retrieve_flows(): + if flow.id not in self._hot_store: + flows.append(flow) + return flows + + def clear_storage(self): + self.db_store.clear() + self._hot_store.clear() + self._view = [] + + def store_count(self) -> int: + ln = 0 + for fid in self._hot_store.keys(): + if fid not in self.db_store: + ln += 1 + return ln + len(self.db_store) + + @staticmethod + def _generate_order(o: str, f: http.HTTPFlow) -> typing.Optional[typing.Union[str, int, float]]: + if o == "time": + return f.request.timestamp_start or 0 + if o == "method": + return f.request.method + if o == "url": + return f.request.url + if o == "size": + s = 0 + if f.request.raw_content: + s += len(f.request.raw_content) + if f.response and f.response.raw_content: + s += len(f.response.raw_content) + return s + return None + + def _store_order(self, f: http.HTTPFlow): + self._order_store[f.id] = {} + for order in orders: + self._order_store[f.id][order] = self._generate_order(order, f) + + def set_order(self, order: str) -> None: + if order not in orders: + raise CommandError( + "Unknown flow order: %s" % order + ) + if order != self.order: + self.order = order + newview = [ + (self._order_store[t[1]][order], t[1]) for t in self._view + ] + self._view = sorted(newview) + + def _refilter(self): + self._view = [] + flows = self.load_storage() + for f in flows: + if self.filter(f): + self.update_view(f) + + def set_filter(self, input_filter: typing.Optional[str]) -> None: + filt = matchall if not input_filter else flowfilter.parse(input_filter) + if not filt: + raise CommandError( + "Invalid interception filter: %s" % filt + ) + self.filter = filt + self._refilter() + + def update_view(self, f): + if any([f.id == t[1] for t in self._view]): + self._view = [(order, fid) for order, fid in self._view if fid != f.id] + o = self._order_store[f.id][self.order] + self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) + + def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: + for f in flows: + self._store_order(f) + if f.id in self._hot_store: + self._hot_store.pop(f.id) + self._hot_store[f.id] = f + if self.filter(f): + self.update_view(f) + + def request(self, f): + self.update([f]) + + def error(self, f): + self.update([f]) + + def response(self, f): + self.update([f]) + + def intercept(self, f): + self.update([f]) + + def resume(self, f): + self.update([f]) + + def kill(self, f): + self.update([f]) diff --git a/mitmproxy/io/protobuf.py b/mitmproxy/io/protobuf.py index 9a00eacf..c8ca3acc 100644 --- a/mitmproxy/io/protobuf.py +++ b/mitmproxy/io/protobuf.py @@ -189,7 +189,7 @@ def load_http(hf: http_pb2.HTTPFlow) -> HTTPFlow: return f -def loads(b: bytes, typ="http") -> flow.Flow: +def loads(b: bytes, typ="http") -> typing.Union[HTTPFlow]: if typ != 'http': raise exceptions.TypeError("Flow types different than HTTP not supported yet!") else: diff --git a/mitmproxy/io/sql/session_create.sql b/mitmproxy/io/sql/session_create.sql index bfc98b94..b9c28c03 100644 --- a/mitmproxy/io/sql/session_create.sql +++ b/mitmproxy/io/sql/session_create.sql @@ -1,3 +1,5 @@ +PRAGMA foreign_keys = ON; + CREATE TABLE flow ( id VARCHAR(36) PRIMARY KEY, content BLOB diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py index d4b1109b..20feb69d 100644 --- a/test/mitmproxy/addons/test_session.py +++ b/test/mitmproxy/addons/test_session.py @@ -1,13 +1,40 @@ import sqlite3 +import asyncio import pytest import os +from mitmproxy import ctx +from mitmproxy import http +from mitmproxy.test import tflow, tutils +from mitmproxy.test import taddons from mitmproxy.addons import session -from mitmproxy.exceptions import SessionLoadException +from mitmproxy.exceptions import SessionLoadException, CommandError from mitmproxy.utils.data import pkg_data class TestSession: + + @staticmethod + def tft(*, method="GET", start=0): + f = tflow.tflow() + f.request.method = method + f.request.timestamp_start = start + return f + + @staticmethod + def start_session(fp=None): + s = session.Session() + with taddons.context() as tctx: + tctx.master.addons.add(s) + tctx.options.session_path = None + tctx.options.view_filter = None + # To make tests quicker + if fp: + s._flush_period = fp + s._FP_DEFAULT = fp + s.running() + return s + def test_session_temporary(self): s = session.SessionDB() td = s.tempdir @@ -56,3 +83,161 @@ class TestSession: assert len(rows) == 1 con.close() os.remove(path) + + def test_session_order_generators(self): + s = session.Session() + tf = tflow.tflow(resp=True) + assert s._generate_order('time', tf) == 946681200 + assert s._generate_order('method', tf) == tf.request.method + assert s._generate_order('url', tf) == tf.request.url + assert s._generate_order('size', tf) == len(tf.request.raw_content) + len(tf.response.raw_content) + assert not s._generate_order('invalid', tf) + + def test_storage_simple(self): + s = session.Session() + ctx.options = taddons.context() + ctx.options.session_path = None + s.running() + f = self.tft(start=1) + assert s.store_count() == 0 + s.request(f) + assert s._view == [(1, f.id)] + assert s._order_store[f.id]['time'] == 1 + assert s._order_store[f.id]['method'] == f.request.method + assert s._order_store[f.id]['url'] == f.request.url + assert s._order_store[f.id]['size'] == len(f.request.raw_content) + assert s.load_view() == [f] + assert s.load_storage(['nonexistent']) == [] + + s.error(f) + s.response(f) + s.intercept(f) + s.resume(f) + s.kill(f) + + # Verify that flow has been updated, not duplicated + assert s._view == [(1, f.id)] + assert s._order_store[f.id]['time'] == 1 + assert s._order_store[f.id]['method'] == f.request.method + assert s._order_store[f.id]['url'] == f.request.url + assert s._order_store[f.id]['size'] == len(f.request.raw_content) + assert s.store_count() == 1 + + f2 = self.tft(start=3) + s.request(f2) + assert s._view == [(1, f.id), (3, f2.id)] + s.request(f2) + assert s._view == [(1, f.id), (3, f2.id)] + + f3 = self.tft(start=2) + s.request(f3) + assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)] + s.request(f3) + assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)] + assert s.store_count() == 3 + + s.clear_storage() + assert len(s._view) == 0 + assert s.store_count() == 0 + + def test_storage_filter(self): + s = self.start_session() + s.request(self.tft(method="get")) + s.request(self.tft(method="put")) + s.request(self.tft(method="get")) + s.request(self.tft(method="put")) + assert len(s._view) == 4 + with taddons.context() as tctx: + tctx.master.addons.add(s) + tctx.options.view_filter = '~m get' + s.configure({"view_filter"}) + assert [f.request.method for f in s.load_view()] == ["GET", "GET"] + assert s.store_count() == 4 + with pytest.raises(CommandError): + s.set_filter("~notafilter") + s.set_filter(None) + assert len(s._view) == 4 + + @pytest.mark.asyncio + async def test_storage_flush_with_specials(self): + s = self.start_session(fp=0.5) + f = self.tft() + s.request(f) + await asyncio.sleep(1) + assert len(s._hot_store) == 0 + f.response = http.HTTPResponse.wrap(tutils.tresp()) + s.response(f) + assert len(s._hot_store) == 1 + assert s.load_storage() == [f] + await asyncio.sleep(1) + assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))]) + + f.server_conn.via = tflow.tserver_conn() + s.request(f) + await asyncio.sleep(0.6) + assert len(s._hot_store) == 0 + assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))]) + + flows = [self.tft() for _ in range(500)] + s.update(flows) + await asyncio.sleep(0.6) + assert s._flush_period == s._FP_DEFAULT * s._FP_DECREMENT + await asyncio.sleep(3) + assert s._flush_period == s._FP_DEFAULT + + @pytest.mark.asyncio + async def test_storage_bodies(self): + # Need to test for configure + # Need to test for set_order + s = self.start_session(fp=0.5) + f = self.tft() + f2 = self.tft(start=1) + f.request.content = b"A" * 1001 + s.request(f) + s.request(f2) + await asyncio.sleep(1.0) + content = s.db_store.con.execute( + "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id] + ).fetchall()[0] + assert content == (1, b"A" * 1001) + assert s.db_store.body_ledger == {f.id} + f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001)) + f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001)) + # Content length is wrong for some reason -- quick fix + f.response.headers['content-length'] = b"1001" + f2.response.headers['content-length'] = b"1001" + s.response(f) + s.response(f2) + await asyncio.sleep(1.0) + rows = s.db_store.con.execute( + "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id] + ).fetchall() + assert len(rows) == 1 + rows = s.db_store.con.execute( + "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f2.id] + ).fetchall() + assert len(rows) == 1 + assert s.db_store.body_ledger == {f.id} + assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))]) + + @pytest.mark.asyncio + async def test_storage_order(self): + s = self.start_session(fp=0.5) + s.request(self.tft(method="GET", start=4)) + s.request(self.tft(method="PUT", start=2)) + s.request(self.tft(method="GET", start=3)) + s.request(self.tft(method="PUT", start=1)) + assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4] + await asyncio.sleep(1.0) + assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4] + with taddons.context() as tctx: + tctx.master.addons.add(s) + tctx.options.view_order = "method" + s.configure({"view_order"}) + assert [i.request.method for i in s.load_view()] == ["GET", "GET", "PUT", "PUT"] + + s.set_order("time") + assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4] + + with pytest.raises(CommandError): + s.set_order("not_an_order") |