aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/http2/protocol.py16
-rw-r--r--netlib/tcp.py9
-rw-r--r--test/http2/test_http2_protocol.py13
-rw-r--r--test/test_tcp.py41
4 files changed, 53 insertions, 26 deletions
diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py
index a77edd9b..8191090c 100644
--- a/netlib/http2/protocol.py
+++ b/netlib/http2/protocol.py
@@ -55,7 +55,7 @@ class HTTP2Protocol(object):
if isinstance(frm, frame.SettingsFrame):
break
- def _read_settings_ack(self, hide=False):
+ def _read_settings_ack(self, hide=False): # pragma no cover
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
@@ -99,12 +99,12 @@ class HTTP2Protocol(object):
raw_bytes = frm.to_bytes()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
- if not hide and self.dump_frames:
+ if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
def read_frame(self, hide=False):
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
- if not hide and self.dump_frames:
+ if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK:
self._apply_settings(frm.settings, hide)
@@ -123,7 +123,9 @@ class HTTP2Protocol(object):
state=self,
flags=frame.Frame.FLAG_ACK),
hide)
- # self._read_settings_ack(hide)
+
+ # be liberal in what we expect from the other end
+ # to be more strict use: self._read_settings_ack(hide)
def _create_headers(self, headers, stream_id, end_stream=True):
# TODO: implement max frame size checks and sending in chunks
@@ -140,7 +142,7 @@ class HTTP2Protocol(object):
stream_id=stream_id,
header_block_fragment=header_block_fragment)
- if self.dump_frames:
+ if self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
return [frm.to_bytes()]
@@ -158,7 +160,7 @@ class HTTP2Protocol(object):
stream_id=stream_id,
payload=body)
- if self.dump_frames:
+ if self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
return [frm.to_bytes()]
@@ -225,8 +227,6 @@ class HTTP2Protocol(object):
if headers is None:
headers = []
- body='foobar'
-
headers = [(b':status', bytes(str(code)))] + headers
if not stream_id:
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 2e847d83..cafc3ed9 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -414,6 +414,9 @@ class _Connection(object):
if cipher_list:
try:
context.set_cipher_list(cipher_list)
+
+ # TODO: maybe change this to with newer pyOpenSSL APIs
+ context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
raise NetLibError("SSL cipher specification error: %s" % str(v))
@@ -421,8 +424,6 @@ class _Connection(object):
if log_ssl_key:
context.set_info_callback(log_ssl_key)
- context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
-
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
if alpn_protos is not None:
# advertise application layer protocols
@@ -526,7 +527,7 @@ class TCPClient(_Connection):
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else:
- return None
+ return ""
class BaseHandler(_Connection):
@@ -636,7 +637,7 @@ class BaseHandler(_Connection):
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else:
- return None
+ return ""
class TCPServer(object):
diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py
index 34c69fa9..231b35e0 100644
--- a/test/http2/test_http2_protocol.py
+++ b/test/http2/test_http2_protocol.py
@@ -300,8 +300,9 @@ class TestReadRequest(test.ServerTestBase):
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c, is_server=True)
- headers, body = protocol.read_request()
+ stream_id, headers, body = protocol.read_request()
+ assert stream_id
assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'}
assert body == b'foobar'
@@ -309,17 +310,17 @@ class TestReadRequest(test.ServerTestBase):
class TestCreateResponse():
c = tcp.TCPClient(("127.0.0.1", 0))
- def test_create_request_simple(self):
+ def test_create_response_simple(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
assert len(bytes) == 1
assert bytes[0] ==\
'00000101050000000288'.decode('hex')
- def test_create_request_with_body(self):
+ def test_create_response_with_body(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
- 200, [(b'foo', b'bar')], 'foobar')
+ 200, 1, [(b'foo', b'bar')], 'foobar')
assert len(bytes) == 2
assert bytes[0] ==\
- '00000901040000000288408294e7838c767f'.decode('hex')
+ '00000901040000000188408294e7838c767f'.decode('hex')
assert bytes[1] ==\
- '000006000100000002666f6f626172'.decode('hex')
+ '000006000100000001666f6f626172'.decode('hex')
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 0cecaaa2..122c1f0f 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -41,6 +41,18 @@ class HangHandler(tcp.BaseHandler):
time.sleep(1)
+class ALPNHandler(tcp.BaseHandler):
+ sni = None
+
+ def handle(self):
+ alp = self.get_alpn_proto_negotiated()
+ if alp:
+ self.wfile.write("%s" % alp)
+ else:
+ self.wfile.write("NONE")
+ self.wfile.flush()
+
+
class TestServer(test.ServerTestBase):
handler = EchoHandler
@@ -416,30 +428,43 @@ class TestTimeOut(test.ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
-class TestALPN(test.ServerTestBase):
- handler = EchoHandler
+class TestALPNClient(test.ServerTestBase):
+ handler = ALPNHandler
ssl = dict(
- alpn_select="foobar"
+ alpn_select="bar"
)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- c.convert_to_ssl(alpn_protos=["foobar"])
- assert c.get_alpn_proto_negotiated() == "foobar"
+ c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"])
+ assert c.get_alpn_proto_negotiated() == "bar"
+ assert c.rfile.readline().strip() == "bar"
def test_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- assert c.get_alpn_proto_negotiated() == None
+ c.convert_to_ssl()
+ assert c.get_alpn_proto_negotiated() == ""
+ assert c.rfile.readline().strip() == "NONE"
else:
def test_none_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- c.convert_to_ssl(alpn_protos=["foobar"])
- assert c.get_alpn_proto_negotiated() == None
+ c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"])
+ assert c.get_alpn_proto_negotiated() == ""
+ assert c.rfile.readline() == "NONE"
+
+class TestNoSSLNoALPNClient(test.ServerTestBase):
+ handler = ALPNHandler
+
+ def test_no_ssl_no_alpn(self):
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ c.connect()
+ assert c.get_alpn_proto_negotiated() == ""
+ assert c.rfile.readline().strip() == "NONE"
class TestSSLTimeOut(test.ServerTestBase):