aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
author Aldo Cortesi <aldo@nullcube.com> 2015-05-05 10:47:02 +1200
committer Aldo Cortesi <aldo@nullcube.com> 2015-05-05 10:47:02 +1200
commit f2bc58cdd2f2b9b0025a88c0faccf55e10b29353 (patch)
tree 9353f4b78f4bf8ec1e6b4169155da084a2c05ea9 /netlib
parent 08b2e2a6a98fd175e1b49d62dffde34e91c77b1c (diff)
download mitmproxy-f2bc58cdd2f2b9b0025a88c0faccf55e10b29353.tar.gz
mitmproxy-f2bc58cdd2f2b9b0025a88c0faccf55e10b29353.tar.bz2
mitmproxy-f2bc58cdd2f2b9b0025a88c0faccf55e10b29353.zip
Add tcp.Reader.safe_read, use it in socks and websockets
safe_read is guaranteed either to return a byte string of exactly the requested length or to raise an exception. It's particularly useful for implementing binary protocols.
Diffstat (limited to 'netlib')
-rw-r--r-- netlib/socks.py 32
-rw-r--r-- netlib/tcp.py 48
-rw-r--r-- netlib/websockets.py 16
3 files changed, 51 insertions, 45 deletions
diff --git a/netlib/socks.py b/netlib/socks.py
index 497b8eef..6f9f57bd 100644
--- a/netlib/socks.py
+++ b/netlib/socks.py
@@ -52,20 +52,6 @@ METHOD = utils.BiDi(
)
-def _read(f, n):
- try:
- d = f.read(n)
- if len(d) == n:
- return d
- else:
- raise SocksError(
- REP.GENERAL_SOCKS_SERVER_FAILURE,
- "Incomplete Read"
- )
- except socket.error as e:
- raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e))
-
-
class ClientGreeting(object):
__slots__ = ("ver", "methods")
@@ -75,9 +61,9 @@ class ClientGreeting(object):
@classmethod
def from_file(cls, f):
- ver, nmethods = struct.unpack("!BB", _read(f, 2))
+ ver, nmethods = struct.unpack("!BB", f.safe_read(2))
methods = array.array("B")
- methods.fromstring(_read(f, nmethods))
+ methods.fromstring(f.safe_read(nmethods))
return cls(ver, methods)
def to_file(self, f):
@@ -94,7 +80,7 @@ class ServerGreeting(object):
@classmethod
def from_file(cls, f):
- ver, method = struct.unpack("!BB", _read(f, 2))
+ ver, method = struct.unpack("!BB", f.safe_read(2))
return cls(ver, method)
def to_file(self, f):
@@ -112,27 +98,27 @@ class Message(object):
@classmethod
def from_file(cls, f):
- ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4))
+ ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4))
if rsv != 0x00:
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE,
"Socks Request: Invalid reserved byte: %s" % rsv)
if atyp == ATYP.IPV4_ADDRESS:
# We use tnoa here as ntop is not commonly available on Windows.
- host = socket.inet_ntoa(_read(f, 4))
+ host = socket.inet_ntoa(f.safe_read(4))
use_ipv6 = False
elif atyp == ATYP.IPV6_ADDRESS:
- host = socket.inet_ntop(socket.AF_INET6, _read(f, 16))
+ host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16))
use_ipv6 = True
elif atyp == ATYP.DOMAINNAME:
- length, = struct.unpack("!B", _read(f, 1))
- host = _read(f, length)
+ length, = struct.unpack("!B", f.safe_read(1))
+ host = f.safe_read(length)
use_ipv6 = False
else:
raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,
"Socks Request: Unknown ATYP: %s" % atyp)
- port, = struct.unpack("!H", _read(f, 2))
+ port, = struct.unpack("!H", f.safe_read(2))
addr = tcp.Address((host, port), use_ipv6=use_ipv6)
return cls(ver, msg, atyp, addr)
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 84008e2c..dbe114a1 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3
class NetLibError(Exception): pass
class NetLibDisconnect(NetLibError): pass
+class NetLibIncomplete(NetLibError): pass
class NetLibTimeout(NetLibError): pass
class NetLibSSLError(NetLibError): pass
@@ -195,10 +196,23 @@ class Reader(_FileLike):
break
return result
+ def safe_read(self, length):
+ """
+ Like .read, but is guaranteed to either return length bytes, or
+ raise an exception.
+ """
+ result = self.read(length)
+ if length != -1 and len(result) != length:
+ raise NetLibIncomplete(
+ "Expected %s bytes, got %s"%(length, len(result))
+ )
+ return result
+
class Address(object):
"""
- This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
+ This class wraps an IPv4/IPv6 tuple to provide named attributes and
+ ipv6 information.
"""
def __init__(self, address, use_ipv6=False):
self.address = tuple(address)
@@ -247,22 +261,28 @@ def close_socket(sock):
"""
try:
# We already indicate that we close our end.
- sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux
+ # may raise "Transport endpoint is not connected" on Linux
+ sock.shutdown(socket.SHUT_WR)
- # Section 4.2.2.13 of RFC 1122 tells us that a close() with any
- # pending readable data could lead to an immediate RST being sent (which is the case on Windows).
+ # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
+ # readable data could lead to an immediate RST being sent (which is the
+ # case on Windows).
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
#
- # This in turn results in the following issue: If we send an error page to the client and then close the socket,
- # the RST may be received by the client before the error page and the users sees a connection error rather than
- # the error page. Thus, we try to empty the read buffer on Windows first.
- # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
+ # This in turn results in the following issue: If we send an error page
+ # to the client and then close the socket, the RST may be received by
+ # the client before the error page and the users sees a connection
+ # error rather than the error page. Thus, we try to empty the read
+ # buffer on Windows first. (see
+ # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
#
+
if os.name == "nt": # pragma: no cover
- # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
- # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
- # recv() would block infinitely.
- # As a workaround, we set a timeout here even if we are in blocking mode.
+ # We cannot rely on the shutdown()-followed-by-read()-eof technique
+ # proposed by the page above: Some remote machines just don't send
+ # a TCP FIN, which would leave us in the unfortunate situation that
+ # recv() would block infinitely. As a workaround, we set a timeout
+ # here even if we are in blocking mode.
sock.settimeout(sock.gettimeout() or 20)
# limit at a megabyte so that we don't read infinitely
@@ -292,10 +312,10 @@ class _Connection(object):
def finish(self):
self.finished = True
-
# If we have an SSL connection, wfile.close == connection.close
# (We call _FileLike.set_descriptor(conn))
- # Closing the socket is not our task, therefore we don't call close then.
+ # Closing the socket is not our task, therefore we don't call close
+ # then.
if type(self.connection) != SSL.Connection:
if not getattr(self.wfile, "closed", False):
try:
diff --git a/netlib/websockets.py b/netlib/websockets.py
index 0ad0e294..6d08e101 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -5,7 +5,7 @@ import os
import struct
import io
-from . import utils, odict
+from . import utils, odict, tcp
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
@@ -217,8 +217,8 @@ class FrameHeader:
"""
read a websockets frame header
"""
- first_byte = utils.bytes_to_int(fp.read(1))
- second_byte = utils.bytes_to_int(fp.read(1))
+ first_byte = utils.bytes_to_int(fp.safe_read(1))
+ second_byte = utils.bytes_to_int(fp.safe_read(1))
fin = utils.getbit(first_byte, 7)
rsv1 = utils.getbit(first_byte, 6)
@@ -235,13 +235,13 @@ class FrameHeader:
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
- payload_length = utils.bytes_to_int(fp.read(2))
+ payload_length = utils.bytes_to_int(fp.safe_read(2))
elif length_code == 127:
- payload_length = utils.bytes_to_int(fp.read(8))
+ payload_length = utils.bytes_to_int(fp.safe_read(8))
# masking key only present if mask bit set
if mask_bit == 1:
- masking_key = fp.read(4)
+ masking_key = fp.safe_read(4)
else:
masking_key = None
@@ -319,7 +319,7 @@ class Frame(object):
Construct a websocket frame from an in-memory bytestring
to construct a frame from a stream of bytes, use from_file() directly
"""
- return cls.from_file(io.BytesIO(bytestring))
+ return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
def human_readable(self):
hdr = self.header.human_readable()
@@ -351,7 +351,7 @@ class Frame(object):
stream or a disk or an in memory stream reader
"""
header = FrameHeader.from_file(fp)
- payload = fp.read(header.payload_length)
+ payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
payload = Masker(header.masking_key)(payload)