diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_certutils.py | 80 | ||||
-rw-r--r-- | test/test_http.py | 69 | ||||
-rw-r--r-- | test/test_http_auth.py | 8 | ||||
-rw-r--r-- | test/test_socks.py | 84 | ||||
-rw-r--r-- | test/test_tcp.py | 7 | ||||
-rw-r--r-- | test/test_utils.py | 3 | ||||
-rw-r--r-- | test/test_wsgi.py | 30 |
7 files changed, 201 insertions, 80 deletions
diff --git a/test/test_certutils.py b/test/test_certutils.py index 176575ea..55fcc1dc 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -3,34 +3,34 @@ from netlib import certutils, certffi import OpenSSL import tutils -class TestDNTree: - def test_simple(self): - d = certutils.DNTree() - d.add("foo.com", "foo") - d.add("bar.com", "bar") - assert d.get("foo.com") == "foo" - assert d.get("bar.com") == "bar" - assert not d.get("oink.com") - assert not d.get("oink") - assert not d.get("") - assert not d.get("oink.oink") - - d.add("*.match.org", "match") - assert not d.get("match.org") - assert d.get("foo.match.org") == "match" - assert d.get("foo.foo.match.org") == "match" - - def test_wildcard(self): - d = certutils.DNTree() - d.add("foo.com", "foo") - assert not d.get("*.foo.com") - d.add("*.foo.com", "wild") - - d = certutils.DNTree() - d.add("*", "foo") - assert d.get("foo.com") == "foo" - assert d.get("*.foo.com") == "foo" - assert d.get("com") == "foo" +# class TestDNTree: +# def test_simple(self): +# d = certutils.DNTree() +# d.add("foo.com", "foo") +# d.add("bar.com", "bar") +# assert d.get("foo.com") == "foo" +# assert d.get("bar.com") == "bar" +# assert not d.get("oink.com") +# assert not d.get("oink") +# assert not d.get("") +# assert not d.get("oink.oink") +# +# d.add("*.match.org", "match") +# assert not d.get("match.org") +# assert d.get("foo.match.org") == "match" +# assert d.get("foo.foo.match.org") == "match" +# +# def test_wildcard(self): +# d = certutils.DNTree() +# d.add("foo.com", "foo") +# assert not d.get("*.foo.com") +# d.add("*.foo.com", "wild") +# +# d = certutils.DNTree() +# d.add("*", "foo") +# assert d.get("foo.com") == "foo" +# assert d.get("*.foo.com") == "foo" +# assert d.get("com") == "foo" class TestCertStore: @@ -63,10 +63,17 @@ class TestCertStore: ca = certutils.CertStore.from_store(d, "test") c1 = ca.get_cert("foo.com", ["*.bar.com"]) c2 = ca.get_cert("foo.bar.com", []) - assert c1 == c2 + # assert c1 == c2 c3 = ca.get_cert("bar.com", []) assert not c1 == c3 + def test_sans_change(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + _ = ca.get_cert("foo.com", ["*.bar.com"]) + cert, key = ca.get_cert("foo.bar.com", ["*.baz.com"]) + assert "*.baz.com" in cert.altnames + def test_overrides(self): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") @@ -109,11 +116,15 @@ class TestDummyCert: class TestSSLCert: def test_simple(self): - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert"), "rb").read()) + with open(tutils.test_data.path("data/text_cert"), "rb") as f: + d = f.read() + c = certutils.SSLCert.from_pem(d) assert c.cn == "google.com" assert len(c.altnames) == 436 - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_2"), "rb").read()) + with open(tutils.test_data.path("data/text_cert_2"), "rb") as f: + d = f.read() + c = certutils.SSLCert.from_pem(d) assert c.cn == "www.inode.co.nz" assert len(c.altnames) == 2 assert c.digest("sha1") @@ -127,12 +138,15 @@ class TestSSLCert: c.has_expired def test_err_broken_sans(self): - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_weird1"), "rb").read()) + with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f: + d = f.read() + c = certutils.SSLCert.from_pem(d) # This breaks unless we ignore a decoding error. c.altnames def test_der(self): - d = file(tutils.test_data.path("data/dercert"),"rb").read() + with open(tutils.test_data.path("data/dercert"), "rb") as f: + d = f.read() s = certutils.SSLCert.from_der(d) assert s.cn diff --git a/test/test_http.py b/test/test_http.py index e80e4b8f..497e80e2 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -16,26 +16,30 @@ def test_has_chunked_encoding(): def test_read_chunked(): + + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) + + tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(s, None, None, True) == "a" + assert http.read_http_body(s, h, None, "GET", None, True) == "a" s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(s, None, None, True) == "a" + assert http.read_http_body(s, h, None, "GET", None, True) == "a" s = cStringIO.StringIO("\r\n") - tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) + tutils.raises("closed prematurely", http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("1\r\nfoo") - tutils.raises("malformed chunked body", http.read_chunked, s, None, None, True) + tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_chunked, s, None, None, True) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_chunked, s, None, 2, True) + tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) def test_connection_close(): @@ -63,54 +67,73 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() r = cStringIO.StringIO("testing") - assert http.read_http_body(r, h, None, True) == "" + assert http.read_http_body(r, h, None, "GET", None, True) == "" def test_read_http_body_response(): h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, False) == "testing" + assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case h = odict.ODictCaseless() h["content-length"] = [7] s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, False) == "testing" + assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" # test content length: invalid header h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False) # test content length: invalid header #2 h["content-length"] = [-1] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False) # test content length: content length > actual content h["content-length"] = [5] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False) # test content length: content length < actual content s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, None, False)) == 5 + assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5 # test no content length: limit > actual content h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, 100, False)) == 7 + assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False) # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - assert http.read_http_body(s, h, 100, False) == "aaaaa" + assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" +def test_expected_http_body_size(): + # gibber in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["foo"] + tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) + # negative number in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["-7"] + tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) + # explicit length + h = odict.ODictCaseless() + h["content-length"] = ["5"] + assert http.expected_http_body_size(h, False, "GET", 200) == 5 + # no length + h = odict.ODictCaseless() + assert http.expected_http_body_size(h, False, "GET", 200) == -1 + # no length request + h = odict.ODictCaseless() + assert http.expected_http_body_size(h, True, "GET", None) == 0 def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) @@ -229,10 +252,10 @@ class TestReadResponseNoContentLength(test.ServerTestBase): assert content == "bar\r\n\r\n" def test_read_response(): - def tst(data, method, limit): + def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) r = cStringIO.StringIO(data) - return http.read_response(r, method, limit) + return http.read_response(r, method, limit, include_body=include_body) tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) @@ -277,6 +300,14 @@ def test_read_response(): """ tutils.raises("invalid headers", tst, data, "GET", None) + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None, include_body=False)[4] == None + def test_parse_url(): assert not http.parse_url("") diff --git a/test/test_http_auth.py b/test/test_http_auth.py index dd0273fe..176aa3ff 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -12,14 +12,10 @@ class TestPassManNonAnon: class TestPassManHtpasswd: def test_file_errors(self): - s = cStringIO.StringIO("foo") - tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) - s = cStringIO.StringIO("foo:bar$foo") - tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) + tutils.raises("malformed htpasswd file", http_auth.PassManHtpasswd, tutils.test_data.path("data/server.crt")) def test_simple(self): - f = open(tutils.test_data.path("data/htpasswd"),"rb") - pm = http_auth.PassManHtpasswd(f) + pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") p = http.assemble_http_basic_auth(*vals) diff --git a/test/test_socks.py b/test/test_socks.py new file mode 100644 index 00000000..740fdb9c --- /dev/null +++ b/test/test_socks.py @@ -0,0 +1,84 @@ +from cStringIO import StringIO +import socket +from nose.plugins.skip import SkipTest +from netlib import socks, tcp +import tutils + + +def test_client_greeting(): + raw = StringIO("\x05\x02\x00\xBE\xEF") + out = StringIO() + msg = socks.ClientGreeting.from_file(raw) + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-1] + assert msg.ver == 5 + assert len(msg.methods) == 2 + assert 0xBE in msg.methods + assert 0xEF not in msg.methods + + +def test_server_greeting(): + raw = StringIO("\x05\x02") + out = StringIO() + msg = socks.ServerGreeting.from_file(raw) + msg.to_file(out) + + assert out.getvalue() == raw.getvalue() + assert msg.ver == 5 + assert msg.method == 0x02 + + +def test_message(): + raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + out = StringIO() + msg = socks.Message.from_file(raw) + assert raw.read(2) == "\xBE\xEF" + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-2] + assert msg.ver == 5 + assert msg.msg == 0x01 + assert msg.atyp == 0x03 + assert msg.addr == ("example.com", 0xDEAD) + + +def test_message_ipv4(): + # Test ATYP=0x01 (IPV4) + raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + out = StringIO() + msg = socks.Message.from_file(raw) + assert raw.read(2) == "\xBE\xEF" + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-2] + assert msg.addr == ("127.0.0.1", 0xDEAD) + + +def test_message_ipv6(): + if not hasattr(socket, "inet_ntop"): + raise SkipTest("Skipped because inet_ntop is not available") + # Test ATYP=0x04 (IPV6) + ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" + + raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + out = StringIO() + msg = socks.Message.from_file(raw) + assert raw.read(2) == "\xBE\xEF" + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-2] + assert msg.addr.host == ipv6_addr + + +def test_message_invalid_rsv(): + raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + tutils.raises(socks.SocksError, socks.Message.from_file, raw) + + +def test_message_unknown_atyp(): + raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + tutils.raises(socks.SocksError, socks.Message.from_file, raw) + + m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) + tutils.raises(socks.SocksError, m.to_file, StringIO())
\ No newline at end of file diff --git a/test/test_tcp.py b/test/test_tcp.py index 77146829..bf681811 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,5 @@ import cStringIO, Queue, time, socket, random -from netlib import tcp, certutils, test +from netlib import tcp, certutils, test, certffi import mock import tutils from OpenSSL import SSL @@ -129,9 +129,6 @@ class TestServerSSL(test.ServerTestBase): c.wfile.flush() assert c.rfile.readline() == testval - def test_get_remote_cert(self): - assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") - def test_get_current_cipher(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -419,7 +416,7 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase): def test_privkey(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises("unexpected eof", c.convert_to_ssl) + tutils.raises("sslv3 alert handshake failure", c.convert_to_ssl) diff --git a/test/test_utils.py b/test/test_utils.py index 61820a81..971e5076 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,5 @@ from netlib import utils - +import socket def test_hexdump(): assert utils.hexdump("one\0"*10) @@ -10,4 +10,3 @@ def test_cleanBin(): assert utils.cleanBin("\00ne") == ".ne" assert utils.cleanBin("\nne") == "\nne" assert utils.cleanBin("\nne", True) == ".ne" - diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 91a8ff7a..6e1fb146 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -2,11 +2,11 @@ import cStringIO, sys from netlib import wsgi, odict -def treq(): - cc = wsgi.ClientConn(("127.0.0.1", 8888)) +def tflow(): h = odict.ODictCaseless() h["test"] = ["value"] - return wsgi.Request(cc, "http", "GET", "/", h, "") + req = wsgi.Request("http", "GET", "/", h, "") + return wsgi.Flow(("127.0.0.1", 8888), req) class TestApp: @@ -24,22 +24,22 @@ class TestApp: class TestWSGI: def test_make_environ(self): w = wsgi.WSGIAdaptor(None, "foo", 80, "version") - tr = treq() - assert w.make_environ(tr, None) + tf = tflow() + assert w.make_environ(tf, None) - tr.path = "/foo?bar=voing" - r = w.make_environ(tr, None) + tf.request.path = "/foo?bar=voing" + r = w.make_environ(tf, None) assert r["QUERY_STRING"] == "bar=voing" def test_serve(self): ta = TestApp() w = wsgi.WSGIAdaptor(ta, "foo", 80, "version") - r = treq() - r.host = "foo" - r.port = 80 + f = tflow() + f.request.host = "foo" + f.request.port = 80 wfile = cStringIO.StringIO() - err = w.serve(r, wfile) + err = w.serve(f, wfile) assert ta.called assert not err @@ -49,11 +49,11 @@ class TestWSGI: def _serve(self, app): w = wsgi.WSGIAdaptor(app, "foo", 80, "version") - r = treq() - r.host = "foo" - r.port = 80 + f = tflow() + f.request.host = "foo" + f.request.port = 80 wfile = cStringIO.StringIO() - err = w.serve(r, wfile) + err = w.serve(f, wfile) return wfile.getvalue() def test_serve_empty_body(self): |