diff options
-rw-r--r-- | netlib/http2/protocol.py | 16 | ||||
-rw-r--r-- | netlib/tcp.py | 9 | ||||
-rw-r--r-- | test/http2/test_http2_protocol.py | 13 | ||||
-rw-r--r-- | test/test_tcp.py | 41 |
4 files changed, 53 insertions, 26 deletions
diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index a77edd9b..8191090c 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -55,7 +55,7 @@ class HTTP2Protocol(object): if isinstance(frm, frame.SettingsFrame): break - def _read_settings_ack(self, hide=False): + def _read_settings_ack(self, hide=False): # pragma no cover while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): @@ -99,12 +99,12 @@ class HTTP2Protocol(object): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: + if not hide and self.dump_frames: # pragma no cover print(frm.human_readable(">>")) def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: + if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) @@ -123,7 +123,9 @@ class HTTP2Protocol(object): state=self, flags=frame.Frame.FLAG_ACK), hide) - # self._read_settings_ack(hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -140,7 +142,7 @@ class HTTP2Protocol(object): stream_id=stream_id, header_block_fragment=header_block_fragment) - if self.dump_frames: + if self.dump_frames: # pragma no cover print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -158,7 +160,7 @@ class HTTP2Protocol(object): stream_id=stream_id, payload=body) - if self.dump_frames: + if self.dump_frames: # pragma no cover print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -225,8 +227,6 @@ class HTTP2Protocol(object): if headers is None: headers = [] - body='foobar' - headers = [(b':status', bytes(str(code)))] + headers if not stream_id: diff --git a/netlib/tcp.py b/netlib/tcp.py index 2e847d83..cafc3ed9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -414,6 +414,9 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) + + # TODO: maybe change this to with newer pyOpenSSL APIs + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: raise NetLibError("SSL cipher specification error: %s" % str(v)) @@ -421,8 +424,6 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) - if OpenSSL._util.lib.Cryptography_HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols @@ -526,7 +527,7 @@ class TCPClient(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return None + return "" class BaseHandler(_Connection): @@ -636,7 +637,7 @@ class BaseHandler(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return None + return "" class TCPServer(object): diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index 34c69fa9..231b35e0 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -300,8 +300,9 @@ class TestReadRequest(test.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) - headers, body = protocol.read_request() + stream_id, headers, body = protocol.read_request() + assert stream_id assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} assert body == b'foobar' @@ -309,17 +310,17 @@ class TestReadRequest(test.ServerTestBase): class TestCreateResponse(): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_request_simple(self): + def test_create_response_simple(self): bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) assert len(bytes) == 1 assert bytes[0] ==\ '00000101050000000288'.decode('hex') - def test_create_request_with_body(self): + def test_create_response_with_body(self): bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, [(b'foo', b'bar')], 'foobar') + 200, 1, [(b'foo', b'bar')], 'foobar') assert len(bytes) == 2 assert bytes[0] ==\ - '00000901040000000288408294e7838c767f'.decode('hex') + '00000901040000000188408294e7838c767f'.decode('hex') assert bytes[1] ==\ - '000006000100000002666f6f626172'.decode('hex') + '000006000100000001666f6f626172'.decode('hex') diff --git a/test/test_tcp.py b/test/test_tcp.py index 0cecaaa2..122c1f0f 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -41,6 +41,18 @@ class HangHandler(tcp.BaseHandler): time.sleep(1) +class ALPNHandler(tcp.BaseHandler): + sni = None + + def handle(self): + alp = self.get_alpn_proto_negotiated() + if alp: + self.wfile.write("%s" % alp) + else: + self.wfile.write("NONE") + self.wfile.flush() + + class TestServer(test.ServerTestBase): handler = EchoHandler @@ -416,30 +428,43 @@ class TestTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestALPN(test.ServerTestBase): - handler = EchoHandler +class TestALPNClient(test.ServerTestBase): + handler = ALPNHandler ssl = dict( - alpn_select="foobar" + alpn_select="bar" ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: def test_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["foobar"]) - assert c.get_alpn_proto_negotiated() == "foobar" + c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"]) + assert c.get_alpn_proto_negotiated() == "bar" + assert c.rfile.readline().strip() == "bar" def test_no_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - assert c.get_alpn_proto_negotiated() == None + c.convert_to_ssl() + assert c.get_alpn_proto_negotiated() == "" + assert c.rfile.readline().strip() == "NONE" else: def test_none_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["foobar"]) - assert c.get_alpn_proto_negotiated() == None + c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"]) + assert c.get_alpn_proto_negotiated() == "" + assert c.rfile.readline() == "NONE" + +class TestNoSSLNoALPNClient(test.ServerTestBase): + handler = ALPNHandler + + def test_no_ssl_no_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert c.get_alpn_proto_negotiated() == "" + assert c.rfile.readline().strip() == "NONE" class TestSSLTimeOut(test.ServerTestBase): |