aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/websockets.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/websockets.py')
-rw-r--r--netlib/websockets.py60
1 files changed, 23 insertions, 37 deletions
diff --git a/netlib/websockets.py b/netlib/websockets.py
index abf86262..7c127563 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -198,15 +198,13 @@ class Frame(object):
self,
fin, # decmial integer 1 or 0
opcode, # decmial integer 1 - 4
- mask_bit, # decimal integer 1 or 0
- payload_length_code, # decimal integer 1 - 127
- decoded_payload, # bytestring
+ payload = "", # bytestring
+ masking_key = None, # 32 bit byte string
+ mask_bit = 0, # decimal integer 1 or 0
+ payload_length_code = None, # decimal integer 1 - 127
rsv1 = 0, # decimal integer 1 or 0
rsv2 = 0, # decimal integer 1 or 0
rsv3 = 0, # decimal integer 1 or 0
- payload = None, # bytestring
- masking_key = None, # 32 bit byte string
- actual_payload_length = None, # any decimal integer
):
self.fin = fin
self.rsv1 = rsv1
@@ -217,8 +215,6 @@ class Frame(object):
self.payload_length_code = payload_length_code
self.masking_key = masking_key
self.payload = payload
- self.decoded_payload = decoded_payload
- self.actual_payload_length = actual_payload_length
@classmethod
def default(cls, message, from_client = False):
@@ -226,27 +222,19 @@ class Frame(object):
Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame.
"""
- length_code, actual_length = get_payload_length_pair(message)
-
if from_client:
mask_bit = 1
- # Random masking key
masking_key = os.urandom(4)
- payload = apply_mask(message, masking_key)
else:
mask_bit = 0
masking_key = None
- payload = message
return cls(
fin = 1, # final frame
opcode = OPCODE.TEXT, # text
mask_bit = mask_bit,
- payload_length_code = length_code,
- payload = payload,
+ payload = message,
masking_key = masking_key,
- decoded_payload = message,
- actual_payload_length = actual_length
)
def is_valid(self):
@@ -261,17 +249,12 @@ class Frame(object):
0 <= self.rsv3 <= 1,
1 <= self.opcode <= 4,
0 <= self.mask_bit <= 1,
- 1 <= self.payload_length_code <= 127,
- self.actual_payload_length == len(self.payload),
+ #1 <= self.payload_length_code <= 127,
1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
self.masking_key is not None if self.mask_bit else True
]
if not all(constraints):
return False
- elif self.payload and self.masking_key:
- decoded = apply_mask(self.payload, self.masking_key)
- if decoded != self.decoded_payload:
- return False
return True
def human_readable(self): # pragma: nocover
@@ -285,8 +268,6 @@ class Frame(object):
("payload_length_code - " + str(self.payload_length_code)),
("masking_key - " + repr(str(self.masking_key))),
("payload - " + repr(str(self.payload))),
- ("decoded_payload - " + repr(str(self.decoded_payload))),
- ("actual_payload_length - " + str(self.actual_payload_length))
])
@classmethod
@@ -311,9 +292,12 @@ class Frame(object):
rsv3 = self.rsv3,
mask = self.mask_bit,
masking_key = self.masking_key,
- payload_length = self.actual_payload_length
+ payload_length = len(self.payload) if self.payload else 0
)
- b += self.payload # already will be encoded if neccessary
+ if self.masking_key:
+ b += apply_mask(self.payload, self.masking_key)
+ else:
+ b += self.payload
return b
def to_file(self, writer):
@@ -359,10 +343,8 @@ class Frame(object):
payload = fp.read(actual_payload_length)
- if mask_bit == 1:
- decoded_payload = apply_mask(payload, masking_key)
- else:
- decoded_payload = payload
+ if mask_bit == 1 and masking_key:
+ payload = apply_mask(payload, masking_key)
return cls(
fin = fin,
@@ -371,11 +353,17 @@ class Frame(object):
payload_length_code = payload_length,
payload = payload,
masking_key = masking_key,
- decoded_payload = decoded_payload,
- actual_payload_length = actual_payload_length
)
def __eq__(self, other):
+ if self.payload_length_code is None:
+ myplc = make_length_code(len(self.payload))
+ else:
+ myplc = self.payload_length_code
+ if other.payload_length_code is None:
+ otherplc = make_length_code(len(other.payload))
+ else:
+ otherplc = other.payload_length_code
return (
self.fin == other.fin and
self.rsv1 == other.rsv1 and
@@ -383,9 +371,7 @@ class Frame(object):
self.rsv3 == other.rsv3 and
self.opcode == other.opcode and
self.mask_bit == other.mask_bit and
- self.payload_length_code == other.payload_length_code and
self.masking_key == other.masking_key and
- self.payload == other.payload and
- self.decoded_payload == other.decoded_payload and
- self.actual_payload_length == other.actual_payload_length
+ self.payload == other.payload,
+ myplc == otherplc
)