aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/http2/__init__.py1
-rw-r--r--netlib/http2/frame.py79
-rw-r--r--netlib/http2/protocol.py160
-rw-r--r--netlib/http_cookies.py8
-rw-r--r--netlib/http_uastrings.py24
-rw-r--r--netlib/tcp.py88
-rw-r--r--netlib/utils.py2
-rw-r--r--netlib/websockets.py16
8 files changed, 248 insertions, 130 deletions
diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py
index 92897b5d..5acf7696 100644
--- a/netlib/http2/__init__.py
+++ b/netlib/http2/__init__.py
@@ -1,3 +1,2 @@
-
from frame import *
from protocol import *
diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py
index 4a305d82..b4783a02 100644
--- a/netlib/http2/frame.py
+++ b/netlib/http2/frame.py
@@ -1,6 +1,5 @@
import sys
import struct
-from functools import reduce
from hpack.hpack import Encoder, Decoder
from .. import utils
@@ -52,7 +51,7 @@ class Frame(object):
self.stream_id = stream_id
@classmethod
- def _check_frame_size(self, length, state):
+ def _check_frame_size(cls, length, state):
if state:
settings = state.http2_settings
else:
@@ -67,7 +66,7 @@ class Frame(object):
length, max_frame_size))
@classmethod
- def from_file(self, fp, state=None):
+ def from_file(cls, fp, state=None):
"""
read a HTTP/2 frame sent by a server or client
fp is a "file like" object that could be backed by a network
@@ -83,7 +82,7 @@ class Frame(object):
if raw_header[:4] == b'HTTP': # pragma no cover
print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!"
- self._check_frame_size(length, state)
+ cls._check_frame_size(length, state)
payload = fp.safe_read(length)
return FRAMES[fields[2]].from_bytes(
@@ -113,16 +112,13 @@ class Frame(object):
def payload_human_readable(self): # pragma: no cover
raise NotImplementedError()
- def human_readable(self):
+ def human_readable(self, direction="-"):
+ self.length = len(self.payload_bytes())
+
return "\n".join([
- "============================================================",
- "length: %d bytes" % self.length,
- "type: %s (%#x)" % (self.__class__.__name__, self.TYPE),
- "flags: %#x" % self.flags,
- "stream_id: %#x" % self.stream_id,
- "------------------------------------------------------------",
+ "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id),
self.payload_human_readable(),
- "============================================================",
+ "===============================================================",
])
def __eq__(self, other):
@@ -146,10 +142,10 @@ class DataFrame(Frame):
self.pad_length = pad_length
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
- if f.flags & self.FLAG_PADDED:
+ if f.flags & Frame.FLAG_PADDED:
f.pad_length = struct.unpack('!B', payload[0])[0]
f.payload = payload[1:-f.pad_length]
else:
@@ -204,16 +200,16 @@ class HeadersFrame(Frame):
self.weight = weight
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
- if f.flags & self.FLAG_PADDED:
+ if f.flags & Frame.FLAG_PADDED:
f.pad_length = struct.unpack('!B', payload[0])[0]
f.header_block_fragment = payload[1:-f.pad_length]
else:
f.header_block_fragment = payload[0:]
- if f.flags & self.FLAG_PRIORITY:
+ if f.flags & Frame.FLAG_PRIORITY:
f.stream_dependency, f.weight = struct.unpack(
'!LB', f.header_block_fragment[:5])
f.exclusive = bool(f.stream_dependency >> 31)
@@ -279,8 +275,8 @@ class PriorityFrame(Frame):
self.weight = weight
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
f.stream_dependency, f.weight = struct.unpack('!LB', payload)
f.exclusive = bool(f.stream_dependency >> 31)
@@ -325,8 +321,8 @@ class RstStreamFrame(Frame):
self.error_code = error_code
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
f.error_code = struct.unpack('!L', payload)[0]
return f
@@ -369,8 +365,8 @@ class SettingsFrame(Frame):
self.settings = settings
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
for i in xrange(0, len(payload), 6):
identifier, value = struct.unpack("!HL", payload[i:i + 6])
@@ -420,10 +416,10 @@ class PushPromiseFrame(Frame):
self.header_block_fragment = header_block_fragment
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
- if f.flags & self.FLAG_PADDED:
+ if f.flags & Frame.FLAG_PADDED:
f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5])
f.header_block_fragment = payload[5:-f.pad_length]
else:
@@ -461,7 +457,10 @@ class PushPromiseFrame(Frame):
s.append("padding: %d" % self.pad_length)
s.append("promised stream: %#x" % self.promised_stream)
- s.append("header_block_fragment: %s" % str(self.header_block_fragment))
+ s.append(
+ "header_block_fragment: %s" %
+ self.header_block_fragment.encode('hex'))
+
return "\n".join(s)
@@ -480,8 +479,8 @@ class PingFrame(Frame):
self.payload = payload
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
f.payload = payload
return f
@@ -517,8 +516,8 @@ class GoAwayFrame(Frame):
self.data = data
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
f.last_stream, f.error_code = struct.unpack("!LL", payload[:8])
f.last_stream &= 0x7FFFFFFF
@@ -558,8 +557,8 @@ class WindowUpdateFrame(Frame):
self.window_size_increment = window_size_increment
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
f.window_size_increment = struct.unpack("!L", payload)[0]
f.window_size_increment &= 0x7FFFFFFF
@@ -592,8 +591,8 @@ class ContinuationFrame(Frame):
self.header_block_fragment = header_block_fragment
@classmethod
- def from_bytes(self, state, length, flags, stream_id, payload):
- f = self(state=state, length=length, flags=flags, stream_id=stream_id)
+ def from_bytes(cls, state, length, flags, stream_id, payload):
+ f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
f.header_block_fragment = payload
return f
@@ -605,7 +604,11 @@ class ContinuationFrame(Frame):
return self.header_block_fragment
def payload_human_readable(self):
- return "header_block_fragment: %s" % str(self.header_block_fragment)
+ s = []
+ s.append(
+ "header_block_fragment: %s" %
+ self.header_block_fragment.encode('hex'))
+ return "\n".join(s)
_FRAME_CLASSES = [
DataFrame,
diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py
index feac220c..ac89bac4 100644
--- a/netlib/http2/protocol.py
+++ b/netlib/http2/protocol.py
@@ -26,72 +26,106 @@ class HTTP2Protocol(object):
)
# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
- CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
+ CLIENT_CONNECTION_PREFACE =\
+ '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')
ALPN_PROTO_H2 = 'h2'
- def __init__(self, tcp_client):
- self.tcp_client = tcp_client
+ def __init__(self, tcp_handler, is_server=False, dump_frames=False):
+ self.tcp_handler = tcp_handler
+ self.is_server = is_server
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.encoder = Encoder()
self.decoder = Decoder()
+ self.connection_preface_performed = False
+ self.dump_frames = dump_frames
def check_alpn(self):
- alp = self.tcp_client.get_alpn_proto_negotiated()
+ alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != self.ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
- def perform_connection_preface(self):
- self.tcp_client.wfile.write(
- bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
- self.send_frame(frame.SettingsFrame(state=self))
+ def _receive_settings(self, hide=False):
+ while True:
+ frm = self.read_frame(hide)
+ if isinstance(frm, frame.SettingsFrame):
+ break
+
+ def _read_settings_ack(self, hide=False): # pragma no cover
+ while True:
+ frm = self.read_frame(hide)
+ if isinstance(frm, frame.SettingsFrame):
+ assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
+ assert len(settings_ack_frame.settings) == 0
+ break
+
+ def perform_server_connection_preface(self, force=False):
+ if force or not self.connection_preface_performed:
+ self.connection_preface_performed = True
- # read server settings frame
- frm = frame.Frame.from_file(self.tcp_client.rfile, self)
- assert isinstance(frm, frame.SettingsFrame)
- self._apply_settings(frm.settings)
+ magic_length = len(self.CLIENT_CONNECTION_PREFACE)
+ magic = self.tcp_handler.rfile.safe_read(magic_length)
+ assert magic == self.CLIENT_CONNECTION_PREFACE
- # read setting ACK frame
- settings_ack_frame = self.read_frame()
- assert isinstance(settings_ack_frame, frame.SettingsFrame)
- assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
- assert len(settings_ack_frame.settings) == 0
+ self.send_frame(frame.SettingsFrame(state=self), hide=True)
+ self._receive_settings(hide=True)
+
+ def perform_client_connection_preface(self, force=False):
+ if force or not self.connection_preface_performed:
+ self.connection_preface_performed = True
+
+ self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
+
+ self.send_frame(frame.SettingsFrame(state=self), hide=True)
+ self._receive_settings(hide=True)
def next_stream_id(self):
if self.current_stream_id is None:
- self.current_stream_id = 1
+ if self.is_server:
+ # servers must use even stream ids
+ self.current_stream_id = 2
+ else:
+ # clients must use odd stream ids
+ self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
- def send_frame(self, frame):
- raw_bytes = frame.to_bytes()
- self.tcp_client.wfile.write(raw_bytes)
- self.tcp_client.wfile.flush()
+ def send_frame(self, frm, hide=False):
+ raw_bytes = frm.to_bytes()
+ self.tcp_handler.wfile.write(raw_bytes)
+ self.tcp_handler.wfile.flush()
+ if not hide and self.dump_frames: # pragma no cover
+ print(frm.human_readable(">>"))
- def read_frame(self):
- frm = frame.Frame.from_file(self.tcp_client.rfile, self)
- if isinstance(frm, frame.SettingsFrame):
- self._apply_settings(frm.settings)
+ def read_frame(self, hide=False):
+ frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
+ if not hide and self.dump_frames: # pragma no cover
+ print(frm.human_readable("<<"))
+ if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK:
+ self._apply_settings(frm.settings, hide)
return frm
- def _apply_settings(self, settings):
+ def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
if not old_value:
old_value = '-'
-
self.http2_settings[setting] = value
self.send_frame(
frame.SettingsFrame(
state=self,
- flags=frame.Frame.FLAG_ACK))
+ flags=frame.Frame.FLAG_ACK),
+ hide)
+
+ # be liberal in what we expect from the other end
+ # to be more strict use: self._read_settings_ack(hide)
def _create_headers(self, headers, stream_id, end_stream=True):
# TODO: implement max frame size checks and sending in chunks
@@ -102,12 +136,16 @@ class HTTP2Protocol(object):
header_block_fragment = self.encoder.encode(headers)
- bytes = frame.HeadersFrame(
+ frm = frame.HeadersFrame(
state=self,
flags=flags,
stream_id=stream_id,
- header_block_fragment=header_block_fragment).to_bytes()
- return [bytes]
+ header_block_fragment=header_block_fragment)
+
+ if self.dump_frames: # pragma no cover
+ print(frm.human_readable(">>"))
+
+ return [frm.to_bytes()]
def _create_body(self, body, stream_id):
if body is None or len(body) == 0:
@@ -116,21 +154,32 @@ class HTTP2Protocol(object):
# TODO: implement max frame size checks and sending in chunks
# TODO: implement flow-control window
- bytes = frame.DataFrame(
+ frm = frame.DataFrame(
state=self,
flags=frame.Frame.FLAG_END_STREAM,
stream_id=stream_id,
- payload=body).to_bytes()
- return [bytes]
+ payload=body)
+
+ if self.dump_frames: # pragma no cover
+ print(frm.human_readable(">>"))
+
+ return [frm.to_bytes()]
+
def create_request(self, method, path, headers=None, body=None):
if headers is None:
headers = []
+ authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
+ if self.tcp_handler.address.port != 443:
+ authority += ":%d" % self.tcp_handler.address.port
+
headers = [
(b':method', bytes(method)),
(b':path', bytes(path)),
- (b':scheme', b'https')] + headers
+ (b':scheme', b'https'),
+ (b':authority', authority),
+ ] + headers
stream_id = self.next_stream_id()
@@ -139,25 +188,54 @@ class HTTP2Protocol(object):
self._create_body(body, stream_id)))
def read_response(self):
+ stream_id, headers, body = self._receive_transmission()
+ return headers[':status'], headers, body
+
+ def read_request(self):
+ return self._receive_transmission()
+
+ def _receive_transmission(self):
+ body_expected = True
+
+ stream_id = 0
header_block_fragment = b''
body = b''
while True:
frm = self.read_frame()
- if isinstance(frm, frame.HeadersFrame):
+ if isinstance(frm, frame.HeadersFrame)\
+ or isinstance(frm, frame.ContinuationFrame):
+ stream_id = frm.stream_id
header_block_fragment += frm.header_block_fragment
- if frm.flags | frame.Frame.FLAG_END_HEADERS:
+ if frm.flags & frame.Frame.FLAG_END_STREAM:
+ body_expected = False
+ if frm.flags & frame.Frame.FLAG_END_HEADERS:
break
- while True:
+ while body_expected:
frm = self.read_frame()
if isinstance(frm, frame.DataFrame):
body += frm.payload
- if frm.flags | frame.Frame.FLAG_END_STREAM:
+ if frm.flags & frame.Frame.FLAG_END_STREAM:
break
+ # TODO: implement window update & flow
headers = {}
for header, value in self.decoder.decode(header_block_fragment):
headers[header] = value
- return headers[':status'], headers, body
+ return stream_id, headers, body
+
+ def create_response(self, code, stream_id=None, headers=None, body=None):
+ if headers is None:
+ headers = []
+
+ headers = [(b':status', bytes(str(code)))] + headers
+
+ if not stream_id:
+ stream_id = self.next_stream_id()
+
+ return list(itertools.chain(
+ self._create_headers(headers, stream_id, end_stream=(body is None)),
+ self._create_body(body, stream_id),
+ ))
diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py
index 5cb39e5c..b7311714 100644
--- a/netlib/http_cookies.py
+++ b/netlib/http_cookies.py
@@ -158,7 +158,7 @@ def _parse_set_cookie_pairs(s):
return pairs
-def parse_set_cookie_header(str):
+def parse_set_cookie_header(line):
"""
Parse a Set-Cookie header value
@@ -166,7 +166,7 @@ def parse_set_cookie_header(str):
ODictCaseless set of attributes. No attempt is made to parse attribute
values - they are treated purely as strings.
"""
- pairs = _parse_set_cookie_pairs(str)
+ pairs = _parse_set_cookie_pairs(line)
if pairs:
return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:])
@@ -180,12 +180,12 @@ def format_set_cookie_header(name, value, attrs):
return _format_set_cookie_pairs(pairs)
-def parse_cookie_header(str):
+def parse_cookie_header(line):
"""
Parse a Cookie header value.
Returns a (possibly empty) ODict object.
"""
- pairs, off = _read_pairs(str)
+ pairs, off = _read_pairs(line)
return odict.ODict(pairs)
diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py
index d9869531..c1ef557c 100644
--- a/netlib/http_uastrings.py
+++ b/netlib/http_uastrings.py
@@ -5,40 +5,42 @@ from __future__ import (absolute_import, print_function, division)
kept reasonably current to reflect common usage.
"""
+# pylint: line-too-long
+
# A collection of (name, shortcut, string) tuples.
UASTRINGS = [
("android",
"a",
- "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"),
+ "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa
("blackberry",
"l",
- "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"),
+ "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa
("bingbot",
"b",
- "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"),
+ "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa
("chrome",
"c",
- "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"),
+ "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa
("firefox",
"f",
- "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"),
+ "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa
("googlebot",
"g",
- "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"),
+ "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa
("ie9",
"i",
- "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"),
+ "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), # noqa
("ipad",
"p",
- "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"),
+ "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa
("iphone",
"h",
- "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5",
- ),
+ "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa
("safari",
"s",
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10")]
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa
+]
def get_by_shortcut(s):
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 9a980035..65075776 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -7,6 +7,7 @@ import threading
import time
import traceback
+import certifi
import OpenSSL
from OpenSSL import SSL
@@ -19,8 +20,18 @@ SSLv2_METHOD = SSL.SSLv2_METHOD
SSLv3_METHOD = SSL.SSLv3_METHOD
SSLv23_METHOD = SSL.SSLv23_METHOD
TLSv1_METHOD = SSL.TLSv1_METHOD
-OP_NO_SSLv2 = SSL.OP_NO_SSLv2
-OP_NO_SSLv3 = SSL.OP_NO_SSLv3
+TLSv1_1_METHOD = SSL.TLSv1_1_METHOD
+TLSv1_2_METHOD = SSL.TLSv1_2_METHOD
+
+
+SSL_DEFAULT_OPTIONS = (
+ SSL.OP_NO_SSLv2 |
+ SSL.OP_NO_SSLv3 |
+ SSL.OP_CIPHER_SERVER_PREFERENCE
+)
+
+if hasattr(SSL, "OP_NO_COMPRESSION"):
+ SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
class NetLibError(Exception):
@@ -293,7 +304,7 @@ def close_socket(sock):
"""
try:
# We already indicate that we close our end.
- # may raise "Transport endpoint is not connected" on Linux
+ # may raise "Transport endpoint is not connected" on Linux
sock.shutdown(socket.SHUT_WR)
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
@@ -364,20 +375,24 @@ class _Connection(object):
except SSL.Error:
pass
- """
- Creates an SSL Context.
- """
-
def _create_ssl_context(self,
method=SSLv23_METHOD,
- options=(OP_NO_SSLv2 | OP_NO_SSLv3),
+ options=SSL_DEFAULT_OPTIONS,
+ verify_options=SSL.VERIFY_NONE,
+ ca_path=certifi.where(),
+ ca_pemfile=None,
cipher_list=None,
alpn_protos=None,
alpn_select=None,
):
"""
- :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
+ Creates an SSL Context.
+
+ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
+ :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values
+ :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool
+ :param ca_pemfile: Path to a PEM formatted trusted CA certificate
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
:rtype : SSL.Context
"""
@@ -386,6 +401,18 @@ class _Connection(object):
if options is not None:
context.set_options(options)
+ # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs)
+ if verify_options is not None and verify_options is not SSL.VERIFY_NONE:
+ def verify_cert(conn, cert, errno, err_depth, is_cert_verified):
+ if is_cert_verified:
+ return True
+ raise NetLibError(
+ "Upstream certificate validation failed at depth: %s with error number: %s" %
+ (err_depth, errno))
+
+ context.set_verify(verify_options, verify_cert)
+ context.load_verify_locations(ca_pemfile, ca_path)
+
# Workaround for
# https://github.com/pyca/pyopenssl/issues/190
# https://github.com/mitmproxy/mitmproxy/issues/472
@@ -396,6 +423,9 @@ class _Connection(object):
if cipher_list:
try:
context.set_cipher_list(cipher_list)
+
+ # TODO: maybe change this to with newer pyOpenSSL APIs
+ context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
raise NetLibError("SSL cipher specification error: %s" % str(v))
@@ -404,16 +434,17 @@ class _Connection(object):
context.set_info_callback(log_ssl_key)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
- # advertise application layer protocols
if alpn_protos is not None:
+ # advertise application layer protocols
context.set_alpn_protos(alpn_protos)
-
- # select application layer protocol
- if alpn_select is not None:
- def alpn_select_f(conn, options):
- return bytes(alpn_select)
-
- context.set_alpn_select_callback(alpn_select_f)
+ elif alpn_select is not None:
+ # select application layer protocol
+ def alpn_select_callback(conn, options):
+ if alpn_select in options:
+ return bytes(alpn_select)
+ else: # pragma no cover
+ return options[0]
+ context.set_alpn_select_callback(alpn_select_callback)
return context
@@ -458,6 +489,9 @@ class TCPClient(_Connection):
cert: Path to a file containing both client cert and private key.
options: A bit field consisting of OpenSSL.SSL.OP_* values
+ verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values
+ ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool
+ ca_pemfile: Path to a PEM formatted trusted CA certificate
"""
context = self.create_ssl_context(
alpn_protos=alpn_protos,
@@ -499,10 +533,10 @@ class TCPClient(_Connection):
return self.connection.gettimeout()
def get_alpn_proto_negotiated(self):
- if OpenSSL._util.lib.Cryptography_HAS_ALPN:
+ if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
- else: # pragma no cover
- return None
+ else:
+ return ""
class BaseHandler(_Connection):
@@ -531,7 +565,6 @@ class BaseHandler(_Connection):
request_client_cert=None,
chain_file=None,
dhparams=None,
- alpn_select=None,
**sslctx_kwargs):
"""
cert: A certutils.SSLCert object.
@@ -558,9 +591,7 @@ class BaseHandler(_Connection):
until then we're conservative.
"""
- context = self._create_ssl_context(
- alpn_select=alpn_select,
- **sslctx_kwargs)
+ context = self._create_ssl_context(**sslctx_kwargs)
context.use_privatekey(key)
context.use_certificate(cert.x509)
@@ -585,7 +616,7 @@ class BaseHandler(_Connection):
return context
- def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
+ def convert_to_ssl(self, cert, key, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...)
@@ -594,7 +625,6 @@ class BaseHandler(_Connection):
context = self.create_ssl_context(
cert,
key,
- alpn_select=alpn_select,
**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
@@ -612,6 +642,12 @@ class BaseHandler(_Connection):
def settimeout(self, n):
self.connection.settimeout(n)
+ def get_alpn_proto_negotiated(self):
+ if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
+ return self.connection.get_alpn_proto_negotiated()
+ else:
+ return ""
+
class TCPServer(object):
request_queue_size = 20
diff --git a/netlib/utils.py b/netlib/utils.py
index 9c5404e6..ac42bd53 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -67,7 +67,7 @@ def getbit(byte, offset):
return True
-class BiDi:
+class BiDi(object):
"""
A wee utility class for keeping bi-directional mappings, like field
diff --git a/netlib/websockets.py b/netlib/websockets.py
index 346adf1b..c45db4df 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -35,7 +35,7 @@ OPCODE = utils.BiDi(
)
-class Masker:
+class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
@@ -94,15 +94,15 @@ def server_handshake_headers(key):
)
-def make_length_code(len):
+def make_length_code(length):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
- if len <= 125:
- return len
- elif len >= 126 and len <= 65535:
+ if length <= 125:
+ return length
+ elif length >= 126 and length <= 65535:
return 126
else:
return 127
@@ -129,7 +129,7 @@ def create_server_nonce(client_nonce):
DEFAULT = object()
-class FrameHeader:
+class FrameHeader(object):
def __init__(
self,
@@ -216,7 +216,7 @@ class FrameHeader:
return b
@classmethod
- def from_file(klass, fp):
+ def from_file(cls, fp):
"""
read a websockets frame header
"""
@@ -248,7 +248,7 @@ class FrameHeader:
else:
masking_key = None
- return klass(
+ return cls(
fin=fin,
rsv1=rsv1,
rsv2=rsv2,