diff options
-rw-r--r-- | netlib/socks.py | 32 | ||||
-rw-r--r-- | netlib/tcp.py | 48 | ||||
-rw-r--r-- | netlib/websockets.py | 16 | ||||
-rw-r--r-- | test/test_http.py | 8 | ||||
-rw-r--r-- | test/test_socks.py | 16 | ||||
-rw-r--r-- | test/test_websockets.py | 7 | ||||
-rw-r--r-- | test/tutils.py | 15 |
7 files changed, 80 insertions, 62 deletions
diff --git a/netlib/socks.py b/netlib/socks.py index 497b8eef..6f9f57bd 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -52,20 +52,6 @@ METHOD = utils.BiDi( ) -def _read(f, n): - try: - d = f.read(n) - if len(d) == n: - return d - else: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Incomplete Read" - ) - except socket.error as e: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) - - class ClientGreeting(object): __slots__ = ("ver", "methods") @@ -75,9 +61,9 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack("!BB", _read(f, 2)) + ver, nmethods = struct.unpack("!BB", f.safe_read(2)) methods = array.array("B") - methods.fromstring(_read(f, nmethods)) + methods.fromstring(f.safe_read(nmethods)) return cls(ver, methods) def to_file(self, f): @@ -94,7 +80,7 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack("!BB", _read(f, 2)) + ver, method = struct.unpack("!BB", f.safe_read(2)) return cls(ver, method) def to_file(self, f): @@ -112,27 +98,27 @@ class Message(object): @classmethod def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: # We use tnoa here as ntop is not commonly available on Windows. - host = socket.inet_ntoa(_read(f, 4)) + host = socket.inet_ntoa(f.safe_read(4)) use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) + host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", _read(f, 1)) - host = _read(f, length) + length, = struct.unpack("!B", f.safe_read(1)) + host = f.safe_read(length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port, = struct.unpack("!H", _read(f, 2)) + port, = struct.unpack("!H", f.safe_read(2)) addr = tcp.Address((host, port), use_ipv6=use_ipv6) return cls(ver, msg, atyp, addr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 84008e2c..dbe114a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3 class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass +class NetLibIncomplete(NetLibError): pass class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass @@ -195,10 +196,23 @@ class Reader(_FileLike): break return result + def safe_read(self, length): + """ + Like .read, but is guaranteed to either return length bytes, or + raise an exception. + """ + result = self.read(length) + if length != -1 and len(result) != length: + raise NetLibIncomplete( + "Expected %s bytes, got %s"%(length, len(result)) + ) + return result + class Address(object): """ - This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + This class wraps an IPv4/IPv6 tuple to provide named attributes and + ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = tuple(address) @@ -247,22 +261,28 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux + sock.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending + # readable data could lead to an immediate RST being sent (which is the + # case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # This in turn results in the following issue: If we send an error page to the client and then close the socket, - # the RST may be received by the client before the error page and the users sees a connection error rather than - # the error page. Thus, we try to empty the read buffer on Windows first. - # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # This in turn results in the following issue: If we send an error page + # to the client and then close the socket, the RST may be received by + # the client before the error page and the users sees a connection + # error rather than the error page. Thus, we try to empty the read + # buffer on Windows first. (see + # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) # + if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we are in blocking mode. + # We cannot rely on the shutdown()-followed-by-read()-eof technique + # proposed by the page above: Some remote machines just don't send + # a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. As a workaround, we set a timeout + # here even if we are in blocking mode. sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely @@ -292,10 +312,10 @@ class _Connection(object): def finish(self): self.finished = True - # If we have an SSL connection, wfile.close == connection.close # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close then. + # Closing the socket is not our task, therefore we don't call close + # then. if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): try: diff --git a/netlib/websockets.py b/netlib/websockets.py index 0ad0e294..6d08e101 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -5,7 +5,7 @@ import os import struct import io -from . import utils, odict +from . import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -217,8 +217,8 @@ class FrameHeader: """ read a websockets frame header """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) fin = utils.getbit(first_byte, 7) rsv1 = utils.getbit(first_byte, 6) @@ -235,13 +235,13 @@ class FrameHeader: if length_code <= 125: payload_length = length_code elif length_code == 126: - payload_length = utils.bytes_to_int(fp.read(2)) + payload_length = utils.bytes_to_int(fp.safe_read(2)) elif length_code == 127: - payload_length = utils.bytes_to_int(fp.read(8)) + payload_length = utils.bytes_to_int(fp.safe_read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = fp.read(4) + masking_key = fp.safe_read(4) else: masking_key = None @@ -319,7 +319,7 @@ class Frame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_file() directly """ - return cls.from_file(io.BytesIO(bytestring)) + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) def human_readable(self): hdr = self.header.human_readable() @@ -351,7 +351,7 @@ class Frame(object): stream or a disk or an in memory stream reader """ header = FrameHeader.from_file(fp) - payload = fp.read(header.payload_length) + payload = fp.safe_read(header.payload_length) if header.mask == 1 and header.masking_key: payload = Masker(header.masking_key)(payload) diff --git a/test/test_http.py b/test/test_http.py index f1a31b93..63b39f08 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -91,7 +91,7 @@ def test_read_http_body_request(): def test_read_http_body_response(): h = odict.ODictCaseless() - s = cStringIO.StringIO("testing") + s = tcp.Reader(cStringIO.StringIO("testing")) assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" @@ -135,11 +135,11 @@ def test_read_http_body(): # test no content length: limit > actual content h = odict.ODictCaseless() - s = cStringIO.StringIO("testing") + s = tcp.Reader(cStringIO.StringIO("testing")) assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content - s = cStringIO.StringIO("testing") + s = tcp.Reader(cStringIO.StringIO("testing")) tutils.raises( http.HttpError, http.read_http_body, @@ -149,7 +149,7 @@ def test_read_http_body(): # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" diff --git a/test/test_socks.py b/test/test_socks.py index aa4f9c11..6e522826 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -7,7 +7,7 @@ import tutils def test_client_greeting(): - raw = StringIO("\x05\x02\x00\xBE\xEF") + raw = tutils.treader("\x05\x02\x00\xBE\xEF") out = StringIO() msg = socks.ClientGreeting.from_file(raw) msg.to_file(out) @@ -20,7 +20,7 @@ def test_client_greeting(): def test_server_greeting(): - raw = StringIO("\x05\x02") + raw = tutils.treader("\x05\x02") out = StringIO() msg = socks.ServerGreeting.from_file(raw) msg.to_file(out) @@ -31,7 +31,7 @@ def test_server_greeting(): def test_message(): - raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" @@ -46,7 +46,7 @@ def test_message(): def test_message_ipv4(): # Test ATYP=0x01 (IPV4) - raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" @@ -62,7 +62,7 @@ def test_message_ipv6(): # Test ATYP=0x04 (IPV6) ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" - raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" @@ -73,12 +73,12 @@ def test_message_ipv6(): def test_message_invalid_rsv(): - raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) def test_message_unknown_atyp(): - raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) @@ -93,4 +93,4 @@ def test_read(): cs = mock.Mock() cs.read = mock.Mock(side_effect=socket.error) - tutils.raises(socks.SocksError, socks._read, cs, 4)
\ No newline at end of file + tutils.raises(socks.SocksError, socks._read, cs, 4) diff --git a/test/test_websockets.py b/test/test_websockets.py index 428f7c61..7bd5d74e 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,4 +1,3 @@ -import cStringIO import os from nose.tools import raises @@ -170,7 +169,7 @@ class TestFrameHeader: def round(*args, **kwargs): f = websockets.FrameHeader(*args, **kwargs) bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert f == f2 round() round(fin=1) @@ -197,7 +196,7 @@ class TestFrameHeader: def test_funky(self): f = websockets.FrameHeader(masking_key="test", mask=False) bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert not f2.mask def test_violations(self): @@ -221,7 +220,7 @@ class TestFrame: def round(*args, **kwargs): f = websockets.Frame(*args, **kwargs) bytes = f.to_bytes() - f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) + f2 = websockets.Frame.from_file(tutils.treader(bytes)) assert f == f2 round("test") round("test", fin=1) diff --git a/test/tutils.py b/test/tutils.py index ea30f59c..141979f8 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,7 +1,20 @@ -import tempfile, os, shutil +import cStringIO +import tempfile +import os +import shutil from contextlib import contextmanager from libpathod import utils +from netlib import tcp + + +def treader(bytes): + """ + Construct a tcp.Read object from bytes. + """ + fp = cStringIO.StringIO(bytes) + return tcp.Reader(fp) + @contextmanager def tmpdir(*args, **kwargs): |