diff options
-rw-r--r-- | libmproxy/dump.py | 6 | ||||
-rw-r--r-- | libmproxy/filt.py | 2 | ||||
-rw-r--r-- | libmproxy/flow.py | 121 | ||||
-rw-r--r-- | libmproxy/protocol/__init__.py | 8 | ||||
-rw-r--r-- | libmproxy/protocol/http.py | 159 | ||||
-rw-r--r-- | libmproxy/proxy.py | 9 | ||||
-rw-r--r-- | libmproxy/stateobject.py | 4 | ||||
-rw-r--r-- | test/test_dump.py | 18 | ||||
-rw-r--r-- | test/test_filt.py | 43 | ||||
-rw-r--r-- | test/test_flow.py | 122 | ||||
-rw-r--r-- | test/tutils.py | 64 |
11 files changed, 296 insertions, 260 deletions
diff --git a/libmproxy/dump.py b/libmproxy/dump.py index e76ea1ce..7b54f7c1 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -42,14 +42,14 @@ class Options(object): def str_response(resp): r = "%s %s"%(resp.code, resp.msg) - if resp.is_replay(): + if resp.is_replay: r = "[replay] " + r return r def str_request(req, showhost): - if req.client_conn: - c = req.client_conn.address[0] + if req.flow.client_conn: + c = req.flow.client_conn.address.host else: c = "[replay]" r = "%s %s %s"%(c, req.method, req.get_url(showhost)) diff --git a/libmproxy/filt.py b/libmproxy/filt.py index 6a0c3075..09be41b8 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -198,7 +198,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.host, re.IGNORECASE)) + return bool(re.search(self.expr, f.request.host or f.server_conn.address.host, re.IGNORECASE)) class FUrl(_Rex): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 4032461d..b4b939c7 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -8,7 +8,8 @@ import types import tnetstring, filt, script, utils, encoding, proxy from email.utils import parsedate_tz, formatdate, mktime_tz from netlib import odict, http, certutils, wsgi -import controller, version, protocol +from .proxy import ClientConnection, ServerConnection +import controller, version, protocol, stateobject import app @@ -19,6 +20,123 @@ ODict = odict.ODict ODictCaseless = odict.ODictCaseless +class BackreferenceMixin(object): + """ + If an attribute from the _backrefattr tuple is set, + this mixin sets a reference back on the attribute object. + Example: + e = Error() + f = Flow() + f.error = e + assert f is e.flow + """ + _backrefattr = tuple() + + def __setattr__(self, key, value): + super(BackreferenceMixin, self).__setattr__(key, value) + if key in self._backrefattr and value is not None: + setattr(value, self._backrefname, self) + + +class Error(stateobject.SimpleStateObject): + """ + An Error. + + This is distinct from an HTTP error response (say, a code 500), which + is represented by a normal Response object. This class is responsible + for indicating errors that fall outside of normal HTTP communications, + like interrupted connections, timeouts, protocol errors. + + Exposes the following attributes: + + flow: Flow object + msg: Message describing the error + timestamp: Seconds since the epoch + """ + def __init__(self, msg, timestamp=None): + """ + @type msg: str + @type timestamp: float + """ + self.msg = msg + self.timestamp = timestamp or utils.timestamp() + + _stateobject_attributes = dict( + msg=str, + timestamp=float + ) + + @classmethod + def _from_state(cls, state): + f = cls(None) # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f + + def copy(self): + c = copy.copy(self) + return c + + +class Flow(stateobject.SimpleStateObject, BackreferenceMixin): + def __init__(self, conntype, client_conn, server_conn): + self.conntype = conntype + self.client_conn = client_conn + self.server_conn = server_conn + self.error = None + + _backrefattr = ("error",) + _backrefname = "flow" + + _stateobject_attributes = dict( + error=Error, + client_conn=ClientConnection, + server_conn=ServerConnection, + conntype=str + ) + + def _get_state(self): + d = super(Flow, self)._get_state() + d.update(version=version.IVERSION) + return d + + @classmethod + def _from_state(cls, state): + f = cls(None, None, None) + f._load_state(state) + return f + + def copy(self): + f = copy.copy(self) + if self.error: + f.error = self.error.copy() + return f + + def modified(self): + """ + Has this Flow been modified? + """ + if self._backup: + return self._backup != self._get_state() + else: + return False + + def backup(self, force=False): + """ + Save a backup of this Flow, which can be reverted to using a + call to .revert(). + """ + if not self._backup: + self._backup = self._get_state() + + def revert(self): + """ + Revert to the last backed up state. + """ + if self._backup: + self._load_state(self._backup) + self._backup = None + + class AppRegistry: def __init__(self): self.apps = {} @@ -542,6 +660,7 @@ class FlowMaster(controller.Master): rflow = self.server_playback.next_flow(flow) if not rflow: return None + # FIXME response = Response._from_state(flow.request, rflow.response._get_state()) response._set_replay() flow.response = response diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index ae0d99a6..da85500b 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -30,10 +30,10 @@ class ProtocolHandler(object): from . import http, tcp -protocols = dict( - http = dict(handler=http.HTTPHandler, flow=http.HTTPFlow), - tcp = dict(handler=tcp.TCPHandler), -) +protocols = { + 'http': dict(handler=http.HTTPHandler, flow=http.HTTPFlow), + 'tcp': dict(handler=tcp.TCPHandler) +} # PyCharm type hinting behaves bad if this is a dict constructor... def _handler(conntype, connection_handler): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 29cdf446..5faf78e0 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -5,7 +5,9 @@ from netlib import http, tcp, http_status, odict from netlib.odict import ODict, ODictCaseless from . import ProtocolHandler, ConnectionTypeChange, KILL from .. import encoding, utils, version, filt, controller, stateobject -from ..proxy import ProxyError, ClientConnection, ServerConnection +from ..proxy import ProxyError +from ..flow import Flow, Error + HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" CONTENT_MISSING = 0 @@ -51,117 +53,11 @@ class decoded(object): if self.ce: self.o.encode(self.ce) -# FIXME: Move out of http -class BackreferenceMixin(object): - """ - If an attribute from the _backrefattr tuple is set, - this mixin sets a reference back on the attribute object. - Example: - e = Error() - f = Flow() - f.error = e - assert f is e.flow - """ - _backrefattr = tuple() - - def __setattr__(self, key, value): - super(BackreferenceMixin, self).__setattr__(key, value) - if key in self._backrefattr and value is not None: - setattr(value, self._backrefname, self) - -# FIXME: Move out of http -class Error(stateobject.SimpleStateObject): - """ - An Error. - - This is distinct from an HTTP error response (say, a code 500), which - is represented by a normal Response object. This class is responsible - for indicating errors that fall outside of normal HTTP communications, - like interrupted connections, timeouts, protocol errors. - - Exposes the following attributes: - - flow: Flow object - msg: Message describing the error - timestamp: Seconds since the epoch - """ - def __init__(self, msg, timestamp=None): - self.msg = msg - self.timestamp = timestamp or utils.timestamp() - - _stateobject_attributes = dict( - msg=str, - timestamp=float - ) - - def copy(self): - c = copy.copy(self) - return c - -# FIXME: Move out of http -class Flow(stateobject.SimpleStateObject, BackreferenceMixin): - def __init__(self, conntype, client_conn, server_conn, error): - self.conntype = conntype - self.client_conn = client_conn - self.server_conn = server_conn - self.error = error - - _backrefattr = ("error",) - _backrefname = "flow" - - _stateobject_attributes = dict( - error=Error, - client_conn=ClientConnection, - server_conn=ServerConnection, - conntype=str - ) - - def _get_state(self): - d = super(Flow, self)._get_state() - d.update(version=version.IVERSION) - return d - - @classmethod - def _from_state(cls, state): - f = cls(None, None, None, None) - f._load_state(state) - return f - - def copy(self): - f = copy.copy(self) - if self.error: - f.error = self.error.copy() - return f - - def modified(self): - """ - Has this Flow been modified? - """ - if self._backup: - return self._backup != self._get_state() - else: - return False - - def backup(self, force=False): - """ - Save a backup of this Flow, which can be reverted to using a - call to .revert(). - """ - if not self._backup: - self._backup = self._get_state() - - def revert(self): - """ - Revert to the last backed up state. - """ - if self._backup: - self._load_state(self._backup) - self._backup = None - class HTTPMessage(stateobject.SimpleStateObject): def __init__(self): self.flow = None # Will usually set by backref mixin + """@type: HTTPFlow""" def get_decoded_content(self): """ @@ -397,7 +293,7 @@ class HTTPRequest(HTTPMessage): form = form or self.form_out if form == "asterisk" or \ - form == "origin": + form == "origin": request_line = '%s %s HTTP/%s.%s' % (self.method, self.path, self.httpversion[0], self.httpversion[1]) elif form == "authority": request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port, @@ -422,7 +318,9 @@ class HTTPRequest(HTTPMessage): ] ) if not 'host' in headers: - headers["Host"] = [utils.hostport(self.scheme, self.host, self.port)] + headers["Host"] = [utils.hostport(self.scheme, + self.host or self.flow.server_conn.address.host, + self.port or self.flow.server_conn.address.port)] if self.content: headers["Content-Length"] = [str(len(self.content))] @@ -442,7 +340,7 @@ class HTTPRequest(HTTPMessage): Raises an Exception if the request cannot be assembled. """ if self.content == CONTENT_MISSING: - raise Exception("CONTENT_MISSING") # FIXME correct exception class + raise RuntimeError("Cannot assemble flow with CONTENT_MISSING") head = self._assemble_head(form) if self.content: return head + self.content @@ -481,7 +379,6 @@ class HTTPRequest(HTTPMessage): e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] )] - def get_form_urlencoded(self): """ Retrieves the URL-encoded form data, returning an ODict object. @@ -542,15 +439,19 @@ class HTTPRequest(HTTPMessage): def get_url(self, hostheader=False): """ - Returns a URL string, constructed from the Request's URL compnents. + Returns a URL string, constructed from the Request's URL components. If hostheader is True, we use the value specified in the request Host header to construct the URL. """ + host = None if hostheader: - host = self.headers.get_first("host") or self.host - else: - host = self.host + host = self.headers.get_first("host") + if not host: + if self.host: + host = self.host + else: + host = self.flow.server_conn.address.host host = host.encode("idna") return utils.unparse_url(self.scheme, host, self.port, self.path).encode('ascii') @@ -678,7 +579,7 @@ class HTTPResponse(HTTPMessage): ) if self.content: headers["Content-Length"] = [str(len(self.content))] - elif 'Transfer-Encoding' in self.headers: # add content-length for chuncked transfer-encoding with no content + elif 'Transfer-Encoding' in self.headers: # add content-length for chuncked transfer-encoding with no content headers["Content-Length"] = ["0"] return str(headers) @@ -694,7 +595,7 @@ class HTTPResponse(HTTPMessage): Raises an Exception if the request cannot be assembled. """ if self.content == CONTENT_MISSING: - raise Exception("CONTENT_MISSING") # FIXME correct exception class + raise RuntimeError("Cannot assemble flow with CONTENT_MISSING") head = self._assemble_head() if self.content: return head + self.content @@ -759,8 +660,8 @@ class HTTPResponse(HTTPMessage): cookies = [] for header in cookie_headers: pairs = [pair.partition("=") for pair in header.split(';')] - cookie_name = pairs[0][0] # the key of the first key/value pairs - cookie_value = pairs[0][2] # the value of the first key/value pairs + cookie_name = pairs[0][0] # the key of the first key/value pairs + cookie_value = pairs[0][2] # the value of the first key/value pairs cookie_parameters = {key.strip().lower(): value.strip() for key, sep, value in pairs[1:]} cookies.append((cookie_name, (cookie_value, cookie_parameters))) return dict(cookies) @@ -783,12 +684,12 @@ class HTTPFlow(Flow): intercepting: Is this flow currently being intercepted? """ - def __init__(self, client_conn, server_conn, error, request, response): - Flow.__init__(self, "http", client_conn, server_conn, error) - self.request = request - self.response = response + def __init__(self, client_conn, server_conn): + Flow.__init__(self, "http", client_conn, server_conn) + self.request = None + self.response = None - self.intercepting = False # FIXME: Should that rather be an attribute of Flow? + self.intercepting = False # FIXME: Should that rather be an attribute of Flow? self._backup = None _backrefattr = Flow._backrefattr + ("request", "response") @@ -801,7 +702,7 @@ class HTTPFlow(Flow): @classmethod def _from_state(cls, state): - f = cls(None, None, None, None, None) + f = cls(None, None) f._load_state(state) return f @@ -839,7 +740,7 @@ class HTTPFlow(Flow): self.request.reply(KILL) elif self.response and not self.response.reply.acked: self.response.reply(KILL) - master.handle_error(self) + master.handle_error(self.error) self.intercepting = False def intercept(self): @@ -932,7 +833,7 @@ class HTTPHandler(ProtocolHandler): self.process_request(flow.request) flow.response = self.get_response_from_server(flow.request) - self.c.log("response", [flow.response._assemble_response_line() if not LEGACY else flow.response._assemble().splitlines()[0]]) + self.c.log("response", [flow.response._assemble_response_line()]) response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", flow.response if LEGACY else flow) if response_reply is None or response_reply == KILL: @@ -982,7 +883,7 @@ class HTTPHandler(ProtocolHandler): else: err = message - self.c.log("error: %s" %err) + self.c.log("error: %s" % err) if flow: flow.error = Error(err) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index e69bb6da..e3e40c7b 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -172,7 +172,7 @@ class RequestReplayThread(threading.Thread): ) self.channel.ask("response", response) except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) + err = flow.Error(str(v)) self.channel.ask("error", err) """ @@ -291,9 +291,10 @@ class ConnectionHandler: A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening """ # TODO: Implement SSL pass-through handling and change conntype - passthrough = ["echo.websocket.org", - "174.129.224.73" # echo.websocket.org, transparent mode - ] + passthrough = [ + "echo.websocket.org", + "174.129.224.73" # echo.websocket.org, transparent mode + ] if self.server_conn.address.host in passthrough or self.sni in passthrough: self.conntype = "tcp" return diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py index ef8879b8..2cbec068 100644 --- a/libmproxy/stateobject.py +++ b/libmproxy/stateobject.py @@ -1,4 +1,4 @@ -class StateObject: +class StateObject(object): def _get_state(self): raise NotImplementedError @@ -56,7 +56,7 @@ class SimpleStateObject(StateObject): helper for _load_state. loads the given attribute from the state. """ - if state[attr] is None: + if state.get(attr, None) is None: setattr(self, attr, None) return diff --git a/test/test_dump.py b/test/test_dump.py index 031a3f6a..f6688b1a 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -6,11 +6,11 @@ import mock def test_strfuncs(): t = tutils.tresp() - t._set_replay() + t.is_replay = True dump.str_response(t) t = tutils.treq() - t.client_conn = None + t.flow.client_conn = None t.stickycookie = True assert "stickycookie" in dump.str_request(t, False) assert "stickycookie" in dump.str_request(t, True) @@ -20,24 +20,20 @@ def test_strfuncs(): class TestDumpMaster: def _cycle(self, m, content): - req = tutils.treq() - req.content = content + req = tutils.treq(content=content) l = proxy.Log("connect") l.reply = mock.MagicMock() m.handle_log(l) - cc = req.client_conn - cc.connection_error = "error" - resp = tutils.tresp(req) - resp.content = content + cc = req.flow.client_conn + cc.reply = mock.MagicMock() + resp = tutils.tresp(req, content=content) m.handle_clientconnect(cc) sc = proxy.ServerConnection((req.host, req.port)) sc.reply = mock.MagicMock() m.handle_serverconnection(sc) m.handle_request(req) f = m.handle_response(resp) - cd = flow.ClientDisconnect(cc) - cd.reply = mock.MagicMock() - m.handle_clientdisconnect(cd) + m.handle_clientdisconnect(cc) return f def _dummy_cycle(self, n, filt, content, **options): diff --git a/test/test_filt.py b/test/test_filt.py index 4e059196..96fc58f9 100644 --- a/test/test_filt.py +++ b/test/test_filt.py @@ -1,6 +1,7 @@ import cStringIO from libmproxy import filt, flow - +from libmproxy.protocol import http +import tutils class TestParsing: def _dump(self, x): @@ -72,41 +73,37 @@ class TestParsing: class TestMatching: def req(self): - conn = flow.ClientConnect(("one", 2222)) headers = flow.ODictCaseless() headers["header"] = ["qvalue"] - req = flow.Request( - conn, - (1, 1), - "host", - 80, - "http", - "GET", - "/path", - headers, - "content_request" + req = http.HTTPRequest( + "absolute", + "GET", + "http", + "host", + 80, + "/path", + (1, 1), + headers, + "content_request", + None, + None ) - return flow.Flow(req) + f = http.HTTPFlow(tutils.tclient_conn(), None) + f.request = req + return f def resp(self): f = self.req() headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - f.response = flow.Response( - f.request, - (1, 1), - 200, - "message", - headers, - "content_response", - None - ) + f.response = http.HTTPResponse((1, 1), 200, "OK", headers, "content_response", None, None) + return f def err(self): f = self.req() - f.error = flow.Error(f.request, "msg") + f.error = flow.Error("msg") return f def q(self, q, o): diff --git a/test/test_flow.py b/test/test_flow.py index 42118e36..3bcb47d8 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -172,7 +172,9 @@ class TestFlow: def test_copy(self): f = tutils.tflow_full() f2 = f.copy() + assert f == f2 assert not f is f2 + assert f.request == f2.request assert not f.request is f2.request assert f.request.headers == f2.request.headers assert not f.request.headers is f2.request.headers @@ -189,9 +191,7 @@ class TestFlow: assert not f.error is f2.error def test_match(self): - f = tutils.tflow() - f.response = tutils.tresp() - f.request = f.response.request + f = tutils.tflow_full() assert not f.match("~b test") assert f.match(None) assert not f.match("~b test") @@ -201,7 +201,6 @@ class TestFlow: tutils.raises(ValueError, f.match, "~") - def test_backup(self): f = tutils.tflow() f.response = tutils.tresp() @@ -228,12 +227,12 @@ class TestFlow: assert f._get_state() == flow.Flow._from_state(state)._get_state() f.response = None - f.error = flow.Error(f.request, "error") + f.error = flow.Error("error") state = f._get_state() assert f._get_state() == flow.Flow._from_state(state)._get_state() f2 = tutils.tflow() - f2.error = flow.Error(f.request, "e2") + f2.error = flow.Error("e2") assert not f == f2 f._load_state(f2._get_state()) assert f._get_state() == f2._get_state() @@ -286,10 +285,6 @@ class TestFlow: f.accept_intercept() assert f.response.reply.acked - def test_serialization(self): - f = flow.Flow(None) - f.request = tutils.treq() - def test_replace_unicode(self): f = tutils.tflow_full() f.response.content = "\xc2foo" @@ -310,10 +305,6 @@ class TestFlow: assert f.response.headers["bar"] == ["bar"] assert f.response.content == "abarb" - f = tutils.tflow_err() - f.replace("error", "bar") - assert f.error.msg == "bar" - def test_replace_encoded(self): f = tutils.tflow_full() f.request.content = "afoob" @@ -378,16 +369,16 @@ class TestState: c = flow.State() req = tutils.treq() f = c.add_request(req) - e = flow.Error(f.request, "message") + e = flow.Error("message") assert c.add_error(e) - e = flow.Error(tutils.tflow().request, "message") + e = flow.Error("message") assert not c.add_error(e) c = flow.State() req = tutils.treq() f = c.add_request(req) - e = flow.Error(f.request, "message") + e = flow.Error("message") c.set_limit("~e") assert not c.view assert not c.view @@ -444,7 +435,7 @@ class TestState: def _add_error(self, state): req = tutils.treq() f = state.add_request(req) - f.error = flow.Error(f.request, "msg") + f.error = flow.Error("msg") def test_clear(self): c = flow.State() @@ -615,7 +606,7 @@ class TestFlowMaster: assert len(fm.scripts) == 0 assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - err = flow.Error(f.request, "msg") + err = flow.Error("msg") err.reply = controller.DummyReply() fm.handle_error(err) assert fm.scripts[0].ns["log"][-1] == "error" @@ -637,7 +628,7 @@ class TestFlowMaster: fm.anticache = True fm.anticomp = True req = tutils.treq() - fm.handle_clientconnect(req.client_conn) + fm.handle_clientconnect(req.flow.client_conn) f = fm.handle_request(req) assert s.flow_count() == 1 @@ -649,12 +640,12 @@ class TestFlowMaster: rx = tutils.tresp() assert not fm.handle_response(rx) - dc = flow.ClientDisconnect(req.client_conn) + dc = flow.ClientDisconnect(req.flow.client_conn) dc.reply = controller.DummyReply() req.client_conn.requestcount = 1 fm.handle_clientdisconnect(dc) - err = flow.Error(f.request, "msg") + err = flow.Error("msg") err.reply = controller.DummyReply() fm.handle_error(err) @@ -675,7 +666,7 @@ class TestFlowMaster: fm.tick(q) assert fm.state.flow_count() - err = flow.Error(f.request, "error") + err = flow.Error("error") err.reply = controller.DummyReply() fm.handle_error(err) @@ -886,7 +877,8 @@ class TestRequest: def test_anticache(self): h = flow.ODictCaseless() - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h h["if-modified-since"] = ["test"] h["if-none-match"] = ["test"] r.anticache() @@ -896,8 +888,8 @@ class TestRequest: def test_getset_state(self): h = flow.ODictCaseless() h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h state = r._get_state() assert flow.Request._from_state(state) == r @@ -905,7 +897,8 @@ class TestRequest: state = r._get_state() assert flow.Request._from_state(state) == r - r2 = flow.Request(c, (1, 1), "testing", 20, "http", "PUT", "/foo", h, "test") + r2 = tutils.treq() + r2.headers = h assert not r == r2 r._load_state(r2._get_state()) assert r == r2 @@ -971,15 +964,15 @@ class TestRequest: def test_get_cookies_none(self): h = flow.ODictCaseless() - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - assert r.get_cookies() == None + r = tutils.treq() + r.headers = h + assert r.get_cookies() is None def test_get_cookies_single(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_cookies() assert len(result)==1 assert result['cookiename']==('cookievalue',{}) @@ -987,8 +980,8 @@ class TestRequest: def test_get_cookies_double(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue;othercookiename=othercookievalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_cookies() assert len(result)==2 assert result['cookiename']==('cookievalue',{}) @@ -997,26 +990,28 @@ class TestRequest: def test_get_cookies_withequalsign(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=coo=kievalue;othercookiename=othercookievalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_cookies() assert len(result)==2 assert result['cookiename']==('coo=kievalue',{}) assert result['othercookiename']==('othercookievalue',{}) - def test_get_header_size(self): + def test_header_size(self): h = flow.ODictCaseless() h["headername"] = ["headervalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - result = r.get_header_size() - assert result==43 + r = tutils.treq() + r.headers = h + result = len(r._assemble_headers()) + print result + print r._assemble_headers() + assert result == 62 def test_get_transmitted_size(self): h = flow.ODictCaseless() h["headername"] = ["headervalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_transmitted_size() assert result==len("content") r.content = None @@ -1025,9 +1020,9 @@ class TestRequest: def test_get_content_type(self): h = flow.ODictCaseless() h["Content-Type"] = ["text/plain"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - assert r.get_content_type()=="text/plain" + resp = tutils.tresp() + resp.headers = h + assert resp.get_content_type()=="text/plain" class TestResponse: def test_simple(self): @@ -1125,20 +1120,22 @@ class TestResponse: assert not r.headers["content-encoding"] assert r.content == "falafel" - def test_get_header_size(self): + def test_header_size(self): r = tutils.tresp() - result = r.get_header_size() - assert result==49 + result = len(r._assemble_headers()) + assert result==44 def test_get_cookies_none(self): h = flow.ODictCaseless() - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h assert not resp.get_cookies() def test_get_cookies_simple(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result @@ -1147,7 +1144,8 @@ class TestResponse: def test_get_cookies_with_parameters(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result @@ -1161,7 +1159,8 @@ class TestResponse: def test_get_cookies_no_value(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result @@ -1171,7 +1170,8 @@ class TestResponse: def test_get_cookies_twocookies(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue","othercookie=othervalue"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==2 assert "cookiename" in result @@ -1182,19 +1182,20 @@ class TestResponse: def test_get_content_type(self): h = flow.ODictCaseless() h["Content-Type"] = ["text/plain"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h assert resp.get_content_type()=="text/plain" class TestError: def test_getset_state(self): - e = flow.Error(None, "Error") + e = flow.Error("Error") state = e._get_state() - assert flow.Error._from_state(None, state) == e + assert flow.Error._from_state(state) == e assert e.copy() - e2 = flow.Error(None, "bar") + e2 = flow.Error("bar") assert not e == e2 e._load_state(e2._get_state()) assert e == e2 @@ -1203,11 +1204,6 @@ class TestError: e3 = e.copy() assert e3 == e - def test_replace(self): - e = flow.Error(None, "amoop") - e.replace("moo", "bar") - assert e.msg == "abarp" - class TestClientConnect: def test_state(self): diff --git a/test/tutils.py b/test/tutils.py index fb41d77a..78bf5909 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,6 +1,7 @@ import os, shutil, tempfile from contextlib import contextmanager -from libmproxy import flow, utils, controller +from libmproxy import flow, utils, controller, proxy +from libmproxy.protocol import http if os.name != "nt": from libmproxy.console.flowview import FlowView from libmproxy.console import ConsoleState @@ -16,40 +17,65 @@ def SkipWindows(fn): else: return fn + +def tclient_conn(): + return proxy.ClientConnection._from_state(dict( + address=dict(address=("address", 22), use_ipv6=True), + clientcert=None + )) + +def tserver_conn(): + return proxy.ServerConnection._from_state(dict( + address=dict(address=("address", 22), use_ipv6=True), + source_address=dict(address=("address", 22), use_ipv6=True), + cert=None + )) + + def treq(conn=None, content="content"): if not conn: - conn = flow.ClientConnect(("address", 22)) - conn.reply = controller.DummyReply() + conn = tclient_conn() + server_conn = tserver_conn() headers = flow.ODictCaseless() headers["header"] = ["qvalue"] - r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, - content) - r.reply = controller.DummyReply() - return r + + f = http.HTTPFlow(conn, server_conn) + f.request = http.HTTPRequest("origin", "GET", None, None, None, "/path", (1, 1), headers, content, + None, None, None) + f.request.reply = controller.DummyReply() + return f.request -def tresp(req=None): +def tresp(req=None, content="message"): if not req: req = treq() + f = req.flow + headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"),"rb").read()) - resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert) - resp.reply = controller.DummyReply() - return resp + cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"), "rb").read()) + f.server_conn = proxy.ServerConnection._from_state(dict( + address=dict(address=("address", 22), use_ipv6=True), + source_address=None, + cert=cert.to_pem())) + f.response = http.HTTPResponse((1, 1), 200, "OK", headers, content, None, None) + f.response.reply = controller.DummyReply() + return f.response + def terr(req=None): if not req: req = treq() - err = flow.Error(req, "error") - err.reply = controller.DummyReply() - return err + f = req.flow + f.error = flow.Error("error") + f.error.reply = controller.DummyReply() + return f.error -def tflow(r=None): - if r == None: - r = treq() - return flow.Flow(r) +def tflow(req=None): + if not req: + req = treq() + return req.flow def tflow_full(): |