aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/websockets.py
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2015-04-21 22:39:45 +1200
committerAldo Cortesi <aldo@nullcube.com>2015-04-21 22:39:45 +1200
commit3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13 (patch)
tree183a689254da4d824ffe6bd2eba8b07623f840bd /netlib/websockets.py
parente5f12648380cb4401f77e3cae51189ef97b603dc (diff)
downloadmitmproxy-3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13.tar.gz
mitmproxy-3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13.tar.bz2
mitmproxy-3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13.zip
websockets: refactor to use http and header functions in http.py
Diffstat (limited to 'netlib/websockets.py')
-rw-r--r--netlib/websockets.py108
1 files changed, 35 insertions, 73 deletions
diff --git a/netlib/websockets.py b/netlib/websockets.py
index f2d467a5..a03185fa 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -2,13 +2,11 @@ from __future__ import absolute_import
import base64
import hashlib
-import mimetools
-import StringIO
import os
import struct
import io
-from . import utils
+from . import utils, odict
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
@@ -23,6 +21,7 @@ from . import utils
# The magic sha that websocket servers must know to prove they understand
# RFC6455
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+VERSION = "13"
class CONST(object):
@@ -151,9 +150,9 @@ class Frame(object):
("opcode - " + str(self.opcode)),
("mask_bit - " + str(self.mask_bit)),
("payload_length_code - " + str(self.payload_length_code)),
- ("masking_key - " + str(self.masking_key)),
- ("payload - " + str(self.payload)),
- ("decoded_payload - " + str(self.decoded_payload)),
+ ("masking_key - " + repr(str(self.masking_key))),
+ ("payload - " + repr(str(self.payload))),
+ ("decoded_payload - " + repr(str(self.decoded_payload))),
("actual_payload_length - " + str(self.actual_payload_length))
])
@@ -198,24 +197,24 @@ class Frame(object):
second_byte = (self.mask_bit << 7) | self.payload_length_code
- bytes = chr(first_byte) + chr(second_byte)
+ b = chr(first_byte) + chr(second_byte)
if self.actual_payload_length < 126:
pass
elif self.actual_payload_length < CONST.MAX_16_BIT_INT:
# '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length
- bytes += struct.pack('!H', self.actual_payload_length)
+ b += struct.pack('!H', self.actual_payload_length)
elif self.actual_payload_length < CONST.MAX_64_BIT_INT:
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
- bytes += struct.pack('!Q', self.actual_payload_length)
+ b += struct.pack('!Q', self.actual_payload_length)
if self.masking_key is not None:
- bytes += self.masking_key
+ b += self.masking_key
- bytes += self.payload # already will be encoded if neccessary
- return bytes
+ b += self.payload # already will be encoded if neccessary
+ return b
def to_file(self, writer):
writer.write(self.to_bytes())
@@ -313,58 +312,35 @@ def random_masking_key():
return os.urandom(4)
-def create_client_handshake(host, port, key, version, resource):
+def client_handshake_headers(key=None, version=VERSION):
"""
- WebSockets connections are intiated by the client with a valid HTTP
- upgrade request
+ Create the headers for a valid HTTP upgrade request. If Key is not
+ specified, it is generated, and can be found in sec-websocket-key in
+ the returned header set.
+
+ Returns an instance of ODictCaseless
"""
- headers = [
- ('Host', '%s:%s' % (host, port)),
+ if not key:
+ key = base64.b64encode(os.urandom(16)).decode('utf-8')
+ return odict.ODictCaseless([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Key', key),
('Sec-WebSocket-Version', version)
- ]
- request = "GET %s HTTP/1.1" % resource
- return build_handshake(headers, request)
+ ])
-def create_server_handshake(key):
+def server_handshake_headers(key):
"""
The server response is a valid HTTP 101 response.
"""
- headers = [
- ('Connection', 'Upgrade'),
- ('Upgrade', 'websocket'),
- ('Sec-WebSocket-Accept', create_server_nonce(key))
- ]
- request = "HTTP/1.1 101 Switching Protocols"
- return build_handshake(headers, request)
-
-
-def build_handshake(headers, request):
- handshake = [request.encode('utf-8')]
- for header, value in headers:
- handshake.append(("%s: %s" % (header, value)).encode('utf-8'))
- handshake.append(b'\r\n')
- return b'\r\n'.join(handshake)
-
-
-def read_handshake(reader, num_bytes_per_read):
- """
- From provided function that reads bytes, read in a
- complete HTTP request, which terminates with a CLRF
- """
- response = b''
- doubleCLRF = b'\r\n\r\n'
- while True:
- bytes = reader.read(num_bytes_per_read)
- if not bytes:
- break
- response += bytes
- if doubleCLRF in response:
- break
- return response
+ return odict.ODictCaseless(
+ [
+ ('Connection', 'Upgrade'),
+ ('Upgrade', 'websocket'),
+ ('Sec-WebSocket-Accept', create_server_nonce(key))
+ ]
+ )
def get_payload_length_pair(payload_bytestring):
@@ -384,33 +360,19 @@ def get_payload_length_pair(payload_bytestring):
return (length_code, actual_length)
-def process_handshake_from_client(handshake):
- headers = headers_from_http_message(handshake)
- if headers.get("Upgrade", None) != "websocket":
+def check_client_handshake(req):
+ if req.headers.get_first("upgrade", None) != "websocket":
return
- key = headers['Sec-WebSocket-Key']
- return key
+ return req.headers.get_first('sec-websocket-key')
-def process_handshake_from_server(handshake):
- headers = headers_from_http_message(handshake)
- if headers.get("Upgrade", None) != "websocket":
+def check_server_handshake(resp):
+ if resp.headers.get_first("upgrade", None) != "websocket":
return
- key = headers['Sec-WebSocket-Accept']
- return key
-
-
-def headers_from_http_message(http_message):
- return mimetools.Message(
- StringIO.StringIO(http_message.split('\r\n', 1)[1])
- )
+ return resp.headers.get_first('sec-websocket-accept')
def create_server_nonce(client_nonce):
return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
)
-
-
-def create_client_nonce():
- return base64.b64encode(os.urandom(16)).decode('utf-8')