diff options
author | Thomas Kriechbaumer <Kriechi@users.noreply.github.com> | 2016-08-31 13:49:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-31 13:49:03 +0200 |
commit | b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae (patch) | |
tree | 6ce6930f0f2db77c7d90bb940b359f4ea5eded73 | |
parent | 6afbfc85266e783b1bad432ae1f6715bfb16a39f (diff) | |
parent | a8deed1f4ef5992b1c9c74ac69491bdb5f0e8490 (diff) | |
download | mitmproxy-b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae.tar.gz mitmproxy-b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae.tar.bz2 mitmproxy-b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae.zip |
Merge pull request #1511 from arjun23496/count_in_replace
Fixes #1495 - Added count argument for replacing contents in body
-rw-r--r-- | netlib/http/headers.py | 10 | ||||
-rw-r--r-- | netlib/http/message.py | 6 | ||||
-rw-r--r-- | netlib/http/request.py | 6 | ||||
-rw-r--r-- | test/netlib/http/test_headers.py | 5 | ||||
-rw-r--r-- | test/netlib/http/test_message.py | 10 | ||||
-rw-r--r-- | test/netlib/http/test_request.py | 10 |
6 files changed, 38 insertions, 9 deletions
diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 36e5060c..131e8ce5 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -158,7 +158,7 @@ class Headers(multidict.MultiDict): else: return super(Headers, self).items() - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in each "name: value" header line. @@ -172,10 +172,10 @@ class Headers(multidict.MultiDict): repl = strutils.escaped_str_to_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 - + flag_count = count > 0 fields = [] for name, value in self.fields: - line, n = pattern.subn(repl, name + b": " + value) + line, n = pattern.subn(repl, name + b": " + value, count=count) try: name, value = line.split(b": ", 1) except ValueError: @@ -184,6 +184,10 @@ class Headers(multidict.MultiDict): pass else: replacements += n + if flag_count: + count -= n + if count == 0: + break fields.append((name, value)) self.fields = tuple(fields) return replacements diff --git a/netlib/http/message.py b/netlib/http/message.py index ce92bab1..0b64d4a6 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -260,7 +260,7 @@ class Message(basetypes.Serializable): if "content-encoding" not in self.headers: raise ValueError("Invalid content encoding {}".format(repr(e))) - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in both the headers and the body of the message. Encoded body will be decoded @@ -276,9 +276,9 @@ class Message(basetypes.Serializable): replacements = 0 if self.content: self.content, replacements = re.subn( - pattern, repl, self.content, flags=flags + pattern, repl, self.content, flags=flags, count=count ) - replacements += self.headers.replace(pattern, repl, flags) + replacements += self.headers.replace(pattern, repl, flags=flags, count=count) return replacements # Legacy diff --git a/netlib/http/request.py b/netlib/http/request.py index 666a5869..e0aaa8a9 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -80,7 +80,7 @@ class Request(message.Message): self.method, hostport, path ) - def replace(self, pattern, repl, flags=0): + def replace(self, pattern, repl, flags=0, count=0): """ Replaces a regular expression pattern with repl in the headers, the request path and the body of the request. Encoded content will be @@ -94,9 +94,9 @@ class Request(message.Message): if isinstance(repl, six.text_type): repl = strutils.escaped_str_to_bytes(repl) - c = super(Request, self).replace(pattern, repl, flags) + c = super(Request, self).replace(pattern, repl, flags, count) self.path, pc = re.subn( - pattern, repl, self.data.path, flags=flags + pattern, repl, self.data.path, flags=flags, count=count ) c += pc return c diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 51537310..ad2bc548 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -75,6 +75,11 @@ class TestHeaders(object): assert replacements == 0 assert headers["Host"] == "example.com" + def test_replace_with_count(self): + headers = Headers(Host="foobarfoo.com", Accept="foo/bar") + replacements = headers.replace("foo", "bar", count=1) + assert replacements == 1 + def test_parse_content_type(): p = parse_content_type diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py index 12e4706c..74272309 100644 --- a/test/netlib/http/test_message.py +++ b/test/netlib/http/test_message.py @@ -99,6 +99,16 @@ class TestMessage(object): def test_http_version(self): _test_decoded_attr(tresp(), "http_version") + def test_replace(self): + r = tresp() + r.content = b"foofootoo" + r.replace(b"foo", "gg") + assert r.content == b"ggggtoo" + + r.content = b"foofootoo" + r.replace(b"foo", "gg", count=1) + assert r.content == b"ggfootoo" + class TestMessageContentEncoding(object): def test_simple(self): diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index f3cd8b71..9baabaa6 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -26,6 +26,16 @@ class TestRequestCore(object): request.host = None assert repr(request) == "Request(GET /path)" + def replace(self): + r = treq() + r.path = b"foobarfoo" + r.replace(b"foo", "bar") + assert r.path == b"barbarbar" + + r.path = b"foobarfoo" + r.replace(b"foo", "bar", count=1) + assert r.path == b"barbarfoo" + def test_first_line_format(self): _test_passthrough_attr(treq(), "first_line_format") |