aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/websockets.py24
-rw-r--r--test/test_websockets.py26
2 files changed, 12 insertions, 38 deletions
diff --git a/netlib/websockets.py b/netlib/websockets.py
index 016e75c2..b1afa620 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -115,6 +115,8 @@ def create_server_nonce(client_nonce):
DEFAULT = object()
+
+
class FrameHeader:
def __init__(
self,
@@ -128,6 +130,8 @@ class FrameHeader:
mask = DEFAULT,
length_code = DEFAULT
):
+ if not 0 <= opcode < 2 ** 4:
+ raise ValueError("opcode must be 0-16")
self.opcode = opcode
self.payload_length = payload_length
self.fin = fin
@@ -275,26 +279,6 @@ class Frame(object):
masking_key = masking_key,
)
- def is_valid(self):
- """
- Validate websocket frame invariants, call at anytime to ensure the
- Frame has not been corrupted.
- """
- constraints = [
- 0 <= self.header.fin <= 1,
- 0 <= self.header.rsv1 <= 1,
- 0 <= self.header.rsv2 <= 1,
- 0 <= self.header.rsv3 <= 1,
- 1 <= self.header.opcode <= 4,
- 0 <= self.header.mask <= 1,
- #1 <= self.payload_length_code <= 127,
- 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
- self.header.masking_key is not None if self.header.mask else True
- ]
- if not all(constraints):
- return False
- return True
-
def human_readable(self):
return "\n".join([
("fin - " + str(self.header.fin)),
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 06876e0b..215b3958 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -4,6 +4,7 @@ import os
from nose.tools import raises
from netlib import tcp, test, websockets, http
+import tutils
class WebSocketsEchoHandler(tcp.BaseHandler):
@@ -106,25 +107,7 @@ class TestWebSockets(test.ServerTestBase):
"""
msg = self.random_bytes()
client_frame = websockets.Frame.default(msg, from_client = True)
-
server_frame = websockets.Frame.default(msg, from_client = False)
- assert server_frame.is_valid()
-
- def test_is_valid(self):
- def f():
- return websockets.Frame.default(self.random_bytes(10), True)
-
- frame = f()
- assert frame.is_valid()
-
- frame = f()
- frame.header.fin = 2
- assert not frame.is_valid()
-
- frame = f()
- frame.header.mask_bit = 1
- frame.header.masking_key = "foobbarboo"
- assert not frame.is_valid()
def test_serialization_bijection(self):
"""
@@ -207,6 +190,9 @@ class TestFrameHeader:
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
assert not f2.mask
+ def test_violations(self):
+ tutils.raises("opcode", websockets.FrameHeader, opcode=17)
+
class TestFrame:
def test_roundtrip(self):
@@ -216,3 +202,7 @@ class TestFrame:
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
assert f == f2
round("test")
+
+ def test_human_readable(self):
+ f = websockets.Frame()
+ assert f.human_readable()