diff options
author | Maximilian Hils <git@maximilianhils.com> | 2016-02-01 19:38:14 +0100 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2016-02-01 19:38:14 +0100 |
commit | bda49dd178fee1361f3585bd7efad67883298e5a (patch) | |
tree | 46153699b7e8cd96fc534844911a2da42c7af9bb /netlib | |
parent | 7c83a709ea06f3b538f446860f3c7ed463a29b1f (diff) | |
download | mitmproxy-bda49dd178fee1361f3585bd7efad67883298e5a.tar.gz mitmproxy-bda49dd178fee1361f3585bd7efad67883298e5a.tar.bz2 mitmproxy-bda49dd178fee1361f3585bd7efad67883298e5a.zip |
fix #113, make Reader.peek() work on Python 3
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/tcp.py | 30 |
1 files changed, 25 insertions, 5 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index 8902b9dc..57a9b737 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl version_check.check_pyopenssl_version() +if six.PY2: + socket_fileobject = socket._fileobject +else: + socket_fileobject = socket.SocketIO EINTR = 4 @@ -270,7 +274,7 @@ class Reader(_FileLike): TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket_fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: @@ -423,8 +427,17 @@ class _Connection(object): def __init__(self, connection): if connection: self.connection = connection - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) else: self.connection = None self.rfile = None @@ -663,8 +676,15 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + + # See _Connection.__init__ why we do this dance. + if six.PY2: + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(connection, "rb")) + self.wfile = Writer(socket.SocketIO(connection, "wb")) + except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % |