diff options
author | Aldo Cortesi <aldo@nullcube.com> | 2015-05-05 10:47:02 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@nullcube.com> | 2015-05-05 10:47:02 +1200 |
commit | f2bc58cdd2f2b9b0025a88c0faccf55e10b29353 (patch) | |
tree | 9353f4b78f4bf8ec1e6b4169155da084a2c05ea9 /netlib | |
parent | 08b2e2a6a98fd175e1b49d62dffde34e91c77b1c (diff) | |
download | mitmproxy-f2bc58cdd2f2b9b0025a88c0faccf55e10b29353.tar.gz mitmproxy-f2bc58cdd2f2b9b0025a88c0faccf55e10b29353.tar.bz2 mitmproxy-f2bc58cdd2f2b9b0025a88c0faccf55e10b29353.zip |
Add tcp.Reader.safe_read, use it in socks and websockets
safe_read is guaranteed to either return a byte string of exactly the
requested length or raise an exception (NetLibIncomplete on a short read).
It's particularly useful for implementing binary protocols.
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/socks.py | 32 | ||||
-rw-r--r-- | netlib/tcp.py | 48 | ||||
-rw-r--r-- | netlib/websockets.py | 16 |
3 files changed, 51 insertions, 45 deletions
diff --git a/netlib/socks.py b/netlib/socks.py index 497b8eef..6f9f57bd 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -52,20 +52,6 @@ METHOD = utils.BiDi( ) -def _read(f, n): - try: - d = f.read(n) - if len(d) == n: - return d - else: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Incomplete Read" - ) - except socket.error as e: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) - - class ClientGreeting(object): __slots__ = ("ver", "methods") @@ -75,9 +61,9 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack("!BB", _read(f, 2)) + ver, nmethods = struct.unpack("!BB", f.safe_read(2)) methods = array.array("B") - methods.fromstring(_read(f, nmethods)) + methods.fromstring(f.safe_read(nmethods)) return cls(ver, methods) def to_file(self, f): @@ -94,7 +80,7 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack("!BB", _read(f, 2)) + ver, method = struct.unpack("!BB", f.safe_read(2)) return cls(ver, method) def to_file(self, f): @@ -112,27 +98,27 @@ class Message(object): @classmethod def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: # We use tnoa here as ntop is not commonly available on Windows. 
- host = socket.inet_ntoa(_read(f, 4)) + host = socket.inet_ntoa(f.safe_read(4)) use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) + host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", _read(f, 1)) - host = _read(f, length) + length, = struct.unpack("!B", f.safe_read(1)) + host = f.safe_read(length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port, = struct.unpack("!H", _read(f, 2)) + port, = struct.unpack("!H", f.safe_read(2)) addr = tcp.Address((host, port), use_ipv6=use_ipv6) return cls(ver, msg, atyp, addr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 84008e2c..dbe114a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3 class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass +class NetLibIncomplete(NetLibError): pass class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass @@ -195,10 +196,23 @@ class Reader(_FileLike): break return result + def safe_read(self, length): + """ + Like .read, but is guaranteed to either return length bytes, or + raise an exception. + """ + result = self.read(length) + if length != -1 and len(result) != length: + raise NetLibIncomplete( + "Expected %s bytes, got %s"%(length, len(result)) + ) + return result + class Address(object): """ - This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + This class wraps an IPv4/IPv6 tuple to provide named attributes and + ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = tuple(address) @@ -247,22 +261,28 @@ def close_socket(sock): """ try: # We already indicate that we close our end. 
- sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux + sock.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending + # readable data could lead to an immediate RST being sent (which is the + # case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # This in turn results in the following issue: If we send an error page to the client and then close the socket, - # the RST may be received by the client before the error page and the users sees a connection error rather than - # the error page. Thus, we try to empty the read buffer on Windows first. - # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # This in turn results in the following issue: If we send an error page + # to the client and then close the socket, the RST may be received by + # the client before the error page and the users sees a connection + # error rather than the error page. Thus, we try to empty the read + # buffer on Windows first. (see + # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) # + if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we are in blocking mode. 
+ # We cannot rely on the shutdown()-followed-by-read()-eof technique + # proposed by the page above: Some remote machines just don't send + # a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. As a workaround, we set a timeout + # here even if we are in blocking mode. sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely @@ -292,10 +312,10 @@ class _Connection(object): def finish(self): self.finished = True - # If we have an SSL connection, wfile.close == connection.close # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close then. + # Closing the socket is not our task, therefore we don't call close + # then. if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): try: diff --git a/netlib/websockets.py b/netlib/websockets.py index 0ad0e294..6d08e101 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -5,7 +5,7 @@ import os import struct import io -from . import utils, odict +from . import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. 
@@ -217,8 +217,8 @@ class FrameHeader: """ read a websockets frame header """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) fin = utils.getbit(first_byte, 7) rsv1 = utils.getbit(first_byte, 6) @@ -235,13 +235,13 @@ class FrameHeader: if length_code <= 125: payload_length = length_code elif length_code == 126: - payload_length = utils.bytes_to_int(fp.read(2)) + payload_length = utils.bytes_to_int(fp.safe_read(2)) elif length_code == 127: - payload_length = utils.bytes_to_int(fp.read(8)) + payload_length = utils.bytes_to_int(fp.safe_read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = fp.read(4) + masking_key = fp.safe_read(4) else: masking_key = None @@ -319,7 +319,7 @@ class Frame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_file() directly """ - return cls.from_file(io.BytesIO(bytestring)) + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) def human_readable(self): hdr = self.header.human_readable() @@ -351,7 +351,7 @@ class Frame(object): stream or a disk or an in memory stream reader """ header = FrameHeader.from_file(fp) - payload = fp.read(header.payload_length) + payload = fp.safe_read(header.payload_length) if header.mask == 1 and header.masking_key: payload = Masker(header.masking_key)(payload) |