aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_tcp.py
blob: a81632e79356502b0a3731e6a2b39bf084c462fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import cStringIO, threading, Queue
from netlib import tcp
import tutils

class ServerThread(threading.Thread):
    """
    Thread that runs a server object's serve_forever() loop.

    server: any object exposing serve_forever() and shutdown(). It is
    kept on the instance as ``self.server``.
    """
    def __init__(self, server):
        super(ServerThread, self).__init__()
        self.server = server

    def run(self):
        # Thread body: hand control to the wrapped server.
        serve = self.server.serve_forever
        serve()

    def shutdown(self):
        # Forward the stop request to the wrapped server.
        stop = self.server.shutdown
        stop()


class ServerTestBase:
    """
    Base for test classes that need a server running on a background
    thread. Subclasses must provide a ``makeserver()`` classmethod that
    returns the server object to run.

    NOTE(review): setupAll/teardownAll look like nose class-level fixture
    names - confirm against the test runner in use.
    """
    @classmethod
    def setupAll(cls):
        # Build the server through the subclass hook, wrap it in a
        # ServerThread, store the thread on the class, then start it.
        thread = ServerThread(cls.makeserver())
        cls.server = thread
        thread.start()

    @classmethod
    def teardownAll(cls):
        # cls.server is the ServerThread stored by setupAll.
        worker = cls.server
        worker.shutdown()


class EchoHandler(tcp.BaseHandler):
    """
    Handler that reads a single line from the client and reacts to its
    prefix:

      "echo..."   the line is written back unchanged
      "error..."  ValueError is raised
      otherwise   nothing is written

    The write side is flushed whenever no error was raised.
    """
    def handle(self):
        line = self.rfile.readline()
        # The two prefixes cannot both match, so the error case can be
        # dealt with up front as a guard.
        if line.startswith("error"):
            raise ValueError("Testing an error.")
        if line.startswith("echo"):
            self.wfile.write(line)
        self.wfile.flush()


class DisconnectHandler(tcp.BaseHandler):
    """
    Handler that reads and writes nothing: handle() only calls
    self.finish().
    """
    def handle(self):
        # NOTE(review): finish() is presumably what closes the connection
        # on the server side - confirm in tcp.BaseHandler.
        self.finish()


class TServer(tcp.TCPServer):
    """
    TCP server used by the tests in this module.

    addr:    address passed straight to tcp.TCPServer.__init__
    ssl:     if true, each handler is converted to SSL before it runs
    q:       queue that receives the text produced by handle_error
    handler: callable invoked as handler(request, client_address, server)
             to build the per-connection handler
    """
    def __init__(self, addr, ssl, q, handler):
        tcp.TCPServer.__init__(self, addr)
        self.ssl = ssl
        self.q = q
        self.handler = handler

    def handle_connection(self, request, client_address):
        conn = self.handler(request, client_address, self)
        if self.ssl:
            # Certificate and key come from the test data directory.
            cert_path = tutils.test_data.path("data/server.crt")
            key_path = tutils.test_data.path("data/server.key")
            conn.convert_to_ssl(cert_path, key_path)
        conn.handle()
        conn.finish()

    def handle_error(self, request, client_address):
        # Pass an in-memory buffer to the base implementation as its
        # fourth argument, then push whatever ended up in the buffer onto
        # the queue.
        buf = cStringIO.StringIO()
        tcp.TCPServer.handle_error(self, request, client_address, buf)
        report = buf.getvalue()
        self.q.put(report)


class TestServer(ServerTestBase):
    """
    Runs an EchoHandler server without SSL and checks that a line sent by
    a client comes back unchanged.
    """
    @classmethod
    def makeserver(cls):
        # Port 0 is requested; the port the server reports is recorded on
        # the class for the test to connect to.
        cls.q = Queue.Queue()
        server = TServer(
            ("127.0.0.1", 0),
            ssl=False,
            q=cls.q,
            handler=EchoHandler,
        )
        cls.port = server.port
        return server

    def test_echo(self):
        # Send one "echo" line and expect the identical line back.
        client = tcp.TCPClient("127.0.0.1", self.port)
        payload = "echo!\n"
        client.connect()
        client.wfile.write(payload)
        client.wfile.flush()
        reply = client.rfile.readline()
        assert reply == payload


class TestServerSSL(ServerTestBase):
    """
    Runs an EchoHandler server with SSL enabled and checks the echo
    round trip after the client has also converted to SSL.
    """
    @classmethod
    def makeserver(cls):
        # Same setup as the plain echo server, but with the ssl flag set.
        cls.q = Queue.Queue()
        server = TServer(
            ("127.0.0.1", 0),
            ssl=True,
            q=cls.q,
            handler=EchoHandler,
        )
        cls.port = server.port
        return server

    def test_echo(self):
        # Connect, switch the client side to SSL, then do the echo check.
        client = tcp.TCPClient("127.0.0.1", self.port)
        client.connect()
        client.convert_to_ssl()
        payload = "echo!\n"
        client.wfile.write(payload)
        client.wfile.flush()
        reply = client.rfile.readline()
        assert reply == payload


class TestSSLDisconnect(ServerTestBase):
    """
    Runs a DisconnectHandler server with SSL enabled and reads from it on
    the client side; the test passes as long as the read does not raise.
    """
    @classmethod
    def makeserver(cls):
        cls.q = Queue.Queue()
        server = TServer(
            ("127.0.0.1", 0),
            ssl=True,
            q=cls.q,
            handler=DisconnectHandler,
        )
        cls.port = server.port
        return server

    def test_echo(self):
        client = tcp.TCPClient("127.0.0.1", self.port)
        client.connect()
        client.convert_to_ssl()
        # Exercise SSL.ZeroReturnError
        client.rfile.read(10)


class TestTCPClient:
    """Client-side checks that do not need a running server."""
    def test_conerr(self):
        # Connecting to port 0 on localhost is expected to raise
        # tcp.NetLibError.
        client = tcp.TCPClient("127.0.0.1", 0)
        expected = tcp.NetLibError
        tutils.raises(expected, client.connect)


class TestFileLike:
    """Checks for tcp.FileLike wrapped around an in-memory file object."""
    def test_wrap(self):
        raw = cStringIO.StringIO("foobar\nfoobar")
        wrapped = tcp.FileLike(raw)
        wrapped.flush()
        # Two reads: the newline-terminated line, then the unterminated tail.
        first = wrapped.readline()
        assert first == "foobar\n"
        second = wrapped.readline()
        assert second == "foobar"
        # Test __getattr__
        assert wrapped.isatty

    def test_limit(self):
        # readline with a size argument of 3 returns only the first 3 chars.
        raw = cStringIO.StringIO("foobar\nfoobar")
        wrapped = tcp.FileLike(raw)
        head = wrapped.readline(3)
        assert head == "foo"