aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/socks.py32
-rw-r--r--netlib/tcp.py48
-rw-r--r--netlib/websockets.py16
-rw-r--r--test/test_http.py8
-rw-r--r--test/test_socks.py16
-rw-r--r--test/test_websockets.py7
-rw-r--r--test/tutils.py15
7 files changed, 80 insertions, 62 deletions
diff --git a/netlib/socks.py b/netlib/socks.py
index 497b8eef..6f9f57bd 100644
--- a/netlib/socks.py
+++ b/netlib/socks.py
@@ -52,20 +52,6 @@ METHOD = utils.BiDi(
)
-def _read(f, n):
- try:
- d = f.read(n)
- if len(d) == n:
- return d
- else:
- raise SocksError(
- REP.GENERAL_SOCKS_SERVER_FAILURE,
- "Incomplete Read"
- )
- except socket.error as e:
- raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e))
-
-
class ClientGreeting(object):
__slots__ = ("ver", "methods")
@@ -75,9 +61,9 @@ class ClientGreeting(object):
@classmethod
def from_file(cls, f):
- ver, nmethods = struct.unpack("!BB", _read(f, 2))
+ ver, nmethods = struct.unpack("!BB", f.safe_read(2))
methods = array.array("B")
- methods.fromstring(_read(f, nmethods))
+ methods.fromstring(f.safe_read(nmethods))
return cls(ver, methods)
def to_file(self, f):
@@ -94,7 +80,7 @@ class ServerGreeting(object):
@classmethod
def from_file(cls, f):
- ver, method = struct.unpack("!BB", _read(f, 2))
+ ver, method = struct.unpack("!BB", f.safe_read(2))
return cls(ver, method)
def to_file(self, f):
@@ -112,27 +98,27 @@ class Message(object):
@classmethod
def from_file(cls, f):
- ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4))
+ ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4))
if rsv != 0x00:
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE,
"Socks Request: Invalid reserved byte: %s" % rsv)
if atyp == ATYP.IPV4_ADDRESS:
# We use tnoa here as ntop is not commonly available on Windows.
- host = socket.inet_ntoa(_read(f, 4))
+ host = socket.inet_ntoa(f.safe_read(4))
use_ipv6 = False
elif atyp == ATYP.IPV6_ADDRESS:
- host = socket.inet_ntop(socket.AF_INET6, _read(f, 16))
+ host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16))
use_ipv6 = True
elif atyp == ATYP.DOMAINNAME:
- length, = struct.unpack("!B", _read(f, 1))
- host = _read(f, length)
+ length, = struct.unpack("!B", f.safe_read(1))
+ host = f.safe_read(length)
use_ipv6 = False
else:
raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,
"Socks Request: Unknown ATYP: %s" % atyp)
- port, = struct.unpack("!H", _read(f, 2))
+ port, = struct.unpack("!H", f.safe_read(2))
addr = tcp.Address((host, port), use_ipv6=use_ipv6)
return cls(ver, msg, atyp, addr)
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 84008e2c..dbe114a1 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3
class NetLibError(Exception): pass
class NetLibDisconnect(NetLibError): pass
+class NetLibIncomplete(NetLibError): pass
class NetLibTimeout(NetLibError): pass
class NetLibSSLError(NetLibError): pass
@@ -195,10 +196,23 @@ class Reader(_FileLike):
break
return result
+ def safe_read(self, length):
+ """
+ Like .read, but is guaranteed to either return length bytes, or
+ raise an exception.
+ """
+ result = self.read(length)
+ if length != -1 and len(result) != length:
+ raise NetLibIncomplete(
+ "Expected %s bytes, got %s"%(length, len(result))
+ )
+ return result
+
class Address(object):
"""
- This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
+ This class wraps an IPv4/IPv6 tuple to provide named attributes and
+ ipv6 information.
"""
def __init__(self, address, use_ipv6=False):
self.address = tuple(address)
@@ -247,22 +261,28 @@ def close_socket(sock):
"""
try:
# We already indicate that we close our end.
- sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux
+ # may raise "Transport endpoint is not connected" on Linux
+ sock.shutdown(socket.SHUT_WR)
- # Section 4.2.2.13 of RFC 1122 tells us that a close() with any
- # pending readable data could lead to an immediate RST being sent (which is the case on Windows).
+ # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
+ # readable data could lead to an immediate RST being sent (which is the
+ # case on Windows).
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
#
- # This in turn results in the following issue: If we send an error page to the client and then close the socket,
- # the RST may be received by the client before the error page and the users sees a connection error rather than
- # the error page. Thus, we try to empty the read buffer on Windows first.
- # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
+ # This in turn results in the following issue: If we send an error page
+ # to the client and then close the socket, the RST may be received by
+ # the client before the error page and the users sees a connection
+ # error rather than the error page. Thus, we try to empty the read
+ # buffer on Windows first. (see
+ # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
#
+
if os.name == "nt": # pragma: no cover
- # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
- # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
- # recv() would block infinitely.
- # As a workaround, we set a timeout here even if we are in blocking mode.
+ # We cannot rely on the shutdown()-followed-by-read()-eof technique
+ # proposed by the page above: Some remote machines just don't send
+ # a TCP FIN, which would leave us in the unfortunate situation that
+ # recv() would block infinitely. As a workaround, we set a timeout
+ # here even if we are in blocking mode.
sock.settimeout(sock.gettimeout() or 20)
# limit at a megabyte so that we don't read infinitely
@@ -292,10 +312,10 @@ class _Connection(object):
def finish(self):
self.finished = True
-
# If we have an SSL connection, wfile.close == connection.close
# (We call _FileLike.set_descriptor(conn))
- # Closing the socket is not our task, therefore we don't call close then.
+ # Closing the socket is not our task, therefore we don't call close
+ # then.
if type(self.connection) != SSL.Connection:
if not getattr(self.wfile, "closed", False):
try:
diff --git a/netlib/websockets.py b/netlib/websockets.py
index 0ad0e294..6d08e101 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -5,7 +5,7 @@ import os
import struct
import io
-from . import utils, odict
+from . import utils, odict, tcp
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
@@ -217,8 +217,8 @@ class FrameHeader:
"""
read a websockets frame header
"""
- first_byte = utils.bytes_to_int(fp.read(1))
- second_byte = utils.bytes_to_int(fp.read(1))
+ first_byte = utils.bytes_to_int(fp.safe_read(1))
+ second_byte = utils.bytes_to_int(fp.safe_read(1))
fin = utils.getbit(first_byte, 7)
rsv1 = utils.getbit(first_byte, 6)
@@ -235,13 +235,13 @@ class FrameHeader:
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
- payload_length = utils.bytes_to_int(fp.read(2))
+ payload_length = utils.bytes_to_int(fp.safe_read(2))
elif length_code == 127:
- payload_length = utils.bytes_to_int(fp.read(8))
+ payload_length = utils.bytes_to_int(fp.safe_read(8))
# masking key only present if mask bit set
if mask_bit == 1:
- masking_key = fp.read(4)
+ masking_key = fp.safe_read(4)
else:
masking_key = None
@@ -319,7 +319,7 @@ class Frame(object):
Construct a websocket frame from an in-memory bytestring
to construct a frame from a stream of bytes, use from_file() directly
"""
- return cls.from_file(io.BytesIO(bytestring))
+ return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
def human_readable(self):
hdr = self.header.human_readable()
@@ -351,7 +351,7 @@ class Frame(object):
stream or a disk or an in memory stream reader
"""
header = FrameHeader.from_file(fp)
- payload = fp.read(header.payload_length)
+ payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
payload = Masker(header.masking_key)(payload)
diff --git a/test/test_http.py b/test/test_http.py
index f1a31b93..63b39f08 100644
--- a/test/test_http.py
+++ b/test/test_http.py
@@ -91,7 +91,7 @@ def test_read_http_body_request():
def test_read_http_body_response():
h = odict.ODictCaseless()
- s = cStringIO.StringIO("testing")
+ s = tcp.Reader(cStringIO.StringIO("testing"))
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
@@ -135,11 +135,11 @@ def test_read_http_body():
# test no content length: limit > actual content
h = odict.ODictCaseless()
- s = cStringIO.StringIO("testing")
+ s = tcp.Reader(cStringIO.StringIO("testing"))
assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content
- s = cStringIO.StringIO("testing")
+ s = tcp.Reader(cStringIO.StringIO("testing"))
tutils.raises(
http.HttpError,
http.read_http_body,
@@ -149,7 +149,7 @@ def test_read_http_body():
# test chunked
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
- s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
+ s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n"))
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
diff --git a/test/test_socks.py b/test/test_socks.py
index aa4f9c11..6e522826 100644
--- a/test/test_socks.py
+++ b/test/test_socks.py
@@ -7,7 +7,7 @@ import tutils
def test_client_greeting():
- raw = StringIO("\x05\x02\x00\xBE\xEF")
+ raw = tutils.treader("\x05\x02\x00\xBE\xEF")
out = StringIO()
msg = socks.ClientGreeting.from_file(raw)
msg.to_file(out)
@@ -20,7 +20,7 @@ def test_client_greeting():
def test_server_greeting():
- raw = StringIO("\x05\x02")
+ raw = tutils.treader("\x05\x02")
out = StringIO()
msg = socks.ServerGreeting.from_file(raw)
msg.to_file(out)
@@ -31,7 +31,7 @@ def test_server_greeting():
def test_message():
- raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")
+ raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")
out = StringIO()
msg = socks.Message.from_file(raw)
assert raw.read(2) == "\xBE\xEF"
@@ -46,7 +46,7 @@ def test_message():
def test_message_ipv4():
# Test ATYP=0x01 (IPV4)
- raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
+ raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
out = StringIO()
msg = socks.Message.from_file(raw)
assert raw.read(2) == "\xBE\xEF"
@@ -62,7 +62,7 @@ def test_message_ipv6():
# Test ATYP=0x04 (IPV6)
ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344"
- raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF")
+ raw = tutils.treader("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF")
out = StringIO()
msg = socks.Message.from_file(raw)
assert raw.read(2) == "\xBE\xEF"
@@ -73,12 +73,12 @@ def test_message_ipv6():
def test_message_invalid_rsv():
- raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
+ raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
def test_message_unknown_atyp():
- raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
+ raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050)))
@@ -93,4 +93,4 @@ def test_read():
cs = mock.Mock()
cs.read = mock.Mock(side_effect=socket.error)
- tutils.raises(socks.SocksError, socks._read, cs, 4) \ No newline at end of file
+ tutils.raises(socks.SocksError, socks._read, cs, 4)
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 428f7c61..7bd5d74e 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -1,4 +1,3 @@
-import cStringIO
import os
from nose.tools import raises
@@ -170,7 +169,7 @@ class TestFrameHeader:
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
bytes = f.to_bytes()
- f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
+ f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
assert f == f2
round()
round(fin=1)
@@ -197,7 +196,7 @@ class TestFrameHeader:
def test_funky(self):
f = websockets.FrameHeader(masking_key="test", mask=False)
bytes = f.to_bytes()
- f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
+ f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
assert not f2.mask
def test_violations(self):
@@ -221,7 +220,7 @@ class TestFrame:
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
bytes = f.to_bytes()
- f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
+ f2 = websockets.Frame.from_file(tutils.treader(bytes))
assert f == f2
round("test")
round("test", fin=1)
diff --git a/test/tutils.py b/test/tutils.py
index ea30f59c..141979f8 100644
--- a/test/tutils.py
+++ b/test/tutils.py
@@ -1,7 +1,20 @@
-import tempfile, os, shutil
+import cStringIO
+import tempfile
+import os
+import shutil
from contextlib import contextmanager
from libpathod import utils
+from netlib import tcp
+
+
+def treader(bytes):
+ """
+ Construct a tcp.Read object from bytes.
+ """
+ fp = cStringIO.StringIO(bytes)
+ return tcp.Reader(fp)
+
@contextmanager
def tmpdir(*args, **kwargs):