diff options
Diffstat (limited to 'test/test_tcp.py')
-rw-r--r-- | test/test_tcp.py | 127 |
1 files changed, 72 insertions, 55 deletions
diff --git a/test/test_tcp.py b/test/test_tcp.py index 814754cd..ec995702 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -4,16 +4,6 @@ import mock import tutils from OpenSSL import SSL -class SNIHandler(tcp.BaseHandler): - sni = None - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write(self.sni) - self.wfile.flush() - - class EchoHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() -class ClientPeernameHandler(tcp.BaseHandler): - def handle(self): - self.wfile.write(str(self.connection.getpeername())) - self.wfile.flush() - - -class CertHandler(tcp.BaseHandler): - sni = None - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write("%s\n"%self.clientcert.serial) - self.wfile.flush() - - class ClientCipherListHandler(tcp.BaseHandler): sni = None - def handle(self): self.wfile.write("%s"%self.connection.get_cipher_list()) self.wfile.flush() -class CurrentCipherHandler(tcp.BaseHandler): - sni = None - def handle(self): - self.wfile.write("%s"%str(self.get_current_cipher())) - self.wfile.flush() - - -class DisconnectHandler(tcp.BaseHandler): - def handle(self): - self.close() - - class HangHandler(tcp.BaseHandler): def handle(self): while 1: time.sleep(1) -class TimeoutHandler(tcp.BaseHandler): - def handle(self): - self.timeout = False - self.settimeout(0.01) - try: - self.rfile.read(10) - except tcp.NetLibTimeout: - self.timeout = True - - class TestServer(test.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase): class TestServerBind(test.ServerTestBase): - handler = ClientPeernameHandler + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write(str(self.connection.getpeername())) + self.wfile.flush() def test_bind(self): """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """ @@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase): class TestSSLClientCert(test.ServerTestBase): - handler = CertHandler + class handler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write("%s\n"%self.clientcert.serial) + self.wfile.flush() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase): class TestSNI(test.ServerTestBase): - handler = SNIHandler + class handler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write(self.sni) + self.wfile.flush() + ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase): class TestServerCurrentCipher(test.ServerTestBase): - handler = CurrentCipherHandler + class handler(tcp.BaseHandler): + sni = None + def handle(self): + self.wfile.write("%s"%str(self.get_current_cipher())) + self.wfile.flush() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase): - handler = DisconnectHandler + class handler(tcp.BaseHandler): + def handle(self): + self.close() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): - handler = TimeoutHandler + class handler(tcp.BaseHandler): + def handle(self): + self.timeout = False + self.settimeout(0.01) + try: + self.rfile.read(10) + except tcp.NetLibTimeout: + self.timeout = True + def test_timeout(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase): assert ret[0] == "DHE-RSA-AES256-SHA" + +class TestPrivkeyGen(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(d, "test2") + ca2 = certutils.CertStore.from_store(d, "test3") + cert, _ = ca1.get_cert("foo.com", []) + key = ca2.gen_pkey(cert) + self.convert_to_ssl(cert, key) + + def test_privkey(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("bad record mac", c.convert_to_ssl) + + +class TestPrivkeyGenNoFlags(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(d, "test2") + ca2 = certutils.CertStore.from_store(d, "test3") + cert, _ = ca1.get_cert("foo.com", []) + certffi.set_flags(ca2.privkey, 0) + self.convert_to_ssl(cert, ca2.privkey) + + def test_privkey(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("unexpected eof", c.convert_to_ssl) + + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) |