diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/http2/__init__.py | 1 | ||||
-rw-r--r-- | netlib/http2/frame.py | 79 | ||||
-rw-r--r-- | netlib/http2/protocol.py | 160 | ||||
-rw-r--r-- | netlib/http_cookies.py | 8 | ||||
-rw-r--r-- | netlib/http_uastrings.py | 24 | ||||
-rw-r--r-- | netlib/tcp.py | 88 | ||||
-rw-r--r-- | netlib/utils.py | 2 | ||||
-rw-r--r-- | netlib/websockets.py | 16 |
8 files changed, 248 insertions, 130 deletions
diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index 92897b5d..5acf7696 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -1,3 +1,2 @@ - from frame import * from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 4a305d82..b4783a02 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,6 +1,5 @@ import sys import struct -from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils @@ -52,7 +51,7 @@ class Frame(object): self.stream_id = stream_id @classmethod - def _check_frame_size(self, length, state): + def _check_frame_size(cls, length, state): if state: settings = state.http2_settings else: @@ -67,7 +66,7 @@ class Frame(object): length, max_frame_size)) @classmethod - def from_file(self, fp, state=None): + def from_file(cls, fp, state=None): """ read a HTTP/2 frame sent by a server or client fp is a "file like" object that could be backed by a network @@ -83,7 +82,7 @@ class Frame(object): if raw_header[:4] == b'HTTP': # pragma no cover print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - self._check_frame_size(length, state) + cls._check_frame_size(length, state) payload = fp.safe_read(length) return FRAMES[fields[2]].from_bytes( @@ -113,16 +112,13 @@ class Frame(object): def payload_human_readable(self): # pragma: no cover raise NotImplementedError() - def human_readable(self): + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + return "\n".join([ - "============================================================", - "length: %d bytes" % self.length, - "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), - "flags: %#x" % self.flags, - "stream_id: %#x" % self.stream_id, - "------------------------------------------------------------", + "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), - "============================================================", + "===============================================================", ]) def __eq__(self, other): @@ -146,10 +142,10 @@ class DataFrame(Frame): self.pad_length = pad_length @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] f.payload = payload[1:-f.pad_length] else: @@ -204,16 +200,16 @@ class HeadersFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] f.header_block_fragment = payload[1:-f.pad_length] else: f.header_block_fragment = payload[0:] - if f.flags & self.FLAG_PRIORITY: + if f.flags & Frame.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( '!LB', f.header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) @@ -279,8 +275,8 @@ class PriorityFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.stream_dependency, f.weight = struct.unpack('!LB', payload) f.exclusive = bool(f.stream_dependency >> 31) @@ -325,8 +321,8 @@ class RstStreamFrame(Frame): self.error_code = error_code @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.error_code = struct.unpack('!L', payload)[0] return f @@ -369,8 +365,8 @@ class SettingsFrame(Frame): self.settings = settings @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) @@ -420,10 +416,10 @@ class PushPromiseFrame(Frame): self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) f.header_block_fragment = payload[5:-f.pad_length] else: @@ -461,7 +457,10 @@ class PushPromiseFrame(Frame): s.append("padding: %d" % self.pad_length) s.append("promised stream: %#x" % self.promised_stream) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) @@ -480,8 +479,8 @@ class PingFrame(Frame): self.payload = payload @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.payload = payload return f @@ -517,8 +516,8 @@ class GoAwayFrame(Frame): self.data = data @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) f.last_stream &= 0x7FFFFFFF @@ -558,8 +557,8 @@ class WindowUpdateFrame(Frame): self.window_size_increment = window_size_increment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.window_size_increment = struct.unpack("!L", payload)[0] f.window_size_increment &= 0x7FFFFFFF @@ -592,8 +591,8 @@ class ContinuationFrame(Frame): self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.header_block_fragment = payload return f @@ -605,7 +604,11 @@ class ContinuationFrame(Frame): return self.header_block_fragment def payload_human_readable(self): - return "header_block_fragment: %s" % str(self.header_block_fragment) + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) _FRAME_CLASSES = [ DataFrame, diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index feac220c..ac89bac4 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -26,72 +26,106 @@ class HTTP2Protocol(object): ) # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_client): - self.tcp_client = tcp_client + def __init__(self, tcp_handler, is_server=False, dump_frames=False): + self.tcp_handler = tcp_handler + self.is_server = is_server self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None self.encoder = Encoder() self.decoder = Decoder() + self.connection_preface_performed = False + self.dump_frames = dump_frames def check_alpn(self): - alp = self.tcp_client.get_alpn_proto_negotiated() + alp = self.tcp_handler.get_alpn_proto_negotiated() if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True - def perform_connection_preface(self): - self.tcp_client.wfile.write( - bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(frame.SettingsFrame(state=self)) + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert settings_ack_frame.flags & frame.Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 + break + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True - # read server settings frame - frm = frame.Frame.from_file(self.tcp_client.rfile, self) - assert isinstance(frm, frame.SettingsFrame) - self._apply_settings(frm.settings) + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE - # read setting ACK frame - settings_ack_frame = self.read_frame() - assert isinstance(settings_ack_frame, frame.SettingsFrame) - assert settings_ack_frame.flags & frame.Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) def next_stream_id(self): if self.current_stream_id is None: - self.current_stream_id = 1 + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 else: self.current_stream_id += 2 return self.current_stream_id - def send_frame(self, frame): - raw_bytes = frame.to_bytes() - self.tcp_client.wfile.write(raw_bytes) - self.tcp_client.wfile.flush() + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) - def read_frame(self): - frm = frame.Frame.from_file(self.tcp_client.rfile, self) - if isinstance(frm, frame.SettingsFrame): - self._apply_settings(frm.settings) + def read_frame(self, hide=False): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) return frm - def _apply_settings(self, settings): + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] if not old_value: old_value = '-' - self.http2_settings[setting] = value self.send_frame( frame.SettingsFrame( state=self, - flags=frame.Frame.FLAG_ACK)) + flags=frame.Frame.FLAG_ACK), + hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -102,12 +136,16 @@ class HTTP2Protocol(object): header_block_fragment = self.encoder.encode(headers) - bytes = frame.HeadersFrame( + frm = frame.HeadersFrame( state=self, flags=flags, stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] + header_block_fragment=header_block_fragment) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] def _create_body(self, body, stream_id): if body is None or len(body) == 0: @@ -116,21 +154,32 @@ class HTTP2Protocol(object): # TODO: implement max frame size checks and sending in chunks # TODO: implement flow-control window - bytes = frame.DataFrame( + frm = frame.DataFrame( state=self, flags=frame.Frame.FLAG_END_STREAM, stream_id=stream_id, - payload=body).to_bytes() - return [bytes] + payload=body) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + def create_request(self, method, path, headers=None, body=None): if headers is None: headers = [] + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + headers = [ (b':method', bytes(method)), (b':path', bytes(path)), - (b':scheme', b'https')] + headers + (b':scheme', b'https'), + (b':authority', authority), + ] + headers stream_id = self.next_stream_id() @@ -139,25 +188,54 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): + stream_id, headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + + stream_id = 0 header_block_fragment = b'' body = b'' while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame): + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment - if frm.flags | frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: break - while True: + while body_expected: frm = self.read_frame() if isinstance(frm, frame.DataFrame): body += frm.payload - if frm.flags | frame.Frame.FLAG_END_STREAM: + if frm.flags & frame.Frame.FLAG_END_STREAM: break + # TODO: implement window update & flow headers = {} for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - return headers[':status'], headers, body + return stream_id, headers, body + + def create_response(self, code, stream_id=None, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + if not stream_id: + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id), + )) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 5cb39e5c..b7311714 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -158,7 +158,7 @@ def _parse_set_cookie_pairs(s): return pairs -def parse_set_cookie_header(str): +def parse_set_cookie_header(line): """ Parse a Set-Cookie header value @@ -166,7 +166,7 @@ def parse_set_cookie_header(str): ODictCaseless set of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ - pairs = _parse_set_cookie_pairs(str) + pairs = _parse_set_cookie_pairs(line) if pairs: return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) @@ -180,12 +180,12 @@ def format_set_cookie_header(name, value, attrs): return _format_set_cookie_pairs(pairs) -def parse_cookie_header(str): +def parse_cookie_header(line): """ Parse a Cookie header value. Returns a (possibly empty) ODict object. """ - pairs, off = _read_pairs(str) + pairs, off = _read_pairs(line) return odict.ODict(pairs) diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index d9869531..c1ef557c 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -5,40 +5,42 @@ from __future__ import (absolute_import, print_function, division) kept reasonably current to reflect common usage. """ +# pylint: line-too-long + # A collection of (name, shortcut, string) tuples. UASTRINGS = [ ("android", "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa ("blackberry", "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa ("bingbot", "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa ("chrome", "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa ("firefox", "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa ("googlebot", "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa ("ie9", "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), # noqa ("ipad", "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa ("iphone", "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", - ), + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa ("safari", "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10")] + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] def get_by_shortcut(s): diff --git a/netlib/tcp.py b/netlib/tcp.py index 9a980035..65075776 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback +import certifi import OpenSSL from OpenSSL import SSL @@ -19,8 +20,18 @@ SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 +TLSv1_1_METHOD = SSL.TLSv1_1_METHOD +TLSv1_2_METHOD = SSL.TLSv1_2_METHOD + + +SSL_DEFAULT_OPTIONS = ( + SSL.OP_NO_SSLv2 | + SSL.OP_NO_SSLv3 | + SSL.OP_CIPHER_SERVER_PREFERENCE +) + +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION class NetLibError(Exception): @@ -293,7 +304,7 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux sock.shutdown(socket.SHUT_WR) # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending @@ -364,20 +375,24 @@ class _Connection(object): except SSL.Error: pass - """ - Creates an SSL Context. - """ - def _create_ssl_context(self, method=SSLv23_METHOD, - options=(OP_NO_SSLv2 | OP_NO_SSLv3), + options=SSL_DEFAULT_OPTIONS, + verify_options=SSL.VERIFY_NONE, + ca_path=certifi.where(), + ca_pemfile=None, cipher_list=None, alpn_protos=None, alpn_select=None, ): """ - :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + Creates an SSL Context. + + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + :param ca_pemfile: Path to a PEM formatted trusted CA certificate :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context """ @@ -386,6 +401,18 @@ class _Connection(object): if options is not None: context.set_options(options) + # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) + if verify_options is not None and verify_options is not SSL.VERIFY_NONE: + def verify_cert(conn, cert, errno, err_depth, is_cert_verified): + if is_cert_verified: + return True + raise NetLibError( + "Upstream certificate validation failed at depth: %s with error number: %s" % + (err_depth, errno)) + + context.set_verify(verify_options, verify_cert) + context.load_verify_locations(ca_pemfile, ca_path) + # Workaround for # https://github.com/pyca/pyopenssl/issues/190 # https://github.com/mitmproxy/mitmproxy/issues/472 @@ -396,6 +423,9 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) + + # TODO: maybe change this to with newer pyOpenSSL APIs + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: raise NetLibError("SSL cipher specification error: %s" % str(v)) @@ -404,16 +434,17 @@ class _Connection(object): context.set_info_callback(log_ssl_key) if OpenSSL._util.lib.Cryptography_HAS_ALPN: - # advertise application layer protocols if alpn_protos is not None: + # advertise application layer protocols context.set_alpn_protos(alpn_protos) - - # select application layer protocol - if alpn_select is not None: - def alpn_select_f(conn, options): - return bytes(alpn_select) - - context.set_alpn_select_callback(alpn_select_f) + elif alpn_select is not None: + # select application layer protocol + def alpn_select_callback(conn, options): + if alpn_select in options: + return bytes(alpn_select) + else: # pragma no cover + return options[0] + context.set_alpn_select_callback(alpn_select_callback) return context @@ -458,6 +489,9 @@ class TCPClient(_Connection): cert: Path to a file containing both client cert and private key. options: A bit field consisting of OpenSSL.SSL.OP_* values + verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + ca_pemfile: Path to a PEM formatted trusted CA certificate """ context = self.create_ssl_context( alpn_protos=alpn_protos, @@ -499,10 +533,10 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN: + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover - return None + else: + return "" class BaseHandler(_Connection): @@ -531,7 +565,6 @@ class BaseHandler(_Connection): request_client_cert=None, chain_file=None, dhparams=None, - alpn_select=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. @@ -558,9 +591,7 @@ class BaseHandler(_Connection): until then we're conservative. """ - context = self._create_ssl_context( - alpn_select=alpn_select, - **sslctx_kwargs) + context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -585,7 +616,7 @@ class BaseHandler(_Connection): return context - def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs): + def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) @@ -594,7 +625,6 @@ class BaseHandler(_Connection): context = self.create_ssl_context( cert, key, - alpn_select=alpn_select, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() @@ -612,6 +642,12 @@ class BaseHandler(_Connection): def settimeout(self, n): self.connection.settimeout(n) + def get_alpn_proto_negotiated(self): + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + return self.connection.get_alpn_proto_negotiated() + else: + return "" + class TCPServer(object): request_queue_size = 20 diff --git a/netlib/utils.py b/netlib/utils.py index 9c5404e6..ac42bd53 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -67,7 +67,7 @@ def getbit(byte, offset): return True -class BiDi: +class BiDi(object): """ A wee utility class for keeping bi-directional mappings, like field diff --git a/netlib/websockets.py b/netlib/websockets.py index 346adf1b..c45db4df 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,7 +35,7 @@ OPCODE = utils.BiDi( ) -class Masker: +class Masker(object): """ Data sent from the server must be masked to prevent malicious clients @@ -94,15 +94,15 @@ def server_handshake_headers(key): ) -def make_length_code(len): +def make_length_code(length): """ A websockets frame contains an initial length_code, and an optional extended length code to represent the actual length if length code is larger than 125 """ - if len <= 125: - return len - elif len >= 126 and len <= 65535: + if length <= 125: + return length + elif length >= 126 and length <= 65535: return 126 else: return 127 @@ -129,7 +129,7 @@ def create_server_nonce(client_nonce): DEFAULT = object() -class FrameHeader: +class FrameHeader(object): def __init__( self, @@ -216,7 +216,7 @@ class FrameHeader: return b @classmethod - def from_file(klass, fp): + def from_file(cls, fp): """ read a websockets frame header """ @@ -248,7 +248,7 @@ class FrameHeader: else: masking_key = None - return klass( + return cls( fin=fin, rsv1=rsv1, rsv2=rsv2, |