aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <Kriechi@users.noreply.github.com>2016-08-31 13:49:03 +0200
committerGitHub <noreply@github.com>2016-08-31 13:49:03 +0200
commitb4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae (patch)
tree6ce6930f0f2db77c7d90bb940b359f4ea5eded73
parent6afbfc85266e783b1bad432ae1f6715bfb16a39f (diff)
parenta8deed1f4ef5992b1c9c74ac69491bdb5f0e8490 (diff)
downloadmitmproxy-b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae.tar.gz
mitmproxy-b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae.tar.bz2
mitmproxy-b4b2e5fd3431ea3c8b8f00b8f92e8c7fc6f309ae.zip
Merge pull request #1511 from arjun23496/count_in_replace
Fixes #1495 - Added count argument for replacing contents in body
-rw-r--r--netlib/http/headers.py10
-rw-r--r--netlib/http/message.py6
-rw-r--r--netlib/http/request.py6
-rw-r--r--test/netlib/http/test_headers.py5
-rw-r--r--test/netlib/http/test_message.py10
-rw-r--r--test/netlib/http/test_request.py10
6 files changed, 38 insertions, 9 deletions
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 36e5060c..131e8ce5 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -158,7 +158,7 @@ class Headers(multidict.MultiDict):
else:
return super(Headers, self).items()
- def replace(self, pattern, repl, flags=0):
+ def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
header line.
@@ -172,10 +172,10 @@ class Headers(multidict.MultiDict):
repl = strutils.escaped_str_to_bytes(repl)
pattern = re.compile(pattern, flags)
replacements = 0
-
+ flag_count = count > 0
fields = []
for name, value in self.fields:
- line, n = pattern.subn(repl, name + b": " + value)
+ line, n = pattern.subn(repl, name + b": " + value, count=count)
try:
name, value = line.split(b": ", 1)
except ValueError:
@@ -184,6 +184,10 @@ class Headers(multidict.MultiDict):
pass
else:
replacements += n
+ if flag_count:
+ count -= n
+ if count == 0:
+ break
fields.append((name, value))
self.fields = tuple(fields)
return replacements
diff --git a/netlib/http/message.py b/netlib/http/message.py
index ce92bab1..0b64d4a6 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -260,7 +260,7 @@ class Message(basetypes.Serializable):
if "content-encoding" not in self.headers:
raise ValueError("Invalid content encoding {}".format(repr(e)))
- def replace(self, pattern, repl, flags=0):
+ def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in both the headers
and the body of the message. Encoded body will be decoded
@@ -276,9 +276,9 @@ class Message(basetypes.Serializable):
replacements = 0
if self.content:
self.content, replacements = re.subn(
- pattern, repl, self.content, flags=flags
+ pattern, repl, self.content, flags=flags, count=count
)
- replacements += self.headers.replace(pattern, repl, flags)
+ replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
return replacements
# Legacy
diff --git a/netlib/http/request.py b/netlib/http/request.py
index 666a5869..e0aaa8a9 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -80,7 +80,7 @@ class Request(message.Message):
self.method, hostport, path
)
- def replace(self, pattern, repl, flags=0):
+ def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in the headers, the
request path and the body of the request. Encoded content will be
@@ -94,9 +94,9 @@ class Request(message.Message):
if isinstance(repl, six.text_type):
repl = strutils.escaped_str_to_bytes(repl)
- c = super(Request, self).replace(pattern, repl, flags)
+ c = super(Request, self).replace(pattern, repl, flags, count)
self.path, pc = re.subn(
- pattern, repl, self.data.path, flags=flags
+ pattern, repl, self.data.path, flags=flags, count=count
)
c += pc
return c
diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py
index 51537310..ad2bc548 100644
--- a/test/netlib/http/test_headers.py
+++ b/test/netlib/http/test_headers.py
@@ -75,6 +75,11 @@ class TestHeaders(object):
assert replacements == 0
assert headers["Host"] == "example.com"
+ def test_replace_with_count(self):
+ headers = Headers(Host="foobarfoo.com", Accept="foo/bar")
+ replacements = headers.replace("foo", "bar", count=1)
+ assert replacements == 1
+
def test_parse_content_type():
p = parse_content_type
diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py
index 12e4706c..74272309 100644
--- a/test/netlib/http/test_message.py
+++ b/test/netlib/http/test_message.py
@@ -99,6 +99,16 @@ class TestMessage(object):
def test_http_version(self):
_test_decoded_attr(tresp(), "http_version")
+ def test_replace(self):
+ r = tresp()
+ r.content = b"foofootoo"
+ r.replace(b"foo", "gg")
+ assert r.content == b"ggggtoo"
+
+ r.content = b"foofootoo"
+ r.replace(b"foo", "gg", count=1)
+ assert r.content == b"ggfootoo"
+
class TestMessageContentEncoding(object):
def test_simple(self):
diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py
index f3cd8b71..9baabaa6 100644
--- a/test/netlib/http/test_request.py
+++ b/test/netlib/http/test_request.py
@@ -26,6 +26,16 @@ class TestRequestCore(object):
request.host = None
assert repr(request) == "Request(GET /path)"
+ def replace(self):
+ r = treq()
+ r.path = b"foobarfoo"
+ r.replace(b"foo", "bar")
+ assert r.path == b"barbarbar"
+
+ r.path = b"foobarfoo"
+ r.replace(b"foo", "bar", count=1)
+ assert r.path == b"barbarfoo"
+
def test_first_line_format(self):
_test_passthrough_attr(treq(), "first_line_format")