aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-02-01 19:38:14 +0100
committerMaximilian Hils <git@maximilianhils.com>2016-02-01 19:38:14 +0100
commitbda49dd178fee1361f3585bd7efad67883298e5a (patch)
tree46153699b7e8cd96fc534844911a2da42c7af9bb /netlib
parent7c83a709ea06f3b538f446860f3c7ed463a29b1f (diff)
downloadmitmproxy-bda49dd178fee1361f3585bd7efad67883298e5a.tar.gz
mitmproxy-bda49dd178fee1361f3585bd7efad67883298e5a.tar.bz2
mitmproxy-bda49dd178fee1361f3585bd7efad67883298e5a.zip
fix #113, make Reader.peek() work on Python 3
Diffstat (limited to 'netlib')
-rw-r--r--netlib/tcp.py30
1 files changed, 25 insertions, 5 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 8902b9dc..57a9b737 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl
version_check.check_pyopenssl_version()
+if six.PY2:
+ socket_fileobject = socket._fileobject
+else:
+ socket_fileobject = socket.SocketIO
EINTR = 4
@@ -270,7 +274,7 @@ class Reader(_FileLike):
TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
"""
- if isinstance(self.o, socket._fileobject):
+ if isinstance(self.o, socket_fileobject):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e:
@@ -423,8 +427,17 @@ class _Connection(object):
def __init__(self, connection):
if connection:
self.connection = connection
- self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
- self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
+ # Ideally, we would use the Buffered IO in Python 3 by default.
+ # Unfortunately, the implementation of .peek() is broken for n>1 bytes,
+ # as it may just return what's left in the buffer and not all the bytes we want.
+ # As a workaround, we just use unbuffered sockets directly.
+ # https://mail.python.org/pipermail/python-dev/2009-June/089986.html
+ if six.PY2:
+ self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
+ self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
+ else:
+ self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
+ self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
else:
self.connection = None
self.rfile = None
@@ -663,8 +676,15 @@ class TCPClient(_Connection):
connection.connect(self.address())
if not self.source_address:
self.source_address = Address(connection.getsockname())
- self.rfile = Reader(connection.makefile('rb', self.rbufsize))
- self.wfile = Writer(connection.makefile('wb', self.wbufsize))
+
+ # See _Connection.__init__ why we do this dance.
+ if six.PY2:
+ self.rfile = Reader(connection.makefile('rb', self.rbufsize))
+ self.wfile = Writer(connection.makefile('wb', self.wbufsize))
+ else:
+ self.rfile = Reader(socket.SocketIO(connection, "rb"))
+ self.wfile = Writer(socket.SocketIO(connection, "wb"))
+
except (socket.error, IOError) as err:
raise TcpException(
'Error connecting to "%s": %s' %