From b558997fd9db8406b2a24a1831d06e283dbf35a6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 09:42:32 +1200 Subject: Initial checkin. --- netlib/__init__.py | 0 netlib/odict.py | 160 +++++++++++++++++++++++++++++++++++++++ netlib/protocol.py | 218 +++++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/tcp.py | 182 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 560 insertions(+) create mode 100644 netlib/__init__.py create mode 100644 netlib/odict.py create mode 100644 netlib/protocol.py create mode 100644 netlib/tcp.py (limited to 'netlib') diff --git a/netlib/__init__.py b/netlib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netlib/odict.py b/netlib/odict.py new file mode 100644 index 00000000..afc33caa --- /dev/null +++ b/netlib/odict.py @@ -0,0 +1,160 @@ +import re, copy + +def safe_subn(pattern, repl, target, *args, **kwargs): + """ + There are Unicode conversion problems with re.subn. We try to smooth + that over by casting the pattern and replacement to strings. We really + need a better solution that is aware of the actual content ecoding. + """ + return re.subn(str(pattern), str(repl), target, *args, **kwargs) + + +class ODict: + """ + A dictionary-like object for managing ordered (key, value) data. + """ + def __init__(self, lst=None): + self.lst = lst or [] + + def _kconv(self, s): + return s + + def __eq__(self, other): + return self.lst == other.lst + + def __getitem__(self, k): + """ + Returns a list of values matching key. + """ + ret = [] + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + ret.append(i[1]) + return ret + + def _filter_lst(self, k, lst): + k = self._kconv(k) + new = [] + for i in lst: + if self._kconv(i[0]) != k: + new.append(i) + return new + + def __len__(self): + """ + Total number of (key, value) pairs. + """ + return len(self.lst) + + def __setitem__(self, k, valuelist): + """ + Sets the values for key k. If there are existing values for this + key, they are cleared. + """ + if isinstance(valuelist, basestring): + raise ValueError("ODict valuelist should be lists.") + new = self._filter_lst(k, self.lst) + for i in valuelist: + new.append([k, i]) + self.lst = new + + def __delitem__(self, k): + """ + Delete all items matching k. + """ + self.lst = self._filter_lst(k, self.lst) + + def __contains__(self, k): + for i in self.lst: + if self._kconv(i[0]) == self._kconv(k): + return True + return False + + def add(self, key, value): + self.lst.append([key, str(value)]) + + def get(self, k, d=None): + if k in self: + return self[k] + else: + return d + + def items(self): + return self.lst[:] + + def _get_state(self): + return [tuple(i) for i in self.lst] + + @classmethod + def _from_state(klass, state): + return klass([list(i) for i in state]) + + def copy(self): + """ + Returns a copy of this object. + """ + lst = copy.deepcopy(self.lst) + return self.__class__(lst) + + def __repr__(self): + elements = [] + for itm in self.lst: + elements.append(itm[0] + ": " + itm[1]) + elements.append("") + return "\r\n".join(elements) + + def in_any(self, key, value, caseless=False): + """ + Do any of the values matching key contain value? + + If caseless is true, value comparison is case-insensitive. + """ + if caseless: + value = value.lower() + for i in self[key]: + if caseless: + i = i.lower() + if value in i: + return True + return False + + def match_re(self, expr): + """ + Match the regular expression against each (key, value) pair. For + each pair a string of the following format is matched against: + + "key: value" + """ + for k, v in self.lst: + s = "%s: %s"%(k, v) + if re.search(expr, s): + return True + return False + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both keys and + values. Encoded content will be decoded before replacement, and + re-encoded afterwards. + + Returns the number of replacements made. + """ + nlst, count = [], 0 + for i in self.lst: + k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) + count += c + v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) + count += c + nlst.append([k, v]) + self.lst = nlst + return count + + +class ODictCaseless(ODict): + """ + A variant of ODict with "caseless" keys. This version _preserves_ key + case, but does not consider case when setting or getting items. + """ + def _kconv(self, s): + return s.lower() diff --git a/netlib/protocol.py b/netlib/protocol.py new file mode 100644 index 00000000..55bcf440 --- /dev/null +++ b/netlib/protocol.py @@ -0,0 +1,218 @@ +import string, urlparse + +class ProtocolError(Exception): + def __init__(self, code, msg): + self.code, self.msg = code, msg + + def __str__(self): + return "ProtocolError(%s, %s)"%(self.code, self.msg) + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + """ + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + if not scheme: + return None + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + return scheme, host, port, path + + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line + is reached. Return a ODictCaseless object. + """ + ret = [] + name = '' + while 1: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i+1:].strip() + ret.append([name, value]) + return ret + + +def read_chunked(fp, limit): + content = "" + total = 0 + while 1: + line = fp.readline(128) + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + continue + try: + length = int(line,16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise ProtocolError(400, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise ProtocolError(509, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise IOError("Malformed chunked body") + while 1: + line = fp.readline() + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + break + return content + + +def has_chunked_encoding(headers): + for i in headers["transfer-encoding"]: + for j in i.split(","): + if j.lower() == "chunked": + return True + return False + + +def read_http_body(rfile, headers, all, limit): + if has_chunked_encoding(headers): + content = read_chunked(rfile, limit) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif all: + content = rfile.read(limit if limit else None) + else: + content = "" + return content + + +def parse_http_protocol(s): + if not s.startswith("HTTP/"): + return None + major, minor = s.split('/')[1].split('.') + major = int(major) + minor = int(minor) + return major, minor + + +def parse_init_connect(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if method != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + port = int(port) + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return host, port, httpversion + + +def parse_init_proxy(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + parts = parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if not (url.startswith("/") or url == "*"): + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def request_connection_close(httpversion, headers): + """ + Checks the request to see if the client connection should be closed. + """ + if "connection" in headers: + for value in ",".join(headers['connection']).split(","): + value = value.strip() + if value == "close": + return True + elif value == "keep-alive": + return False + # HTTP 1.1 connections are assumed to be persistent + if httpversion == (1, 1): + return False + return True + + +def response_connection_close(httpversion, headers): + """ + Checks the response to see if the client connection should be closed. + """ + if request_connection_close(httpversion, headers): + return True + elif not has_chunked_encoding(headers) and "content-length" in headers: + return True + return False + + +def read_http_body_request(rfile, wfile, headers, httpversion, limit): + if "expect" in headers: + # FIXME: Should be forwarded upstream + expect = ",".join(headers['expect']) + if expect == "100-continue" and httpversion >= (1, 1): + wfile.write('HTTP/1.1 100 Continue\r\n') + wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) + wfile.write('\r\n') + del headers['expect'] + return read_http_body(rfile, headers, False, limit) diff --git a/netlib/tcp.py b/netlib/tcp.py new file mode 100644 index 00000000..08ccba09 --- /dev/null +++ b/netlib/tcp.py @@ -0,0 +1,182 @@ +import select, socket, threading, traceback, sys +from OpenSSL import SSL + + +class NetLibError(Exception): pass + + +class FileLike: + def __init__(self, o): + self.o = o + + def __getattr__(self, attr): + return getattr(self.o, attr) + + def flush(self): + pass + + def read(self, length): + result = '' + while len(result) < length: + try: + data = self.o.read(length) + except SSL.ZeroReturnError: + break + if not data: + break + result += data + return result + + def write(self, v): + self.o.sendall(v) + + def readline(self, size = None): + result = '' + bytes_read = 0 + while True: + if size is not None and bytes_read >= size: + break + ch = self.read(1) + bytes_read += 1 + if not ch: + break + else: + result += ch + if ch == '\n': + break + return result + + +class TCPClient: + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + self.connection, self.rfile, self.wfile = None, None, None + self.cert = None + self.connect() + + def connect(self): + try: + addr = socket.gethostbyname(self.host) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.ssl: + context = SSL.Context(SSL.SSLv23_METHOD) + if self.clientcert: + context.use_certificate_file(self.clientcert) + server = SSL.Connection(context, server) + server.connect((addr, self.port)) + if self.ssl: + self.cert = server.get_peer_certificate() + self.rfile, self.wfile = FileLike(server), FileLike(server) + else: + self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + except socket.error, err: + raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + self.connection = server + + +class BaseHandler: + rbufsize = -1 + wbufsize = 0 + def __init__(self, connection, client_address, server): + self.connection = connection + self.rfile = self.connection.makefile('rb', self.rbufsize) + self.wfile = self.connection.makefile('wb', self.wbufsize) + + self.client_address = client_address + self.server = server + self.handle() + self.finish() + + def convert_to_ssl(self, cert, key): + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_privatekey_file(key) + ctx.use_certificate_file(cert) + self.connection = SSL.Connection(ctx, self.connection) + self.connection.set_accept_state() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + + def finish(self): + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() + except IOError: # pragma: no cover + pass + + def handle(self): # pragma: no cover + raise NotImplementedError + + +class TCPServer: + request_queue_size = 20 + def __init__(self, server_address): + self.server_address = server_address + self.__is_shut_down = threading.Event() + self.__shutdown_request = False + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.server_address) + self.server_address = self.socket.getsockname() + self.socket.listen(self.request_queue_size) + self.port = self.socket.getsockname()[1] + + def request_thread(self, request, client_address): + try: + self.handle_connection(request, client_address) + request.close() + except: + self.handle_error(request, client_address) + request.close() + + def serve_forever(self, poll_interval=0.5): + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + r, w, e = select.select([self.socket], [], [], poll_interval) + if self.socket in r: + try: + request, client_address = self.socket.accept() + except socket.error: + return + try: + t = threading.Thread( + target = self.request_thread, + args = (request, client_address) + ) + t.setDaemon(1) + t.start() + except: + self.handle_error(request, client_address) + request.close() + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + self.__shutdown_request = True + self.__is_shut_down.wait() + self.handle_shutdown() + + def handle_error(self, request, client_address, fp=sys.stderr): + """ + Called when handle_connection raises an exception. + """ + print >> fp, '-'*40 + print >> fp, "Error processing of request from %s:%s"%client_address + print >> fp, traceback.format_exc() + print >> fp, '-'*40 + + def handle_connection(self, request, client_address): # pragma: no cover + """ + Called after client connection. + """ + raise NotImplementedError + + def handle_shutdown(self): + """ + Called after server shutdown. + """ + pass -- cgit v1.2.3 From c7e9051cbbee1e76abb24518268d30a24df3a16a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 10:42:25 +1200 Subject: Import wsgi. --- netlib/wsgi.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 netlib/wsgi.py (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py new file mode 100644 index 00000000..0608245c --- /dev/null +++ b/netlib/wsgi.py @@ -0,0 +1,125 @@ +import cStringIO, urllib, time, sys, traceback +import odict + +def date_time_string(): + """Return the current date and time formatted for a message header.""" + WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] + MONTHS = [None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + now = time.time() + year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) + s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss) + return s + + +class WSGIAdaptor: + def __init__(self, app, domain, port, sversion): + self.app, self.domain, self.port, self.sversion = app, domain, port, sversion + + def make_environ(self, request, errsoc): + if '?' in request.path: + path_info, query = request.path.split('?', 1) + else: + path_info = request.path + query = '' + environ = { + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': request.scheme, + 'wsgi.input': cStringIO.StringIO(request.content), + 'wsgi.errors': errsoc, + 'wsgi.multithread': True, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': self.sversion, + 'REQUEST_METHOD': request.method, + 'SCRIPT_NAME': '', + 'PATH_INFO': urllib.unquote(path_info), + 'QUERY_STRING': query, + 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], + 'SERVER_NAME': self.domain, + 'SERVER_PORT': self.port, + # FIXME: We need to pick up the protocol read from the request. + 'SERVER_PROTOCOL': "HTTP/1.1", + } + if request.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + + for key, value in request.headers.items(): + key = 'HTTP_' + key.upper().replace('-', '_') + if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): + environ[key] = value + return environ + + def error_page(self, soc, headers_sent, s): + """ + Make a best-effort attempt to write an error page. If headers are + already sent, we just bung the error into the page. + """ + c = """ + +

Internal Server Error

+
%s"
+ + """%s + if not headers_sent: + soc.write("HTTP/1.1 500 Internal Server Error\r\n") + soc.write("Content-Type: text/html\r\n") + soc.write("Content-Length: %s\r\n"%len(c)) + soc.write("\r\n") + soc.write(c) + + def serve(self, request, soc): + state = dict( + response_started = False, + headers_sent = False, + status = None, + headers = None + ) + def write(data): + if not state["headers_sent"]: + soc.write("HTTP/1.1 %s\r\n"%state["status"]) + h = state["headers"] + if 'server' not in h: + h["Server"] = [version.NAMEVERSION] + if 'date' not in h: + h["Date"] = [date_time_string()] + soc.write(str(h)) + soc.write("\r\n") + state["headers_sent"] = True + soc.write(data) + soc.flush() + + def start_response(status, headers, exc_info=None): + if exc_info: + try: + if state["headers_sent"]: + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None + elif state["status"]: + raise AssertionError('Response already started') + state["status"] = status + state["headers"] = odict.ODictCaseless(headers) + return write + + errs = cStringIO.StringIO() + try: + dataiter = self.app(self.make_environ(request, errs), start_response) + for i in dataiter: + write(i) + if not state["headers_sent"]: + write("") + except Exception, v: + try: + s = traceback.format_exc() + self.error_page(soc, state["headers_sent"], s) + except Exception, v: # pragma: no cover + pass # pragma: no cover + return errs.getvalue() + + -- cgit v1.2.3 From ce1ef554561d55a414961993dcaf8f11000d1f22 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 14:23:22 +1200 Subject: Adapt WSGI, convert test suite to nose. --- netlib/wsgi.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 0608245c..3c3a8384 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,6 +1,19 @@ import cStringIO, urllib, time, sys, traceback import odict + +class ClientConn: + def __init__(self, address): + self.address = address + + +class Request: + def __init__(self, client_conn, scheme, method, path, headers, content): + self.scheme, self.method, self.path = scheme, method, path + self.headers, self.content = headers, content + self.client_conn = client_conn + + def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] @@ -85,7 +98,7 @@ class WSGIAdaptor: soc.write("HTTP/1.1 %s\r\n"%state["status"]) h = state["headers"] if 'server' not in h: - h["Server"] = [version.NAMEVERSION] + h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] soc.write(str(h)) -- cgit v1.2.3 From 084be7684d5cb367d4b8995dbf01f177af6113bf Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 20 Jun 2012 10:51:02 +1200 Subject: Close socket on shutdown. --- netlib/tcp.py | 1 + 1 file changed, 1 insertion(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 08ccba09..92a7e92f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -158,6 +158,7 @@ class TCPServer: def shutdown(self): self.__shutdown_request = True self.__is_shut_down.wait() + self.socket.close() self.handle_shutdown() def handle_error(self, request, client_address, fp=sys.stderr): -- cgit v1.2.3 From b7062007965ebd8c11e94bd28775ac6d6083eedf Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 20 Jun 2012 11:01:40 +1200 Subject: Drop default poll interval to 0.1s. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 92a7e92f..5a942522 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -131,7 +131,7 @@ class TCPServer: self.handle_error(request, client_address) request.close() - def serve_forever(self, poll_interval=0.5): + def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() try: while not self.__shutdown_request: -- cgit v1.2.3 From 227e72abf4124cbf55328cd15be917b4af99367f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 13:49:57 +1200 Subject: README, setup.py, version --- netlib/version.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 netlib/version.py (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py new file mode 100644 index 00000000..1c4a4b66 --- /dev/null +++ b/netlib/version.py @@ -0,0 +1,4 @@ +IVERSION = (0, 1) +VERSION = ".".join(str(i) for i in IVERSION) +NAME = "netlib" +NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 5cf6aeb926e0b3a1cad23a0b169b8dfa8536a22f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 13:56:17 +1200 Subject: protocol.py -> http.py --- netlib/http.py | 218 +++++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/protocol.py | 218 ----------------------------------------------------- 2 files changed, 218 insertions(+), 218 deletions(-) create mode 100644 netlib/http.py delete mode 100644 netlib/protocol.py (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py new file mode 100644 index 00000000..c676c25c --- /dev/null +++ b/netlib/http.py @@ -0,0 +1,218 @@ +import string, urlparse + +class HttpError(Exception): + def __init__(self, code, msg): + self.code, self.msg = code, msg + + def __str__(self): + return "HttpError(%s, %s)"%(self.code, self.msg) + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + """ + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + if not scheme: + return None + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + return scheme, host, port, path + + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line + is reached. Return a ODictCaseless object. + """ + ret = [] + name = '' + while 1: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i+1:].strip() + ret.append([name, value]) + return ret + + +def read_chunked(fp, limit): + content = "" + total = 0 + while 1: + line = fp.readline(128) + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + continue + try: + length = int(line,16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise HttpError(400, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise HttpError(509, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise IOError("Malformed chunked body") + while 1: + line = fp.readline() + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + break + return content + + +def has_chunked_encoding(headers): + for i in headers["transfer-encoding"]: + for j in i.split(","): + if j.lower() == "chunked": + return True + return False + + +def read_http_body(rfile, headers, all, limit): + if has_chunked_encoding(headers): + content = read_chunked(rfile, limit) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise HttpError(400, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise HttpError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif all: + content = rfile.read(limit if limit else None) + else: + content = "" + return content + + +def parse_http_protocol(s): + if not s.startswith("HTTP/"): + return None + major, minor = s.split('/')[1].split('.') + major = int(major) + minor = int(minor) + return major, minor + + +def parse_init_connect(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if method != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + port = int(port) + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return host, port, httpversion + + +def parse_init_proxy(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + parts = parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if not (url.startswith("/") or url == "*"): + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def request_connection_close(httpversion, headers): + """ + Checks the request to see if the client connection should be closed. + """ + if "connection" in headers: + for value in ",".join(headers['connection']).split(","): + value = value.strip() + if value == "close": + return True + elif value == "keep-alive": + return False + # HTTP 1.1 connections are assumed to be persistent + if httpversion == (1, 1): + return False + return True + + +def response_connection_close(httpversion, headers): + """ + Checks the response to see if the client connection should be closed. + """ + if request_connection_close(httpversion, headers): + return True + elif not has_chunked_encoding(headers) and "content-length" in headers: + return True + return False + + +def read_http_body_request(rfile, wfile, headers, httpversion, limit): + if "expect" in headers: + # FIXME: Should be forwarded upstream + expect = ",".join(headers['expect']) + if expect == "100-continue" and httpversion >= (1, 1): + wfile.write('HTTP/1.1 100 Continue\r\n') + wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) + wfile.write('\r\n') + del headers['expect'] + return read_http_body(rfile, headers, False, limit) diff --git a/netlib/protocol.py b/netlib/protocol.py deleted file mode 100644 index 55bcf440..00000000 --- a/netlib/protocol.py +++ /dev/null @@ -1,218 +0,0 @@ -import string, urlparse - -class ProtocolError(Exception): - def __init__(self, code, msg): - self.code, self.msg = code, msg - - def __str__(self): - return "ProtocolError(%s, %s)"%(self.code, self.msg) - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - """ - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - if not scheme: - return None - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODictCaseless object. - """ - ret = [] - name = '' - while 1: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i+1:].strip() - ret.append([name, value]) - return ret - - -def read_chunked(fp, limit): - content = "" - total = 0 - while 1: - line = fp.readline(128) - if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - continue - try: - length = int(line,16) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise ProtocolError(400, "Invalid chunked encoding length: %s"%line) - if not length: - break - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) - raise ProtocolError(509, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': - raise IOError("Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - break - return content - - -def has_chunked_encoding(headers): - for i in headers["transfer-encoding"]: - for j in i.split(","): - if j.lower() == "chunked": - return True - return False - - -def read_http_body(rfile, headers, all, limit): - if has_chunked_encoding(headers): - content = read_chunked(rfile, limit) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif all: - content = rfile.read(limit if limit else None) - else: - content = "" - return content - - -def parse_http_protocol(s): - if not s.startswith("HTTP/"): - return None - major, minor = s.split('/')[1].split('.') - major = int(major) - minor = int(minor) - return major, minor - - -def parse_init_connect(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - if method != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - port = int(port) - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return host, port, httpversion - - -def parse_init_proxy(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - try: - method, url, protocol = string.split(line) - except ValueError: - return None - if not (url.startswith("/") or url == "*"): - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return method, url, httpversion - - -def request_connection_close(httpversion, headers): - """ - Checks the request to see if the client connection should be closed. - """ - if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False - # HTTP 1.1 connections are assumed to be persistent - if httpversion == (1, 1): - return False - return True - - -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif not has_chunked_encoding(headers) and "content-length" in headers: - return True - return False - - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - if "expect" in headers: - # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) - wfile.write('\r\n') - del headers['expect'] - return read_http_body(rfile, headers, False, limit) -- cgit v1.2.3 From 1263221ddd06da12f3f1f5f9c3e55858b304ce54 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 15:07:42 +1200 Subject: 100% testcoverage for netlib.http --- netlib/http.py | 91 +++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 33 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index c676c25c..da43d070 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -57,36 +57,40 @@ def read_headers(fp): return ret -def read_chunked(fp, limit): +def read_chunked(code, fp, limit): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ content = "" total = 0 while 1: line = fp.readline(128) if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - continue - try: - length = int(line,16) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(400, "Invalid chunked encoding length: %s"%line) - if not length: - break - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) - raise HttpError(509, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': - raise IOError("Malformed chunked body") + raise HttpError(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise HttpError(code, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise HttpError(code, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise HttpError(code, "Malformed chunked body") while 1: line = fp.readline() if line == "": - raise IOError("Connection closed") + raise HttpError(code, "Connection closed prematurely") if line == '\r\n' or line == '\n': break return content @@ -100,18 +104,27 @@ def has_chunked_encoding(headers): return False -def read_http_body(rfile, headers, all, limit): +def read_http_body(code, rfile, headers, all, limit): + """ + Read an HTTP body: + + code: The HTTP error code to be used when raising HttpError + rfile: A file descriptor to read from + headers: An ODictCaseless object + all: Should we read all data? + limit: Size limit. + """ if has_chunked_encoding(headers): - content = read_chunked(rfile, limit) + content = read_chunked(code, rfile, limit) elif "content-length" in headers: try: l = int(headers["content-length"][0]) except ValueError: # FIXME: Not strictly correct - this could be from the server, in which # case we should send a 502. - raise HttpError(400, "Invalid content-length header: %s"%headers["content-length"]) + raise HttpError(code, "Invalid content-length header: %s"%headers["content-length"]) if limit is not None and l > limit: - raise HttpError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) content = rfile.read(l) elif all: content = rfile.read(limit if limit else None) @@ -121,6 +134,10 @@ def read_http_body(rfile, headers, all, limit): def parse_http_protocol(s): + """ + Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or + None. + """ if not s.startswith("HTTP/"): return None major, minor = s.split('/')[1].split('.') @@ -201,18 +218,26 @@ def response_connection_close(httpversion, headers): """ if request_connection_close(httpversion, headers): return True - elif not has_chunked_encoding(headers) and "content-length" in headers: - return True - return False + elif (not has_chunked_encoding(headers)) and "content-length" in headers: + return False + return True def read_http_body_request(rfile, wfile, headers, httpversion, limit): + """ + Read the HTTP body from a client request. + """ if "expect" in headers: # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): + if "100-continue" in headers['expect'] and httpversion >= (1, 1): wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) wfile.write('\r\n') del headers['expect'] - return read_http_body(rfile, headers, False, limit) + return read_http_body(400, rfile, headers, False, limit) + + +def read_http_body_response(rfile, headers, False, limit): + """ + Read the HTTP body from a server response. + """ + return read_http_body(500, rfile, headers, False, limit) -- cgit v1.2.3 From 171de05d8ea4a31b0f97c38206b44826364d7693 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 18:34:51 +1200 Subject: Add http_status.py --- netlib/http_status.py | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 netlib/http_status.py (limited to 'netlib') diff --git a/netlib/http_status.py b/netlib/http_status.py new file mode 100644 index 00000000..9f3f7e15 --- /dev/null +++ b/netlib/http_status.py @@ -0,0 +1,103 @@ + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: "Use Proxy", + # 306 not defined?? + TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} -- cgit v1.2.3 From 0de765f3600bfa977cffb48da1efa26f2e3236f3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 21:49:23 +1200 Subject: Make read_headers return an ODictCaseless object. --- netlib/http.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index da43d070..1f5f8901 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,5 @@ import string, urlparse +import odict class HttpError(Exception): def __init__(self, code, msg): @@ -54,7 +55,7 @@ def read_headers(fp): name = line[:i] value = line[i+1:].strip() ret.append([name, value]) - return ret + return odict.ODictCaseless(ret) def read_chunked(code, fp, limit): @@ -107,7 +108,7 @@ def has_chunked_encoding(headers): def read_http_body(code, rfile, headers, all, limit): """ Read an HTTP body: - + code: The HTTP error code to be used when raising HttpError rfile: A file descriptor to read from headers: An ODictCaseless object -- cgit v1.2.3 From 5988b65419d6d498b760876b47e4bd627b2467f6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 22:45:40 +1200 Subject: Add and unit test http.read_response --- netlib/http.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 1f5f8901..f0982b6d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -128,7 +128,7 @@ def read_http_body(code, rfile, headers, all, limit): raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) content = rfile.read(l) elif all: - content = rfile.read(limit if limit else None) + content = rfile.read(limit if limit else -1) else: content = "" return content @@ -141,7 +141,10 @@ def parse_http_protocol(s): """ if not s.startswith("HTTP/"): return None - major, minor = s.split('/')[1].split('.') + _, version = s.split('/') + if "." not in version: + return None + major, minor = version.split('.') major = int(major) minor = int(minor) return major, minor @@ -237,8 +240,37 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit): return read_http_body(400, rfile, headers, False, limit) -def read_http_body_response(rfile, headers, False, limit): +def read_http_body_response(rfile, headers, all, limit): """ Read the HTTP body from a server response. """ - return read_http_body(500, rfile, headers, False, limit) + return read_http_body(500, rfile, headers, all, limit) + + +def read_response(rfile, method, body_size_limit): + line = rfile.readline() + if line == "\r\n" or line == "\n": # Possible leftover from previous message + line = rfile.readline() + if not line: + raise HttpError(502, "Blank server response.") + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if not len(parts) == 3: + raise HttpError(502, "Invalid server response: %s."%line) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version: %s."%httpversion) + try: + code = int(code) + except ValueError: + raise HttpError(502, "Invalid server response: %s."%line) + headers = read_headers(rfile) + if code >= 100 and code <= 199: + return read_response(rfile, method, body_size_limit) + if method == "HEAD" or code == 204 or code == 304: + content = "" + else: + content = read_http_body_response(rfile, headers, True, body_size_limit) + return httpversion, code, msg, headers, content -- cgit v1.2.3 From 820ac5152e02108f9d4e2226da1ba4369f67a4df Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 22:57:09 +1200 Subject: WSGI SERVER_PORT should be a string. --- netlib/wsgi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 3c3a8384..755bea5a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -55,7 +55,7 @@ class WSGIAdaptor: 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], 'SERVER_NAME': self.domain, - 'SERVER_PORT': self.port, + 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } -- cgit v1.2.3 From 7d01d5c7970c2b1b86bc6c98be5dfcaa145b1d53 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 23:13:09 +1200 Subject: Don't read all from server by default. This can cause us to hang waiting for data. More research is needed to establish the right course of action here. --- netlib/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index f0982b6d..150995dd 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -272,5 +272,5 @@ def read_response(rfile, method, body_size_limit): if method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body_response(rfile, headers, True, body_size_limit) + content = read_http_body_response(rfile, headers, False, body_size_limit) return httpversion, code, msg, headers, content -- cgit v1.2.3 From 8f0754b9c48176aa479dc7701c42b26e115163a5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 11:00:39 +1200 Subject: SSL tests, plus some self-signed test certificates. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 5a942522..007cf3a5 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,8 +48,8 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert): - self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + def __init__(self, ssl, host, port, clientcert, sni): + self.ssl, self.host, self.port, self.clientcert, self.sni = ssl, host, port, clientcert, sni self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.connect() -- cgit v1.2.3 From f3237503a77258d37b67c5716ac178cbfd7ffe1b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 11:23:04 +1200 Subject: Don't connect during __init__ methods for either client or server. This means we now need to do these things explicitly at the caller. --- netlib/tcp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 007cf3a5..25e83e07 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,11 +48,10 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert, sni): - self.ssl, self.host, self.port, self.clientcert, self.sni = ssl, host, port, clientcert, sni + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert self.connection, self.rfile, self.wfile = None, None, None self.cert = None - self.connect() def connect(self): try: @@ -75,6 +74,9 @@ class TCPClient: class BaseHandler: + """ + The instantiator is expected to call the handle() and finish() methods. + """ rbufsize = -1 wbufsize = 0 def __init__(self, connection, client_address, server): @@ -84,8 +86,6 @@ class BaseHandler: self.client_address = client_address self.server = server - self.handle() - self.finish() def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) -- cgit v1.2.3 From 47f862ae278c61df9bd1b62ec291a954fc0707ea Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 11:34:10 +1200 Subject: Add a finished flag to BaseHandler, and catch an extra OpenSSL exception. --- netlib/tcp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 25e83e07..91b0c742 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -20,7 +20,7 @@ class FileLike: while len(result) < length: try: data = self.o.read(length) - except SSL.ZeroReturnError: + except (SSL.ZeroReturnError, SSL.SysCallError): break if not data: break @@ -86,6 +86,7 @@ class BaseHandler: self.client_address = client_address self.server = server + self.finished = False def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) @@ -97,6 +98,7 @@ class BaseHandler: self.wfile = FileLike(self.connection) def finish(self): + self.finished = True try: if not getattr(self.wfile, "closed", False): self.wfile.flush() -- cgit v1.2.3 From 353efec7ce032a447efbba60c5ccea441bc573fb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 14:42:15 +1200 Subject: Improve TCPClient interface. - Don't pass SSL parameters on instantiation. - Add a convert_to_ssl method analogous to that in TCPServer. --- netlib/tcp.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 91b0c742..3c5c89b7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,29 +48,30 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert): - self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + def __init__(self, host, port): + self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None + def convert_to_ssl(self, clientcert=None): + context = SSL.Context(SSL.SSLv23_METHOD) + if clientcert: + context.use_certificate_file(self.clientcert) + self.connection = SSL.Connection(context, self.connection) + self.connection.set_connect_state() + self.cert = self.connection.get_peer_certificate() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + def connect(self): try: addr = socket.gethostbyname(self.host) - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self.ssl: - context = SSL.Context(SSL.SSLv23_METHOD) - if self.clientcert: - context.use_certificate_file(self.clientcert) - server = SSL.Connection(context, server) - server.connect((addr, self.port)) - if self.ssl: - self.cert = server.get_peer_certificate() - self.rfile, self.wfile = FileLike(server), FileLike(server) - else: - self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection.connect((addr, self.port)) + self.rfile, self.wfile = connection.makefile('rb'), connection.makefile('wb') except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) - self.connection = server + self.connection = connection class BaseHandler: -- cgit v1.2.3 From ea457fac2e270c258172be65a0eeb4701ad23d8e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 16:16:01 +1200 Subject: Perform handshake immediately on SSL conversion. Otherwise the handshake happens at first write, which can balls up if either side hangs immediately. --- netlib/tcp.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 3c5c89b7..276d3162 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -59,6 +59,7 @@ class TCPClient: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) self.connection.set_connect_state() + self.connection.do_handshake() self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -95,6 +96,7 @@ class BaseHandler: ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() + self.connection.do_handshake() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) -- cgit v1.2.3 From ccf2603ddc9c832f9533eeb3c4ffbbd685b00057 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 Jun 2012 09:50:42 +1200 Subject: Add SNI. --- netlib/tcp.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 276d3162..c8ffefdf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -53,11 +53,13 @@ class TCPClient: self.connection, self.rfile, self.wfile = None, None, None self.cert = None - def convert_to_ssl(self, clientcert=None): + def convert_to_ssl(self, clientcert=None, sni=None): context = SSL.Context(SSL.SSLv23_METHOD) if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) + if sni: + self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() self.connection.do_handshake() self.cert = self.connection.get_peer_certificate() @@ -92,10 +94,12 @@ class BaseHandler: def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() + # SNI callback happens during do_handshake() self.connection.do_handshake() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -111,6 +115,23 @@ class BaseHandler: except IOError: # pragma: no cover pass + def handle_sni(self, connection): + """ + Called if the client has given a server name indication. + + Server name can be retrieved like this: + + connection.get_servername() + + And you can specify the connection keys as follows: + + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) + """ + pass + def handle(self): # pragma: no cover raise NotImplementedError -- cgit v1.2.3 From 658c9c0446591e41d6ebdb223c62c00342b83206 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 Jun 2012 14:49:23 +1200 Subject: Hunt down a tricky WSGI socket hang. --- netlib/tcp.py | 12 +++++++++--- netlib/wsgi.py | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c8ffefdf..aa923fdd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -61,7 +61,10 @@ class TCPClient: if sni: self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() - self.connection.do_handshake() + try: + self.connection.do_handshake() + except SSL.Error, v: + raise NetLibError("SSL handshake error: %s"%str(v)) self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -82,7 +85,7 @@ class BaseHandler: The instantiator is expected to call the handle() and finish() methods. """ rbufsize = -1 - wbufsize = 0 + wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection self.rfile = self.connection.makefile('rb', self.rbufsize) @@ -100,7 +103,10 @@ class BaseHandler: self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() # SNI callback happens during do_handshake() - self.connection.do_handshake() + try: + self.connection.do_handshake() + except SSL.Error, v: + raise NetLibError("SSL handshake error: %s"%str(v)) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 755bea5a..6fe6b6b3 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -104,7 +104,8 @@ class WSGIAdaptor: soc.write(str(h)) soc.write("\r\n") state["headers_sent"] = True - soc.write(data) + if data: + soc.write(data) soc.flush() def start_response(status, headers, exc_info=None): -- cgit v1.2.3 From abe335e57dd2871a6ea6cfe2559f9b29ae0c33bb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 Jun 2012 23:52:35 +1200 Subject: Add a flag to track SSL connection establishment. --- netlib/tcp.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index aa923fdd..9b1fc65e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -52,6 +52,7 @@ class TCPClient: self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None + self.ssl_established = False def convert_to_ssl(self, clientcert=None, sni=None): context = SSL.Context(SSL.SSLv23_METHOD) @@ -68,6 +69,7 @@ class TCPClient: self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) + self.ssl_established = True def connect(self): try: @@ -94,6 +96,7 @@ class BaseHandler: self.client_address = client_address self.server = server self.finished = False + self.ssl_established = False def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) @@ -109,6 +112,7 @@ class BaseHandler: raise NetLibError("SSL handshake error: %s"%str(v)) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) + self.ssl_established = True def finish(self): self.finished = True -- cgit v1.2.3 From d0fd8385e60ea6149d9ff6876fb5b4343187b23a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 12:11:55 +1200 Subject: Fix termiantion error in file read. --- netlib/tcp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 9b1fc65e..0ab7f0e4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -17,7 +17,7 @@ class FileLike: def read(self, length): result = '' - while len(result) < length: + while length > 0: try: data = self.o.read(length) except (SSL.ZeroReturnError, SSL.SysCallError): @@ -25,6 +25,7 @@ class FileLike: if not data: break result += data + length -= len(data) return result def write(self, v): -- cgit v1.2.3 From 5d4c7829bfdda8c0a5fd28896fd925d63221b929 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 16:24:22 +1200 Subject: Minor refactoring. --- netlib/http.py | 3 +++ netlib/tcp.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 150995dd..9c72c601 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -248,6 +248,9 @@ def read_http_body_response(rfile, headers, all, limit): def read_response(rfile, method, body_size_limit): + """ + Return an (httpversion, code, msg, headers, content) tuple. + """ line = rfile.readline() if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() diff --git a/netlib/tcp.py b/netlib/tcp.py index 9b1fc65e..49c8b7a2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -156,8 +156,8 @@ class TCPServer: self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind(self.server_address) self.server_address = self.socket.getsockname() + self.port = self.server_address[1] self.socket.listen(self.request_queue_size) - self.port = self.socket.getsockname()[1] def request_thread(self, request, client_address): try: -- cgit v1.2.3 From f7fcb1c80b2874df05db4603549c6a24d12e58c0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 16:42:00 +1200 Subject: Add certutils to netlib. --- netlib/certutils.py | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 netlib/certutils.py (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py new file mode 100644 index 00000000..31b1fa08 --- /dev/null +++ b/netlib/certutils.py @@ -0,0 +1,219 @@ +import os, ssl, hashlib, socket, time, datetime +from pyasn1.type import univ, constraint, char, namedtype, tag +from pyasn1.codec.der.decoder import decode +import OpenSSL + +CERT_SLEEP_TIME = 1 +CERT_EXPIRY = str(365 * 3) + + +def create_ca(): + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) + ca = OpenSSL.crypto.X509() + ca.set_serial_number(int(time.time()*10000)) + ca.set_version(2) + ca.get_subject().CN = "mitmproxy" + ca.get_subject().O = "mitmproxy" + ca.gmtime_adj_notBefore(0) + ca.gmtime_adj_notAfter(24 * 60 * 60 * 720) + ca.set_issuer(ca.get_subject()) + ca.set_pubkey(key) + ca.add_extensions([ + OpenSSL.crypto.X509Extension("basicConstraints", True, + "CA:TRUE"), + OpenSSL.crypto.X509Extension("nsCertType", True, + "sslCA"), + OpenSSL.crypto.X509Extension("extendedKeyUsage", True, + "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension("keyUsage", False, + "keyCertSign, cRLSign"), + OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", + subject=ca), + ]) + ca.sign(key, "sha1") + return key, ca + + +def dummy_ca(path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + if path.endswith(".pem"): + basename, _ = os.path.splitext(path) + else: + basename = path + + key, ca = create_ca() + + # Dump the CA plus private key + f = open(path, "w") + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PEM format + f = open(os.path.join(dirname, basename + "-cert.pem"), "w") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Create a .cer file with the same contents for Android + f = open(os.path.join(dirname, basename + "-cert.cer"), "w") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PKCS12 format for Windows devices + f = open(os.path.join(dirname, basename + "-cert.p12"), "w") + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + f.close() + return True + + +def dummy_cert(certdir, ca, commonname, sans): + """ + certdir: Certificate directory. + ca: Path to the certificate authority file, or None. + commonname: Common name for the generated certificate. + + Returns cert path if operation succeeded, None if not. + """ + namehash = hashlib.sha256(commonname).hexdigest() + certpath = os.path.join(certdir, namehash + ".pem") + if os.path.exists(certpath): + return certpath + + ss = [] + for i in sans: + ss.append("DNS: %s"%i) + ss = ", ".join(ss) + + if ca: + raw = file(ca, "r").read() + ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + else: + key, ca = create_ca() + + req = OpenSSL.crypto.X509Req() + subj = req.get_subject() + subj.CN = commonname + req.set_pubkey(ca.get_pubkey()) + req.sign(key, "sha1") + if ss: + req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + + cert = OpenSSL.crypto.X509() + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) + cert.set_issuer(ca.get_subject()) + cert.set_subject(req.get_subject()) + cert.set_serial_number(int(time.time()*10000)) + if ss: + cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + cert.set_pubkey(req.get_pubkey()) + cert.sign(key, "sha1") + + f = open(certpath, "w") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) + f.close() + + return certpath + + +class _GeneralName(univ.Choice): + # We are only interested in dNSNames. We use a default handler to ignore + # other types. + componentType = namedtype.NamedTypes( + namedtype.NamedType('dNSName', char.IA5String().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) + ), + ) + + +class _GeneralNames(univ.SequenceOf): + componentType = _GeneralName() + sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) + + +class SSLCert: + def __init__(self, pemtxt): + """ + Returns a (common name, [subject alternative names]) tuple. + """ + self.cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pemtxt) + + @classmethod + def from_der(klass, der): + pem = ssl.DER_cert_to_PEM_cert(der) + return klass(pem) + + def digest(self, name): + return self.cert.digest(name) + + @property + def issuer(self): + return self.cert.get_issuer().get_components() + + @property + def notbefore(self): + t = self.cert.get_notBefore() + return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + + @property + def notafter(self): + t = self.cert.get_notAfter() + return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + + @property + def has_expired(self): + return self.cert.has_expired() + + @property + def subject(self): + return self.cert.get_subject().get_components() + + @property + def serial(self): + return self.cert.get_serial_number() + + @property + def keyinfo(self): + pk = self.cert.get_pubkey() + types = { + OpenSSL.crypto.TYPE_RSA: "RSA", + OpenSSL.crypto.TYPE_DSA: "DSA", + } + return ( + types.get(pk.type(), "UNKNOWN"), + pk.bits() + ) + + @property + def cn(self): + cn = None + for i in self.subject: + if i[0] == "CN": + cn = i[1] + return cn + + @property + def altnames(self): + altnames = [] + for i in range(self.cert.get_extension_count()): + ext = self.cert.get_extension(i) + if ext.get_short_name() == "subjectAltName": + dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + for i in dec[0]: + altnames.append(i[0].asOctets()) + return altnames + + +def get_remote_cert(host, port): # pragma: no cover + addr = socket.gethostbyname(host) + s = ssl.get_server_certificate((addr, port)) + return SSLCert(s) -- cgit v1.2.3 From b0ef9ad07ba4b805f3130237dcf9207434c33d84 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 22:11:58 +1200 Subject: Refactor certutils.SSLCert API. --- netlib/certutils.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 31b1fa08..6c9a5c57 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -141,49 +141,54 @@ class _GeneralNames(univ.SequenceOf): class SSLCert: - def __init__(self, pemtxt): + def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ - self.cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pemtxt) + self.x509 = cert + + @classmethod + def from_pem(klass, txt): + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) + return klass(x509) @classmethod def from_der(klass, der): pem = ssl.DER_cert_to_PEM_cert(der) - return klass(pem) + return klass.from_pem(pem) def digest(self, name): - return self.cert.digest(name) + return self.x509.digest(name) @property def issuer(self): - return self.cert.get_issuer().get_components() + return self.x509.get_issuer().get_components() @property def notbefore(self): - t = self.cert.get_notBefore() + t = self.x509.get_notBefore() return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") @property def notafter(self): - t = self.cert.get_notAfter() + t = self.x509.get_notAfter() return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") @property def has_expired(self): - return self.cert.has_expired() + return self.x509.has_expired() @property def subject(self): - return self.cert.get_subject().get_components() + return self.x509.get_subject().get_components() @property def serial(self): - return self.cert.get_serial_number() + return self.x509.get_serial_number() @property def keyinfo(self): - pk = self.cert.get_pubkey() + pk = self.x509.get_pubkey() types = { OpenSSL.crypto.TYPE_RSA: "RSA", OpenSSL.crypto.TYPE_DSA: "DSA", @@ -204,8 +209,8 @@ class SSLCert: @property def altnames(self): altnames = [] - for i in range(self.cert.get_extension_count()): - ext = self.cert.get_extension(i) + for i in range(self.x509.get_extension_count()): + ext = self.x509.get_extension(i) if ext.get_short_name() == "subjectAltName": dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) for i in dec[0]: -- cgit v1.2.3 From a1491a6ae037b7874dd71de11f5cd43e10aa46e7 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 08:15:55 +1200 Subject: Add a get_remote_cert method to tcp client. --- netlib/certutils.py | 10 ++++++---- netlib/tcp.py | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 6c9a5c57..180e1ac0 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -2,6 +2,7 @@ import os, ssl, hashlib, socket, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode import OpenSSL +import tcp CERT_SLEEP_TIME = 1 CERT_EXPIRY = str(365 * 3) @@ -218,7 +219,8 @@ class SSLCert: return altnames -def get_remote_cert(host, port): # pragma: no cover - addr = socket.gethostbyname(host) - s = ssl.get_server_certificate((addr, port)) - return SSLCert(s) +def get_remote_cert(host, port, sni): + c = tcp.TCPClient(host, port) + c.connect() + c.convert_to_ssl(sni=sni) + return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index ef3298d5..6c5b4976 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,5 +1,6 @@ import select, socket, threading, traceback, sys from OpenSSL import SSL +import certutils class NetLibError(Exception): pass -- cgit v1.2.3 From 92c7d38bd343a0436d73c0a984fe111996e15059 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 09:56:58 +1200 Subject: Handle obscure termination scenario, where interpreter exits before thread termination. --- netlib/tcp.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 0ab7f0e4..f02be550 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -117,14 +117,11 @@ class BaseHandler: def finish(self): self.finished = True - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.connection.close() - self.wfile.close() - self.rfile.close() - except IOError: # pragma: no cover - pass + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() def handle_sni(self, connection): """ @@ -165,8 +162,15 @@ class TCPServer: self.handle_connection(request, client_address) request.close() except: - self.handle_error(request, client_address) - request.close() + try: + self.handle_error(request, client_address) + request.close() + # Why a blanket except here? In some circumstances, a thread can + # persist until the interpreter exits. When this happens, all modules + # and builtins are set to None, and things balls up indeterminate + # ways. + except: + pass def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() -- cgit v1.2.3 From 3f9aad53ab9b567ddc89848c54234d667a846db8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 10:59:03 +1200 Subject: Return a certutils.SSLCert object from get_remote_cert. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a265ef7a..b3fc2212 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -68,7 +68,7 @@ class TCPClient: self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%str(v)) - self.cert = self.connection.get_peer_certificate() + self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) self.ssl_established = True -- cgit v1.2.3 From 7480f87cd721de6ca9d0cdb7c9437bdb58b16ba0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 14:56:21 +1200 Subject: Add utility function for converstion to PEM. --- netlib/certutils.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 180e1ac0..dcd54053 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -158,6 +158,9 @@ class SSLCert: pem = ssl.DER_cert_to_PEM_cert(der) return klass.from_pem(pem) + def to_pem(self): + return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, self.x509) + def digest(self, name): return self.x509.digest(name) -- cgit v1.2.3 From 67669a2a578157782a621fa1ac5531bbb2db8029 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 30 Jun 2012 10:52:28 +1200 Subject: Allow control of buffer size for TCPClient, improve error messages. --- netlib/http.py | 6 +++--- netlib/tcp.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 9c72c601..acd9d85e 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -260,15 +260,15 @@ def read_response(rfile, method, body_size_limit): if len(parts) == 2: # handle missing message gracefully parts.append("") if not len(parts) == 3: - raise HttpError(502, "Invalid server response: %s."%line) + raise HttpError(502, "Invalid server response: %s"%repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version: %s."%httpversion) + raise HttpError(502, "Invalid HTTP version: %s"%repr(httpversion)) try: code = int(code) except ValueError: - raise HttpError(502, "Invalid server response: %s."%line) + raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) if code >= 100 and code <= 199: return read_response(rfile, method, body_size_limit) diff --git a/netlib/tcp.py b/netlib/tcp.py index b3fc2212..bb0a00b9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -50,6 +50,8 @@ class FileLike: class TCPClient: + rbufsize = -1 + wbufsize = -1 def __init__(self, host, port): self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None @@ -78,7 +80,8 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile, self.wfile = connection.makefile('rb'), connection.makefile('wb') + self.rfile = connection.makefile('rb', self.rbufsize) + self.wfile = connection.makefile('wb', self.wbufsize) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection -- cgit v1.2.3 From 96af5c16a065a8167d167ed1d4dc9e0a77566e25 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 4 Jul 2012 21:30:07 +1200 Subject: Expose SSL options, use TLSv1 by default for client connections. --- netlib/tcp.py | 46 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index bb0a00b9..54148172 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,37 @@ import select, socket, threading, traceback, sys from OpenSSL import SSL import certutils +SSLv2_METHOD = SSL.SSLv2_METHOD +SSLv3_METHOD = SSL.SSLv3_METHOD +SSLv23_METHOD = SSL.SSLv23_METHOD +TLSv1_METHOD = SSL.TLSv1_METHOD + +OP_ALL = SSL.OP_ALL +OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE +OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE +OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS +OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA +OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER +OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG +OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG +OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG +OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG +OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG +OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU +OP_NO_SSLv2 = SSL.OP_NO_SSLv2 +OP_NO_SSLv3 = SSL.OP_NO_SSLv3 +OP_NO_TICKET = SSL.OP_NO_TICKET +OP_NO_TLSv1 = SSL.OP_NO_TLSv1 +OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1 +OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2 +OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE +OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG +OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG +OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG +OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG +OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG + class NetLibError(Exception): pass @@ -58,8 +89,10 @@ class TCPClient: self.cert = None self.ssl_established = False - def convert_to_ssl(self, clientcert=None, sni=None): - context = SSL.Context(SSL.SSLv23_METHOD) + def convert_to_ssl(self, clientcert=None, sni=None, method=TLSv1_METHOD, options=None): + context = SSL.Context(method) + if not options is None: + ctx.set_options(options) if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) @@ -103,8 +136,13 @@ class BaseHandler: self.finished = False self.ssl_established = False - def convert_to_ssl(self, cert, key): - ctx = SSL.Context(SSL.SSLv23_METHOD) + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None): + """ + method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + """ + ctx = SSL.Context(method) + if not options is None: + ctx.set_options(options) ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) -- cgit v1.2.3 From 20cc1b6aa4488d9b230469ba57b6a92380bfeeca Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 5 Jul 2012 09:37:43 +1200 Subject: Refactor TCP test suite. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 54148172..0af3d463 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -209,7 +209,7 @@ class TCPServer: request.close() # Why a blanket except here? In some circumstances, a thread can # persist until the interpreter exits. When this happens, all modules - # and builtins are set to None, and things balls up indeterminate + # and builtins are set to None, and things balls up in indeterminate # ways. except: pass -- cgit v1.2.3 From ba7437abcbf3db11e227cae5e5c1d2df5975c77c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Jul 2012 23:50:38 +1200 Subject: Add an exception to indicate remote disconnects. --- netlib/tcp.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 0af3d463..281a0438 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -36,6 +36,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass +class NetLibDisconnect(Exception): pass + class FileLike: def __init__(self, o): @@ -61,7 +63,10 @@ class FileLike: return result def write(self, v): - self.o.sendall(v) + try: + return self.o.sendall(v) + except SSL.SysCallError: + raise NetLibDisconnect() def readline(self, size = None): result = '' @@ -159,11 +164,15 @@ class BaseHandler: def finish(self): self.finished = True - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.connection.close() - self.wfile.close() - self.rfile.close() + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.wfile.close() + self.rfile.close() + self.connection.close() + except socket.error: + # Remote has disconnected + pass def handle_sni(self, connection): """ -- cgit v1.2.3 From 721e2c8277123a99abf6299ee4703109c57675db Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 10 Jul 2012 16:22:45 +1200 Subject: Somewhat nicer handling of errors after thread termination. --- netlib/tcp.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 281a0438..53ad8a05 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -213,15 +213,8 @@ class TCPServer: self.handle_connection(request, client_address) request.close() except: - try: - self.handle_error(request, client_address) - request.close() - # Why a blanket except here? In some circumstances, a thread can - # persist until the interpreter exits. When this happens, all modules - # and builtins are set to None, and things balls up in indeterminate - # ways. - except: - pass + self.handle_error(request, client_address) + request.close() def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -257,10 +250,14 @@ class TCPServer: """ Called when handle_connection raises an exception. """ - print >> fp, '-'*40 - print >> fp, "Error processing of request from %s:%s"%client_address - print >> fp, traceback.format_exc() - print >> fp, '-'*40 + # If a thread has persisted after interpreter exit, the module might be + # none. + if traceback: + exc = traceback.format_exc() + print >> fp, '-'*40 + print >> fp, "Error in processing of request from %s:%s"%client_address + print >> fp, exc + print >> fp, '-'*40 def handle_connection(self, request, client_address): # pragma: no cover """ -- cgit v1.2.3 From 4fdc2179e25926d531ea8c4a5d6fc78ce75cd6ff Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 10 Jul 2012 16:34:39 +1200 Subject: Don't write empty values. --- netlib/tcp.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 53ad8a05..6ba58d86 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -63,10 +63,11 @@ class FileLike: return result def write(self, v): - try: - return self.o.sendall(v) - except SSL.SysCallError: - raise NetLibDisconnect() + if v: + try: + return self.o.sendall(v) + except SSL.SysCallError: + raise NetLibDisconnect() def readline(self, size = None): result = '' -- cgit v1.2.3 From 1227369db31bff39707091f562b0ad946d14728a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 11 Jul 2012 07:16:45 +1200 Subject: Signal errors back to caller in WSGI .serve() --- netlib/wsgi.py | 1 + 1 file changed, 1 insertion(+) (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 6fe6b6b3..4fa2c537 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -131,6 +131,7 @@ class WSGIAdaptor: except Exception, v: try: s = traceback.format_exc() + errs.write(s) self.error_page(soc, state["headers_sent"], s) except Exception, v: # pragma: no cover pass # pragma: no cover -- cgit v1.2.3 From 9ab7842c81e8b34cd99d5f3e8e98282729d85344 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 11 Jul 2012 11:09:41 +0200 Subject: fix relative certdir --- netlib/certutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index dcd54053..3effe610 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -43,8 +43,9 @@ def dummy_ca(path): os.makedirs(dirname) if path.endswith(".pem"): basename, _ = os.path.splitext(path) + basename = os.path.basename(basename) else: - basename = path + basename = os.path.basename(basename) key, ca = create_ca() -- cgit v1.2.3 From 63d789109a7ef0bb18e01fdf63851db86aef23bd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 20 Jul 2012 14:43:51 +1200 Subject: close() methods for clients and servers. --- netlib/tcp.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 6ba58d86..b7f2b3bc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -66,7 +66,7 @@ class FileLike: if v: try: return self.o.sendall(v) - except SSL.SysCallError: + except SSL.Error: raise NetLibDisconnect() def readline(self, size = None): @@ -125,6 +125,20 @@ class TCPClient: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + else: + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + except (socket.error, SSL.Error): + # Socket probably already closed + pass + class BaseHandler: """ @@ -170,7 +184,7 @@ class BaseHandler: self.wfile.flush() self.wfile.close() self.rfile.close() - self.connection.close() + self.close() except socket.error: # Remote has disconnected pass @@ -195,6 +209,20 @@ class BaseHandler: def handle(self): # pragma: no cover raise NotImplementedError + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + else: + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + except (socket.error, SSL.Error): + # Socket probably already closed + pass + class TCPServer: request_queue_size = 20 @@ -252,7 +280,7 @@ class TCPServer: Called when handle_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be - # none. + # none. if traceback: exc = traceback.format_exc() print >> fp, '-'*40 -- cgit v1.2.3 From a1a1663c0fc3a1e76637a0ef3997da697ea97cfe Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 20 Jul 2012 14:45:58 +1200 Subject: Fix cert path. --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 3effe610..1f61132e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -45,7 +45,7 @@ def dummy_ca(path): basename, _ = os.path.splitext(path) basename = os.path.basename(basename) else: - basename = os.path.basename(basename) + basename = os.path.basename(path) key, ca = create_ca() -- cgit v1.2.3 From ba53d2e4caa34df883a2cd6322d607426c97201b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 20 Jul 2012 15:15:07 +1200 Subject: Set ssl_established right after the connection object is changed. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index b7f2b3bc..3aee4c74 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -102,6 +102,7 @@ class TCPClient: if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) + self.ssl_established = True if sni: self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() @@ -112,7 +113,6 @@ class TCPClient: self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) - self.ssl_established = True def connect(self): try: @@ -167,6 +167,7 @@ class BaseHandler: ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) + self.ssl_established = True self.connection.set_accept_state() # SNI callback happens during do_handshake() try: @@ -175,7 +176,6 @@ class BaseHandler: raise NetLibError("SSL handshake error: %s"%str(v)) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) - self.ssl_established = True def finish(self): self.finished = True -- cgit v1.2.3 From 2387d2e8ed7d94e42b1ac02a4ea73f54e4c63ab8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 16:10:54 +1200 Subject: Timeout for TCP clients. --- netlib/tcp.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 3aee4c74..8771e789 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback, sys +import select, socket, threading, traceback, sys, time from OpenSSL import SSL import certutils @@ -35,8 +35,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass - class NetLibDisconnect(Exception): pass +class NetLibTimeout(Exception): pass class FileLike: @@ -47,15 +47,25 @@ class FileLike: return getattr(self.o, attr) def flush(self): - pass + if hasattr(self.o, "flush"): + self.o.flush() def read(self, length): result = '' + start = time.time() while length > 0: try: data = self.o.read(length) except (SSL.ZeroReturnError, SSL.SysCallError): break + except SSL.WantReadError: + if (time.time() - start) < self.o.gettimeout(): + time.sleep(0.1) + continue + else: + raise NetLibTimeout + except socket.timeout: + raise NetLibTimeout if not data: break result += data @@ -65,7 +75,11 @@ class FileLike: def write(self, v): if v: try: - return self.o.sendall(v) + if hasattr(self.o, "sendall"): + return self.o.sendall(v) + else: + r = self.o.write(v) + return r except SSL.Error: raise NetLibDisconnect() @@ -119,12 +133,18 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile = connection.makefile('rb', self.rbufsize) - self.wfile = connection.makefile('wb', self.wbufsize) + self.rfile = FileLike(connection.makefile('rb', self.rbufsize)) + self.wfile = FileLike(connection.makefile('wb', self.wbufsize)) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection + def settimeout(self, n): + self.connection.settimeout(n) + + def gettimeout(self): + self.connection.gettimeout() + def close(self): """ Does a hard close of the socket, i.e. a shutdown, followed by a close. @@ -148,8 +168,8 @@ class BaseHandler: wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection - self.rfile = self.connection.makefile('rb', self.rbufsize) - self.wfile = self.connection.makefile('wb', self.wbufsize) + self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize)) + self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize)) self.client_address = client_address self.server = server -- cgit v1.2.3 From 29f907ecf98468a89b5a7575b539938dc6741a8e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 17:27:23 +1200 Subject: Handle HTTP versions malformed due to non-integer major/minor numbers. --- netlib/http.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index acd9d85e..88e66ce4 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -145,8 +145,11 @@ def parse_http_protocol(s): if "." not in version: return None major, minor = version.split('.') - major = int(major) - minor = int(minor) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None return major, minor -- cgit v1.2.3 From b2c491fe3936b04b0c8e6349775bf53063c170a6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 17:50:21 +1200 Subject: Handle socket disconnects on reads. --- netlib/tcp.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 8771e789..ac4fab95 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -66,6 +66,8 @@ class FileLike: raise NetLibTimeout except socket.timeout: raise NetLibTimeout + except socket.error: + raise NetLibDisconnect if not data: break result += data -- cgit v1.2.3 From 619f3c6edce50a6e83b817d43ee0357cc763dd3d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 20:51:05 +1200 Subject: Handle unexpected SSL connection termination in readline. --- netlib/tcp.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index ac4fab95..a68b608b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -56,7 +56,7 @@ class FileLike: while length > 0: try: data = self.o.read(length) - except (SSL.ZeroReturnError, SSL.SysCallError): + except SSL.ZeroReturnError: break except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): @@ -68,6 +68,8 @@ class FileLike: raise NetLibTimeout except socket.error: raise NetLibDisconnect + except SSL.SysCallError, v: + raise NetLibDisconnect if not data: break result += data @@ -82,7 +84,7 @@ class FileLike: else: r = self.o.write(v) return r - except SSL.Error: + except (SSL.Error, socket.error): raise NetLibDisconnect() def readline(self, size = None): @@ -91,7 +93,10 @@ class FileLike: while True: if size is not None and bytes_read >= size: break - ch = self.read(1) + try: + ch = self.read(1) + except NetLibDisconnect: + break bytes_read += 1 if not ch: break -- cgit v1.2.3 From ed64b0e79699681bd5db3ff2823c47a424fbc3e1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 22 Jul 2012 12:35:16 +1200 Subject: Fix http_protocol parsing crash discovered with pathoc fuzzing. --- netlib/http.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 88e66ce4..9d6db003 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -141,10 +141,10 @@ def parse_http_protocol(s): """ if not s.startswith("HTTP/"): return None - _, version = s.split('/') + _, version = s.split('/', 1) if "." not in version: return None - major, minor = version.split('.') + major, minor = version.split('.', 1) try: major = int(major) minor = int(minor) -- cgit v1.2.3 From eb88cea3c74a253d3a08d010bfd328aa845c6d5b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 23 Jul 2012 23:20:32 +1200 Subject: Catch an amazingly subtle SSL connection corruption bug. Closing a set of pseudo-file descriptors in the wrong order caused junk data to be written to the SSL stream. An apparent bug in OpenSSL then lets this corrupt the _next_ SSL connection. --- netlib/tcp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a68b608b..66a26872 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -209,9 +209,9 @@ class BaseHandler: try: if not getattr(self.wfile, "closed", False): self.wfile.flush() + self.close() self.wfile.close() self.rfile.close() - self.close() except socket.error: # Remote has disconnected pass @@ -245,10 +245,10 @@ class BaseHandler: self.connection.shutdown() else: self.connection.shutdown(socket.SHUT_RDWR) - self.connection.close() - except (socket.error, SSL.Error): + except (socket.error, SSL.Error), v: # Socket probably already closed pass + self.connection.close() class TCPServer: -- cgit v1.2.3 From 91752990d5863526745e5c31cfb4b7459d11047e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 24 Jul 2012 11:39:49 +1200 Subject: Handle HTTP responses that have a body but no content-length or transfer encoding We check if the server sent a connection:close header, and read till the socket closes. Closes #2 --- netlib/http.py | 37 +++++++++++++++++++++++-------------- netlib/tcp.py | 11 ++++++++--- 2 files changed, 31 insertions(+), 17 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 9d6db003..980d3f62 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -97,12 +97,21 @@ def read_chunked(code, fp, limit): return content -def has_chunked_encoding(headers): - for i in headers["transfer-encoding"]: +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: for j in i.split(","): - if j.lower() == "chunked": - return True - return False + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] def read_http_body(code, rfile, headers, all, limit): @@ -207,12 +216,11 @@ def request_connection_close(httpversion, headers): Checks the request to see if the client connection should be closed. """ if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False # HTTP 1.1 connections are assumed to be persistent if httpversion == (1, 1): return False @@ -243,10 +251,11 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit): return read_http_body(400, rfile, headers, False, limit) -def read_http_body_response(rfile, headers, all, limit): +def read_http_body_response(rfile, headers, limit): """ Read the HTTP body from a server response. """ + all = "close" in get_header_tokens(headers, "connection") return read_http_body(500, rfile, headers, all, limit) @@ -267,7 +276,7 @@ def read_response(rfile, method, body_size_limit): proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version: %s"%repr(httpversion)) + raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) try: code = int(code) except ValueError: @@ -278,5 +287,5 @@ def read_response(rfile, method, body_size_limit): if method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body_response(rfile, headers, False, body_size_limit) + content = read_http_body_response(rfile, headers, body_size_limit) return httpversion, code, msg, headers, content diff --git a/netlib/tcp.py b/netlib/tcp.py index 66a26872..7d3705da 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -40,6 +40,7 @@ class NetLibTimeout(Exception): pass class FileLike: + BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o @@ -51,11 +52,14 @@ class FileLike: self.o.flush() def read(self, length): + """ + If length is None, we read until connection closes. + """ result = '' start = time.time() - while length > 0: + while length == -1 or length > 0: try: - data = self.o.read(length) + data = self.o.read(self.BLOCKSIZE if length == -1 else length) except SSL.ZeroReturnError: break except SSL.WantReadError: @@ -73,7 +77,8 @@ class FileLike: if not data: break result += data - length -= len(data) + if length != -1: + length -= len(data) return result def write(self, v): -- cgit v1.2.3 From 728ef107a00e7d6cef0c7d826f39a89197ddb732 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 24 Jul 2012 14:55:54 +1200 Subject: Ignore SAN entries that we don't understand. --- netlib/certutils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 1f61132e..f55a096b 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,6 +1,7 @@ import os, ssl, hashlib, socket, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode +from pyasn1.error import PyAsn1Error import OpenSSL import tcp @@ -217,7 +218,10 @@ class SSLCert: for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) if ext.get_short_name() == "subjectAltName": - dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + try: + dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + except PyAsn1Error: + continue for i in dec[0]: altnames.append(i[0].asOctets()) return altnames -- cgit v1.2.3 From 4fb5d15f1480dd6ca86578aca2d0784bfef31dac Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 29 Jul 2012 15:53:42 +1200 Subject: Bump version. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 1c4a4b66..20460ad5 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 1) +IVERSION = (0, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From eafa5566c27ec321131a9d83d85dab512aae7a37 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 30 Jul 2012 11:30:31 +1200 Subject: Handle disconnects on flush. --- netlib/tcp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 7d3705da..e7bc79a8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,8 +48,11 @@ class FileLike: return getattr(self.o, attr) def flush(self): - if hasattr(self.o, "flush"): - self.o.flush() + try: + if hasattr(self.o, "flush"): + self.o.flush() + except socket.error, v: + raise NetLibDisconnect(str(v)) def read(self, length): """ -- cgit v1.2.3 From 1c21a28e6423edf3b903191610b45345720e0458 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 30 Jul 2012 12:50:35 +1200 Subject: read_headers: handle some crashes, return None on invalid data. --- netlib/http.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 980d3f62..b71eb72d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -36,8 +36,8 @@ def parse_url(url): def read_headers(fp): """ - Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODictCaseless object. + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. """ ret = [] name = '' @@ -46,6 +46,8 @@ def read_headers(fp): if not line or line == '\r\n' or line == '\n': break if line[0] in ' \t': + if not ret: + return None # continued header ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: @@ -55,6 +57,8 @@ def read_headers(fp): name = line[:i] value = line[i+1:].strip() ret.append([name, value]) + else: + return None return odict.ODictCaseless(ret) @@ -282,6 +286,8 @@ def read_response(rfile, method, body_size_limit): except ValueError: raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") if code >= 100 and code <= 199: return read_response(rfile, method, body_size_limit) if method == "HEAD" or code == 204 or code == 304: -- cgit v1.2.3 From 877a3e206263edbd8a973689b08f8c004de0225f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 18 Aug 2012 18:14:13 +1200 Subject: Add a get_first convenience function to ODict. --- netlib/odict.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index afc33caa..629fcade 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -80,6 +80,12 @@ class ODict: else: return d + def get_first(self, k, d=None): + if k in self: + return self[k][0] + else: + return d + def items(self): return self.lst[:] -- cgit v1.2.3 From 33557245bf2212c08cd645bcf21a73b773646607 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 23 Aug 2012 12:57:22 +1200 Subject: v0.2.1 --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 20460ad5..614b87a1 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 2) +IVERSION = (0, 2, 1) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 1c80c2fdd7dd9873abc7b0a74936dab7beda7c5c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 1 Sep 2012 23:04:44 +1200 Subject: Add a collection of standard User-Agent strings. These will be used in both mitmproxy and pathod. --- netlib/http_uastrings.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 netlib/http_uastrings.py (limited to 'netlib') diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py new file mode 100644 index 00000000..826c31a5 --- /dev/null +++ b/netlib/http_uastrings.py @@ -0,0 +1,77 @@ +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ( + "android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02" + ), + + ( + "blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+" + ), + + ( + "bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)" + ), + + ( + "chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1" + ), + + ( + "firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1" + ), + + ( + "googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)" + ), + + ( + "ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))" + ), + + ( + "ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3" + ), + + ( + "iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", + ), + + ( + "safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10" + ) +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. + """ + for i in UASTRINGS: + if s == i[1]: + return i -- cgit v1.2.3 From 8a6cca530c5293aa2b77edd3bf928540ec771928 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 24 Sep 2012 10:47:41 +1200 Subject: Don't create fresh FileLike objects when converting to SSL --- netlib/tcp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index e7bc79a8..0fed7380 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -44,6 +44,9 @@ class FileLike: def __init__(self, o): self.o = o + def set_descriptor(self, o): + self.o = o + def __getattr__(self, attr): return getattr(self.o, attr) @@ -140,8 +143,8 @@ class TCPClient: except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%str(v)) self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) - self.rfile = FileLike(self.connection) - self.wfile = FileLike(self.connection) + self.rfile.set_descriptor(self.connection) + self.wfile.set_descriptor(self.connection) def connect(self): try: @@ -209,8 +212,8 @@ class BaseHandler: self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%str(v)) - self.rfile = FileLike(self.connection) - self.wfile = FileLike(self.connection) + self.rfile.set_descriptor(self.connection) + self.wfile.set_descriptor(self.connection) def finish(self): self.finished = True -- cgit v1.2.3 From 3a21e28bf13b5710639337fdc29741e9b6b71405 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 24 Sep 2012 11:10:21 +1200 Subject: Split FileLike into Writer and Reader, and add logging functionality. --- netlib/tcp.py | 69 +++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 16 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 0fed7380..e1318435 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -39,10 +39,11 @@ class NetLibDisconnect(Exception): pass class NetLibTimeout(Exception): pass -class FileLike: +class _FileLike: BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o + self._log = None def set_descriptor(self, o): self.o = o @@ -50,6 +51,37 @@ class FileLike: def __getattr__(self, attr): return getattr(self.o, attr) + def start_log(self): + """ + Starts or resets the log. + + This will store all bytes read or written. + """ + self._log = [] + + def stop_log(self): + """ + Stops the log. + """ + self._log = None + + def is_logging(self): + return self._log is not None + + def get_log(self): + """ + Returns the log as a string. + """ + if not self.is_logging(): + raise ValueError("Not logging!") + return "".join(self._log) + + def add_log(self, v): + if self.is_logging(): + self._log.append(v) + + +class Writer(_FileLike): def flush(self): try: if hasattr(self.o, "flush"): @@ -57,6 +89,21 @@ class FileLike: except socket.error, v: raise NetLibDisconnect(str(v)) + def write(self, v): + if v: + try: + if hasattr(self.o, "sendall"): + self.add_log(v) + return self.o.sendall(v) + else: + r = self.o.write(v) + self.add_log(v[:r]) + return r + except (SSL.Error, socket.error): + raise NetLibDisconnect() + + +class Reader(_FileLike): def read(self, length): """ If length is None, we read until connection closes. @@ -85,19 +132,9 @@ class FileLike: result += data if length != -1: length -= len(data) + self.add_log(result) return result - def write(self, v): - if v: - try: - if hasattr(self.o, "sendall"): - return self.o.sendall(v) - else: - r = self.o.write(v) - return r - except (SSL.Error, socket.error): - raise NetLibDisconnect() - def readline(self, size = None): result = '' bytes_read = 0 @@ -151,8 +188,8 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile = FileLike(connection.makefile('rb', self.rbufsize)) - self.wfile = FileLike(connection.makefile('wb', self.wbufsize)) + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection @@ -186,8 +223,8 @@ class BaseHandler: wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection - self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize)) - self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize)) + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) self.client_address = client_address self.server = server -- cgit v1.2.3 From b308824193342c11c88b8bad2645a5b09efcf48f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 24 Sep 2012 11:21:48 +1200 Subject: Create netlib.utils, move cleanBin and hexdump from libmproxy.utils. --- netlib/utils.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 netlib/utils.py (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py new file mode 100644 index 00000000..ea749545 --- /dev/null +++ b/netlib/utils.py @@ -0,0 +1,36 @@ + +def cleanBin(s, fixspacing=False): + """ + Cleans binary data to make it safe to display. If fixspacing is True, + tabs, newlines and so forth will be maintained, if not, they will be + replaced with a placeholder. + """ + parts = [] + for i in s: + o = ord(i) + if (o > 31 and o < 127): + parts.append(i) + elif i in "\n\r\t" and not fixspacing: + parts.append(i) + else: + parts.append(".") + return "".join(parts) + + +def hexdump(s): + """ + Returns a set of tuples: + (offset, hex, str) + """ + parts = [] + for i in range(0, len(s), 16): + o = "%.10x"%i + part = s[i:i+16] + x = " ".join("%.2x"%ord(i) for i in part) + if len(part) < 16: + x += " " + x += " ".join(" " for i in range(16 - len(part))) + parts.append( + (o, x, cleanBin(part, True)) + ) + return parts -- cgit v1.2.3 From 064b4c80018d9b76c2bedc010ab45c8b9ea7faa3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 27 Sep 2012 10:59:46 +1200 Subject: Make cleanBin escape carriage returns. We get confusing output on terminals if we leave \r unescaped. --- netlib/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index ea749545..7621a1dc 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -10,7 +10,7 @@ def cleanBin(s, fixspacing=False): o = ord(i) if (o > 31 and o < 127): parts.append(i) - elif i in "\n\r\t" and not fixspacing: + elif i in "\n\t" and not fixspacing: parts.append(i) else: parts.append(".") -- cgit v1.2.3 From 15679e010d99def2fb7efd1de5533099a12772ca Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 1 Oct 2012 11:30:02 +1300 Subject: Add a settimeout method to tcp.BaseHandler. --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index e1318435..414c1237 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -284,6 +284,9 @@ class BaseHandler: def handle(self): # pragma: no cover raise NotImplementedError + def settimeout(self, n): + self.connection.settimeout(n) + def close(self): """ Does a hard close of the socket, i.e. a shutdown, followed by a close. -- cgit v1.2.3 From 77869634e20ae5a2646d7455e499866e9cfafbab Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 9 Oct 2012 16:25:15 +1300 Subject: Limit reads to block length. --- netlib/tcp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 414c1237..f8f877de 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -106,13 +106,17 @@ class Writer(_FileLike): class Reader(_FileLike): def read(self, length): """ - If length is None, we read until connection closes. + If length is -1, we read until connection closes. """ result = '' start = time.time() while length == -1 or length > 0: + if length == -1 or length > self.BLOCKSIZE: + rlen = self.BLOCKSIZE + else: + rlen = length try: - data = self.o.read(self.BLOCKSIZE if length == -1 else length) + data = self.o.read(rlen) except SSL.ZeroReturnError: break except SSL.WantReadError: -- cgit v1.2.3 From 6517d9e717883bc3cd0eb361e2aa0f58259cae60 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 14 Oct 2012 09:03:23 +1300 Subject: More info on disconnect exception. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index f8f877de..7656e398 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -99,8 +99,8 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error): - raise NetLibDisconnect() + except (SSL.Error, socket.error), v: + raise NetLibDisconnect(str(v)) class Reader(_FileLike): -- cgit v1.2.3 From f8e10bd6ae1adba0897669bb8b90b9180150350a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 31 Oct 2012 22:24:45 +1300 Subject: Bump version. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 614b87a1..30a4c0f9 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 2, 1) +IVERSION = (0, 2, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 043d05bcdeae482ca1d9b80375a1922e54896a6b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 5 Dec 2012 04:03:39 +0100 Subject: add __iter__ for odict --- netlib/odict.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 629fcade..bddb3877 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -22,6 +22,9 @@ class ODict: def __eq__(self, other): return self.lst == other.lst + def __iter__(self): + return self.lst.__iter__() + def __getitem__(self, k): """ Returns a list of values matching key. -- cgit v1.2.3 From ddc08efde1a5132734f1f06481a97e484cc368e3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 4 Jan 2013 14:23:52 +1300 Subject: Minor cleanup of http.parse_init* methods. --- netlib/http.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index b71eb72d..3f730a1a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -166,36 +166,43 @@ def parse_http_protocol(s): return major, minor -def parse_init_connect(line): +def parse_init(line): try: method, url, protocol = string.split(line) except ValueError: return None - if method != 'CONNECT': + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def parse_init_connect(line): + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': return None try: host, port = url.split(":") except ValueError: return None port = int(port) - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None return host, port, httpversion def parse_init_proxy(line): - try: - method, url, protocol = string.split(line) - except ValueError: + v = parse_init(line) + if not v: return None + method, url, httpversion = v + parts = parse_url(url) if not parts: return None scheme, host, port, path = parts - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None return method, scheme, host, port, path, httpversion @@ -203,15 +210,13 @@ def parse_init_http(line): """ Returns (method, url, httpversion) """ - try: - method, url, protocol = string.split(line) - except ValueError: + v = parse_init(line) + if not v: return None + method, url, httpversion = v + if not (url.startswith("/") or url == "*"): return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None return method, url, httpversion -- cgit v1.2.3 From d3b46feb6011c106b42d297b1a4807d187991345 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 5 Jan 2013 20:06:55 +1300 Subject: Handle non-integer port error in parse_init_connect correctly --- netlib/http.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 3f730a1a..076baf87 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -189,7 +189,10 @@ def parse_init_connect(line): host, port = url.split(":") except ValueError: return None - port = int(port) + try: + port = int(port) + except ValueError: + return None return host, port, httpversion -- cgit v1.2.3 From 72032d7fe75fae1bc1318cf0390e55af6a93ff4d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 6 Jan 2013 01:15:53 +1300 Subject: Basic certificate store implementation and cert utils API cleanup. --- netlib/certutils.py | 72 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 19 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index f55a096b..51fd9da9 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,4 @@ -import os, ssl, hashlib, socket, time, datetime +import os, ssl, hashlib, socket, time, datetime, tempfile, shutil from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -76,30 +76,24 @@ def dummy_ca(path): return True -def dummy_cert(certdir, ca, commonname, sans): +def dummy_cert(fp, ca, commonname, sans): """ - certdir: Certificate directory. + Generates and writes a certificate to fp. + ca: Path to the certificate authority file, or None. commonname: Common name for the generated certificate. + sans: A list of Subject Alternate Names. Returns cert path if operation succeeded, None if not. """ - namehash = hashlib.sha256(commonname).hexdigest() - certpath = os.path.join(certdir, namehash + ".pem") - if os.path.exists(certpath): - return certpath - ss = [] for i in sans: ss.append("DNS: %s"%i) ss = ", ".join(ss) - if ca: - raw = file(ca, "r").read() - ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - else: - key, ca = create_ca() + raw = file(ca, "r").read() + ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) req = OpenSSL.crypto.X509Req() subj = req.get_subject() @@ -110,7 +104,7 @@ def dummy_cert(certdir, ca, commonname, sans): req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notBefore() cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) cert.set_subject(req.get_subject()) @@ -120,11 +114,51 @@ def dummy_cert(certdir, ca, commonname, sans): cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") - f = open(certpath, "w") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) - f.close() + fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) + fp.close() - return certpath + + +class CertStore: + """ + Implements an on-disk certificate store. + """ + def __init__(self, certdir=None): + """ + certdir: The certificate store directory. If None, a temporary + directory will be created, and destroyed when the .cleanup() method + is called. + """ + if certdir: + self.remove = False + self.certdir = certdir + else: + self.remove = True + self.certdir = tempfile.mkdtemp(prefix="certstore") + + def get_cert(self, commonname, sans, cacert=False): + """ + Returns the path to a certificate. + + commonname: Common name for the generated certificate. Must be a + valid, plain-ASCII, IDNA-encoded domain name. + + sans: A list of Subject Alternate Names. + + cacert: An optional path to a CA certificate. If specified, the + cert is created if it does not exist, else return None. + """ + certpath = os.path.join(self.certdir, commonname + ".pem") + if os.path.exists(certpath): + return certpath + elif cacert: + f = open(certpath, "w") + dummy_cert(f, cacert, commonname, sans) + return certpath + + def cleanup(self): + if self.remove: + shutil.rmtree(self.certdir) class _GeneralName(univ.Choice): -- cgit v1.2.3 From 91834ea78f36e1e89d4f19ecdddef83b0286b4d4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 6 Jan 2013 01:16:58 +1300 Subject: Generate certificates with a commencement date an hour in the past. This helps smooth over small discrepancies in client and server times, where it's possible for a certificate to seem to be "in the future" to the client. --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 51fd9da9..87d9d5d8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -104,7 +104,7 @@ def dummy_cert(fp, ca, commonname, sans): req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore() + cert.gmtime_adj_notBefore(-3600) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) cert.set_subject(req.get_subject()) -- cgit v1.2.3 From e4acace8ea741af798523d6ff1d148d129f23582 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 6 Jan 2013 01:34:39 +1300 Subject: Sanity-check certstore common names. --- netlib/certutils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 87d9d5d8..3fd57b2b 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -136,6 +136,18 @@ class CertStore: self.remove = True self.certdir = tempfile.mkdtemp(prefix="certstore") + def check_domain(self, commonname): + try: + commonname.decode("idna") + commonname.decode("ascii") + except: + return False + if ".." in commonname: + return False + if "/" in commonname: + return False + return True + def get_cert(self, commonname, sans, cacert=False): """ Returns the path to a certificate. @@ -147,7 +159,11 @@ class CertStore: cacert: An optional path to a CA certificate. If specified, the cert is created if it does not exist, else return None. + + Return None if the certificate could not be found or generated. """ + if not self.check_domain(commonname): + return None certpath = os.path.join(self.certdir, commonname + ".pem") if os.path.exists(certpath): return certpath -- cgit v1.2.3 From 10457e876ad6db9c66973c925b7e65f2a16ffbca Mon Sep 17 00:00:00 2001 From: Israel Nir Date: Thu, 10 Jan 2013 15:51:37 +0200 Subject: adding read timestamp to enable better resolution of when certain reads were performed (timestamp is updated when the first byte is available on the network) --- netlib/tcp.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 7656e398..76fb7ca0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -44,6 +44,7 @@ class _FileLike: def __init__(self, o): self.o = o self._log = None + self.timestamp = None def set_descriptor(self, o): self.o = o @@ -80,6 +81,8 @@ class _FileLike: if self.is_logging(): self._log.append(v) + def reset_timestamp(self): + self.timestamp = None class Writer(_FileLike): def flush(self): @@ -131,6 +134,7 @@ class Reader(_FileLike): raise NetLibDisconnect except SSL.SysCallError, v: raise NetLibDisconnect + self.timestamp = self.timestamp or time.time() if not data: break result += data -- cgit v1.2.3 From 04048b4c73f477f11d41788366eddffaae6bbb20 Mon Sep 17 00:00:00 2001 From: Rouli Date: Wed, 16 Jan 2013 22:30:19 +0200 Subject: renaming the timestamp in preparation of other timestamps that will be added later, adding tests --- netlib/tcp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 76fb7ca0..9c5cfa64 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -44,7 +44,7 @@ class _FileLike: def __init__(self, o): self.o = o self._log = None - self.timestamp = None + self.first_byte_timestamp = None def set_descriptor(self, o): self.o = o @@ -81,8 +81,8 @@ class _FileLike: if self.is_logging(): self._log.append(v) - def reset_timestamp(self): - self.timestamp = None + def reset_timestamps(self): + self.first_byte_timestamp = None class Writer(_FileLike): def flush(self): @@ -134,7 +134,7 @@ class Reader(_FileLike): raise NetLibDisconnect except SSL.SysCallError, v: raise NetLibDisconnect - self.timestamp = self.timestamp or time.time() + self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break result += data -- cgit v1.2.3 From 1499529e62e6d2892a6908472398854094af89fb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 18 Jan 2013 17:07:35 +1300 Subject: Fix client cert typo. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 9c5cfa64..afb7e059 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -177,7 +177,7 @@ class TCPClient: if not options is None: ctx.set_options(options) if clientcert: - context.use_certificate_file(self.clientcert) + context.use_certificate_file(clientcert) self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: -- cgit v1.2.3 From 00d20abdd4863d15fdda826615dab264c8e14d4a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 20 Jan 2013 22:13:38 +1300 Subject: Beef up client certificate handling substantially. --- netlib/certutils.py | 6 +++--- netlib/tcp.py | 10 +++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 3fd57b2b..e1407936 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -256,11 +256,11 @@ class SSLCert: @property def cn(self): - cn = None + c = None for i in self.subject: if i[0] == "CN": - cn = i[1] - return cn + c = i[1] + return c @property def altnames(self): diff --git a/netlib/tcp.py b/netlib/tcp.py index afb7e059..4b547d1f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,10 +173,14 @@ class TCPClient: self.ssl_established = False def convert_to_ssl(self, clientcert=None, sni=None, method=TLSv1_METHOD, options=None): + """ + clientcert: Path to a file containing both client cert and private key. + """ context = SSL.Context(method) if not options is None: ctx.set_options(options) if clientcert: + context.use_privatekey_file(clientcert) context.use_certificate_file(clientcert) self.connection = SSL.Connection(context, self.connection) self.ssl_established = True @@ -238,6 +242,7 @@ class BaseHandler: self.server = server self.finished = False self.ssl_established = False + self.clientcert = None def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None): """ @@ -246,13 +251,16 @@ class BaseHandler: ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) + def ver(*args): + self.clientcert = certutils.SSLCert(args[1]) + ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() - # SNI callback happens during do_handshake() try: self.connection.do_handshake() except SSL.Error, v: -- cgit v1.2.3 From 7248a22d5e381dd57d69c06f8e67e60fd55e55ba Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 20 Jan 2013 22:36:54 +1300 Subject: Improve error signalling for client certificates. --- netlib/tcp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 4b547d1f..d0ca09f3 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -177,11 +177,14 @@ class TCPClient: clientcert: Path to a file containing both client cert and private key. """ context = SSL.Context(method) - if not options is None: + if options is not None: ctx.set_options(options) if clientcert: - context.use_privatekey_file(clientcert) - context.use_certificate_file(clientcert) + try: + context.use_privatekey_file(clientcert) + context.use_certificate_file(clientcert) + except SSL.Error, v: + raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: -- cgit v1.2.3 From 2eb6651e5180035cd3e17f9048b16ea38719a9ac Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 25 Jan 2013 15:54:41 +1300 Subject: Extract TCP test utilities into netlib.test --- netlib/tcp.py | 11 +++++----- netlib/test.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 netlib/test.py (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index d0ca09f3..56cc0dea 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback, sys, time +import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils @@ -84,13 +84,14 @@ class _FileLike: def reset_timestamps(self): self.first_byte_timestamp = None + class Writer(_FileLike): def flush(self): - try: - if hasattr(self.o, "flush"): + if hasattr(self.o, "flush"): + try: self.o.flush() - except socket.error, v: - raise NetLibDisconnect(str(v)) + except socket.error, v: + raise NetLibDisconnect(str(v)) def write(self, v): if v: diff --git a/netlib/test.py b/netlib/test.py new file mode 100644 index 00000000..2f72f979 --- /dev/null +++ b/netlib/test.py @@ -0,0 +1,67 @@ +import threading, Queue, cStringIO +import tcp + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.q = Queue.Queue() + s = cls.makeserver() + cls.port = s.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)): + """ + ssl: A {cert, key, v3_only} dict. + """ + tcp.TCPServer.__init__(self, addr) + self.ssl, self.q = ssl, q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl: + if self.ssl["v3_only"]: + method = tcp.SSLv3_METHOD + options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + else: + method = tcp.SSLv23_METHOD + options = None + h.convert_to_ssl( + self.ssl["cert"], + self.ssl["key"], + method = method, + options = options, + ) + h.handle() + h.finish() + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) -- cgit v1.2.3 From cc4867064be42409fd5fb8271901b03029b787de Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 25 Jan 2013 16:03:59 +1300 Subject: Streamline netlib.test API --- netlib/test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/test.py b/netlib/test.py index 2f72f979..7d24d80e 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -14,6 +14,8 @@ class ServerThread(threading.Thread): class ServerTestBase: + ssl = None + handler = None @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -22,11 +24,14 @@ class ServerTestBase: cls.server = ServerThread(s) cls.server.start() + @classmethod + def makeserver(cls): + return TServer(cls.ssl, cls.q, cls.handler) + @classmethod def teardownAll(cls): cls.server.shutdown() - @property def last_handler(self): return self.server.server.last_handler -- cgit v1.2.3 From e5b125eec8e732112af9884cf3ab35377913303a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 26 Jan 2013 21:19:35 +1300 Subject: Introduce the mock module to improve unit tests. There are a few socket corner-cases that are incredibly hard to reproduce in a unit test suite, so we use mock to trigger the exceptions instead. --- netlib/tcp.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 56cc0dea..a79f3ac4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -87,6 +87,9 @@ class _FileLike: class Writer(_FileLike): def flush(self): + """ + May raise NetLibDisconnect + """ if hasattr(self.o, "flush"): try: self.o.flush() @@ -94,6 +97,9 @@ class Writer(_FileLike): raise NetLibDisconnect(str(v)) def write(self, v): + """ + May raise NetLibDisconnect + """ if v: try: if hasattr(self.o, "sendall"): -- cgit v1.2.3 From 7433dfceae3b2ac7e709fbcedd9e298800d2ac1b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 26 Jan 2013 21:29:45 +1300 Subject: Bump unit tests, fix two serious wee buglets discovered. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a79f3ac4..40bd4bde 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -185,7 +185,7 @@ class TCPClient: """ context = SSL.Context(method) if options is not None: - ctx.set_options(options) + context.set_options(options) if clientcert: try: context.use_privatekey_file(clientcert) @@ -220,7 +220,7 @@ class TCPClient: self.connection.settimeout(n) def gettimeout(self): - self.connection.gettimeout() + return self.connection.gettimeout() def close(self): """ -- cgit v1.2.3 From 7d185356655fa2f40c452c273a3cd039360d20c1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 27 Jan 2013 19:21:18 +1300 Subject: 100% test coverage --- netlib/tcp.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 40bd4bde..556f97ac 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -355,20 +355,13 @@ class TCPServer: while not self.__shutdown_request: r, w, e = select.select([self.socket], [], [], poll_interval) if self.socket in r: - try: - request, client_address = self.socket.accept() - except socket.error: - return - try: - t = threading.Thread( - target = self.request_thread, - args = (request, client_address) - ) - t.setDaemon(1) - t.start() - except: - self.handle_error(request, client_address) - request.close() + request, client_address = self.socket.accept() + t = threading.Thread( + target = self.request_thread, + args = (request, client_address) + ) + t.setDaemon(1) + t.start() finally: self.__shutdown_request = False self.__is_shut_down.set() -- cgit v1.2.3 From c6f9a2d74dc0b2d9185743a02e4c1410983f0c3f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 11:08:43 +1300 Subject: More accurate description of an HTTP read error, make pyflakes happy. --- netlib/certutils.py | 2 +- netlib/http.py | 2 +- netlib/tcp.py | 4 ++-- netlib/wsgi.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index e1407936..b3ba1dcf 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,4 @@ -import os, ssl, hashlib, socket, time, datetime, tempfile, shutil +import os, ssl, time, datetime, tempfile, shutil from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error diff --git a/netlib/http.py b/netlib/http.py index 076baf87..29bcf43d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -279,7 +279,7 @@ def read_response(rfile, method, body_size_limit): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: - raise HttpError(502, "Blank server response.") + raise HttpError(502, "Server disconnect.") parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append("") diff --git a/netlib/tcp.py b/netlib/tcp.py index 556f97ac..0a15d2ac 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -139,7 +139,7 @@ class Reader(_FileLike): raise NetLibTimeout except socket.error: raise NetLibDisconnect - except SSL.SysCallError, v: + except SSL.SysCallError: raise NetLibDisconnect self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: @@ -322,7 +322,7 @@ class BaseHandler: self.connection.shutdown() else: self.connection.shutdown(socket.SHUT_RDWR) - except (socket.error, SSL.Error), v: + except (socket.error, SSL.Error): # Socket probably already closed pass self.connection.close() diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 4fa2c537..dffc2ace 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,4 +1,4 @@ -import cStringIO, urllib, time, sys, traceback +import cStringIO, urllib, time, traceback import odict @@ -128,13 +128,13 @@ class WSGIAdaptor: write(i) if not state["headers_sent"]: write("") - except Exception, v: + except Exception: try: s = traceback.format_exc() errs.write(s) self.error_page(soc, state["headers_sent"], s) - except Exception, v: # pragma: no cover - pass # pragma: no cover + except Exception: # pragma: no cover + pass return errs.getvalue() -- cgit v1.2.3 From 97e11a219fb2a752d5b726b203874101d7ab651c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 15:36:15 +1300 Subject: Housekeeping and cleanup, some minor argument name changes. --- netlib/certutils.py | 1 - netlib/http.py | 9 ++++++--- netlib/tcp.py | 10 +++++----- 3 files changed, 11 insertions(+), 9 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index b3ba1dcf..859c93f1 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -118,7 +118,6 @@ def dummy_cert(fp, ca, commonname, sans): fp.close() - class CertStore: """ Implements an on-disk certificate store. diff --git a/netlib/http.py b/netlib/http.py index 29bcf43d..58993686 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -9,6 +9,9 @@ class HttpError(Exception): return "HttpError(%s, %s)"%(self.code, self.msg) +class HttpErrorConnClosed(HttpError): pass + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -73,7 +76,7 @@ def read_chunked(code, fp, limit): while 1: line = fp.readline(128) if line == "": - raise HttpError(code, "Connection closed prematurely") + raise HttpErrorConnClosed(code, "Connection closed prematurely") if line != '\r\n' and line != '\n': try: length = int(line, 16) @@ -95,7 +98,7 @@ def read_chunked(code, fp, limit): while 1: line = fp.readline() if line == "": - raise HttpError(code, "Connection closed prematurely") + raise HttpErrorConnClosed(code, "Connection closed prematurely") if line == '\r\n' or line == '\n': break return content @@ -279,7 +282,7 @@ def read_response(rfile, method, body_size_limit): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: - raise HttpError(502, "Server disconnect.") + raise HttpErrorConnClosed(502, "Server disconnect.") parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append("") diff --git a/netlib/tcp.py b/netlib/tcp.py index 0a15d2ac..d909a5a4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -179,17 +179,17 @@ class TCPClient: self.cert = None self.ssl_established = False - def convert_to_ssl(self, clientcert=None, sni=None, method=TLSv1_METHOD, options=None): + def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ - clientcert: Path to a file containing both client cert and private key. + cert: Path to a file containing both client cert and private key. """ context = SSL.Context(method) if options is not None: context.set_options(options) - if clientcert: + if cert: try: - context.use_privatekey_file(clientcert) - context.use_certificate_file(clientcert) + context.use_privatekey_file(cert) + context.use_certificate_file(cert) except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) -- cgit v1.2.3 From f30df13384b1c31ee7bcd78b0caea37043434bcf Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Feb 2013 21:11:09 +1300 Subject: Make sni_handler an argument to BaseHandler.convert_to_ssl --- netlib/tcp.py | 35 +++++++++++++++-------------------- netlib/test.py | 1 + 2 files changed, 16 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index d909a5a4..485d821f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -254,15 +254,27 @@ class BaseHandler: self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None): """ method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + handle_sni: SNI handler, should take a connection object. Server + name can be retrieved like this: + + connection.get_servername() + + And you can specify the connection keys as follows: + + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) """ ctx = SSL.Context(method) if not options is None: ctx.set_options(options) - # SNI callback happens during do_handshake() - ctx.set_tlsext_servername_callback(self.handle_sni) + if handle_sni: + # SNI callback happens during do_handshake() + ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) def ver(*args): @@ -290,23 +302,6 @@ class BaseHandler: # Remote has disconnected pass - def handle_sni(self, connection): - """ - Called if the client has given a server name indication. - - Server name can be retrieved like this: - - connection.get_servername() - - And you can specify the connection keys as follows: - - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) - """ - pass - def handle(self): # pragma: no cover raise NotImplementedError diff --git a/netlib/test.py b/netlib/test.py index 7d24d80e..3378279b 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -62,6 +62,7 @@ class TServer(tcp.TCPServer): self.ssl["key"], method = method, options = options, + handle_sni = getattr(h, "handle_sni", None) ) h.handle() h.finish() -- cgit v1.2.3 From 0fa63519654db2567995f3c3ac6e464796de66a3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Feb 2013 09:28:48 +1300 Subject: ODict.keys --- netlib/odict.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index bddb3877..0759a5bf 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -36,6 +36,9 @@ class ODict: ret.append(i[1]) return ret + def keys(self): + return list(set([self._kconv(i[0]) for i in self.lst])) + def _filter_lst(self, k, lst): k = self._kconv(k) new = [] -- cgit v1.2.3 From 97537417f01c17903fb4cebd59991eea57faa5e6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 2 Mar 2013 16:57:38 +1300 Subject: Factor out http.parse_response_line --- netlib/http.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 58993686..bc09c8a1 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -274,6 +274,20 @@ def read_http_body_response(rfile, headers, limit): return read_http_body(500, rfile, headers, all, limit) +def parse_response_line(line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + def read_response(rfile, method, body_size_limit): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -283,19 +297,13 @@ def read_response(rfile, method, body_size_limit): line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if not len(parts) == 3: + parts = parse_response_line(line) + if not parts: raise HttpError(502, "Invalid server response: %s"%repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) - try: - code = int(code) - except ValueError: - raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") -- cgit v1.2.3 From 0acab862a65ef4a1823a1bfb702d8be1e3d7b83d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 10:37:28 +1300 Subject: Integrate HTTP auth, test to 100% --- netlib/contrib/__init__.py | 0 netlib/contrib/md5crypt.py | 94 +++++++++++++++++++++++++++++++++++++ netlib/http.py | 22 ++++++++- netlib/http_auth.py | 113 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 228 insertions(+), 1 deletion(-) create mode 100644 netlib/contrib/__init__.py create mode 100644 netlib/contrib/md5crypt.py create mode 100644 netlib/http_auth.py (limited to 'netlib') diff --git a/netlib/contrib/__init__.py b/netlib/contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netlib/contrib/md5crypt.py b/netlib/contrib/md5crypt.py new file mode 100644 index 00000000..d64ea8ac --- /dev/null +++ b/netlib/contrib/md5crypt.py @@ -0,0 +1,94 @@ +# Based on FreeBSD src/lib/libcrypt/crypt.c 1.2 +# http://www.freebsd.org/cgi/cvsweb.cgi/~checkout~/src/lib/libcrypt/crypt.c?rev=1.2&content-type=text/plain + +# Original license: +# * "THE BEER-WARE LICENSE" (Revision 42): +# * wrote this file. As long as you retain this notice you +# * can do whatever you want with this stuff. If we meet some day, and you think +# * this stuff is worth it, you can buy me a beer in return. Poul-Henning Kamp + +# This port adds no further stipulations. I forfeit any copyright interest. + +import md5 + +def md5crypt(password, salt, magic='$1$'): + # /* The password first, since that is what is most unknown */ /* Then our magic string */ /* Then the raw salt */ + m = md5.new() + m.update(password + magic + salt) + + # /* Then just as many characters of the MD5(pw,salt,pw) */ + mixin = md5.md5(password + salt + password).digest() + for i in range(0, len(password)): + m.update(mixin[i % 16]) + + # /* Then something really weird... */ + # Also really broken, as far as I can tell. -m + i = len(password) + while i: + if i & 1: + m.update('\x00') + else: + m.update(password[0]) + i >>= 1 + + final = m.digest() + + # /* and now, just to make sure things don't run too fast */ + for i in range(1000): + m2 = md5.md5() + if i & 1: + m2.update(password) + else: + m2.update(final) + + if i % 3: + m2.update(salt) + + if i % 7: + m2.update(password) + + if i & 1: + m2.update(final) + else: + m2.update(password) + + final = m2.digest() + + # This is the bit that uses to64() in the original code. + + itoa64 = './0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + + rearranged = '' + for a, b, c in ((0, 6, 12), (1, 7, 13), (2, 8, 14), (3, 9, 15), (4, 10, 5)): + v = ord(final[a]) << 16 | ord(final[b]) << 8 | ord(final[c]) + for i in range(4): + rearranged += itoa64[v & 0x3f]; v >>= 6 + + v = ord(final[11]) + for i in range(2): + rearranged += itoa64[v & 0x3f]; v >>= 6 + + return magic + salt + '$' + rearranged + +if __name__ == '__main__': + + def test(clear_password, the_hash): + magic, salt = the_hash[1:].split('$')[:2] + magic = '$' + magic + '$' + return md5crypt(clear_password, salt, magic) == the_hash + + test_cases = ( + (' ', '$1$yiiZbNIH$YiCsHZjcTkYd31wkgW8JF.'), + ('pass', '$1$YeNsbWdH$wvOF8JdqsoiLix754LTW90'), + ('____fifteen____', '$1$s9lUWACI$Kk1jtIVVdmT01p0z3b/hw1'), + ('____sixteen_____', '$1$dL3xbVZI$kkgqhCanLdxODGq14g/tW1'), + ('____seventeen____', '$1$NaH5na7J$j7y8Iss0hcRbu3kzoJs5V.'), + ('__________thirty-three___________', '$1$HO7Q6vzJ$yGwp2wbL5D7eOVzOmxpsy.'), + ('apache', '$apr1$J.w5a/..$IW9y6DR0oO/ADuhlMF5/X1') + ) + + for clearpw, hashpw in test_cases: + if test(clearpw, hashpw): + print '%s: pass' % clearpw + else: + print '%s: FAIL' % clearpw diff --git a/netlib/http.py b/netlib/http.py index bc09c8a1..10b6a402 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,4 @@ -import string, urlparse +import string, urlparse, binascii import odict class HttpError(Exception): @@ -169,6 +169,26 @@ def parse_http_protocol(s): return major, minor +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + def parse_init(line): try: method, url, protocol = string.split(line) diff --git a/netlib/http_auth.py b/netlib/http_auth.py new file mode 100644 index 00000000..d478ab10 --- /dev/null +++ b/netlib/http_auth.py @@ -0,0 +1,113 @@ +import binascii +import contrib.md5crypt as md5crypt +import http + + +class NullProxyAuth(): + """ + No proxy auth at all (returns empty challange headers) + """ + def __init__(self, password_manager): + self.password_manager = password_manager + + def clean(self, headers): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + + def authenticate(self, headers): + """ + Tests that the user is allowed to use the proxy + """ + return True + + def auth_challenge_headers(self): + """ + Returns a dictionary containing the headers require to challenge the user + """ + return {} + + +class BasicProxyAuth(NullProxyAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + def __init__(self, password_manager, realm): + NullProxyAuth.__init__(self, password_manager) + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] + + def authenticate(self, headers): + auth_value = headers.get(self.AUTH_HEADER, []) + if not auth_value: + return False + parts = http.parse_http_basic_auth(auth_value[0]) + if not parts: + return False + scheme, username, password = parts + if scheme.lower()!='basic': + return False + if not self.password_manager.test(username, password): + return False + self.username = username + return True + + def auth_challenge_headers(self): + return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} + + +class PassMan(): + def test(self, username, password_token): + return False + + +class PassManNonAnon: + """ + Ensure the user specifies a username, accept any password. + """ + def test(self, username, password_token): + if username: + return True + return False + + +class PassManHtpasswd: + """ + Read usernames and passwords from an htpasswd file + """ + def __init__(self, fp): + """ + Raises ValueError if htpasswd file is invalid. + """ + self.usernames = {} + for l in fp: + l = l.strip().split(':') + if len(l) != 2: + raise ValueError("Invalid htpasswd file.") + parts = l[1].split('$') + if len(parts) != 4: + raise ValueError("Invalid htpasswd file.") + self.usernames[l[0]] = dict( + token = l[1], + dummy = parts[0], + magic = parts[1], + salt = parts[2], + hashed_password = parts[3] + ) + + def test(self, username, password_token): + ui = self.usernames.get(username) + if not ui: + return False + expected = md5crypt.md5crypt(password_token, ui["salt"], '$'+ui["magic"]+'$') + return expected==ui["token"] + + +class PassManSingleUser: + def __init__(self, username, password): + self.username, self.password = username, password + + def test(self, username, password_token): + return self.username==username and self.password==password_token -- cgit v1.2.3 From 1fe1a802adbef93b5b024a85d8dafb112ed652bb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 12:16:09 +1300 Subject: 100% test coverage. --- netlib/http_auth.py | 2 +- netlib/tcp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http_auth.py b/netlib/http_auth.py index d478ab10..4adae179 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -96,7 +96,7 @@ class PassManHtpasswd: salt = parts[2], hashed_password = parts[3] ) - + def test(self, username, password_token): ui = self.usernames.get(username) if not ui: diff --git a/netlib/tcp.py b/netlib/tcp.py index 485d821f..07b28cf9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -298,7 +298,7 @@ class BaseHandler: self.close() self.wfile.close() self.rfile.close() - except socket.error: + except (socket.error, NetLibDisconnect): # Remote has disconnected pass -- cgit v1.2.3 From 2897ddfbee5ec3da72863cb8d5ee1370c9698f8a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 14:52:06 +1300 Subject: Stricter error checking for http.parse_url --- netlib/http.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 10b6a402..c864f1de 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -15,6 +15,11 @@ class HttpErrorConnClosed(HttpError): pass def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer + host is a valid IDNA-encoded hostname + path is valid ASCII """ scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) if not scheme: @@ -34,6 +39,14 @@ def parse_url(url): path = urlparse.urlunparse(('', '', path, params, query, fragment)) if not path.startswith("/"): path = "/" + path + try: + host.decode("idna") + except ValueError: + return None + try: + path.decode("ascii") + except ValueError: + return None return scheme, host, port, path -- cgit v1.2.3 From cd4ed8530fa04fcbd54009e9db6ad9ea2518a10b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 15:03:57 +1300 Subject: Check that hosts in parse_url do not contain NULL bytes. --- netlib/http.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index c864f1de..1b03d330 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -18,7 +18,7 @@ def parse_url(url): Checks that: port is an integer - host is a valid IDNA-encoded hostname + host is a valid IDNA-encoded hostname with no null-bytes path is valid ASCII """ scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) @@ -43,6 +43,8 @@ def parse_url(url): host.decode("idna") except ValueError: return None + if "\0" in host: + return None try: path.decode("ascii") except ValueError: -- cgit v1.2.3 From 7b9300743e879a8a2e35f5786b23a17261350ff9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 15:08:17 +1300 Subject: More parse_url solidification: check that port is in range 0-65535 --- netlib/http.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 1b03d330..5628dd4d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -17,7 +17,7 @@ def parse_url(url): Returns a (scheme, host, port, path) tuple, or None on error. Checks that: - port is an integer + port is an integer 0-65535 host is a valid IDNA-encoded hostname with no null-bytes path is valid ASCII """ @@ -49,6 +49,8 @@ def parse_url(url): path.decode("ascii") except ValueError: return None + if not 0 <= port <= 65535: + return None return scheme, host, port, path -- cgit v1.2.3 From b21a7da142625e3b47d712cd21cbd440eb48f490 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 15:12:58 +1300 Subject: parse_url: Handle invalid IPv6 addresses --- netlib/http.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 5628dd4d..2c9e69cb 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -21,7 +21,10 @@ def parse_url(url): host is a valid IDNA-encoded hostname with no null-bytes path is valid ASCII """ - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None if not scheme: return None if ':' in netloc: -- cgit v1.2.3 From 5a050bb6b2b1a0bf05f4cd35d87e6f1d7a2608c0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 21:36:19 +1300 Subject: Tighten up checks on port ranges and path character sets. --- netlib/http.py | 37 ++++++++++++++++++++++++++----------- netlib/utils.py | 8 ++++++++ 2 files changed, 34 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 2c9e69cb..0f2caa5a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,5 +1,5 @@ import string, urlparse, binascii -import odict +import odict, utils class HttpError(Exception): def __init__(self, code, msg): @@ -12,6 +12,22 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass +def _is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def _is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -42,17 +58,11 @@ def parse_url(url): path = urlparse.urlunparse(('', '', path, params, query, fragment)) if not path.startswith("/"): path = "/" + path - try: - host.decode("idna") - except ValueError: + if not _is_valid_host(host): return None - if "\0" in host: + if not utils.isascii(path): return None - try: - path.decode("ascii") - except ValueError: - return None - if not 0 <= port <= 65535: + if not _is_valid_port(port): return None return scheme, host, port, path @@ -236,6 +246,10 @@ def parse_init_connect(line): port = int(port) except ValueError: return None + if not _is_valid_port(port): + return None + if not _is_valid_host(host): + return None return host, port, httpversion @@ -260,7 +274,8 @@ def parse_init_http(line): if not v: return None method, url, httpversion = v - + if not utils.isascii(url): + return None if not (url.startswith("/") or url == "*"): return None return method, url, httpversion diff --git a/netlib/utils.py b/netlib/utils.py index 7621a1dc..61fd54ae 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,12 @@ +def isascii(s): + try: + s.decode("ascii") + except ValueError: + return False + return True + + def cleanBin(s, fixspacing=False): """ Cleans binary data to make it safe to display. If fixspacing is True, -- cgit v1.2.3 From 5f0ad7b2a6b857419017e3e72062ab4e0e328238 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 22:13:23 +1300 Subject: Ensure that HTTP methods are ASCII. --- netlib/http.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 0f2caa5a..f1a2bfb5 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -227,6 +227,8 @@ def parse_init(line): httpversion = parse_http_protocol(protocol) if not httpversion: return None + if not utils.isascii(method): + return None return method, url, httpversion -- cgit v1.2.3 From a94d17970e739cdda4e6223b3af8136b05e6e192 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 5 Mar 2013 09:09:52 +1300 Subject: Sync version number with mitmproxy. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 30a4c0f9..d90c000c 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 2, 2) +IVERSION = (0, 9) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 241465c368c0117a8d86c17c44b39fed3116c6e0 Mon Sep 17 00:00:00 2001 From: Tim Becker Date: Fri, 19 Apr 2013 15:37:14 +0200 Subject: extensions aren't supported in v1, set to v3 (value=2) if using them. --- netlib/certutils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 859c93f1..8407dcc8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -110,6 +110,7 @@ def dummy_cert(fp, ca, commonname, sans): cert.set_subject(req.get_subject()) cert.set_serial_number(int(time.time()*10000)) if ss: + cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") -- cgit v1.2.3 From 9c13224353eefbb6b1824ded20846036b07c558f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 5 May 2013 13:49:20 +1200 Subject: Fix exception hierarchy. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 07b28cf9..b67ad0bb 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -35,8 +35,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass -class NetLibDisconnect(Exception): pass -class NetLibTimeout(Exception): pass +class NetLibDisconnect(NetLibError): pass +class NetLibTimeout(NetLibError): pass class _FileLike: -- cgit v1.2.3 From 7f0aa415e1ab95ed6b27a760cc9aa8ff4ee85080 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 13 May 2013 08:48:21 +1200 Subject: Add a request_client_cert argument to server SSL conversion. By default, we now do not request the client cert. We're supposed to be able to do this with no negative effects - if the client has no cert to present, we're notified and proceed as usual. Unfortunately, Android seems to have a bug (tested on 4.2.2) - when an Android client is asked to present a certificate it does not have, it hangs up, which is frankly bogus. Some time down the track we may be able to make the proper behaviour the default again, but until then we're conservative. --- netlib/certutils.py | 3 --- netlib/tcp.py | 20 ++++++++++++++++---- netlib/test.py | 3 ++- 3 files changed, 18 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 8407dcc8..f18318f6 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,9 +5,6 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp -CERT_SLEEP_TIME = 1 -CERT_EXPIRY = str(365 * 3) - def create_ca(): key = OpenSSL.crypto.PKey() diff --git a/netlib/tcp.py b/netlib/tcp.py index b67ad0bb..47953724 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -240,6 +240,7 @@ class TCPClient: class BaseHandler: """ The instantiator is expected to call the handle() and finish() methods. + """ rbufsize = -1 wbufsize = -1 @@ -252,9 +253,10 @@ class BaseHandler: self.server = server self.finished = False self.ssl_established = False + self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): """ method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD handle_sni: SNI handler, should take a connection object. Server @@ -268,6 +270,15 @@ class BaseHandler: new_context.use_privatekey(key) new_context.use_certificate(cert) connection.set_context(new_context) + + The request_client_cert argument requires some explanation. We're + supposed to be able to do this with no negative effects - if the + client has no cert to present, we're notified and proceed as usual. + Unfortunately, Android seems to have a bug (tested on 4.2.2) - when + an Android client is asked to present a certificate it does not + have, it hangs up, which is frankly bogus. Some time down the track + we may be able to make the proper behaviour the default again, but + until then we're conservative. """ ctx = SSL.Context(method) if not options is None: @@ -277,9 +288,10 @@ class BaseHandler: ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) - def ver(*args): - self.clientcert = certutils.SSLCert(args[1]) - ctx.set_verify(SSL.VERIFY_PEER, ver) + if request_client_cert: + def ver(*args): + self.clientcert = certutils.SSLCert(args[1]) + ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() diff --git a/netlib/test.py b/netlib/test.py index 3378279b..deaef64e 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -62,7 +62,8 @@ class TServer(tcp.TCPServer): self.ssl["key"], method = method, options = options, - handle_sni = getattr(h, "handle_sni", None) + handle_sni = getattr(h, "handle_sni", None), + request_client_cert = self.ssl["request_client_cert"] ) h.handle() h.finish() -- cgit v1.2.3 From c9ab1c60b5d43f0b4d645c751350b16e9e562b55 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 16 Jun 2013 00:28:21 +0200 Subject: always read files in binary mode --- netlib/certutils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index f18318f6..4c06eb8f 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -48,23 +48,23 @@ def dummy_ca(path): key, ca = create_ca() # Dump the CA plus private key - f = open(path, "w") + f = open(path, "wb") f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PEM format - f = open(os.path.join(dirname, basename + "-cert.pem"), "w") + f = open(os.path.join(dirname, basename + "-cert.pem"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Create a .cer file with the same contents for Android - f = open(os.path.join(dirname, basename + "-cert.cer"), "w") + f = open(os.path.join(dirname, basename + "-cert.cer"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(dirname, basename + "-cert.p12"), "w") + f = open(os.path.join(dirname, basename + "-cert.p12"), "wb") p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) p12.set_privatekey(key) @@ -88,7 +88,7 @@ def dummy_cert(fp, ca, commonname, sans): ss.append("DNS: %s"%i) ss = ", ".join(ss) - raw = file(ca, "r").read() + raw = file(ca, "rb").read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) @@ -165,7 +165,7 @@ class CertStore: if os.path.exists(certpath): return certpath elif cacert: - f = open(certpath, "w") + f = open(certpath, "wb") dummy_cert(f, cacert, commonname, sans) return certpath -- cgit v1.2.3 From 73f8a1e2e0006c2a37ae6264afe70a8207ffbb54 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 16 Jun 2013 13:38:39 +1200 Subject: Bump version. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index d90c000c..63a9d862 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9) +IVERSION = (0, 9, 1) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 68e2e782b0afdc03844b107c28627391c51dd036 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 17 Jun 2013 17:03:17 +0200 Subject: attempt to fix 'half-duplex' TCP close sequence --- netlib/tcp.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 47953724..e37cb707 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -230,11 +230,15 @@ class TCPClient: if self.ssl_established: self.connection.shutdown() else: - self.connection.shutdown(socket.SHUT_RDWR) - self.connection.close() + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass except (socket.error, SSL.Error): # Socket probably already closed pass + self.connection.close() class BaseHandler: @@ -328,10 +332,15 @@ class BaseHandler: if self.ssl_established: self.connection.shutdown() else: - self.connection.shutdown(socket.SHUT_RDWR) + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass except (socket.error, SSL.Error): # Socket probably already closed pass + self.connection.close() -- cgit v1.2.3 From 02376b6a75fdb397a865697723f7282dbf70deca Mon Sep 17 00:00:00 2001 From: Andrey Plotnikov Date: Sun, 7 Jul 2013 13:33:56 +0800 Subject: Add socket binding support for TCPClient --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 47953724..b5e9e2c4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,11 +173,12 @@ class Reader(_FileLike): class TCPClient: rbufsize = -1 wbufsize = -1 - def __init__(self, host, port): + def __init__(self, host, port, source_address=None): self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.source_address = source_address def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -209,6 +210,8 @@ class TCPClient: try: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.source_address: + connection.bind(self.source_address) connection.connect((addr, self.port)) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) -- cgit v1.2.3 From f5fdfd8a9f17e0fe213a9cf54acae84e4bc31462 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 30 Jul 2013 09:42:13 +1200 Subject: Clarify the interface for flush and close methods. --- netlib/tcp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 69ad2da5..123c6515 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -93,7 +93,7 @@ class Writer(_FileLike): if hasattr(self.o, "flush"): try: self.o.flush() - except socket.error, v: + except (socket.error, IOError), v: raise NetLibDisconnect(str(v)) def write(self, v): @@ -215,7 +215,7 @@ class TCPClient: connection.connect((addr, self.port)) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - except socket.error, err: + except (socket.error, IOError), err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection @@ -238,16 +238,16 @@ class TCPClient: #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass - except (socket.error, SSL.Error): + self.connection.close() + except (socket.error, SSL.Error, IOError): # Socket probably already closed pass - self.connection.close() class BaseHandler: """ The instantiator is expected to call the handle() and finish() methods. - + """ rbufsize = -1 wbufsize = -1 -- cgit v1.2.3 From b9f06b473cd464e82bc53a973c5e190f93377bce Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 10 Aug 2013 23:07:09 +1200 Subject: Better handling of cert errors. --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 123c6515..df1f8fea 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -37,6 +37,7 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass class NetLibTimeout(NetLibError): pass +class NetLibSSLError(NetLibError): pass class _FileLike: @@ -129,6 +130,8 @@ class Reader(_FileLike): data = self.o.read(rlen) except SSL.ZeroReturnError: break + except SSL.Error, v: + raise NetLibSSLError(v.message) except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) -- cgit v1.2.3 From 2da57ecff0e9572e45663dbad1c5f520e57c531f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 11 Aug 2013 11:47:07 +1200 Subject: Correct order of precedence for SSL errors. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index df1f8fea..f4a8acf9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -130,8 +130,6 @@ class Reader(_FileLike): data = self.o.read(rlen) except SSL.ZeroReturnError: break - except SSL.Error, v: - raise NetLibSSLError(v.message) except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) @@ -144,6 +142,8 @@ class Reader(_FileLike): raise NetLibDisconnect except SSL.SysCallError: raise NetLibDisconnect + except SSL.Error, v: + raise NetLibSSLError(v.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break -- cgit v1.2.3 From 62edceee093dd54956ed5b623dfb4cb8c1309a16 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 12 Aug 2013 16:03:29 +1200 Subject: Revamp dummy cert generation. We no longer use on-disk storage - we just keep the certs in memory. --- netlib/certutils.py | 45 +++++++++++++-------------------------------- netlib/tcp.py | 3 ++- netlib/test.py | 7 +++++-- 3 files changed, 20 insertions(+), 35 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 4c06eb8f..7dcb5450 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -73,7 +73,7 @@ def dummy_ca(path): return True -def dummy_cert(fp, ca, commonname, sans): +def dummy_cert(ca, commonname, sans): """ Generates and writes a certificate to fp. @@ -111,27 +111,15 @@ def dummy_cert(fp, ca, commonname, sans): cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") - - fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) - fp.close() + return SSLCert(cert) class CertStore: """ - Implements an on-disk certificate store. + Implements an in-memory certificate store. """ - def __init__(self, certdir=None): - """ - certdir: The certificate store directory. If None, a temporary - directory will be created, and destroyed when the .cleanup() method - is called. - """ - if certdir: - self.remove = False - self.certdir = certdir - else: - self.remove = True - self.certdir = tempfile.mkdtemp(prefix="certstore") + def __init__(self): + self.certs = {} def check_domain(self, commonname): try: @@ -145,33 +133,26 @@ class CertStore: return False return True - def get_cert(self, commonname, sans, cacert=False): + def get_cert(self, commonname, sans, cacert): """ - Returns the path to a certificate. + Returns an SSLCert object. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. sans: A list of Subject Alternate Names. - cacert: An optional path to a CA certificate. If specified, the - cert is created if it does not exist, else return None. + cacert: The path to a CA certificate. Return None if the certificate could not be found or generated. """ if not self.check_domain(commonname): return None - certpath = os.path.join(self.certdir, commonname + ".pem") - if os.path.exists(certpath): - return certpath - elif cacert: - f = open(certpath, "wb") - dummy_cert(f, cacert, commonname, sans) - return certpath - - def cleanup(self): - if self.remove: - shutil.rmtree(self.certdir) + if commonname in self.certs: + return self.certs[commonname] + c = dummy_cert(cacert, commonname, sans) + self.certs[commonname] = c + return c class _GeneralName(univ.Choice): diff --git a/netlib/tcp.py b/netlib/tcp.py index f4a8acf9..31e9a398 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -268,6 +268,7 @@ class BaseHandler: def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): """ + cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: @@ -297,7 +298,7 @@ class BaseHandler: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) - ctx.use_certificate_file(cert) + ctx.use_certificate(cert.x509) if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) diff --git a/netlib/test.py b/netlib/test.py index deaef64e..661395c5 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,5 +1,5 @@ import threading, Queue, cStringIO -import tcp +import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): @@ -51,6 +51,9 @@ class TServer(tcp.TCPServer): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: + cert = certutils.SSLCert.from_pem( + file(self.ssl["cert"], "r").read() + ) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 @@ -58,7 +61,7 @@ class TServer(tcp.TCPServer): method = tcp.SSLv23_METHOD options = None h.convert_to_ssl( - self.ssl["cert"], + cert, self.ssl["key"], method = method, options = options, -- cgit v1.2.3 From c44f354fd0f9b4f1432913dd70cf1579910dfa4b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 17 Aug 2013 16:15:37 +0200 Subject: fix windows bugs --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..2de647ae 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -235,6 +235,7 @@ class TCPClient: try: if self.ssl_established: self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. @@ -302,6 +303,7 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) + return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True @@ -338,6 +340,7 @@ class BaseHandler: try: if self.ssl_established: self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. -- cgit v1.2.3 From 28a0030c1ecacb8ac5c6e6453b6a22bdf94d9f7e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 19 Aug 2013 19:41:20 +0200 Subject: compatibility fixes for windows --- netlib/tcp.py | 3 ++- netlib/test.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 2de647ae..f4a713f9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -303,7 +303,8 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) - return True + # err 20 = X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY + #return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True diff --git a/netlib/test.py b/netlib/test.py index 661395c5..87802bd5 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -52,7 +52,7 @@ class TServer(tcp.TCPServer): self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( - file(self.ssl["cert"], "r").read() + file(self.ssl["cert"], "rb").read() ) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD -- cgit v1.2.3 From d5b3e397e142ae60275fb89ea765423903e99bb6 Mon Sep 17 00:00:00 2001 From: Israel Nir Date: Wed, 21 Aug 2013 13:42:30 +0300 Subject: adding cipher list selection option to BaseHandler --- netlib/tcp.py | 4 +++- netlib/test.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..f1496a32 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -266,7 +266,7 @@ class BaseHandler: self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -294,6 +294,8 @@ class BaseHandler: ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if cipher_list: + ctx.set_cipher_list(cipher_list) if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) diff --git a/netlib/test.py b/netlib/test.py index 661395c5..139d95bb 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -66,7 +66,8 @@ class TServer(tcp.TCPServer): method = method, options = options, handle_sni = getattr(h, "handle_sni", None), - request_client_cert = self.ssl["request_client_cert"] + request_client_cert = self.ssl["request_client_cert"], + cipher_list = self.ssl.get("cipher_list", None) ) h.handle() h.finish() -- cgit v1.2.3 From 7428f954744725381ced7c273609ca14d767dfff Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 25 Aug 2013 10:22:09 +1200 Subject: Handle interrupted system call errors. --- netlib/tcp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..bee1f75b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -376,7 +376,13 @@ class TCPServer: self.__is_shut_down.clear() try: while not self.__shutdown_request: - r, w, e = select.select([self.socket], [], [], poll_interval) + try: + r, w, e = select.select([self.socket], [], [], poll_interval) + except select.error, ex: + if ex[0] == 4: + continue + else: + raise if self.socket in r: request, client_address = self.socket.accept() t = threading.Thread( -- cgit v1.2.3 From 8a261b2c01fe49de896bf9808af8fbb66b300cfc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 25 Aug 2013 10:30:48 +1200 Subject: Bump version. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 63a9d862..32013c35 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9, 1) +IVERSION = (0, 9, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 98f765f693fc4fa7245c3179da1d791661ed502a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 24 Sep 2013 21:18:41 +0200 Subject: Don't create a certificate request when creating a dummy cert --- netlib/certutils.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 7dcb5450..60e41427 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -92,24 +92,16 @@ def dummy_cert(ca, commonname, sans): ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - req = OpenSSL.crypto.X509Req() - subj = req.get_subject() - subj.CN = commonname - req.set_pubkey(ca.get_pubkey()) - req.sign(key, "sha1") - if ss: - req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) - cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) - cert.set_subject(req.get_subject()) + cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) - cert.set_pubkey(req.get_pubkey()) + cert.set_pubkey(ca.get_pubkey()) cert.sign(key, "sha1") return SSLCert(cert) -- cgit v1.2.3 From 53b7c5abdd7c6dbb8ecaa1aa1000296f86eb45fa Mon Sep 17 00:00:00 2001 From: Sean Coates Date: Mon, 7 Oct 2013 16:48:30 -0400 Subject: allow specification of o, cn, expiry --- netlib/certutils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 60e41427..a21f0188 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,17 +5,20 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp +default_exp = 62208000 # =24 * 60 * 60 * 720 +default_o = "mitmproxy" +default_cn = "mitmproxy" -def create_ca(): +def create_ca(o=default_o, cn=default_cn, exp=default_exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) ca = OpenSSL.crypto.X509() ca.set_serial_number(int(time.time()*10000)) ca.set_version(2) - ca.get_subject().CN = "mitmproxy" - ca.get_subject().O = "mitmproxy" + ca.get_subject().CN = cn + ca.get_subject().O = o ca.gmtime_adj_notBefore(0) - ca.gmtime_adj_notAfter(24 * 60 * 60 * 720) + ca.gmtime_adj_notAfter(exp) ca.set_issuer(ca.get_subject()) ca.set_pubkey(key) ca.add_extensions([ @@ -35,7 +38,7 @@ def create_ca(): return key, ca -def dummy_ca(path): +def dummy_ca(path, o=default_o, cn=default_cn, exp=default_exp): dirname = os.path.dirname(path) if not os.path.exists(dirname): os.makedirs(dirname) @@ -45,7 +48,7 @@ def dummy_ca(path): else: basename = os.path.basename(path) - key, ca = create_ca() + key, ca = create_ca(o=o, cn=cn, exp=exp) # Dump the CA plus private key f = open(path, "wb") -- cgit v1.2.3 From 642b3f002ed7020ee359d23d46802b0bb02c1018 Mon Sep 17 00:00:00 2001 From: Sean Coates Date: Mon, 7 Oct 2013 16:55:35 -0400 Subject: remove tempfile and shutil imports because they're not actually used --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 60e41427..dab7e318 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,4 @@ -import os, ssl, time, datetime, tempfile, shutil +import os, ssl, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error -- cgit v1.2.3 From 5e4ccbd7edc6eebf9eee25fd4d6ca64994ed6522 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 19 Nov 2013 04:11:24 +0100 Subject: attempt to fix #24 --- netlib/http.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index f1a2bfb5..7060b688 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -283,32 +283,23 @@ def parse_init_http(line): return method, url, httpversion -def request_connection_close(httpversion, headers): +def connection_close(httpversion, headers): """ - Checks the request to see if the client connection should be closed. + Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 """ + # At first, check if we have an explicit Connection header. if "connection" in headers: toks = get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: return False - # HTTP 1.1 connections are assumed to be persistent + # If we don't have a Connection header, HTTP 1.1 connections are assumed to be persistent if httpversion == (1, 1): return False return True -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif (not has_chunked_encoding(headers)) and "content-length" in headers: - return False - return True - def read_http_body_request(rfile, wfile, headers, httpversion, limit): """ -- cgit v1.2.3 From e402e3b862312ca4f7bd7dd633db3654143c3380 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 21 Nov 2013 01:07:56 +0100 Subject: add custom argparse actions to seamlessly integrate ProxyAuth classes --- netlib/http_auth.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) (limited to 'netlib') diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 4adae179..6c91c7c5 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,6 +1,7 @@ import binascii import contrib.md5crypt as md5crypt import http +from argparse import Action, ArgumentTypeError class NullProxyAuth(): @@ -111,3 +112,46 @@ class PassManSingleUser: def test(self, username, password_token): return self.username==username and self.password==password_token + + +class AuthAction(Action): + """ + Helper class to allow seamless integration int argparse. Example usage: + parser.add_argument( + "--nonanonymous", + action=NonanonymousAuthAction, nargs=0, + help="Allow access to any user long as a credentials are specified." + ) + """ + def __call__(self, parser, namespace, values, option_string=None): + passman = self.getPasswordManager(values) + if passman: + authenticator = BasicProxyAuth(passman, "mitmproxy") + else: + authenticator = NullProxyAuth(None) + setattr(namespace, "authenticator", authenticator) + + def getPasswordManager(self, s): + """ + returns the password manager + """ + raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + def getPasswordManager(self, s): + if len(s.split(':')) != 2: + raise ArgumentTypeError("Invalid single-user specification. Please use the format username:password") + username, password = s.split(':') + return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + def getPasswordManager(self, s): + return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + def getPasswordManager(self, s): + with open(s, "r") as f: + return PassManHtpasswd(f) \ No newline at end of file -- cgit v1.2.3 From 5aad09ab816b2343ca686d45e6c5d2b8ba07b10b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 10:15:19 +1300 Subject: Fix client certificate request feature. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index f4a713f9..23458742 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -303,8 +303,8 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) - # err 20 = X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY - #return True + # Return true to prevent cert verification error + return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True -- cgit v1.2.3 From d05c20d8fab3345e19c06ac0de00a2c8f30c44ef Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 13:15:08 +1300 Subject: Domain checks for persistent cert store is now irrelevant. We no longer store these on disk, so we don't care about path components. --- netlib/certutils.py | 14 -------------- netlib/tcp.py | 5 +++-- 2 files changed, 3 insertions(+), 16 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 22b5c35c..d9b8ce57 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -116,18 +116,6 @@ class CertStore: def __init__(self): self.certs = {} - def check_domain(self, commonname): - try: - commonname.decode("idna") - commonname.decode("ascii") - except: - return False - if ".." in commonname: - return False - if "/" in commonname: - return False - return True - def get_cert(self, commonname, sans, cacert): """ Returns an SSLCert object. @@ -141,8 +129,6 @@ class CertStore: Return None if the certificate could not be found or generated. """ - if not self.check_domain(commonname): - return None if commonname in self.certs: return self.certs[commonname] c = dummy_cert(cacert, commonname, sans) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8fe04d2e..b3be43d6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -346,8 +346,9 @@ class BaseHandler: self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent. + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass except (socket.error, SSL.Error): -- cgit v1.2.3 From 7213f86d49960a625643fb6179e6a3731b16d462 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 13:35:42 +1300 Subject: Unit test auth actions. --- netlib/http_auth.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 6c91c7c5..71f120d6 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -125,23 +125,19 @@ class AuthAction(Action): """ def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) - if passman: - authenticator = BasicProxyAuth(passman, "mitmproxy") - else: - authenticator = NullProxyAuth(None) + authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, "authenticator", authenticator) - def getPasswordManager(self, s): - """ - returns the password manager - """ + def getPasswordManager(self, s): # pragma: nocover raise NotImplementedError() class SingleuserAuthAction(AuthAction): def getPasswordManager(self, s): if len(s.split(':')) != 2: - raise ArgumentTypeError("Invalid single-user specification. Please use the format username:password") + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) username, password = s.split(':') return PassManSingleUser(username, password) @@ -154,4 +150,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): with open(s, "r") as f: - return PassManHtpasswd(f) \ No newline at end of file + return PassManHtpasswd(f) + -- cgit v1.2.3 From 390f2a46c920ee332d758d6c46999b5147e0b30b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 8 Dec 2013 01:37:45 +0100 Subject: make AuthAction generic --- netlib/http_auth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 6c91c7c5..948d503a 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -33,6 +33,7 @@ class NullProxyAuth(): class BasicProxyAuth(NullProxyAuth): CHALLENGE_HEADER = 'Proxy-Authenticate' AUTH_HEADER = 'Proxy-Authorization' + def __init__(self, password_manager, realm): NullProxyAuth.__init__(self, password_manager) self.realm = realm @@ -125,11 +126,10 @@ class AuthAction(Action): """ def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) - if passman: - authenticator = BasicProxyAuth(passman, "mitmproxy") - else: - authenticator = NullProxyAuth(None) - setattr(namespace, "authenticator", authenticator) + if not passman: + raise ArgumentTypeError("Error creating password manager for proxy authentication.") + authenticator = BasicProxyAuth(passman, "mitmproxy") + setattr(namespace, self.dest, authenticator) def getPasswordManager(self, s): """ -- cgit v1.2.3 From 4840c6b3bf5c9e992895f9c3117ceddca4c0cc33 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 15:26:30 +1300 Subject: Fix race condition in test suite. --- netlib/tcp.py | 1 - 1 file changed, 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index b3be43d6..5a07c013 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -354,7 +354,6 @@ class BaseHandler: except (socket.error, SSL.Error): # Socket probably already closed pass - self.connection.close() -- cgit v1.2.3 From d66fd5ba1b11ad57b7825b7feb67392f45e88c24 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 10 Dec 2013 22:20:12 +1300 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 32013c35..9b2e037e 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9, 2) +IVERSION = (0, 10) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From a7ac97eb823f599ca04f588f6cbe4da28e00a194 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Thu, 12 Dec 2013 07:00:58 +0100 Subject: support ipv6 --- netlib/tcp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 5a07c013..ee5fe618 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -176,12 +176,13 @@ class Reader(_FileLike): class TCPClient: rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None): + def __init__(self, host, port, source_address=None, use_ipv6=False): self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False self.source_address = source_address + self.use_ipv6 = use_ipv6 def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -211,11 +212,10 @@ class TCPClient: def connect(self): try: - addr = socket.gethostbyname(self.host) - connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address) - connection.connect((addr, self.port)) + connection.connect((self.host, self.port)) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: -- cgit v1.2.3 From 6f26cec83e77f8998b50988c54196f9dfae5b7dd Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Thu, 12 Dec 2013 07:11:13 +0100 Subject: tab fix --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index ee5fe618..aa9ca027 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -182,7 +182,7 @@ class TCPClient: self.cert = None self.ssl_established = False self.source_address = source_address - self.use_ipv6 = use_ipv6 + self.use_ipv6 = use_ipv6 def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ -- cgit v1.2.3 From 969595cca70edc4d02d5f676221267edf01e4252 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 13 Dec 2013 06:24:08 +0100 Subject: add requirements.txt, small changes --- netlib/http.py | 4 ++++ netlib/http_auth.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 7060b688..e160bd79 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -233,6 +233,10 @@ def parse_init(line): def parse_init_connect(line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ v = parse_init(line) if not v: return None diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 69bee5c1..8f062826 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -126,8 +126,6 @@ class AuthAction(Action): """ def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) - if not passman: - raise ArgumentTypeError("Error creating password manager for proxy authentication.") authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, self.dest, authenticator) -- cgit v1.2.3 From cebec67e08bcb9a4dc353ca18aedc53d0230ea42 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 15 Dec 2013 06:43:54 +0100 Subject: refactor read_http_body --- netlib/http.py | 95 ++++++++++++++++++++++++---------------------------------- netlib/test.py | 2 +- 2 files changed, 40 insertions(+), 57 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index e160bd79..454edb3a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -95,14 +95,17 @@ def read_headers(fp): return odict.ODictCaseless(ret) -def read_chunked(code, fp, limit): +def read_chunked(fp, headers, limit, is_request): """ Read a chunked HTTP body. May raise HttpError. """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. content = "" total = 0 + code = 400 if is_request else 502 while 1: line = fp.readline(128) if line == "": @@ -151,35 +154,6 @@ def has_chunked_encoding(headers): return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] -def read_http_body(code, rfile, headers, all, limit): - """ - Read an HTTP body: - - code: The HTTP error code to be used when raising HttpError - rfile: A file descriptor to read from - headers: An ODictCaseless object - all: Should we read all data? - limit: Size limit. - """ - if has_chunked_encoding(headers): - content = read_chunked(code, rfile, limit) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif all: - content = rfile.read(limit if limit else -1) - else: - content = "" - return content - - def parse_http_protocol(s): """ Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or @@ -304,28 +278,6 @@ def connection_close(httpversion, headers): return True - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - """ - Read the HTTP body from a client request. - """ - if "expect" in headers: - # FIXME: Should be forwarded upstream - if "100-continue" in headers['expect'] and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('\r\n') - del headers['expect'] - return read_http_body(400, rfile, headers, False, limit) - - -def read_http_body_response(rfile, headers, limit): - """ - Read the HTTP body from a server response. - """ - all = "close" in get_header_tokens(headers, "connection") - return read_http_body(500, rfile, headers, all, limit) - - def parse_response_line(line): parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully @@ -359,10 +311,41 @@ def read_response(rfile, method, body_size_limit): headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - if code >= 100 and code <= 199: - return read_response(rfile, method, body_size_limit) - if method == "HEAD" or code == 204 or code == 304: + + # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + if method == "HEAD" or (code in [204, 304]) or 100 <= code <= 199: content = "" else: - content = read_http_body_response(rfile, headers, body_size_limit) + content = read_http_body(rfile, headers, body_size_limit, False) return httpversion, code, msg, headers, content + + +def read_http_body(rfile, headers, limit, is_request): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False otherwise + """ + if has_chunked_encoding(headers): + content = read_chunked(rfile, headers, limit, is_request) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + if l < 0: + raise ValueError() + except ValueError: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif is_request: + content = "" + else: + content = rfile.read(limit if limit else -1) + not_done = rfile.read(1) + if not_done: + raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) + return content \ No newline at end of file diff --git a/netlib/test.py b/netlib/test.py index cd1a3847..85a56739 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -18,7 +18,7 @@ class ServerTestBase: handler = None addr = ("localhost", 0) use_ipv6 = False - + @classmethod def setupAll(cls): cls.q = Queue.Queue() -- cgit v1.2.3 From 5717e7300c1cc4a17f0fb0659dcf591fbd0a6e40 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 5 Jan 2014 10:57:50 +1300 Subject: Make it possible to pass custom environment variables into wsgi apps. --- netlib/wsgi.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index dffc2ace..647cb899 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -33,7 +33,7 @@ class WSGIAdaptor: def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - def make_environ(self, request, errsoc): + def make_environ(self, request, errsoc, **extra): if '?' in request.path: path_info, query = request.path.split('?', 1) else: @@ -59,6 +59,7 @@ class WSGIAdaptor: # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } + environ.update(extra) if request.client_conn.address: environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address @@ -86,7 +87,7 @@ class WSGIAdaptor: soc.write("\r\n") soc.write(c) - def serve(self, request, soc): + def serve(self, request, soc, **env): state = dict( response_started = False, headers_sent = False, @@ -123,7 +124,7 @@ class WSGIAdaptor: errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs), start_response) + dataiter = self.app(self.make_environ(request, errs, **env), start_response) for i in dataiter: write(i) if not state["headers_sent"]: -- cgit v1.2.3 From ac1a700fa16e2ae2146425844823bff70cc86f4b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 8 Jan 2014 14:46:55 +1300 Subject: Make certificate not-before time 48 hours. Fixes #200 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index d9b8ce57..0349bec7 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -96,7 +96,7 @@ def dummy_cert(ca, commonname, sans): key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600) + cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) cert.get_subject().CN = commonname -- cgit v1.2.3 From 951f2d517fa2e464d654a54bebacbd983f944c62 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 01:57:37 +0100 Subject: change parameter names to reflect changes --- netlib/tcp.py | 29 +++++++++++++---------------- netlib/test.py | 2 +- 2 files changed, 14 insertions(+), 17 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 33f7ef3a..d35818bf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -138,8 +138,8 @@ class Reader(_FileLike): raise NetLibTimeout except socket.timeout: raise NetLibTimeout - except socket.error: - raise NetLibDisconnect + except socket.error, v: + raise NetLibDisconnect(v[1]) except SSL.SysCallError: raise NetLibDisconnect except SSL.Error, v: @@ -255,16 +255,13 @@ class BaseHandler: """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, client_address, server): + def __init__(self, connection): self.connection = connection self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - self.client_address = client_address - self.server = server self.finished = False self.ssl_established = False - self.clientcert = None def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): @@ -371,13 +368,13 @@ class TCPServer: self.port = self.server_address[1] self.socket.listen(self.request_queue_size) - def request_thread(self, request, client_address): + def connection_thread(self, connection, client_address): try: - self.handle_connection(request, client_address) - request.close() + self.handle_client_connection(connection, client_address) except: - self.handle_error(request, client_address) - request.close() + self.handle_error(connection, client_address) + finally: + connection.close() def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -391,10 +388,10 @@ class TCPServer: else: raise if self.socket in r: - request, client_address = self.socket.accept() + connection, client_address = self.socket.accept() t = threading.Thread( - target = self.request_thread, - args = (request, client_address) + target = self.connection_thread, + args = (connection, client_address) ) t.setDaemon(1) t.start() @@ -410,7 +407,7 @@ class TCPServer: def handle_error(self, request, client_address, fp=sys.stderr): """ - Called when handle_connection raises an exception. + Called when handle_client_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be # none. @@ -421,7 +418,7 @@ class TCPServer: print >> fp, exc print >> fp, '-'*40 - def handle_connection(self, request, client_address): # pragma: no cover + def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ diff --git a/netlib/test.py b/netlib/test.py index cd1a3847..0c36da6a 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -50,7 +50,7 @@ class TServer(tcp.TCPServer): self.handler_klass = handler_klass self.last_handler = None - def handle_connection(self, request, client_address): + def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: -- cgit v1.2.3 From d0a6d2e2545089893d3789e3c787e269645df852 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 05:33:21 +0100 Subject: fix tests, remove duplicate code --- netlib/tcp.py | 91 ++++++++++++++++++++++++---------------------------------- netlib/test.py | 2 +- 2 files changed, 38 insertions(+), 55 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index d35818bf..e48f4f6b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -138,8 +138,8 @@ class Reader(_FileLike): raise NetLibTimeout except socket.timeout: raise NetLibTimeout - except socket.error, v: - raise NetLibDisconnect(v[1]) + except socket.error: + raise NetLibDisconnect except SSL.SysCallError: raise NetLibDisconnect except SSL.Error, v: @@ -173,7 +173,40 @@ class Reader(_FileLike): return result -class TCPClient: +class SocketCloseMixin: + def finish(self): + self.finished = True + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.close() + self.wfile.close() + self.rfile.close() + except (socket.error, NetLibDisconnect): + # Remote has disconnected + pass + + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) + else: + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass + self.connection.close() + except (socket.error, SSL.Error, IOError): + # Socket probably already closed + pass + + +class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 def __init__(self, host, port, source_address=None, use_ipv6=False): @@ -228,27 +261,8 @@ class TCPClient: def gettimeout(self): return self.connection.gettimeout() - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass - -class BaseHandler: +class BaseHandler(SocketCloseMixin): """ The instantiator is expected to call the handle() and finish() methods. @@ -315,43 +329,12 @@ class BaseHandler: self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def finish(self): - self.finished = True - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.close() - self.wfile.close() - self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): self.connection.settimeout(n) - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - except (socket.error, SSL.Error): - # Socket probably already closed - pass - self.connection.close() class TCPServer: diff --git a/netlib/test.py b/netlib/test.py index 2209ebc3..f5599082 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -51,7 +51,7 @@ class TServer(tcp.TCPServer): self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) + h = self.handler_klass(request) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From 0f22039bcadd26c2745f609085bcfdbba35b4945 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 18 Jan 2014 22:55:40 +0100 Subject: add CONNECT request to list of request types that don't have a response body --- netlib/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 454edb3a..51f85627 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -313,7 +313,7 @@ def read_response(rfile, method, body_size_limit): raise HttpError(502, "Invalid headers.") # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method == "HEAD" or (code in [204, 304]) or 100 <= code <= 199: + if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = "" else: content = read_http_body(rfile, headers, body_size_limit, False) -- cgit v1.2.3 From 8266699acdfcb786ba2c87007a17632ff1893fe5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 19 Jan 2014 18:17:06 +1300 Subject: Silence pyflakes, adjust requirements.txt --- netlib/http_auth.py | 1 - 1 file changed, 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 8f062826..be99fb3d 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,4 +1,3 @@ -import binascii import contrib.md5crypt as md5crypt import http from argparse import Action, ArgumentTypeError -- cgit v1.2.3 From 763cb90b66b23cd94b6e37df3d4c7b8e7f89492a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 17:26:35 +0100 Subject: add tcp.Address to unify ipv4/ipv6 address handling --- netlib/certutils.py | 2 +- netlib/tcp.py | 56 +++++++++++++++++++++++++++++++++++++++-------------- netlib/test.py | 11 +++++------ 3 files changed, 48 insertions(+), 21 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 0349bec7..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -237,7 +237,7 @@ class SSLCert: def get_remote_cert(host, port, sni): - c = tcp.TCPClient(host, port) + c = tcp.TCPClient((host, port)) c.connect() c.convert_to_ssl(sni=sni) return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index e48f4f6b..bad166d0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,6 +173,35 @@ class Reader(_FileLike): return result +class Address(tuple): + """ + This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + """ + def __new__(cls, address, use_ipv6=False): + a = super(Address, cls).__new__(cls, tuple(address)) + a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + return a + + @classmethod + def wrap(cls, t): + if isinstance(t, cls): + return t + else: + return cls(t) + + @property + def host(self): + return self[0] + + @property + def port(self): + return self[1] + + @property + def is_ipv6(self): + return self.family == socket.AF_INET6 + + class SocketCloseMixin: def finish(self): self.finished = True @@ -209,10 +238,9 @@ class SocketCloseMixin: class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None, use_ipv6=False): - self.host, self.port = host, port + def __init__(self, address, source_address=None): + self.address = Address.wrap(address) self.source_address = source_address - self.use_ipv6 = use_ipv6 self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -245,14 +273,14 @@ class TCPClient(SocketCloseMixin): def connect(self): try: - connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address) - connection.connect((self.host, self.port)) + connection.connect(self.address) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) self.connection = connection def settimeout(self, n): @@ -269,8 +297,9 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection): + def __init__(self, connection, address): self.connection = connection + self.address = Address.wrap(address) self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) @@ -339,19 +368,18 @@ class BaseHandler(SocketCloseMixin): class TCPServer: request_queue_size = 20 - def __init__(self, server_address, use_ipv6=False): - self.server_address = server_address - self.use_ipv6 = use_ipv6 + def __init__(self, address): + self.address = Address.wrap(address) self.__is_shut_down = threading.Event() self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.port = self.server_address[1] + self.socket.bind(self.address) + self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) def connection_thread(self, connection, client_address): + client_address = Address(client_address) try: self.handle_client_connection(connection, client_address) except: diff --git a/netlib/test.py b/netlib/test.py index f5599082..565b97cd 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -17,19 +17,18 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - use_ipv6 = False @classmethod def setupAll(cls): cls.q = Queue.Queue() s = cls.makeserver() - cls.port = s.port + cls.port = s.address.port cls.server = ServerThread(s) cls.server.start() @classmethod def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6) + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) @classmethod def teardownAll(cls): @@ -41,17 +40,17 @@ class ServerTestBase: class TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr, use_ipv6): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A {cert, key, v3_only} dict. """ - tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6) + tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request) + h = self.handler_klass(request, client_address) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From e18ac4b672e8645388dc8057801092ce417f1511 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 20:30:16 +0100 Subject: re-add server attribute to BaseHandler --- netlib/tcp.py | 4 +++- netlib/test.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index bad166d0..729e513e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -297,9 +297,11 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, address): + + def __init__(self, connection, address, server): self.connection = connection self.address = Address.wrap(address) + self.server = server self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) diff --git a/netlib/test.py b/netlib/test.py index 565b97cd..2f6a7107 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -50,7 +50,7 @@ class TServer(tcp.TCPServer): self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address) + h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From ff9656be80192ac837cf98997f9fe6c00c9c5a32 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 30 Jan 2014 20:07:30 +0100 Subject: remove subclassing of tuple in tcp.Address, move StateObject into netlib --- netlib/certutils.py | 12 +++++++- netlib/odict.py | 7 ++++- netlib/stateobject.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/tcp.py | 45 ++++++++++++++++++++--------- 4 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 netlib/stateobject.py (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 94294f6e..139203b9 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,6 +3,7 @@ from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL +from netlib.stateobject import StateObject import tcp default_exp = 62208000 # =24 * 60 * 60 * 720 @@ -152,13 +153,22 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert: +class SSLCert(StateObject): def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ self.x509 = cert + def _get_state(self): + return self.to_pem() + + def _load_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + def _from_state(cls, state): + return cls.from_pem(state) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index 0759a5bf..8e195afc 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,4 +1,6 @@ import re, copy +from netlib.stateobject import StateObject + def safe_subn(pattern, repl, target, *args, **kwargs): """ @@ -9,7 +11,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict: +class ODict(StateObject): """ A dictionary-like object for managing ordered (key, value) data. """ @@ -98,6 +100,9 @@ class ODict: def _get_state(self): return [tuple(i) for i in self.lst] + def _load_state(self, state): + self.list = [list(i) for i in state] + @classmethod def _from_state(klass, state): return klass([list(i) for i in state]) diff --git a/netlib/stateobject.py b/netlib/stateobject.py new file mode 100644 index 00000000..c2ef2cd4 --- /dev/null +++ b/netlib/stateobject.py @@ -0,0 +1,80 @@ +from types import ClassType + + +class StateObject: + def _get_state(self): + raise NotImplementedError + + def _load_state(self, state): + raise NotImplementedError + + @classmethod + def _from_state(cls, state): + raise NotImplementedError + + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: # we may compare with something that's not a StateObject + return False + + +class SimpleStateObject(StateObject): + """ + A StateObject with opionated conventions that tries to keep everything DRY. + + Simply put, you agree on a list of attributes and their type. + Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. + SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. + Overriding _get_state or _load_state to add custom adjustments is always possible. + """ + + _stateobject_attributes = None # none by default to raise an exception if definition was forgotten + """ + An attribute-name -> class-or-type dict containing all attributes that should be serialized + If the attribute is a class, this class must be a subclass of StateObject. + """ + + def _get_state(self): + return {attr: self.__get_state_attr(attr, cls) + for attr, cls in self._stateobject_attributes.iteritems()} + + def __get_state_attr(self, attr, cls): + """ + helper for _get_state. + returns the value of the given attribute + """ + if getattr(self, attr) is None: + return None + if isinstance(cls, ClassType): + return getattr(self, attr)._get_state() + else: + return getattr(self, attr) + + def _load_state(self, state): + for attr, cls in self._stateobject_attributes.iteritems(): + self.__load_state_attr(attr, cls, state) + + def __load_state_attr(self, attr, cls, state): + """ + helper for _load_state. + loads the given attribute from the state. + """ + if state[attr] is not None: # First, catch None as value. + if isinstance(cls, ClassType): # Is the attribute a StateObject itself? + assert issubclass(cls, StateObject) + curr = getattr(self, attr) + if curr: # if the attribute is already present, delegate to the objects ._load_state method. + curr._load_state(state[attr]) + else: # otherwise, create a new object. + setattr(self, attr, cls._from_state(state[attr])) + else: + setattr(self, attr, cls(state[attr])) + else: + setattr(self, attr, None) + + @classmethod + def _from_state(cls, state): + f = cls() # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index 729e513e..c26d1191 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils +from netlib.stateobject import StateObject SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD @@ -173,14 +174,13 @@ class Reader(_FileLike): return result -class Address(tuple): +class Address(StateObject): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ - def __new__(cls, address, use_ipv6=False): - a = super(Address, cls).__new__(cls, tuple(address)) - a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET - return a + def __init__(self, address, use_ipv6=False): + self.address = address + self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET @classmethod def wrap(cls, t): @@ -189,18 +189,35 @@ class Address(tuple): else: return cls(t) + def __call__(self): + return self.address + @property def host(self): - return self[0] + return self.address[0] @property def port(self): - return self[1] + return self.address[1] @property - def is_ipv6(self): + def use_ipv6(self): return self.family == socket.AF_INET6 + def _load_state(self, state): + self.address = state["address"] + self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET + + def _get_state(self): + return dict( + address=self.address, + use_ipv6=self.use_ipv6 + ) + + @classmethod + def _from_state(cls, state): + return cls(**state) + class SocketCloseMixin: def finish(self): @@ -240,7 +257,7 @@ class TCPClient(SocketCloseMixin): wbufsize = -1 def __init__(self, address, source_address=None): self.address = Address.wrap(address) - self.source_address = source_address + self.source_address = Address.wrap(source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -275,12 +292,12 @@ class TCPClient(SocketCloseMixin): try: connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: - connection.bind(self.source_address) - connection.connect(self.address) + connection.bind(self.source_address()) + connection.connect(self.address()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection def settimeout(self, n): @@ -376,7 +393,7 @@ class TCPServer: self.__shutdown_request = False self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.address) + self.socket.bind(self.address()) self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) @@ -427,7 +444,7 @@ class TCPServer: if traceback: exc = traceback.format_exc() print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s"%client_address + print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) print >> fp, exc print >> fp, '-'*40 -- cgit v1.2.3 From dc45b4bf19bff5edc0b72ccb68fad04d479aff83 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 31 Jan 2014 01:06:53 +0100 Subject: move StateObject back into libmproxy --- netlib/certutils.py | 12 +------- netlib/odict.py | 3 +- netlib/stateobject.py | 80 --------------------------------------------------- netlib/tcp.py | 21 ++++---------- 4 files changed, 7 insertions(+), 109 deletions(-) delete mode 100644 netlib/stateobject.py (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 139203b9..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,7 +3,6 @@ from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from netlib.stateobject import StateObject import tcp default_exp = 62208000 # =24 * 60 * 60 * 720 @@ -153,22 +152,13 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert(StateObject): +class SSLCert: def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ self.x509 = cert - def _get_state(self): - return self.to_pem() - - def _load_state(self, state): - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) - - def _from_state(cls, state): - return cls.from_pem(state) - @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index 8e195afc..46b74e8e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,4 @@ import re, copy -from netlib.stateobject import StateObject def safe_subn(pattern, repl, target, *args, **kwargs): @@ -11,7 +10,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict(StateObject): +class ODict: """ A dictionary-like object for managing ordered (key, value) data. """ diff --git a/netlib/stateobject.py b/netlib/stateobject.py deleted file mode 100644 index c2ef2cd4..00000000 --- a/netlib/stateobject.py +++ /dev/null @@ -1,80 +0,0 @@ -from types import ClassType - - -class StateObject: - def _get_state(self): - raise NotImplementedError - - def _load_state(self, state): - raise NotImplementedError - - @classmethod - def _from_state(cls, state): - raise NotImplementedError - - def __eq__(self, other): - try: - return self._get_state() == other._get_state() - except AttributeError: # we may compare with something that's not a StateObject - return False - - -class SimpleStateObject(StateObject): - """ - A StateObject with opionated conventions that tries to keep everything DRY. - - Simply put, you agree on a list of attributes and their type. - Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. - SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. - Overriding _get_state or _load_state to add custom adjustments is always possible. - """ - - _stateobject_attributes = None # none by default to raise an exception if definition was forgotten - """ - An attribute-name -> class-or-type dict containing all attributes that should be serialized - If the attribute is a class, this class must be a subclass of StateObject. - """ - - def _get_state(self): - return {attr: self.__get_state_attr(attr, cls) - for attr, cls in self._stateobject_attributes.iteritems()} - - def __get_state_attr(self, attr, cls): - """ - helper for _get_state. - returns the value of the given attribute - """ - if getattr(self, attr) is None: - return None - if isinstance(cls, ClassType): - return getattr(self, attr)._get_state() - else: - return getattr(self, attr) - - def _load_state(self, state): - for attr, cls in self._stateobject_attributes.iteritems(): - self.__load_state_attr(attr, cls, state) - - def __load_state_attr(self, attr, cls, state): - """ - helper for _load_state. - loads the given attribute from the state. - """ - if state[attr] is not None: # First, catch None as value. - if isinstance(cls, ClassType): # Is the attribute a StateObject itself? - assert issubclass(cls, StateObject) - curr = getattr(self, attr) - if curr: # if the attribute is already present, delegate to the objects ._load_state method. - curr._load_state(state[attr]) - else: # otherwise, create a new object. - setattr(self, attr, cls._from_state(state[attr])) - else: - setattr(self, attr, cls(state[attr])) - else: - setattr(self, attr, None) - - @classmethod - def _from_state(cls, state): - f = cls() # the default implementation assumes an empty constructor. Override accordingly. - f._load_state(state) - return f \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index c26d1191..346bc053 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,7 +1,6 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils -from netlib.stateobject import StateObject SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD @@ -174,13 +173,13 @@ class Reader(_FileLike): return result -class Address(StateObject): +class Address(object): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = address - self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + self.use_ipv6 = use_ipv6 @classmethod def wrap(cls, t): @@ -204,19 +203,9 @@ class Address(StateObject): def use_ipv6(self): return self.family == socket.AF_INET6 - def _load_state(self, state): - self.address = state["address"] - self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET - - def _get_state(self): - return dict( - address=self.address, - use_ipv6=self.use_ipv6 - ) - - @classmethod - def _from_state(cls, state): - return cls(**state) + @use_ipv6.setter + def use_ipv6(self, b): + self.family = socket.AF_INET6 if b else socket.AF_INET class SocketCloseMixin: -- cgit v1.2.3 From 0bbc40dc33dd7bd3729e639874882dd6dd7ea818 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 4 Feb 2014 04:51:41 +0100 Subject: store used sni in TCPClient, add equality check for tcp.Address --- netlib/tcp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 346bc053..94ea8806 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -207,8 +207,12 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __eq__(self, other): + other = Address.wrap(other) + return (self.address, self.family) == (other.address, other.family) -class SocketCloseMixin: + +class SocketCloseMixin(object): def finish(self): self.finished = True try: @@ -250,6 +254,7 @@ class TCPClient(SocketCloseMixin): self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.sni = None def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -267,6 +272,7 @@ class TCPClient(SocketCloseMixin): self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: + self.sni = sni self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() try: -- cgit v1.2.3 From 7fc544bc7ff8fd610ba9db92c0d3b59a0b040b5b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 5 Feb 2014 21:34:14 +0100 Subject: adjust netlib.wsgi to reflect changes in mitmproxys flow format --- netlib/tcp.py | 2 +- netlib/wsgi.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 94ea8806..34e47999 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -178,7 +178,7 @@ class Address(object): This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ def __init__(self, address, use_ipv6=False): - self.address = address + self.address = tuple(address) self.use_ipv6 = use_ipv6 @classmethod diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 647cb899..b576bdff 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,17 +1,22 @@ import cStringIO, urllib, time, traceback -import odict +import odict, tcp class ClientConn: def __init__(self, address): - self.address = address + self.address = tcp.Address.wrap(address) + + +class Flow: + def __init__(self, client_conn): + self.client_conn = client_conn class Request: def __init__(self, client_conn, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.client_conn = client_conn + self.flow = Flow(client_conn) def date_time_string(): @@ -60,8 +65,8 @@ class WSGIAdaptor: 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + if request.flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() for key, value in request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') -- cgit v1.2.3 From a72ae4d85c08b5716cd88715081be0f1ecaeb9d4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 11 Feb 2014 12:09:58 +0100 Subject: Bump version Do it now already so that mitmproxy will warn the user if netlib is not from master. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 9b2e037e..1d3250e1 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 10) +IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From c276b4294cac97c1281ce9bb4934e49d0ba970a2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 15 Feb 2014 23:16:28 +0100 Subject: allow super() on TCPServer, add thread names for better debugging --- netlib/http_auth.py | 2 +- netlib/tcp.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/http_auth.py b/netlib/http_auth.py index be99fb3d..b0451e3b 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,4 +1,4 @@ -import contrib.md5crypt as md5crypt +from .contrib import md5crypt import http from argparse import Action, ArgumentTypeError diff --git a/netlib/tcp.py b/netlib/tcp.py index 34e47999..5c351bae 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -380,7 +380,7 @@ class BaseHandler(SocketCloseMixin): -class TCPServer: +class TCPServer(object): request_queue_size = 20 def __init__(self, address): self.address = Address.wrap(address) @@ -416,7 +416,10 @@ class TCPServer: connection, client_address = self.socket.accept() t = threading.Thread( target = self.connection_thread, - args = (connection, client_address) + args = (connection, client_address), + name = "ConnectionThread (%s:%s -> %s:%s)" % + (client_address[0], client_address[1], + self.address.host, self.address.port) ) t.setDaemon(1) t.start() @@ -443,7 +446,7 @@ class TCPServer: print >> fp, exc print >> fp, '-'*40 - def handle_client_connection(self, conn, client_address): # pragma: no cover + def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ -- cgit v1.2.3 From 3443bae94e090b0bf12005ef4f0ca474bd903fb1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 27 Feb 2014 18:35:16 +1300 Subject: Cipher suite selection for client connections, improved error handling --- netlib/tcp.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 5c351bae..23449baf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -235,7 +235,8 @@ class SocketCloseMixin(object): self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent. #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass @@ -256,11 +257,16 @@ class TCPClient(SocketCloseMixin): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): + def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): """ cert: Path to a file containing both client cert and private key. """ context = SSL.Context(method) + if cipher_list: + try: + context.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) if options is not None: context.set_options(options) if cert: @@ -350,7 +356,10 @@ class BaseHandler(SocketCloseMixin): if not options is None: ctx.set_options(options) if cipher_list: - ctx.set_cipher_list(cipher_list) + try: + ctx.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) @@ -399,6 +408,10 @@ class TCPServer(object): except: self.handle_error(connection, client_address) finally: + try: + connection.shutdown(socket.SHUT_RDWR) + except: + pass connection.close() def serve_forever(self, poll_interval=0.1): -- cgit v1.2.3 From 7788391903ef67ed1e779560936d60402159f8f5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 13:50:19 +1300 Subject: Minor improvement to CertStore interface --- netlib/certutils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 94294f6e..0b29d52f 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -113,10 +113,11 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self): + def __init__(self, cacert): self.certs = {} + self.cacert = cacert - def get_cert(self, commonname, sans, cacert): + def get_cert(self, commonname, sans): """ Returns an SSLCert object. @@ -125,13 +126,11 @@ class CertStore: sans: A list of Subject Alternate Names. - cacert: The path to a CA certificate. - Return None if the certificate could not be found or generated. """ if commonname in self.certs: return self.certs[commonname] - c = dummy_cert(cacert, commonname, sans) + c = dummy_cert(self.cacert, commonname, sans) self.certs[commonname] = c return c -- cgit v1.2.3 From e381c0366863ae412547e16d67860137a6b89a32 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 16:47:10 +1300 Subject: Cleanups, tests, and no-cover directives for code sections we can't test. --- netlib/odict.py | 10 ---------- netlib/tcp.py | 8 +++++--- 2 files changed, 5 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 46b74e8e..7c743f4e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,16 +96,6 @@ class ODict: def items(self): return self.lst[:] - def _get_state(self): - return [tuple(i) for i in self.lst] - - def _load_state(self, state): - self.list = [list(i) for i in state] - - @classmethod - def _from_state(klass, state): - return klass([list(i) for i in state]) - def copy(self): """ Returns a copy of this object. diff --git a/netlib/tcp.py b/netlib/tcp.py index 23449baf..8f2ebdf0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,8 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils +EINTR = 4 + SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD @@ -238,7 +240,7 @@ class SocketCloseMixin(object): #Section 4.2.2.13 of RFC 1122 tells us that a close() with any # pending readable data could lead to an immediate RST being sent. #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): + while self.connection.recv(4096): # pragma: no cover pass self.connection.close() except (socket.error, SSL.Error, IOError): @@ -420,8 +422,8 @@ class TCPServer(object): while not self.__shutdown_request: try: r, w, e = select.select([self.socket], [], [], poll_interval) - except select.error, ex: - if ex[0] == 4: + except select.error, ex: # pragma: no cover + if ex[0] == EINTR: continue else: raise -- cgit v1.2.3 From 1acaf1c880ba7054e4eb1cc1ed4ea5d0cf852e61 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 16:54:21 +1300 Subject: Re-add state operations to ODict. --- netlib/odict.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 7c743f4e..46b74e8e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,6 +96,16 @@ class ODict: def items(self): return self.lst[:] + def _get_state(self): + return [tuple(i) for i in self.lst] + + def _load_state(self, state): + self.list = [list(i) for i in state] + + @classmethod + def _from_state(klass, state): + return klass([list(i) for i in state]) + def copy(self): """ Returns a copy of this object. -- cgit v1.2.3 From cfaa3da25cee39c5395a6ff27dfc47ff07dbeef6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 21:37:28 +1300 Subject: Use PyOpenSSL's underlying ffi interface to get current cipher for connections. --- netlib/tcp.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 8f2ebdf0..0dff807b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,7 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils + EINTR = 4 SSLv2_METHOD = SSL.SSLv2_METHOD @@ -214,7 +215,16 @@ class Address(object): return (self.address, self.family) == (other.address, other.family) -class SocketCloseMixin(object): +class _Connection(object): + def get_current_cipher(self): + if not self.ssl_established: + return None + c = SSL._lib.SSL_get_current_cipher(self.connection._ssl) + name = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_name(c))) + bits = SSL._lib.SSL_CIPHER_get_bits(c, SSL._ffi.NULL) + version = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_version(c))) + return name, bits, version + def finish(self): self.finished = True try: @@ -248,7 +258,7 @@ class SocketCloseMixin(object): pass -class TCPClient(SocketCloseMixin): +class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 def __init__(self, address, source_address=None): @@ -310,7 +320,7 @@ class TCPClient(SocketCloseMixin): return self.connection.gettimeout() -class BaseHandler(SocketCloseMixin): +class BaseHandler(_Connection): """ The instantiator is expected to call the handle() and finish() methods. -- cgit v1.2.3 From 7c82418e0baca311487230074655f5f106bcdd2b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 4 Mar 2014 14:12:58 +1300 Subject: Beef up CertStore, add DH params. --- netlib/certutils.py | 157 ++++++++++++++++++++++++++++------------------------ 1 file changed, 85 insertions(+), 72 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 0b29d52f..b9c291d0 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,23 +5,27 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp -default_exp = 62208000 # =24 * 60 * 60 * 720 -default_o = "mitmproxy" -default_cn = "mitmproxy" - -def create_ca(o=default_o, cn=default_cn, exp=default_exp): +DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 +# Generated with "openssl dhparam". It's too slow to generate this on startup. +DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- +MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 +zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK +1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC +-----END DH PARAMETERS-----""" + +def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) - ca = OpenSSL.crypto.X509() - ca.set_serial_number(int(time.time()*10000)) - ca.set_version(2) - ca.get_subject().CN = cn - ca.get_subject().O = o - ca.gmtime_adj_notBefore(0) - ca.gmtime_adj_notAfter(exp) - ca.set_issuer(ca.get_subject()) - ca.set_pubkey(key) - ca.add_extensions([ + cert = OpenSSL.crypto.X509() + cert.set_serial_number(int(time.time()*10000)) + cert.set_version(2) + cert.get_subject().CN = cn + cert.get_subject().O = o + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(exp) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), OpenSSL.crypto.X509Extension("nsCertType", True, @@ -32,80 +36,39 @@ def create_ca(o=default_o, cn=default_cn, exp=default_exp): OpenSSL.crypto.X509Extension("keyUsage", False, "keyCertSign, cRLSign"), OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", - subject=ca), + subject=cert), ]) - ca.sign(key, "sha1") - return key, ca - - -def dummy_ca(path, o=default_o, cn=default_cn, exp=default_exp): - dirname = os.path.dirname(path) - if not os.path.exists(dirname): - os.makedirs(dirname) - if path.endswith(".pem"): - basename, _ = os.path.splitext(path) - basename = os.path.basename(basename) - else: - basename = os.path.basename(path) - - key, ca = create_ca(o=o, cn=cn, exp=exp) - - # Dump the CA plus private key - f = open(path, "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() - - # Dump the certificate in PEM format - f = open(os.path.join(dirname, basename + "-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() - - # Create a .cer file with the same contents for Android - f = open(os.path.join(dirname, basename + "-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() - - # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(dirname, basename + "-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() - return True - - -def dummy_cert(ca, commonname, sans): + cert.sign(key, "sha1") + return key, cert + + +def dummy_cert(pkey, cacert, commonname, sans): """ - Generates and writes a certificate to fp. + Generates a dummy certificate. - ca: Path to the certificate authority file, or None. + pkey: CA private key + cacert: CA certificate commonname: Common name for the generated certificate. sans: A list of Subject Alternate Names. - Returns cert path if operation succeeded, None if not. + Returns cert if operation succeeded, None if not. """ ss = [] for i in sans: ss.append("DNS: %s"%i) ss = ", ".join(ss) - raw = file(ca, "rb").read() - ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) - cert.set_issuer(ca.get_subject()) + cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) - cert.set_pubkey(ca.get_pubkey()) - cert.sign(key, "sha1") + cert.set_pubkey(cacert.get_pubkey()) + cert.sign(pkey, "sha1") return SSLCert(cert) @@ -113,9 +76,59 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, cacert): + def __init__(self, pkey, cert): + self.pkey, self.cert = pkey, cert self.certs = {} - self.cacert = cacert + + @classmethod + def from_store(klass, path, basename): + p = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(p): + key, ca = klass.create_store(path, basename) + else: + p = os.path.join(path, basename + "-ca.pem") + raw = file(p, "rb").read() + ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + return klass(key, ca) + + @classmethod + def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + if not os.path.exists(path): + os.makedirs(path) + + o = o or basename + cn = cn or basename + + key, ca = create_ca(o=o, cn=cn, exp=expiry) + # Dump the CA plus private key + f = open(os.path.join(path, basename + "-ca.pem"), "wb") + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PEM format + f = open(os.path.join(path, basename + "-cert.pem"), "wb") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Create a .cer file with the same contents for Android + f = open(os.path.join(path, basename + "-cert.cer"), "wb") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PKCS12 format for Windows devices + f = open(os.path.join(path, basename + "-cert.p12"), "wb") + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + f.close() + + f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") + f.write(DEFAULT_DHPARAM) + f.close() + return key, ca def get_cert(self, commonname, sans): """ @@ -130,7 +143,7 @@ class CertStore: """ if commonname in self.certs: return self.certs[commonname] - c = dummy_cert(self.cacert, commonname, sans) + c = dummy_cert(self.pkey, self.cert, commonname, sans) self.certs[commonname] = c return c -- cgit v1.2.3 From 0c3bc1cff2a8b1c4c425be5c1ca11c4b850bcc68 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 5 Mar 2014 13:19:16 +1300 Subject: Much more sophisticated certificate store - Handle wildcard lookup - Handle lookup of SANs - Provide hooks for registering override certs and keys for specific domains (including wildcard specifications) --- netlib/certutils.py | 87 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index b9c291d0..fafcb5fd 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -4,6 +4,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL import tcp +import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -42,11 +43,11 @@ def create_ca(o, cn, exp): return key, cert -def dummy_cert(pkey, cacert, commonname, sans): +def dummy_cert(privkey, cacert, commonname, sans): """ Generates a dummy certificate. - pkey: CA private key + privkey: CA private key cacert: CA certificate commonname: Common name for the generated certificate. sans: A list of Subject Alternate Names. @@ -68,17 +69,55 @@ def dummy_cert(pkey, cacert, commonname, sans): cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(cacert.get_pubkey()) - cert.sign(pkey, "sha1") + cert.sign(privkey, "sha1") return SSLCert(cert) +class _Node(UserDict.UserDict): + def __init__(self): + UserDict.UserDict.__init__(self) + self.value = None + + +class DNTree: + """ + Domain store that knows about wildcards. DNS wildcards are very + restricted - the only valid variety is an asterisk on the left-most + domain component, i.e.: + + *.foo.com + """ + def __init__(self): + self.d = _Node() + + def add(self, dn, cert): + parts = dn.split(".") + parts.reverse() + current = self.d + for i in parts: + current = current.setdefault(i, _Node()) + current.value = cert + + def get(self, dn): + parts = dn.split(".") + current = self.d + for i in reversed(parts): + if i in current: + current = current[i] + elif "*" in current: + return current["*"].value + else: + return None + return current.value + + class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, pkey, cert): - self.pkey, self.cert = pkey, cert - self.certs = {} + def __init__(self, privkey, cacert): + self.privkey, self.cacert = privkey, cacert + self.certs = DNTree() @classmethod def from_store(klass, path, basename): @@ -130,9 +169,29 @@ class CertStore: f.close() return key, ca + def add_cert_file(self, commonname, path): + raw = file(path, "rb").read() + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + try: + privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + except Exception: + privkey = None + self.add_cert(SSLCert(cert), privkey, commonname) + + def add_cert(self, cert, privkey, *names): + """ + Adds a cert to the certstore. We register the CN in the cert plus + any SANs, and also the list of names provided as an argument. + """ + self.certs.add(cert.cn, (cert, privkey)) + for i in cert.altnames: + self.certs.add(i, (cert, privkey)) + for i in names: + self.certs.add(i, (cert, privkey)) + def get_cert(self, commonname, sans): """ - Returns an SSLCert object. + Returns an (cert, privkey) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -141,11 +200,12 @@ class CertStore: Return None if the certificate could not be found or generated. """ - if commonname in self.certs: - return self.certs[commonname] - c = dummy_cert(self.pkey, self.cert, commonname, sans) - self.certs[commonname] = c - return c + c = self.certs.get(commonname) + if not c: + c = dummy_cert(self.privkey, self.cacert, commonname, sans) + self.add_cert(c, None) + c = (c, None) + return (c[0], c[1] or self.privkey) class _GeneralName(univ.Choice): @@ -171,6 +231,9 @@ class SSLCert: """ self.x509 = cert + def __eq__(self, other): + return self.digest("sha1") == other.digest("sha1") + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) -- cgit v1.2.3 From 86730a9a4c3a14b510590aa97a8ae8989cb6ec5e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 5 Mar 2014 13:43:52 +1300 Subject: Handler convert_to_ssl now takes a key object, not a path. --- netlib/tcp.py | 2 +- netlib/test.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 0dff807b..83059bc2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -375,7 +375,7 @@ class BaseHandler(_Connection): if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) - ctx.use_privatekey_file(key) + ctx.use_privatekey(key) ctx.use_certificate(cert.x509) if request_client_cert: def ver(*args): diff --git a/netlib/test.py b/netlib/test.py index 2f6a7107..b88b3586 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,5 +1,6 @@ import threading, Queue, cStringIO import tcp, certutils +import OpenSSL class ServerThread(threading.Thread): def __init__(self, server): @@ -49,6 +50,8 @@ class TServer(tcp.TCPServer): self.handler_klass = handler_klass self.last_handler = None + + def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h @@ -56,6 +59,8 @@ class TServer(tcp.TCPServer): cert = certutils.SSLCert.from_pem( file(self.ssl["cert"], "rb").read() ) + raw = file(self.ssl["key"], "rb").read() + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 @@ -63,8 +68,7 @@ class TServer(tcp.TCPServer): method = tcp.SSLv23_METHOD options = None h.convert_to_ssl( - cert, - self.ssl["key"], + cert, key, method = method, options = options, handle_sni = getattr(h, "handle_sni", None), -- cgit v1.2.3 From 52b14aa1d1bbeb3e2b8c62ee9939b9575ee1840f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 5 Mar 2014 17:29:14 +1300 Subject: CertStore: cope with certs that have no common name --- netlib/certutils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index fafcb5fd..d544cfa6 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -169,21 +169,22 @@ class CertStore: f.close() return key, ca - def add_cert_file(self, commonname, path): + def add_cert_file(self, spec, path): raw = file(path, "rb").read() cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) try: privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: privkey = None - self.add_cert(SSLCert(cert), privkey, commonname) + self.add_cert(SSLCert(cert), privkey, spec) def add_cert(self, cert, privkey, *names): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. """ - self.certs.add(cert.cn, (cert, privkey)) + if cert.cn: + self.certs.add(cert.cn, (cert, privkey)) for i in cert.altnames: self.certs.add(i, (cert, privkey)) for i in names: -- cgit v1.2.3 From 2a12aa3c47d57cc2d3a36f6726a5f081ca493457 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 7 Mar 2014 16:38:50 +1300 Subject: Support Ephemeral Diffie-Hellman --- netlib/certutils.py | 24 +++++++++++++++++++----- netlib/tcp.py | 7 ++++++- netlib/test.py | 11 ++++++----- 3 files changed, 31 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index d544cfa6..19148382 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -115,10 +115,22 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert): + def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert + self.dhparams = dhparams self.certs = DNTree() + @classmethod + def load_dhparam(klass, path): + bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") + if bio != OpenSSL.SSL._ffi.NULL: + bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) + dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( + bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL + ) + dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) + return dh + @classmethod def from_store(klass, path, basename): p = os.path.join(path, basename + "-ca.pem") @@ -129,7 +141,9 @@ class CertStore: raw = file(p, "rb").read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - return klass(key, ca) + dhp = os.path.join(path, basename + "-dhparam.pem") + dh = klass.load_dhparam(dhp) + return klass(key, ca, dh) @classmethod def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): @@ -147,17 +161,17 @@ class CertStore: f.close() # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-cert.pem"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-cert.cer"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-cert.p12"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) p12.set_privatekey(key) diff --git a/netlib/tcp.py b/netlib/tcp.py index 83059bc2..078ac497 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -339,7 +339,10 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): + def convert_to_ssl(self, cert, key, + method=SSLv23_METHOD, options=None, handle_sni=None, + request_client_cert=False, cipher_list=None, dhparams=None + ): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -377,6 +380,8 @@ class BaseHandler(_Connection): ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey(key) ctx.use_certificate(cert.x509) + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams) if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) diff --git a/netlib/test.py b/netlib/test.py index b88b3586..bb0012ad 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -18,7 +18,6 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -43,15 +42,16 @@ class ServerTestBase: class TServer(tcp.TCPServer): def __init__(self, ssl, q, handler_klass, addr): """ - ssl: A {cert, key, v3_only} dict. + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only """ tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None - - def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h @@ -73,7 +73,8 @@ class TServer(tcp.TCPServer): options = options, handle_sni = getattr(h, "handle_sni", None), request_client_cert = self.ssl["request_client_cert"], - cipher_list = self.ssl.get("cipher_list", None) + cipher_list = self.ssl.get("cipher_list", None), + dhparams = self.ssl.get("dhparams", None) ) h.handle() h.finish() -- cgit v1.2.3 From f5cc63d653b27210d9c3d7646c01c3a9d540d9c7 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 10 Mar 2014 17:29:27 +1300 Subject: Certificate flags --- netlib/certffi.py | 36 ++++++++++++++++++++++++++++++++++++ netlib/certutils.py | 7 +++++++ 2 files changed, 43 insertions(+) create mode 100644 netlib/certffi.py (limited to 'netlib') diff --git a/netlib/certffi.py b/netlib/certffi.py new file mode 100644 index 00000000..c5d7c95e --- /dev/null +++ b/netlib/certffi.py @@ -0,0 +1,36 @@ +import cffi +import OpenSSL +xffi = cffi.FFI() +xffi.cdef (""" + struct rsa_meth_st { + int flags; + ...; + }; + struct rsa_st { + int pad; + long version; + struct rsa_meth_st *meth; + ...; + }; +""") +xffi.verify( + """#include """, + extra_compile_args=['-w'] +) + +def handle(privkey): + new = xffi.new("struct rsa_st*") + newbuf = xffi.buffer(new) + rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey) + oldbuf = OpenSSL.SSL._ffi.buffer(rsa) + newbuf[:] = oldbuf[:] + return new + +def set_flags(privkey, val): + hdl = handle(privkey) + hdl.meth.flags = val + return privkey + +def get_flags(privkey): + hdl = handle(privkey) + return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 19148382..92b219ee 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -111,6 +111,7 @@ class DNTree: return current.value + class CertStore: """ Implements an in-memory certificate store. @@ -222,6 +223,11 @@ class CertStore: c = (c, None) return (c[0], c[1] or self.privkey) + def gen_pkey(self, cert): + import certffi + certffi.set_flags(self.privkey, 1) + return self.privkey + class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore @@ -326,6 +332,7 @@ class SSLCert: return altnames + def get_remote_cert(host, port, sni): c = tcp.TCPClient((host, port)) c.connect() -- cgit v1.2.3 From 4bd15a28b73f521fc08ea77512198faffeaaa247 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 10 Mar 2014 17:43:39 +0100 Subject: fix #28 --- netlib/tcp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 078ac497..c5f97f94 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -143,7 +143,9 @@ class Reader(_FileLike): raise NetLibTimeout except socket.error: raise NetLibDisconnect - except SSL.SysCallError: + except SSL.SysCallError as e: + if e.args == (-1, 'Unexpected EOF'): + break raise NetLibDisconnect except SSL.Error, v: raise NetLibSSLError(v.message) -- cgit v1.2.3 From 34e469eb558cae999b13510b029714a31d9dd1f3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 11 Mar 2014 20:23:27 +0100 Subject: create dhparam file if it doesn't exist, fix mitmproxy/mitmproxy#235 --- netlib/certutils.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 92b219ee..ebe643e4 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -123,6 +123,13 @@ class CertStore: @classmethod def load_dhparam(klass, path): + + # netlib<=0.10 doesn't generate a dhparam file. + # Create it now if neccessary. + if not os.path.exists(path): + with open(path, "wb") as f: + f.write(DEFAULT_DHPARAM) + bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) -- cgit v1.2.3 From d8f54c7c038872fb6f05952214654843c9103da1 Mon Sep 17 00:00:00 2001 From: Bradley Baetz Date: Thu, 20 Mar 2014 11:12:11 +1100 Subject: Change the criticality of a number of X509 extentions, to match the RFCs and real-world CAs/certs. This improve compatability with older browsers/clients. --- netlib/certutils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index ebe643e4..4c50b984 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -29,12 +29,12 @@ def create_ca(o, cn, exp): cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", True, + OpenSSL.crypto.X509Extension("nsCertType", False, "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", True, + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" ), - OpenSSL.crypto.X509Extension("keyUsage", False, + OpenSSL.crypto.X509Extension("keyUsage", True, "keyCertSign, cRLSign"), OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", subject=cert), @@ -67,7 +67,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) - cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha1") return SSLCert(cert) -- cgit v1.2.3 From e7c3e4c5acdf9a229e13502e14a39caac332fe6c Mon Sep 17 00:00:00 2001 From: Pedro Worcel Date: Sun, 30 Mar 2014 20:58:47 +1300 Subject: Change error into awesome user-friendlyness Hi there, I was getting a very weird error "ODict valuelist should be lists", when attempting to add a header. My code was as followed: ``` msg.headers["API-Key"] = new_headers["API-Key"] 42 msg.headers["API-Sign"] = new_headers["API-Sign"] ``` In the end, that was because there could be multiple equal headers. In order to cater to that, it you guys might enjoy the patch I attach, for it converts strings automatically into lists of multiple headers. I think it should work, but I haven't tested it :$ It'd allow me to have the above code, instead of this one below: ``` msg.headers["API-Key"] = [new_headers["API-Key"]] 42 msg.headers["API-Sign"] = [new_headers["API-Sign"]] ``` --- netlib/odict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 46b74e8e..d0ff5cf6 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -60,7 +60,9 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("ODict valuelist should be lists.") + # convert the string into a single element list. + valuelist = [valuelist] + new = self._filter_lst(k, self.lst) for i in valuelist: new.append([k, i]) -- cgit v1.2.3 From bb10dfc5055b6877f35a362ee7705c612aece418 Mon Sep 17 00:00:00 2001 From: Pedro Worcel Date: Mon, 31 Mar 2014 20:19:23 +1300 Subject: Instead of removing the error, for consistency, leaving the error as-was and replaced the message with something that may or may not be more understandable :P --- netlib/odict.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index d0ff5cf6..0640c25d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -60,9 +60,8 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - # convert the string into a single element list. - valuelist = [valuelist] - + raise ValueError("Expected list instead of string. E.g. odict['elem'] = ['string1', 'string2']") + new = self._filter_lst(k, self.lst) for i in valuelist: new.append([k, i]) -- cgit v1.2.3 From c2c952b3ccf0a1803bd64d4a77998c754298e31a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 31 Mar 2014 12:44:20 +0200 Subject: make error message example less abstract. --- netlib/odict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 0640c25d..ea95a586 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -60,7 +60,7 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list instead of string. E.g. odict['elem'] = ['string1', 'string2']") + raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") new = self._filter_lst(k, self.lst) for i in valuelist: -- cgit v1.2.3 From 92081eee04ebbdae6443d24b74404c76fd4f17d4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Apr 2014 19:40:37 +0200 Subject: Update certutils.py refs mitmproxy/mitmproxy#200 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index ebe643e4..187abfae 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -22,7 +22,7 @@ def create_ca(o, cn, exp): cert.set_version(2) cert.get_subject().CN = cn cert.get_subject().O = o - cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(exp) cert.set_issuer(cert.get_subject()) cert.set_pubkey(key) -- cgit v1.2.3 From a8345af282692a7faf859b37f2748705091004fe Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 15 May 2014 13:51:59 +0200 Subject: extract cert creation to be accessible in handle_sni callbacks --- netlib/tcp.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c5f97f94..7b05222f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -341,10 +341,9 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, - method=SSLv23_METHOD, options=None, handle_sni=None, - request_client_cert=False, cipher_list=None, dhparams=None - ): + def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, + handle_sni=None, request_client_cert=None, cipher_list=None, + dhparams=None ): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -390,6 +389,14 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + return ctx + + def convert_to_ssl(self, **kwargs): + """ + Convert connection to SSL. + For a list of parameters, see BaseHandler._create_ssl_context(...) + """ + ctx = self._create_ssl_context(**kwargs) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() -- cgit v1.2.3 From 71834aeab144d8bf083785f668989ad3fb21554e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 15 May 2014 14:15:33 +0200 Subject: make cert and key mandatory --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 7b05222f..e72d5e48 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -391,12 +391,12 @@ class BaseHandler(_Connection): ctx.set_verify(SSL.VERIFY_PEER, ver) return ctx - def convert_to_ssl(self, **kwargs): + def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - ctx = self._create_ssl_context(**kwargs) + ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() -- cgit v1.2.3 From 52c6ba8880363ba5d82b5e767559afbc72371272 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 15 May 2014 18:15:29 +0200 Subject: properly subclass Exception in HTTPError --- netlib/http.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 51f85627..f5b8118a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,15 +1,15 @@ import string, urlparse, binascii import odict, utils -class HttpError(Exception): - def __init__(self, code, msg): - self.code, self.msg = code, msg - def __str__(self): - return "HttpError(%s, %s)"%(self.code, self.msg) +class HttpError(Exception): + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code -class HttpErrorConnClosed(HttpError): pass +class HttpErrorConnClosed(HttpError): + pass def _is_valid_port(port): -- cgit v1.2.3 From 66ac56509f754d1239f81c92b6f7cfb65509dc47 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 21 May 2014 01:14:55 +0200 Subject: add support for ctx.load_verify_locations, refs mitmproxy/mitmproxy#174 --- netlib/tcp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index e72d5e48..c5bb7c4b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -343,7 +343,7 @@ class BaseHandler(_Connection): def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None ): + dhparams=None, ca_file=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -371,6 +371,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if ca_file: + ctx.load_verify_locations(ca_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) @@ -450,7 +452,7 @@ class TCPServer(object): if ex[0] == EINTR: continue else: - raise + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( -- cgit v1.2.3 From dc071c4ea7c77b640cb733d769f06631dceb8477 Mon Sep 17 00:00:00 2001 From: Pritam Baral Date: Wed, 28 May 2014 07:10:10 +0530 Subject: Ignore username:password part in url --- netlib/http.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index f5b8118a..d000b802 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -43,6 +43,8 @@ def parse_url(url): return None if not scheme: return None + if '@' in netloc: + _, netloc = string.rsplit(netloc, '@', maxsplit=1) if ':' in netloc: host, port = string.rsplit(netloc, ':', maxsplit=1) try: -- cgit v1.2.3 From 217660f5db8f91fa351c188e1e61903e9f54e94d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 14:30:42 +0200 Subject: add socks module --- netlib/socks.py | 142 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 netlib/socks.py (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py new file mode 100644 index 00000000..daebe577 --- /dev/null +++ b/netlib/socks.py @@ -0,0 +1,142 @@ +import socket +import struct +from array import array +from .tcp import Address + + +class SocksError(Exception): + def __init__(self, code, message): + super(SocksError, self).__init__(message) + self.code = code + +class VERSION: + SOCKS4 = 0x04 + SOCKS5 = 0x05 + + +class CMD: + CONNECT = 0x01 + BIND = 0x02 + UDP_ASSOCIATE = 0x03 + + +class ATYP: + IPV4_ADDRESS = 0x01 + DOMAINNAME = 0x03 + IPV6_ADDRESS = 0x04 + +class REP: + SUCCEEDED = 0x00 + GENERAL_SOCKS_SERVER_FAILURE = 0x01 + CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 + NETWORK_UNREACHABLE = 0x03 + HOST_UNREACHABLE = 0x04 + CONNECTION_REFUSED = 0x05 + TTL_EXPIRED = 0x06 + COMMAND_NOT_SUPPORTED = 0x07 + ADDRESS_TYPE_NOT_SUPPORTED = 0x08 + +class METHOD: + NO_AUTHENTICATION_REQUIRED = 0x00 + GSSAPI = 0x01 + USERNAME_PASSWORD = 0x02 + NO_ACCEPTABLE_METHODS = 0xFF + + +class ClientGreeting(object): + __slots__ = ("ver", "methods") + + def __init__(self, ver, methods): + self.ver = ver + self.methods = methods + + @classmethod + def from_file(cls, f): + ver, nmethods = struct.unpack_from("!BB", f) + methods = array("B") + methods.fromfile(f, nmethods) + return cls(ver, methods) + + def to_file(self, f): + struct.pack_into("!BB", f, 0, self.ver, len(self.methods)) + self.methods.tofile(f) + + +class ServerGreeting(object): + __slots__ = ("ver", "method") + + def __init__(self, ver, method): + self.ver = ver + self.method = method + + @classmethod + def from_file(cls, f): + ver, method = struct.unpack_from("!BB", f) + return cls(ver, method) + + def to_file(self, f): + struct.pack_into("!BB", f, 0, self.ver, self.method) + + +class Request(object): + __slots__ = ("ver", "cmd", "atyp", "dst") + + def __init__(self, ver, cmd, atyp, dst): + self.ver = ver + self.cmd = cmd + self.atyp = atyp + self.dst = dst + + @classmethod + def from_file(cls, f): + ver, cmd, rsv, atyp = struct.unpack_from("!BBBB", f) + if rsv != 0x00: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv) + + if atyp == ATYP.IPV4_ADDRESS: + host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + use_ipv6 = False + elif atyp == ATYP.IPV6_ADDRESS: + host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + use_ipv6 = True + elif atyp == ATYP.DOMAINNAME: + length = struct.unpack_from("!B", f) + host = f.read(length) + use_ipv6 = False + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Socks Request: Unknown ATYP: %s" % atyp) + + port = struct.unpack_from("!H", f) + dst = Address(host, port, use_ipv6=use_ipv6) + return Request(ver, cmd, atyp, dst) + + def to_file(self, f): + raise NotImplementedError() + +class Reply(object): + __slots__ = ("ver", "rep", "atyp", "bnd") + + def __init__(self, ver, rep, atyp, bnd): + self.ver = ver + self.rep = rep + self.atyp = atyp + self.bnd = bnd + + @classmethod + def from_file(cls, f): + raise NotImplementedError() + + def to_file(self, f): + struct.pack_into("!BBBB", f, 0, self.ver, self.rep, 0x00, self.atyp) + if self.atyp == ATYP.IPV4_ADDRESS: + f.write(socket.inet_aton(self.bnd.host)) + elif self.atyp == ATYP.IPV6_ADDRESS: + f.write(socket.inet_pton(socket.AF_INET6, self.bnd.host)) + elif self.atyp == ATYP.DOMAINNAME: + struct.pack_into("!B", f, 0, len(self.bnd.host)) + f.write(self.bnd.host) + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) + struct.pack_into("!H", f, 0, self.bnd.port) \ No newline at end of file -- cgit v1.2.3 From dc3d3e5f0a8c4de734187c39888af5fbdb63d8a0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 20:31:10 +0200 Subject: add inet_ntop/inet_pton functions --- netlib/utils.py | 29 ++++++++++++++++++++++++++--- netlib/version.py | 1 + 2 files changed, 27 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index 61fd54ae..00e1cd12 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,3 +1,5 @@ +import socket + def isascii(s): try: @@ -32,9 +34,9 @@ def hexdump(s): """ parts = [] for i in range(0, len(s), 16): - o = "%.10x"%i - part = s[i:i+16] - x = " ".join("%.2x"%ord(i) for i in part) + o = "%.10x" % i + part = s[i:i + 16] + x = " ".join("%.2x" % ord(i) for i in part) if len(part) < 16: x += " " x += " ".join(" " for i in range(16 - len(part))) @@ -42,3 +44,24 @@ def hexdump(s): (o, x, cleanBin(part, True)) ) return parts + + +def inet_ntop(address_family, packed_ip): + if hasattr(socket, "inet_ntop"): + return socket.inet_ntop(address_family, packed_ip) + # Windows Fallbacks + if address_family == socket.AF_INET: + return socket.inet_ntoa(packed_ip) + if address_family == socket.AF_INET6: + ip = packed_ip.encode("hex") + return ":".join([ip[i:i + 4] for i in range(0, len(ip), 4)]) + + +def inet_pton(address_family, ip_string): + if hasattr(socket, "inet_pton"): + return socket.inet_pton(address_family, ip_string) + # Windows Fallbacks + if address_family == socket.AF_INET: + return socket.inet_aton(ip_string) + if address_family == socket.AF_INET6: + return ip_string.replace(":", "").decode("hex") \ No newline at end of file diff --git a/netlib/version.py b/netlib/version.py index 1d3250e1..25565d40 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,5 @@ IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) +MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 6405595ae8593a52f6b81d7f311044f113476d82 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 20:31:28 +0200 Subject: socks module: polish, add tests --- netlib/socks.py | 71 +++++++++++++++++++++++---------------------------------- 1 file changed, 28 insertions(+), 43 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index daebe577..01f54859 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -1,7 +1,7 @@ import socket import struct from array import array -from .tcp import Address +from . import tcp, utils class SocksError(Exception): @@ -9,6 +9,7 @@ class SocksError(Exception): super(SocksError, self).__init__(message) self.code = code + class VERSION: SOCKS4 = 0x04 SOCKS5 = 0x05 @@ -25,6 +26,7 @@ class ATYP: DOMAINNAME = 0x03 IPV6_ADDRESS = 0x04 + class REP: SUCCEEDED = 0x00 GENERAL_SOCKS_SERVER_FAILURE = 0x01 @@ -36,6 +38,7 @@ class REP: COMMAND_NOT_SUPPORTED = 0x07 ADDRESS_TYPE_NOT_SUPPORTED = 0x08 + class METHOD: NO_AUTHENTICATION_REQUIRED = 0x00 GSSAPI = 0x01 @@ -52,15 +55,14 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack_from("!BB", f) + ver, nmethods = struct.unpack("!BB", f.read(2)) methods = array("B") - methods.fromfile(f, nmethods) + methods.fromstring(f.read(nmethods)) return cls(ver, methods) def to_file(self, f): - struct.pack_into("!BB", f, 0, self.ver, len(self.methods)) - self.methods.tofile(f) - + f.write(struct.pack("!BB", self.ver, len(self.methods))) + f.write(self.methods.tostring()) class ServerGreeting(object): __slots__ = ("ver", "method") @@ -71,72 +73,55 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack_from("!BB", f) + ver, method = struct.unpack("!BB", f.read(2)) return cls(ver, method) def to_file(self, f): - struct.pack_into("!BB", f, 0, self.ver, self.method) + f.write(struct.pack("!BB", self.ver, self.method)) +class Message(object): + __slots__ = ("ver", "msg", "atyp", "addr") -class Request(object): - __slots__ = ("ver", "cmd", "atyp", "dst") - - def __init__(self, ver, cmd, atyp, dst): + def __init__(self, ver, msg, atyp, addr): self.ver = ver - self.cmd = cmd + self.msg = msg self.atyp = atyp - self.dst = dst + self.addr = addr @classmethod def from_file(cls, f): - ver, cmd, rsv, atyp = struct.unpack_from("!BBBB", f) + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.read(4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + host = utils.inet_ntop(socket.AF_INET, f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + host = utils.inet_ntop(socket.AF_INET6, f.read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length = struct.unpack_from("!B", f) + length, = struct.unpack("!B", f.read(1)) host = f.read(length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port = struct.unpack_from("!H", f) - dst = Address(host, port, use_ipv6=use_ipv6) - return Request(ver, cmd, atyp, dst) - - def to_file(self, f): - raise NotImplementedError() - -class Reply(object): - __slots__ = ("ver", "rep", "atyp", "bnd") - - def __init__(self, ver, rep, atyp, bnd): - self.ver = ver - self.rep = rep - self.atyp = atyp - self.bnd = bnd - - @classmethod - def from_file(cls, f): - raise NotImplementedError() + port, = struct.unpack("!H", f.read(2)) + addr = tcp.Address((host, port), use_ipv6=use_ipv6) + return cls(ver, msg, atyp, addr) def to_file(self, f): - struct.pack_into("!BBBB", f, 0, self.ver, self.rep, 0x00, self.atyp) + f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) if self.atyp == ATYP.IPV4_ADDRESS: - f.write(socket.inet_aton(self.bnd.host)) + f.write(utils.inet_pton(socket.AF_INET, self.addr.host)) elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(socket.inet_pton(socket.AF_INET6, self.bnd.host)) + f.write(utils.inet_pton(socket.AF_INET6, self.addr.host)) elif self.atyp == ATYP.DOMAINNAME: - struct.pack_into("!B", f, 0, len(self.bnd.host)) - f.write(self.bnd.host) + f.write(struct.pack("!B", len(self.addr.host))) + f.write(self.addr.host) else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) - struct.pack_into("!H", f, 0, self.bnd.port) \ No newline at end of file + f.write(struct.pack("!H", self.addr.port)) \ No newline at end of file -- cgit v1.2.3 From e69133f98c513a99c017ad561ea9195280e3f7c5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 21:16:47 +0200 Subject: remove ntop windows workaround --- netlib/socks.py | 8 ++++---- netlib/utils.py | 23 +---------------------- 2 files changed, 5 insertions(+), 26 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index 01f54859..97df3478 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -96,10 +96,10 @@ class Message(object): "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = utils.inet_ntop(socket.AF_INET, f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = utils.inet_ntop(socket.AF_INET6, f.read(16)) + host = socket.inet_ntop(socket.AF_INET6, f.read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: length, = struct.unpack("!B", f.read(1)) @@ -116,9 +116,9 @@ class Message(object): def to_file(self, f): f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) if self.atyp == ATYP.IPV4_ADDRESS: - f.write(utils.inet_pton(socket.AF_INET, self.addr.host)) + f.write(socket.inet_aton(self.addr.host)) elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(utils.inet_pton(socket.AF_INET6, self.addr.host)) + f.write(socket.inet_pton(socket.AF_INET6, self.addr.host)) elif self.atyp == ATYP.DOMAINNAME: f.write(struct.pack("!B", len(self.addr.host))) f.write(self.addr.host) diff --git a/netlib/utils.py b/netlib/utils.py index 00e1cd12..69ba456a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -43,25 +43,4 @@ def hexdump(s): parts.append( (o, x, cleanBin(part, True)) ) - return parts - - -def inet_ntop(address_family, packed_ip): - if hasattr(socket, "inet_ntop"): - return socket.inet_ntop(address_family, packed_ip) - # Windows Fallbacks - if address_family == socket.AF_INET: - return socket.inet_ntoa(packed_ip) - if address_family == socket.AF_INET6: - ip = packed_ip.encode("hex") - return ":".join([ip[i:i + 4] for i in range(0, len(ip), 4)]) - - -def inet_pton(address_family, ip_string): - if hasattr(socket, "inet_pton"): - return socket.inet_pton(address_family, ip_string) - # Windows Fallbacks - if address_family == socket.AF_INET: - return socket.inet_aton(ip_string) - if address_family == socket.AF_INET6: - return ip_string.replace(":", "").decode("hex") \ No newline at end of file + return parts \ No newline at end of file -- cgit v1.2.3 From 4d5d8b65114d061da4f6a41673011ce643c29aab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 29 Jun 2014 13:10:07 +0200 Subject: mark nsCertType non-critical, fix #39 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 187abfae..8aec5e82 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -29,7 +29,7 @@ def create_ca(o, cn, exp): cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", True, + OpenSSL.crypto.X509Extension("nsCertType", False, "sslCA"), OpenSSL.crypto.X509Extension("extendedKeyUsage", True, "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" -- cgit v1.2.3 From 273c25a705c7784ed3fbe15faa11effe05809519 Mon Sep 17 00:00:00 2001 From: Brad Peabody Date: Sat, 12 Jul 2014 22:42:06 -0700 Subject: added option for read_response to only read the headers, beginnings of implementing streamed result in mitmproxy --- netlib/http.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index f5b8118a..21cde538 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -292,7 +292,7 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit): +def read_response(rfile, method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. """ @@ -315,8 +315,10 @@ def read_response(rfile, method, body_size_limit): # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = "" - else: + elif include_body: content = read_http_body(rfile, headers, body_size_limit, False) + else: + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content -- cgit v1.2.3 From 24ef9c61a39f24c8f5ec4414a4a9d0b6a2bc4283 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 14 Jul 2014 17:38:49 +0200 Subject: improve docs --- netlib/http.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 21cde538..413c73a1 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -292,12 +292,17 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit, include_body=True): +def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") @@ -312,13 +317,13 @@ def read_response(rfile, method, body_size_limit, include_body=True): if headers is None: raise HttpError(502, "Invalid headers.") - # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: + # Parse response body according to http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = "" elif include_body: content = read_http_body(rfile, headers, body_size_limit, False) else: - content = None # if include_body==False then a None content means the body should be read separately + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content -- cgit v1.2.3 From 280d9b862575d79b391e28c80156697d2d674c48 Mon Sep 17 00:00:00 2001 From: Brad Peabody Date: Thu, 17 Jul 2014 22:34:29 -0700 Subject: added some additional functions for dealing with chunks - needed for mitmproxy streaming capability --- netlib/http.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 21cde538..736c2c88 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -136,6 +136,49 @@ def read_chunked(fp, headers, limit, is_request): break return content +def read_next_chunk(fp, headers, is_request): + """ + Read next piece of a chunked HTTP body. Returns next piece of + content as a string or None if we hit the end. + """ + # TODO: see and understand the FIXME in read_chunked and + # see if we need to apply here? + content = "" + code = 400 if is_request else 502 + line = fp.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + try: + length = int(line, 16) + except ValueError: + # TODO: see note in this part of read_chunked() + raise HttpError(code, "Invalid chunked encoding length: %s"%line) + if length > 0: + content += fp.read(length) + print "read content: '%s'" % content + line = fp.readline(5) + if line == '': + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n': + raise HttpError(code, "Malformed chunked body: '%s' (len=%d)" % (line, length)) + if content == "": + content = None # normalize zero length to None, meaning end of chunked stream + return content # return this chunk + +def write_chunk(fp, content): + """ + Write a chunk with chunked encoding format, returns True + if there should be more chunks or False if you passed + None, meaning this was the last chunk. + """ + if content == None or content == "": + fp.write("0\r\n\r\n") + return False + fp.write("%x\r\n" % len(content)) + fp.write(content) + fp.write("\r\n") + return True + def get_header_tokens(headers, key): """ @@ -350,4 +393,22 @@ def read_http_body(rfile, headers, limit, is_request): not_done = rfile.read(1) if not_done: raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) - return content \ No newline at end of file + return content + +def expected_http_body_size(headers, is_request): + """ + Returns length of body expected or -1 if not + known and we should just read until end of + stream. + """ + if "content-length" in headers: + try: + l = int(headers["content-length"][0]) + if l < 0: + raise ValueError() + return l + except ValueError: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) + elif is_request: + return 0 + return -1 -- cgit v1.2.3 From a7837846a2c20f3fc48406fc63845aec1a7efae0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Jul 2014 22:55:25 +0200 Subject: temporarily replace DNTree with a simpler cert lookup mechanism, fix mitmproxy/mitmproxy#295 --- netlib/certutils.py | 99 ++++++++++++++++++++++++++++------------------------- 1 file changed, 53 insertions(+), 46 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 8aec5e82..87fb99c3 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,5 @@ import os, ssl, time, datetime +import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -73,42 +74,44 @@ def dummy_cert(privkey, cacert, commonname, sans): return SSLCert(cert) -class _Node(UserDict.UserDict): - def __init__(self): - UserDict.UserDict.__init__(self) - self.value = None - - -class DNTree: - """ - Domain store that knows about wildcards. DNS wildcards are very - restricted - the only valid variety is an asterisk on the left-most - domain component, i.e.: - - *.foo.com - """ - def __init__(self): - self.d = _Node() - - def add(self, dn, cert): - parts = dn.split(".") - parts.reverse() - current = self.d - for i in parts: - current = current.setdefault(i, _Node()) - current.value = cert - - def get(self, dn): - parts = dn.split(".") - current = self.d - for i in reversed(parts): - if i in current: - current = current[i] - elif "*" in current: - return current["*"].value - else: - return None - return current.value +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value @@ -119,7 +122,7 @@ class CertStore: def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert self.dhparams = dhparams - self.certs = DNTree() + self.certs = dict() @classmethod def load_dhparam(klass, path): @@ -206,11 +209,11 @@ class CertStore: any SANs, and also the list of names provided as an argument. """ if cert.cn: - self.certs.add(cert.cn, (cert, privkey)) + self.certs[cert.cn] = (cert, privkey) for i in cert.altnames: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) for i in names: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) def get_cert(self, commonname, sans): """ @@ -223,12 +226,16 @@ class CertStore: Return None if the certificate could not be found or generated. """ - c = self.certs.get(commonname) - if not c: - c = dummy_cert(self.privkey, self.cacert, commonname, sans) - self.add_cert(c, None) - c = (c, None) - return (c[0], c[1] or self.privkey) + + potential_keys = [commonname] + sans + [(commonname, tuple(sans))] + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + if name: + c = self.certs[name] + else: + c = dummy_cert(self.privkey, self.cacert, commonname, sans), None + self.certs[(commonname, tuple(sans))] = c + + return c[0], (c[1] or self.privkey) def gen_pkey(self, cert): import certffi -- cgit v1.2.3 From d382bb27bf4732def621cddb46fc4cc1d2143ab4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 19 Jul 2014 00:02:31 +0200 Subject: certstore: add support for asterisk form to DNTree replacement --- netlib/certutils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 87fb99c3..308d6cf8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -215,6 +215,19 @@ class CertStore: for i in names: self.certs[i] = (cert, privkey) + @staticmethod + def asterisk_forms(dn): + parts = dn.split(".") + parts.reverse() + curr_dn = "" + dn_forms = ["*"] + for part in parts[:-1]: + curr_dn = "." + part + curr_dn # .example.com + dn_forms.append("*" + curr_dn) # *.example.com + if parts[-1] != "*": + dn_forms.append(parts[-1] + curr_dn) + return dn_forms + def get_cert(self, commonname, sans): """ Returns an (cert, privkey) tuple. @@ -227,7 +240,11 @@ class CertStore: Return None if the certificate could not be found or generated. """ - potential_keys = [commonname] + sans + [(commonname, tuple(sans))] + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) if name: c = self.certs[name] -- cgit v1.2.3 From 6bd5df79f82a33b7e725afb5f279bda4cba41935 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Jul 2014 14:01:24 +0200 Subject: refactor response length handling --- netlib/http.py | 183 +++++++++++++++++++++++++-------------------------------- 1 file changed, 81 insertions(+), 102 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 736c2c88..f88e6652 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,5 @@ import string, urlparse, binascii +import sys import odict, utils @@ -88,14 +89,14 @@ def read_headers(fp): # We're being liberal in what we accept, here. if i > 0: name = line[:i] - value = line[i+1:].strip() + value = line[i + 1:].strip() ret.append([name, value]) else: return None return odict.ODictCaseless(ret) -def read_chunked(fp, headers, limit, is_request): +def read_chunked(fp, limit, is_request): """ Read a chunked HTTP body. @@ -103,10 +104,9 @@ def read_chunked(fp, headers, limit, is_request): """ # FIXME: Should check if chunked is the final encoding in the headers # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. - content = "" total = 0 code = 400 if is_request else 502 - while 1: + while True: line = fp.readline(128) if line == "": raise HttpErrorConnClosed(code, "Connection closed prematurely") @@ -114,70 +114,19 @@ def read_chunked(fp, headers, limit, is_request): try: length = int(line, 16) except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid chunked encoding length: %s"%line) - if not length: - break + raise HttpError(code, "Invalid chunked encoding length: %s" % line) total += length if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) + msg = "HTTP Body too large." \ + " Limit is %s, chunked content length was at least %s" % (limit, total) raise HttpError(code, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line == '\r\n' or line == '\n': - break - return content - -def read_next_chunk(fp, headers, is_request): - """ - Read next piece of a chunked HTTP body. Returns next piece of - content as a string or None if we hit the end. - """ - # TODO: see and understand the FIXME in read_chunked and - # see if we need to apply here? - content = "" - code = 400 if is_request else 502 - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - try: - length = int(line, 16) - except ValueError: - # TODO: see note in this part of read_chunked() - raise HttpError(code, "Invalid chunked encoding length: %s"%line) - if length > 0: - content += fp.read(length) - print "read content: '%s'" % content - line = fp.readline(5) - if line == '': - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n': - raise HttpError(code, "Malformed chunked body: '%s' (len=%d)" % (line, length)) - if content == "": - content = None # normalize zero length to None, meaning end of chunked stream - return content # return this chunk - -def write_chunk(fp, content): - """ - Write a chunk with chunked encoding format, returns True - if there should be more chunks or False if you passed - None, meaning this was the last chunk. - """ - if content == None or content == "": - fp.write("0\r\n\r\n") - return False - fp.write("%x\r\n" % len(content)) - fp.write(content) - fp.write("\r\n") - return True + yield line, chunk, '\r\n' + if length == 0: + return def get_header_tokens(headers, key): @@ -307,6 +256,7 @@ def parse_init_http(line): def connection_close(httpversion, headers): """ Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 + Note that a connection should be closed as well if the response has been read until end of the stream. """ # At first, check if we have an explicit Connection header. if "connection" in headers: @@ -323,7 +273,7 @@ def connection_close(httpversion, headers): def parse_response_line(line): parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully + if len(parts) == 2: # handle missing message gracefully parts.append("") if len(parts) != 3: return None @@ -335,37 +285,38 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit, include_body=True): +def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") parts = parse_response_line(line) if not parts: - raise HttpError(502, "Invalid server response: %s"%repr(line)) + raise HttpError(502, "Invalid server response: %s" % repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: - content = "" - elif include_body: - content = read_http_body(rfile, headers, body_size_limit, False) + if include_body: + content = read_http_body(rfile, headers, body_size_limit, request_method, code, False) else: - content = None # if include_body==False then a None content means the body should be read separately + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content -def read_http_body(rfile, headers, limit, is_request): +def read_http_body(*args, **kwargs): + return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs)) + + +def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): """ Read an HTTP message body: @@ -374,41 +325,69 @@ def read_http_body(rfile, headers, limit, is_request): limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise """ - if has_chunked_encoding(headers): - content = read_chunked(rfile, headers, limit, is_request) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - if l < 0: - raise ValueError() - except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif is_request: - content = "" + if max_chunk_size is None: + max_chunk_size = limit or sys.maxint + + expected_size = expected_http_body_size(headers, is_request, request_method, response_code) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding") + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError(400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size)) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size else: - content = rfile.read(limit if limit else -1) + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size not_done = rfile.read(1) if not_done: raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) - return content -def expected_http_body_size(headers, is_request): + +def expected_http_body_size(headers, is_request, request_method, response_code): """ - Returns length of body expected or -1 if not - known and we should just read until end of - stream. + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. """ + + # Determine response size according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None if "content-length" in headers: try: - l = int(headers["content-length"][0]) - if l < 0: + size = int(headers["content-length"][0]) + if size < 0: raise ValueError() - return l + return size except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) - elif is_request: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"]) + if is_request: return 0 - return -1 + return -1 \ No newline at end of file -- cgit v1.2.3 From 197dae918388b53fde6f79dcec9613a0ac1d4ba1 Mon Sep 17 00:00:00 2001 From: kronick Date: Tue, 29 Jul 2014 15:12:13 +0200 Subject: Made attribute optional (as it is in pyOpenSSL) See https://github.com/pyca/pyopenssl/commit/0d7e8a1af28ab22950b21afa3fd451cec7dd5fdc -- It looks like this constant isn't set on some platforms (including Raspberry Pi's libssl) --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c5bb7c4b..9c92ce38 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -17,7 +17,10 @@ OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +try: + OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +except AttributeError: + pass OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG -- cgit v1.2.3 From 1c1167eda0a2757b8fb6588f0400d47020fdb1ab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 16 Aug 2014 15:28:09 +0200 Subject: use passlib instead of md5crypt --- netlib/contrib/__init__.py | 0 netlib/contrib/md5crypt.py | 94 ---------------------------------------------- netlib/http_auth.py | 29 +++----------- 3 files changed, 5 insertions(+), 118 deletions(-) delete mode 100644 netlib/contrib/__init__.py delete mode 100644 netlib/contrib/md5crypt.py (limited to 'netlib') diff --git a/netlib/contrib/__init__.py b/netlib/contrib/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/netlib/contrib/md5crypt.py b/netlib/contrib/md5crypt.py deleted file mode 100644 index d64ea8ac..00000000 --- a/netlib/contrib/md5crypt.py +++ /dev/null @@ -1,94 +0,0 @@ -# Based on FreeBSD src/lib/libcrypt/crypt.c 1.2 -# http://www.freebsd.org/cgi/cvsweb.cgi/~checkout~/src/lib/libcrypt/crypt.c?rev=1.2&content-type=text/plain - -# Original license: -# * "THE BEER-WARE LICENSE" (Revision 42): -# * wrote this file. As long as you retain this notice you -# * can do whatever you want with this stuff. If we meet some day, and you think -# * this stuff is worth it, you can buy me a beer in return. Poul-Henning Kamp - -# This port adds no further stipulations. I forfeit any copyright interest. - -import md5 - -def md5crypt(password, salt, magic='$1$'): - # /* The password first, since that is what is most unknown */ /* Then our magic string */ /* Then the raw salt */ - m = md5.new() - m.update(password + magic + salt) - - # /* Then just as many characters of the MD5(pw,salt,pw) */ - mixin = md5.md5(password + salt + password).digest() - for i in range(0, len(password)): - m.update(mixin[i % 16]) - - # /* Then something really weird... */ - # Also really broken, as far as I can tell. -m - i = len(password) - while i: - if i & 1: - m.update('\x00') - else: - m.update(password[0]) - i >>= 1 - - final = m.digest() - - # /* and now, just to make sure things don't run too fast */ - for i in range(1000): - m2 = md5.md5() - if i & 1: - m2.update(password) - else: - m2.update(final) - - if i % 3: - m2.update(salt) - - if i % 7: - m2.update(password) - - if i & 1: - m2.update(final) - else: - m2.update(password) - - final = m2.digest() - - # This is the bit that uses to64() in the original code. - - itoa64 = './0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - - rearranged = '' - for a, b, c in ((0, 6, 12), (1, 7, 13), (2, 8, 14), (3, 9, 15), (4, 10, 5)): - v = ord(final[a]) << 16 | ord(final[b]) << 8 | ord(final[c]) - for i in range(4): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - v = ord(final[11]) - for i in range(2): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - return magic + salt + '$' + rearranged - -if __name__ == '__main__': - - def test(clear_password, the_hash): - magic, salt = the_hash[1:].split('$')[:2] - magic = '$' + magic + '$' - return md5crypt(clear_password, salt, magic) == the_hash - - test_cases = ( - (' ', '$1$yiiZbNIH$YiCsHZjcTkYd31wkgW8JF.'), - ('pass', '$1$YeNsbWdH$wvOF8JdqsoiLix754LTW90'), - ('____fifteen____', '$1$s9lUWACI$Kk1jtIVVdmT01p0z3b/hw1'), - ('____sixteen_____', '$1$dL3xbVZI$kkgqhCanLdxODGq14g/tW1'), - ('____seventeen____', '$1$NaH5na7J$j7y8Iss0hcRbu3kzoJs5V.'), - ('__________thirty-three___________', '$1$HO7Q6vzJ$yGwp2wbL5D7eOVzOmxpsy.'), - ('apache', '$apr1$J.w5a/..$IW9y6DR0oO/ADuhlMF5/X1') - ) - - for clearpw, hashpw in test_cases: - if test(clearpw, hashpw): - print '%s: pass' % clearpw - else: - print '%s: FAIL' % clearpw diff --git a/netlib/http_auth.py b/netlib/http_auth.py index b0451e3b..937b66f0 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,4 +1,4 @@ -from .contrib import md5crypt +from passlib.apache import HtpasswdFile import http from argparse import Action, ArgumentTypeError @@ -78,32 +78,14 @@ class PassManHtpasswd: """ Read usernames and passwords from an htpasswd file """ - def __init__(self, fp): + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. """ - self.usernames = {} - for l in fp: - l = l.strip().split(':') - if len(l) != 2: - raise ValueError("Invalid htpasswd file.") - parts = l[1].split('$') - if len(parts) != 4: - raise ValueError("Invalid htpasswd file.") - self.usernames[l[0]] = dict( - token = l[1], - dummy = parts[0], - magic = parts[1], - salt = parts[2], - hashed_password = parts[3] - ) + self.htpasswd = HtpasswdFile(path) def test(self, username, password_token): - ui = self.usernames.get(username) - if not ui: - return False - expected = md5crypt.md5crypt(password_token, ui["salt"], '$'+ui["magic"]+'$') - return expected==ui["token"] + return bool(self.htpasswd.check_password(username, password_token)) class PassManSingleUser: @@ -149,6 +131,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): - with open(s, "r") as f: - return PassManHtpasswd(f) + return PassManHtpasswd(s) -- cgit v1.2.3 From 6d1b601ddf070ef1335be1804386fa0f4a2fcbd4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 16 Aug 2014 15:53:07 +0200 Subject: minor cleanups --- netlib/__init__.py | 1 + netlib/certffi.py | 9 +++++++-- netlib/certutils.py | 15 +++------------ netlib/http.py | 3 ++- netlib/http_auth.py | 3 ++- netlib/http_status.py | 1 + netlib/http_uastrings.py | 2 ++ netlib/odict.py | 1 + netlib/socks.py | 17 +++++++++-------- netlib/tcp.py | 3 ++- netlib/test.py | 3 ++- netlib/utils.py | 2 +- netlib/version.py | 2 ++ netlib/wsgi.py | 3 ++- 14 files changed, 37 insertions(+), 28 deletions(-) (limited to 'netlib') diff --git a/netlib/__init__.py b/netlib/__init__.py index e69de29b..9b4faa33 100644 --- a/netlib/__init__.py +++ b/netlib/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/certffi.py b/netlib/certffi.py index c5d7c95e..81dc72e8 100644 --- a/netlib/certffi.py +++ b/netlib/certffi.py @@ -1,7 +1,9 @@ +from __future__ import (absolute_import, print_function, division) import cffi import OpenSSL + xffi = cffi.FFI() -xffi.cdef (""" +xffi.cdef(""" struct rsa_meth_st { int flags; ...; @@ -18,6 +20,7 @@ xffi.verify( extra_compile_args=['-w'] ) + def handle(privkey): new = xffi.new("struct rsa_st*") newbuf = xffi.buffer(new) @@ -26,11 +29,13 @@ def handle(privkey): newbuf[:] = oldbuf[:] return new + def set_flags(privkey, val): hdl = handle(privkey) - hdl.meth.flags = val + hdl.meth.flags = val return privkey + def get_flags(privkey): hdl = handle(privkey) return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 308d6cf8..18179917 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,11 +1,10 @@ +from __future__ import (absolute_import, print_function, division) import os, ssl, time, datetime import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -import tcp -import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -255,7 +254,7 @@ class CertStore: return c[0], (c[1] or self.privkey) def gen_pkey(self, cert): - import certffi + from . import certffi certffi.set_flags(self.privkey, 1) return self.privkey @@ -360,12 +359,4 @@ class SSLCert: continue for i in dec[0]: altnames.append(i[0].asOctets()) - return altnames - - - -def get_remote_cert(host, port, sni): - c = tcp.TCPClient((host, port)) - c.connect() - c.convert_to_ssl(sni=sni) - return c.cert + return altnames \ No newline at end of file diff --git a/netlib/http.py b/netlib/http.py index 774bac6c..a49f0588 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import string, urlparse, binascii import sys -import odict, utils +from . import odict, utils class HttpError(Exception): diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 937b66f0..49f5925f 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) from passlib.apache import HtpasswdFile -import http from argparse import Action, ArgumentTypeError +from . import http class NullProxyAuth(): diff --git a/netlib/http_status.py b/netlib/http_status.py index 9f3f7e15..7dba2d56 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) CONTINUE = 100 SWITCHING = 101 diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index 826c31a5..d0d145da 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + """ A small collection of useful user-agent header strings. These should be kept reasonably current to reflect common usage. diff --git a/netlib/odict.py b/netlib/odict.py index ea95a586..a0e1f694 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) import re, copy diff --git a/netlib/socks.py b/netlib/socks.py index 97df3478..1da5b6cc 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -1,7 +1,8 @@ +from __future__ import (absolute_import, print_function, division) import socket import struct -from array import array -from . import tcp, utils +import array +from . import tcp class SocksError(Exception): @@ -10,24 +11,24 @@ class SocksError(Exception): self.code = code -class VERSION: +class VERSION(object): SOCKS4 = 0x04 SOCKS5 = 0x05 -class CMD: +class CMD(object): CONNECT = 0x01 BIND = 0x02 UDP_ASSOCIATE = 0x03 -class ATYP: +class ATYP(object): IPV4_ADDRESS = 0x01 DOMAINNAME = 0x03 IPV6_ADDRESS = 0x04 -class REP: +class REP(object): SUCCEEDED = 0x00 GENERAL_SOCKS_SERVER_FAILURE = 0x01 CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 @@ -39,7 +40,7 @@ class REP: ADDRESS_TYPE_NOT_SUPPORTED = 0x08 -class METHOD: +class METHOD(object): NO_AUTHENTICATION_REQUIRED = 0x00 GSSAPI = 0x01 USERNAME_PASSWORD = 0x02 @@ -56,7 +57,7 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): ver, nmethods = struct.unpack("!BB", f.read(2)) - methods = array("B") + methods = array.array("B") methods.fromstring(f.read(nmethods)) return cls(ver, methods) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9c92ce38..f49346a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import select, socket, threading, sys, time, traceback from OpenSSL import SSL -import certutils +from . import certutils EINTR = 4 diff --git a/netlib/test.py b/netlib/test.py index bb0012ad..31a848a6 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import threading, Queue, cStringIO -import tcp, certutils import OpenSSL +from . import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): diff --git a/netlib/utils.py b/netlib/utils.py index 69ba456a..79077ac6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,4 @@ -import socket +from __future__ import (absolute_import, print_function, division) def isascii(s): diff --git a/netlib/version.py b/netlib/version.py index 25565d40..913f753a 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index b576bdff..492803ab 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,6 @@ +from __future__ import (absolute_import, print_function, division) import cStringIO, urllib, time, traceback -import odict, tcp +from . import odict, tcp class ClientConn: -- cgit v1.2.3 From 3d489f3bb7db6dda7b8476f6daa2177048c911ff Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 3 Sep 2014 17:15:50 +0200 Subject: adapt netlib.wsgi to changes in mitmproxy/mitmproxy#341 --- netlib/tcp.py | 8 ++++---- netlib/wsgi.py | 32 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index f49346a1..b386603c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -486,10 +486,10 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-' * 40, file=fp) + print("Error in processing of request from %s:%s" % (client_address.host, client_address.port), file=fp) + print(exc, file=fp) + print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 492803ab..568b1f9c 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -9,15 +9,15 @@ class ClientConn: class Flow: - def __init__(self, client_conn): - self.client_conn = client_conn + def __init__(self, address, request): + self.client_conn = ClientConn(address) + self.request = request class Request: - def __init__(self, client_conn, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.flow = Flow(client_conn) def date_time_string(): @@ -39,37 +39,37 @@ class WSGIAdaptor: def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - def make_environ(self, request, errsoc, **extra): - if '?' in request.path: - path_info, query = request.path.split('?', 1) + def make_environ(self, flow, errsoc, **extra): + if '?' in flow.request.path: + path_info, query = flow.request.path.split('?', 1) else: - path_info = request.path + path_info = flow.request.path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': request.scheme, - 'wsgi.input': cStringIO.StringIO(request.content), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': request.method, + 'REQUEST_METHOD': flow.request.method, 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.flow.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() + if flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = flow.client_conn.address() - for key, value in request.headers.items(): + for key, value in flow.request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value -- cgit v1.2.3 From ec628bc37d173b622e905e8012a08a7328cf7215 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 01:10:44 +0200 Subject: fix tcp.Address inequality comparison --- netlib/tcp.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index b386603c..5ecfca9d 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -216,10 +216,16 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __repr__(self): + return repr(self.address) + def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) + def __ne__(self, other): + return not self.__eq__(other) + class _Connection(object): def get_current_cipher(self): -- cgit v1.2.3 From 4bf7f3c0ff5158cd178756bc2a414f506fb34e05 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 16:55:02 +0200 Subject: set source_address if not manually specified --- netlib/tcp.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 5ecfca9d..ede8682b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -319,6 +319,8 @@ class TCPClient(_Connection): if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: -- cgit v1.2.3 From d9a731b23a930474adc35d6b4ebee68cd05a0940 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 19:18:43 +0200 Subject: make inequality comparison work --- netlib/certutils.py | 3 +++ netlib/odict.py | 3 +++ 2 files changed, 6 insertions(+) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 18179917..84316882 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -285,6 +285,9 @@ class SSLCert: def __eq__(self, other): return self.digest("sha1") == other.digest("sha1") + def __ne__(self, other): + return not self.__eq__(other) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index a0e1f694..1e51bb3f 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -24,6 +24,9 @@ class ODict: def __eq__(self, other): return self.lst == other.lst + def __ne__(self, other): + return not self.__eq__(other) + def __iter__(self): return self.lst.__iter__() -- cgit v1.2.3 From 3b81d678c4ff6ae8be563c3d087c4786648c24af Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 7 Sep 2014 11:24:41 +1200 Subject: Use print function after future import --- netlib/tcp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index f49346a1..a5b9af22 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -486,10 +486,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-'*40, file=fp) + print( + "Error in processing of request from %s:%s" % ( + client_address.host, client_address.port + ), file=fp) + print(exc, file=fp) + print('-'*40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ -- cgit v1.2.3 From f4013dcd406c731c08c02789f80ccb364844c0ff Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 7 Sep 2014 12:47:17 +1200 Subject: Add a FIXME note for discarded credentials --- netlib/http.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 53a47d50..35e959cd 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -46,6 +46,9 @@ def parse_url(url): if not scheme: return None if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. _, netloc = string.rsplit(netloc, '@', maxsplit=1) if ':' in netloc: host, port = string.rsplit(netloc, ':', maxsplit=1) -- cgit v1.2.3 From f90ea89e69b3ff9fb612b0ee6024f5546f198ca6 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Sep 2014 18:38:05 +0200 Subject: more verbose errors --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 2704eeae..0a3c4ff9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -308,7 +308,7 @@ class TCPClient(_Connection): try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -417,7 +417,7 @@ class BaseHandler(_Connection): try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) -- cgit v1.2.3 From 63c1efd3946ce672640b43b005d12f8f117d670a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 9 Sep 2014 10:08:56 +1200 Subject: Remove avoidable imports from OpenSSL Fixes #38 --- netlib/tcp.py | 59 ++++++++++++++++++++++------------------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 2704eeae..080797b4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,12 @@ from __future__ import (absolute_import, print_function, division) -import select, socket, threading, sys, time, traceback +import select +import socket +import sys +import threading +import time +import traceback from OpenSSL import SSL + from . import certutils @@ -11,35 +17,6 @@ SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD -OP_ALL = SSL.OP_ALL -OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE -OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE -OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS -OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA -OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER -OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -try: - OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING -except AttributeError: - pass -OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG -OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG -OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG -OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG -OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -OP_NO_TICKET = SSL.OP_NO_TICKET -OP_NO_TLSv1 = SSL.OP_NO_TLSv1 -OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1 -OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2 -OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE -OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG -OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG -OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG -OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG -OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG - class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass @@ -251,7 +228,8 @@ class _Connection(object): def close(self): """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. + Does a hard close of the socket, i.e. a shutdown, followed by a + close. """ try: if self.ssl_established: @@ -273,6 +251,7 @@ class _Connection(object): class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap(source_address) if source_address else None @@ -284,6 +263,8 @@ class TCPClient(_Connection): def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): """ cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values """ context = SSL.Context(method) if cipher_list: @@ -358,18 +339,22 @@ class BaseHandler(_Connection): dhparams=None, ca_file=None): """ cert: A certutils.SSLCert object. + method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: - connection.get_servername() + connection.get_servername() + + options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: + And you can specify the connection keys as follows: - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) The request_client_cert argument requires some explanation. We're supposed to be able to do this with no negative effects - if the -- cgit v1.2.3 From 414a0a1602b27e9ed1d5aae42ad06d781a5461a6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 17 Sep 2014 11:47:07 +1200 Subject: Adjust for state object protocol changes in mitmproxy. --- netlib/odict.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 1e51bb3f..3fb38d85 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -101,16 +101,6 @@ class ODict: def items(self): return self.lst[:] - def _get_state(self): - return [tuple(i) for i in self.lst] - - def _load_state(self, state): - self.list = [list(i) for i in state] - - @classmethod - def _from_state(klass, state): - return klass([list(i) for i in state]) - def copy(self): """ Returns a copy of this object. @@ -171,6 +161,18 @@ class ODict: self.lst = nlst return count + # Implement the StateObject protocol from mitmproxy + def get_state(self): + return [tuple(i) for i in self.lst] + + def load_state(self, state): + self.list = [list(i) for i in state] + + @classmethod + def from_state(klass, state): + return klass([list(i) for i in state]) + + class ODictCaseless(ODict): """ -- cgit v1.2.3 From 0e307964698379a973e8a1f96e3145188b9c0b8d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 17 Sep 2014 14:04:26 +1200 Subject: Short-form getstate --- netlib/odict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 3fb38d85..61448e6d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -162,7 +162,7 @@ class ODict: return count # Implement the StateObject protocol from mitmproxy - def get_state(self): + def get_state(self, short=False): return [tuple(i) for i in self.lst] def load_state(self, state): -- cgit v1.2.3 From e73a2dbab12296d9787164b5b33320b6d31784d5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 28 Sep 2014 03:15:26 +0200 Subject: minor changes --- netlib/tcp.py | 6 +++--- netlib/test.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c8a02ab4..4f5423e4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -471,7 +471,7 @@ class TCPServer(object): self.socket.close() self.handle_shutdown() - def handle_error(self, request, client_address, fp=sys.stderr): + def handle_error(self, connection, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ @@ -479,13 +479,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print('-'*40, file=fp) + print('-' * 40, file=fp) print( "Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) - print('-'*40, file=fp) + print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/test.py b/netlib/test.py index 31a848a6..fb468907 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -64,7 +64,7 @@ class TServer(tcp.TCPServer): key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD - options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: method = tcp.SSLv23_METHOD options = None @@ -80,7 +80,7 @@ class TServer(tcp.TCPServer): h.handle() h.finish() - def handle_error(self, request, client_address): + def handle_error(self, connection, client_address, fp=None): s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, request, client_address, s) + tcp.TCPServer.handle_error(self, connection, client_address, s) self.q.put(s.getvalue()) -- cgit v1.2.3 From 274688172d62131ddf30cf67e6c084e0e928d4bf Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 8 Oct 2014 18:40:46 +0200 Subject: fix mitmproxy/mitmproxy#373 --- netlib/tcp.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 4f5423e4..aca4bd1b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -232,16 +232,25 @@ class _Connection(object): close. """ try: - if self.ssl_established: + if type(self.connection) == SSL.Connection: self.connection.shutdown() self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): # pragma: no cover - pass + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # Do not call this for an SSL.Connection: + # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection + # again at this point, calls the SNI handler and segfaults. + # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 + # (if this turns out to be an issue for successful SSL connections, + # we should check for ssl_established or access the socket directly) + + while self.connection.recv(4096): # pragma: no cover + pass self.connection.close() except (socket.error, SSL.Error, IOError): # Socket probably already closed @@ -281,7 +290,6 @@ class TCPClient(_Connection): except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) - self.ssl_established = True if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) @@ -290,6 +298,7 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -397,12 +406,12 @@ class BaseHandler(_Connection): """ ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) - self.ssl_established = True self.connection.set_accept_state() try: self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) -- cgit v1.2.3 From fdb6f5552d43d7ab02320ccd7e6d58750e33c4c4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 8 Oct 2014 20:46:30 +0200 Subject: CertStore: add support for cert chains --- netlib/certutils.py | 70 +++++++++++++++++++++++++++++++---------------------- netlib/tcp.py | 6 ++--- 2 files changed, 44 insertions(+), 32 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index fe067ca1..c9e6df26 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -113,13 +113,21 @@ def dummy_cert(privkey, cacert, commonname, sans): # return current.value +class CertStoreEntry(object): + def __init__(self, cert, pkey=None, chain_file=None): + self.cert = cert + self.pkey = pkey + self.chain_file = chain_file + class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert, dhparams=None): - self.privkey, self.cacert = privkey, cacert + def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None): + self.default_pkey = default_pkey + self.default_ca = default_ca + self.default_chain_file = default_chain_file self.dhparams = dhparams self.certs = dict() @@ -142,21 +150,21 @@ class CertStore: return dh @classmethod - def from_store(klass, path, basename): - p = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(p): - key, ca = klass.create_store(path, basename) + def from_store(cls, path, basename): + ca_path = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(ca_path): + key, ca = cls.create_store(path, basename) else: - p = os.path.join(path, basename + "-ca.pem") - raw = file(p, "rb").read() + with open(ca_path, "rb") as f: + raw = f.read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - dhp = os.path.join(path, basename + "-dhparam.pem") - dh = klass.load_dhparam(dhp) - return klass(key, ca, dh) + dh_path = os.path.join(path, basename + "-dhparam.pem") + dh = cls.load_dhparam(dh_path) + return cls(key, ca, ca_path, dh) @classmethod - def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -194,25 +202,29 @@ class CertStore: return key, ca def add_cert_file(self, spec, path): - raw = file(path, "rb").read() - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + with open(path, "rb") as f: + raw = f.read() + cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - privkey = None - self.add_cert(SSLCert(cert), privkey, spec) + pkey = None + self.add_cert( + CertStoreEntry(cert, pkey, path), + spec + ) - def add_cert(self, cert, privkey, *names): + def add_cert(self, entry, *names): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. """ - if cert.cn: - self.certs[cert.cn] = (cert, privkey) - for i in cert.altnames: - self.certs[i] = (cert, privkey) + if entry.cert.cn: + self.certs[entry.cert.cn] = entry + for i in entry.cert.altnames: + self.certs[i] = entry for i in names: - self.certs[i] = (cert, privkey) + self.certs[i] = entry @staticmethod def asterisk_forms(dn): @@ -246,17 +258,17 @@ class CertStore: name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) if name: - c = self.certs[name] + entry = self.certs[name] else: - c = dummy_cert(self.privkey, self.cacert, commonname, sans), None - self.certs[(commonname, tuple(sans))] = c + entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans)) + self.certs[(commonname, tuple(sans))] = entry - return c[0], (c[1] or self.privkey) + return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file) def gen_pkey(self, cert): from . import certffi - certffi.set_flags(self.privkey, 1) - return self.privkey + certffi.set_flags(self.default_pkey, 1) + return self.default_pkey class _GeneralName(univ.Choice): diff --git a/netlib/tcp.py b/netlib/tcp.py index aca4bd1b..8e87bec8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -345,7 +345,7 @@ class BaseHandler(_Connection): def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None, ca_file=None): + dhparams=None, chain_file=None): """ cert: A certutils.SSLCert object. @@ -377,8 +377,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) - if ca_file: - ctx.load_verify_locations(ca_file) + if chain_file: + ctx.load_verify_locations(chain_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) -- cgit v1.2.3 From 9ef84ccc1cdd0d8da890ba012812c760e31f2fab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 00:15:39 +0200 Subject: clean up code --- netlib/certutils.py | 73 +++++++++++++++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 36 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index c9e6df26..af6177d8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -114,9 +114,9 @@ def dummy_cert(privkey, cacert, commonname, sans): class CertStoreEntry(object): - def __init__(self, cert, pkey=None, chain_file=None): + def __init__(self, cert, privatekey, chain_file): self.cert = cert - self.pkey = pkey + self.privatekey = privatekey self.chain_file = chain_file @@ -124,15 +124,15 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None): - self.default_pkey = default_pkey + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams self.certs = dict() - @classmethod - def load_dhparam(klass, path): + @staticmethod + def load_dhparam(path): # netlib<=0.10 doesn't generate a dhparam file. # Create it now if neccessary. @@ -163,8 +163,8 @@ class CertStore: dh = cls.load_dhparam(dh_path) return cls(key, ca, ca_path, dh) - @classmethod - def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + @staticmethod + def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -173,32 +173,28 @@ class CertStore: key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key - f = open(os.path.join(path, basename + "-ca.pem"), "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() - - f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") - f.write(DEFAULT_DHPARAM) - f.close() + with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + + with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + f.write(DEFAULT_DHPARAM) + return key, ca def add_cert_file(self, spec, path): @@ -206,11 +202,11 @@ class CertStore: raw = f.read() cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - pkey = None + privatekey = self.default_privatekey self.add_cert( - CertStoreEntry(cert, pkey, path), + CertStoreEntry(cert, privatekey, path), spec ) @@ -241,7 +237,7 @@ class CertStore: def get_cert(self, commonname, sans): """ - Returns an (cert, privkey) tuple. + Returns an (cert, privkey, cert_chain) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -260,15 +256,20 @@ class CertStore: if name: entry = self.certs[name] else: - entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans)) + entry = CertStoreEntry( + cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file + ) self.certs[(commonname, tuple(sans))] = entry - return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file) + return entry.cert, entry.privatekey, entry.chain_file def gen_pkey(self, cert): + # FIXME: We should do something with cert here? from . import certffi - certffi.set_flags(self.default_pkey, 1) - return self.default_pkey + certffi.set_flags(self.default_privatekey, 1) + return self.default_privatekey class _GeneralName(univ.Choice): -- cgit v1.2.3 From 987fa22e646e2ab79cf93adf7966b5a27273685a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 01:46:08 +0200 Subject: make socks reading more bulletproof --- netlib/socks.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index 1da5b6cc..5b05b397 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -47,6 +47,17 @@ class METHOD(object): NO_ACCEPTABLE_METHODS = 0xFF +def _read(f, n): + try: + d = f.read(n) + if len(d) == n: + return d + else: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Incomplete Read") + except socket.error as e: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) + + class ClientGreeting(object): __slots__ = ("ver", "methods") @@ -56,9 +67,9 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack("!BB", f.read(2)) + ver, nmethods = struct.unpack("!BB", _read(f, 2)) methods = array.array("B") - methods.fromstring(f.read(nmethods)) + methods.fromstring(_read(f, nmethods)) return cls(ver, methods) def to_file(self, f): @@ -74,7 +85,7 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack("!BB", f.read(2)) + ver, method = struct.unpack("!BB", _read(f, 2)) return cls(ver, method) def to_file(self, f): @@ -91,26 +102,26 @@ class Message(object): @classmethod def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", f.read(4)) + ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(_read(f, 4)) # We use tnoa here as ntop is not commonly available on Windows. use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", f.read(1)) - host = f.read(length) + length, = struct.unpack("!B", _read(f, 1)) + host = _read(f, length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port, = struct.unpack("!H", f.read(2)) + port, = struct.unpack("!H", _read(f, 2)) addr = tcp.Address((host, port), use_ipv6=use_ipv6) return cls(ver, msg, atyp, addr) -- cgit v1.2.3 From e6a8730f98d61583f31ac530e2a1c8da2fa181ed Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 04:42:39 +0200 Subject: fix tcp closing for ssled connections --- netlib/tcp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 8e87bec8..7a970be6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -238,19 +238,18 @@ class _Connection(object): else: self.connection.shutdown(socket.SHUT_WR) + if type(self.connection) != SSL.Connection or self.ssl_established: # Section 4.2.2.13 of RFC 1122 tells us that a close() with any # pending readable data could lead to an immediate RST being sent (which is the case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # Do not call this for an SSL.Connection: + # Do not call this for every SSL.Connection: # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection # again at this point, calls the SNI handler and segfaults. # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 - # (if this turns out to be an issue for successful SSL connections, - # we should check for ssl_established or access the socket directly) - while self.connection.recv(4096): # pragma: no cover pass + self.connection.close() except (socket.error, SSL.Error, IOError): # Socket probably already closed -- cgit v1.2.3 From 29a4e9105053118aa8c0b458bcb8f10f0bc333d1 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 17 Oct 2014 18:48:30 +0200 Subject: fix mitmproxy/mitmproxy#375 --- netlib/tcp.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 7a970be6..4705f6df 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -243,12 +243,21 @@ class _Connection(object): # pending readable data could lead to an immediate RST being sent (which is the case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # + # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we were in blocking mode. + # Please let us know if you have a better solution to this problem. + # # Do not call this for every SSL.Connection: # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection # again at this point, calls the SNI handler and segfaults. # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 + timeout = self.connection.gettimeout() + self.connection.settimeout(timeout or 60) while self.connection.recv(4096): # pragma: no cover pass + self.connection.settimeout(timeout) self.connection.close() except (socket.error, SSL.Error, IOError): -- cgit v1.2.3 From ed5e6855652cd3a41579f700d2fb81169c60c3ea Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 22 Oct 2014 17:54:20 +0200 Subject: refactor tcp close, fix mitmproxy/mitmproxy#376 --- netlib/tcp.py | 99 ++++++++++++++++++++++++++++++----------------------------- 1 file changed, 51 insertions(+), 48 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 4705f6df..46c28cd9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -204,6 +204,37 @@ class Address(object): return not self.__eq__(other) +def close_socket(sock): + """ + Does a hard close of a socket, without emitting a RST. + """ + try: + # We already indicate that we close our end. + # If we close RD, any further received bytes would result in a RST being set, which we want to avoid + # for our purposes + sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we are in blocking mode. + # Please let us know if you have a better solution to this problem. + + sock.settimeout(sock.gettimeout() or 20) + # may raise a timeout/disconnect exception. + while sock.recv(4096): # pragma: no cover + pass + + except socket.error: + pass + + sock.close() + + class _Connection(object): def get_current_cipher(self): if not self.ssl_established: @@ -216,59 +247,36 @@ class _Connection(object): def finish(self): self.finished = True - try: + + # If we have an SSL connection, wfile.close == connection.close + # (We call _FileLike.set_descriptor(conn)) + # Closing the socket is not our task, therefore we don't call close then. + if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): self.wfile.flush() - self.close() + self.wfile.close() self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a - close. - """ - try: - if type(self.connection) == SSL.Connection: + else: + try: self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - - if type(self.connection) != SSL.Connection or self.ssl_established: - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent (which is the case on Windows). - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - # - # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we were in blocking mode. - # Please let us know if you have a better solution to this problem. - # - # Do not call this for every SSL.Connection: - # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection - # again at this point, calls the SNI handler and segfaults. - # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 - timeout = self.connection.gettimeout() - self.connection.settimeout(timeout or 60) - while self.connection.recv(4096): # pragma: no cover - pass - self.connection.settimeout(timeout) - - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass + except SSL.Error: + pass class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + def close(self): + # Make sure to close the real socket, not the SSL proxy. + # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, + # it tries to renegotiate... + if type(self.connection) == SSL.Connection: + close_socket(self.connection._socket) + else: + close_socket(self.connection) + def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap(source_address) if source_address else None @@ -430,7 +438,6 @@ class BaseHandler(_Connection): self.connection.settimeout(n) - class TCPServer(object): request_queue_size = 20 def __init__(self, address): @@ -450,11 +457,7 @@ class TCPServer(object): except: self.handle_error(connection, client_address) finally: - try: - connection.shutdown(socket.SHUT_RDWR) - except: - pass - connection.close() + close_socket(connection) def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() -- cgit v1.2.3 From ba468f12b8f59f63ce85b221f0cb2d9e004efe6e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 26 Oct 2014 17:30:26 +1300 Subject: Whitespace and legibility --- netlib/http.py | 80 +++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 23 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 35e959cd..9268418c 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -120,11 +120,14 @@ def read_chunked(fp, limit, is_request): try: length = int(line, 16) except ValueError: - raise HttpError(code, "Invalid chunked encoding length: %s" % line) + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) total += length if limit is not None and total > limit: - msg = "HTTP Body too large." \ - " Limit is %s, chunked content length was at least %s" % (limit, total) + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) raise HttpError(code, msg) chunk = fp.read(length) suffix = fp.readline(5) @@ -149,7 +152,9 @@ def get_header_tokens(headers, key): def has_chunked_encoding(headers): - return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] def parse_http_protocol(s): @@ -261,8 +266,9 @@ def parse_init_http(line): def connection_close(httpversion, headers): """ - Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 - Note that a connection should be closed as well if the response has been read until end of the stream. + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. """ # At first, check if we have an explicit Connection header. if "connection" in headers: @@ -271,7 +277,8 @@ def connection_close(httpversion, headers): return True elif "keep-alive" in toks: return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to be persistent + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent if httpversion == (1, 1): return False return True @@ -317,14 +324,25 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): raise HttpError(502, "Invalid headers.") if include_body: - content = read_http_body(rfile, headers, body_size_limit, request_method, code, False) + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) else: - content = None # if include_body==False then a None content means the body should be read separately + # if include_body==False then a None content means the body should be + # read separately + content = None return httpversion, code, msg, headers, content def read_http_body(*args, **kwargs): - return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs)) + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): @@ -334,12 +352,15 @@ def read_http_body_chunked(rfile, headers, limit, request_method, response_code, rfile: A file descriptor to read from headers: An ODictCaseless object limit: Size limit. - is_request: True if the body to read belongs to a request, False otherwise + is_request: True if the body to read belongs to a request, False + otherwise """ if max_chunk_size is None: max_chunk_size = limit or sys.maxint - expected_size = expected_http_body_size(headers, is_request, request_method, response_code) + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) if expected_size is None: if has_chunked_encoding(headers): @@ -347,11 +368,18 @@ def read_http_body_chunked(rfile, headers, limit, request_method, response_code, for x in read_chunked(rfile, limit, is_request): yield x else: # pragma: nocover - raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding") + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) elif expected_size >= 0: if limit is not None and expected_size > limit: - raise HttpError(400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size)) + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) bytes_left = expected_size while bytes_left: chunk_size = min(bytes_left, max_chunk_size) @@ -368,7 +396,10 @@ def read_http_body_chunked(rfile, headers, limit, request_method, response_code, bytes_left -= chunk_size not_done = rfile.read(1) if not_done: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) def expected_http_body_size(headers, is_request, request_method, response_code): @@ -378,16 +409,16 @@ def expected_http_body_size(headers, is_request, request_method, response_code): - None, if the size in unknown in advance (chunked encoding) - -1, if all data should be read until end of stream. """ - - # Determine response size according to http://tools.ietf.org/html/rfc7230#section-3.3 + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 if request_method: request_method = request_method.upper() if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): return 0 if has_chunked_encoding(headers): return None @@ -398,7 +429,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code): raise ValueError() return size except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"]) + raise HttpError( + 400 if is_request else 502, + "Invalid content-length header: %s" % headers["content-length"] + ) if is_request: return 0 return -1 -- cgit v1.2.3 From 9ce2f473f6febf3738dca77b20ab9a7d3092d3d0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 7 Nov 2014 15:59:00 +1300 Subject: Simplify expected_http_body_size signature, fixing a traceback found in fuzzing --- netlib/http.py | 10 +++++----- netlib/http_auth.py | 4 ++-- netlib/socks.py | 18 ++++++++++++++---- 3 files changed, 21 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 9268418c..d2fc6343 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -406,8 +406,11 @@ def expected_http_body_size(headers, is_request, request_method, response_code): """ Returns the expected body length: - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) + - None, if the size in unknown in advance (chunked encoding or invalid + data) - -1, if all data should be read until end of stream. + + May raise HttpError. """ # Determine response size according to # http://tools.ietf.org/html/rfc7230#section-3.3 @@ -429,10 +432,7 @@ def expected_http_body_size(headers, is_request, request_method, response_code): raise ValueError() return size except ValueError: - raise HttpError( - 400 if is_request else 502, - "Invalid content-length header: %s" % headers["content-length"] - ) + return None if is_request: return 0 return -1 diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 49f5925f..dca6e2f3 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,5 +1,4 @@ from __future__ import (absolute_import, print_function, division) -from passlib.apache import HtpasswdFile from argparse import Action, ArgumentTypeError from . import http @@ -83,7 +82,8 @@ class PassManHtpasswd: """ Raises ValueError if htpasswd file is invalid. """ - self.htpasswd = HtpasswdFile(path) + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) def test(self, username, password_token): return bool(self.htpasswd.check_password(username, password_token)) diff --git a/netlib/socks.py b/netlib/socks.py index 5b05b397..a3c4e9a2 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -53,7 +53,10 @@ def _read(f, n): if len(d) == n: return d else: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Incomplete Read") + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Incomplete Read" + ) except socket.error as e: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) @@ -76,6 +79,7 @@ class ClientGreeting(object): f.write(struct.pack("!BB", self.ver, len(self.methods))) f.write(self.methods.tostring()) + class ServerGreeting(object): __slots__ = ("ver", "method") @@ -91,6 +95,7 @@ class ServerGreeting(object): def to_file(self, f): f.write(struct.pack("!BB", self.ver, self.method)) + class Message(object): __slots__ = ("ver", "msg", "atyp", "addr") @@ -108,7 +113,8 @@ class Message(object): "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = socket.inet_ntoa(_read(f, 4)) # We use tnoa here as ntop is not commonly available on Windows. + # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(_read(f, 4)) use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) @@ -135,5 +141,9 @@ class Message(object): f.write(struct.pack("!B", len(self.addr.host))) f.write(self.addr.host) else: - raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) - f.write(struct.pack("!H", self.addr.port)) \ No newline at end of file + raise SocksError( + REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Unknown ATYP: %s" % self.atyp + ) + f.write(struct.pack("!H", self.addr.port)) + -- cgit v1.2.3 From 0811a9ebde4975d4e934cf4752376dd0db9bb7e4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 7 Nov 2014 16:01:41 +1300 Subject: .flush can raise NetlibDisconnect. This fixes a traceback found in fuzzing. --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 46c28cd9..6b7540aa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -253,7 +253,10 @@ class _Connection(object): # Closing the socket is not our task, therefore we don't call close then. if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): - self.wfile.flush() + try: + self.wfile.flush() + except NetLibDisconnect: + pass self.wfile.close() self.rfile.close() -- cgit v1.2.3 From 60584387ff860befe38ada5ec9d35f3c529d0238 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 11 Nov 2014 12:26:20 +0100 Subject: be more explicit about requirements --- netlib/version.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 913f753a..15a8edf9 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -5,3 +5,7 @@ VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION + +NEXT_MINORVERSION = list(IVERSION) +NEXT_MINORVERSION[1] += 1 +NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) \ No newline at end of file -- cgit v1.2.3 From c56e7a90d886d7169a75246de062f0f90028ae6c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 15 Nov 2014 12:31:13 +1300 Subject: Fix tracebacks in connection finish --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 6b7540aa..1c3bf230 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -255,10 +255,10 @@ class _Connection(object): if not getattr(self.wfile, "closed", False): try: self.wfile.flush() + self.wfile.close() except NetLibDisconnect: pass - self.wfile.close() self.rfile.close() else: try: -- cgit v1.2.3 From 7098c90a6dceddda20de4d7a7dabf836247a38af Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 15 Nov 2014 12:45:06 +1300 Subject: Bump version to 0.11.1 --- netlib/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 15a8edf9..f67d06b3 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 11) +IVERSION = (0, 11, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" @@ -8,4 +8,4 @@ NAMEVERSION = NAME + " " + VERSION NEXT_MINORVERSION = list(IVERSION) NEXT_MINORVERSION[1] += 1 -NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) \ No newline at end of file +NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) -- cgit v1.2.3 From 438c1fbc7dddcbddd234db3806a4d6b5770d9904 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 15 Dec 2014 12:32:36 +0100 Subject: TCPClient: Use TLS1.1+ where available, BaseHandler: disable SSLv2 --- netlib/tcp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 1c3bf230..7010eef0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,8 @@ SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD +OP_NO_SSLv2 = SSL.OP_NO_SSLv2 +OP_NO_SSLv3 = SSL.OP_NO_SSLv3 class NetLibError(Exception): pass @@ -288,7 +290,7 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): + def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None): """ cert: Path to a file containing both client cert and private key. @@ -362,7 +364,7 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, + def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2, handle_sni=None, request_client_cert=None, cipher_list=None, dhparams=None, chain_file=None): """ -- cgit v1.2.3 From 3c919631d40cef69dacd166dabafc238a753edc8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 28 Dec 2014 22:46:19 +1300 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index f67d06b3..826c66fe 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 11, 1) +IVERSION = (0, 11, 2) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From c9de3e770b8b8567cc3c233e9d0f82fd7a47e634 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 17 Feb 2015 11:59:07 +1300 Subject: By popular demand, bump dummy cert expiry to 5 years fixes #52 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index af6177d8..948eb85d 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -61,7 +61,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) - cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) + cert.gmtime_adj_notAfter(60 * 60 * 24 * 30 * 365 * 5) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) -- cgit v1.2.3 From 7e5bb74e7211dbe06b33847475854f54c56aa8d5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 17 Feb 2015 12:03:52 +1300 Subject: 5 years is enough... --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 948eb85d..3eb9846d 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -61,7 +61,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) - cert.gmtime_adj_notAfter(60 * 60 * 24 * 30 * 365 * 5) + cert.gmtime_adj_notAfter(60 * 60 * 24 * 365 * 5) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) -- cgit v1.2.3 From 2a2402dfffc9f1a51869170793673eaf49207d0f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 17 Feb 2015 00:10:10 +0100 Subject: ...two years is not enough. --- netlib/certutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 3eb9846d..5d8a56b8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -6,7 +6,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 +DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 @@ -61,7 +61,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) - cert.gmtime_adj_notAfter(60 * 60 * 24 * 365 * 5) + cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) -- cgit v1.2.3 From 224f737646a3f9d0d6540a295524806df7ed1943 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 16:59:29 +0100 Subject: add option to log ssl keys refs mitmproxy/mitmproxy#475 --- netlib/tcp.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 7010eef0..c6e0075e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,5 @@ from __future__ import (absolute_import, print_function, division) +import os import select import socket import sys @@ -26,6 +27,37 @@ class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass +class SSLKeyLogger(object): + def __init__(self, filename): + self.filename = filename + self.f = None + self.lock = threading.Lock() + + __name__ = "SSLKeyLogger" # required for functools.wraps, which pyOpenSSL uses. + + def __call__(self, connection, where, ret): + if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: + with self.lock: + if not self.f: + self.f = open(self.filename, "ab") + self.f.write("\r\n") + client_random = connection.client_random().encode("hex") + masterkey = connection.master_key().encode("hex") + self.f.write("CLIENT_RANDOM {} {}\r\n".format(client_random, masterkey)) + self.f.flush() + + def close(self): + with self.lock: + if self.f: + self.f.close() + +_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE") +if _logfile: + log_ssl_key = SSLKeyLogger(_logfile) +else: + log_ssl_key = False + + class _FileLike: BLOCKSIZE = 1024 * 32 def __init__(self, o): @@ -314,6 +346,8 @@ class TCPClient(_Connection): if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) + if log_ssl_key: + context.set_info_callback(log_ssl_key) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -418,6 +452,8 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + if log_ssl_key: + ctx.set_info_callback(log_ssl_key) return ctx def convert_to_ssl(self, cert, key, **sslctx_kwargs): -- cgit v1.2.3 From 63fb43369029d33ce77cb2ce1df397e99494562c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 20:40:17 +0100 Subject: fix #53 --- netlib/odict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 61448e6d..f97f074b 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -166,7 +166,7 @@ class ODict: return [tuple(i) for i in self.lst] def load_state(self, state): - self.list = [list(i) for i in state] + self.lst = [list(i) for i in state] @classmethod def from_state(klass, state): -- cgit v1.2.3 From da1eb94ccd36b31ea7e05c6a4e01dd5a6cf20376 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 22:02:52 +0100 Subject: 100% test coverage :tada: --- netlib/tcp.py | 26 ++++++++++++++------------ netlib/test.py | 3 ++- 2 files changed, 16 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c6e0075e..7f98b4f9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -39,6 +39,9 @@ class SSLKeyLogger(object): if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: with self.lock: if not self.f: + d = os.path.dirname(self.filename) + if not os.path.isdir(d): + os.makedirs(d) self.f = open(self.filename, "ab") self.f.write("\r\n") client_random = connection.client_random().encode("hex") @@ -51,11 +54,13 @@ class SSLKeyLogger(object): if self.f: self.f.close() -_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE") -if _logfile: - log_ssl_key = SSLKeyLogger(_logfile) -else: - log_ssl_key = False + @staticmethod + def create_logfun(filename): + if filename: + return SSLKeyLogger(filename) + return False + +log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) class _FileLike: @@ -161,9 +166,9 @@ class Reader(_FileLike): except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise NetLibDisconnect - except SSL.Error, v: - raise NetLibSSLError(v.message) + raise NetLibSSLError(e.message) + except SSL.Error as e: + raise NetLibSSLError(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -179,10 +184,7 @@ class Reader(_FileLike): while True: if size is not None and bytes_read >= size: break - try: - ch = self.read(1) - except NetLibDisconnect: - break + ch = self.read(1) bytes_read += 1 if not ch: break diff --git a/netlib/test.py b/netlib/test.py index fb468907..3a23ba8f 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -75,7 +75,8 @@ class TServer(tcp.TCPServer): handle_sni = getattr(h, "handle_sni", None), request_client_cert = self.ssl["request_client_cert"], cipher_list = self.ssl.get("cipher_list", None), - dhparams = self.ssl.get("dhparams", None) + dhparams = self.ssl.get("dhparams", None), + chain_file = self.ssl.get("chain_file", None) ) h.handle() h.finish() -- cgit v1.2.3 From d71f3b68fda688fec358b59fdcfaaa7031b3b80d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 22:27:23 +0100 Subject: make tests more robust, fix coveralls --- netlib/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/test.py b/netlib/test.py index 3a23ba8f..db30c0e6 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -15,7 +15,7 @@ class ServerThread(threading.Thread): self.server.shutdown() -class ServerTestBase: +class ServerTestBase(object): ssl = None handler = None addr = ("localhost", 0) -- cgit v1.2.3 From dbadc1b61327d06bb176d0465ad5831a619126be Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 7 Mar 2015 01:22:02 +0100 Subject: clean up cert handling, fix mitmproxy/mitmproxy#472 --- netlib/tcp.py | 140 ++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 86 insertions(+), 54 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 7f98b4f9..ba4f008c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -302,6 +302,43 @@ class _Connection(object): except SSL.Error: pass + """ + Creates an SSL Context. + """ + def _create_ssl_context(self, + method=SSLv23_METHOD, + options=(OP_NO_SSLv2 | OP_NO_SSLv3), + cipher_list=None + ): + """ + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html + :rtype : SSL.Context + """ + context = SSL.Context(method) + # Options (NO_SSLv2/3) + if options is not None: + context.set_options(options) + + # Workaround for + # https://github.com/pyca/pyopenssl/issues/190 + # https://github.com/mitmproxy/mitmproxy/issues/472 + context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared. + + # Cipher List + if cipher_list: + try: + context.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) + + # SSLKEYLOGFILE + if log_ssl_key: + context.set_info_callback(log_ssl_key) + + return context + class TCPClient(_Connection): rbufsize = -1 @@ -324,32 +361,28 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None): - """ - cert: Path to a file containing both client cert and private key. - - options: A bit field consisting of OpenSSL.SSL.OP_* values - """ - context = SSL.Context(method) - if cipher_list: - try: - context.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) - if options is not None: - context.set_options(options) + def create_ssl_context(self, cert=None, **sslctx_kwargs): + context = self._create_ssl_context(**sslctx_kwargs) + # Client Certs if cert: try: context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) + return context + + def convert_to_ssl(self, sni=None, **sslctx_kwargs): + """ + cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values + """ + context = self.create_ssl_context(**sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) - if log_ssl_key: - context.set_info_callback(log_ssl_key) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -400,21 +433,21 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2, - handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None, chain_file=None): + def create_ssl_context(self, + cert, key, + handle_sni=None, + request_client_cert=None, + chain_file=None, + dhparams=None, + **sslctx_kwargs): """ cert: A certutils.SSLCert object. - method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD - handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: connection.get_servername() - options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: new_context = Context(TLSv1_METHOD) @@ -431,40 +464,38 @@ class BaseHandler(_Connection): we may be able to make the proper behaviour the default again, but until then we're conservative. """ - ctx = SSL.Context(method) - if not options is None: - ctx.set_options(options) - if chain_file: - ctx.load_verify_locations(chain_file) - if cipher_list: - try: - ctx.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) + context = self._create_ssl_context(**sslctx_kwargs) + + context.use_privatekey(key) + context.use_certificate(cert.x509) + if handle_sni: # SNI callback happens during do_handshake() - ctx.set_tlsext_servername_callback(handle_sni) - ctx.use_privatekey(key) - ctx.use_certificate(cert.x509) - if dhparams: - SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams) + context.set_tlsext_servername_callback(handle_sni) + if request_client_cert: - def ver(*args): - self.clientcert = certutils.SSLCert(args[1]) + def save_cert(conn, cert, errno, depth, preverify_ok): + self.clientcert = certutils.SSLCert(cert) # Return true to prevent cert verification error return True - ctx.set_verify(SSL.VERIFY_PEER, ver) - if log_ssl_key: - ctx.set_info_callback(log_ssl_key) - return ctx + context.set_verify(SSL.VERIFY_PEER, save_cert) + + # Cert Verify + if chain_file: + context.load_verify_locations(chain_file) + + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) + + return context def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) - self.connection = SSL.Connection(ctx, self.connection) + context = self.create_ssl_context(cert, key, **sslctx_kwargs) + self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: self.connection.do_handshake() @@ -474,7 +505,7 @@ class BaseHandler(_Connection): self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def handle(self): # pragma: no cover + def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): @@ -483,6 +514,7 @@ class BaseHandler(_Connection): class TCPServer(object): request_queue_size = 20 + def __init__(self, address): self.address = Address.wrap(address) self.__is_shut_down = threading.Event() @@ -508,7 +540,7 @@ class TCPServer(object): while not self.__shutdown_request: try: r, w, e = select.select([self.socket], [], [], poll_interval) - except select.error, ex: # pragma: no cover + except select.error as ex: # pragma: no cover if ex[0] == EINTR: continue else: @@ -516,12 +548,12 @@ class TCPServer(object): if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( - target = self.connection_thread, - args = (connection, client_address), - name = "ConnectionThread (%s:%s -> %s:%s)" % - (client_address[0], client_address[1], - self.address.host, self.address.port) - ) + target=self.connection_thread, + args=(connection, client_address), + name="ConnectionThread (%s:%s -> %s:%s)" % + (client_address[0], client_address[1], + self.address.host, self.address.port) + ) t.setDaemon(1) t.start() finally: -- cgit v1.2.3 From d5eff70b6e7acb3bd60a5e6f8233cf4936a5d606 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 7 Mar 2015 01:31:31 +0100 Subject: fix tests on Windows --- netlib/tcp.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index ba4f008c..b2f11851 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback from OpenSSL import SSL +import OpenSSL from . import certutils @@ -301,6 +302,10 @@ class _Connection(object): self.connection.shutdown() except SSL.Error: pass + except KeyError as e: + # Workaround for https://github.com/pyca/pyopenssl/pull/183 + if OpenSSL.__version__ != "0.14": + raise e """ Creates an SSL Context. -- cgit v1.2.3 From 6fbe3006afa46c4c5f19e5c52b66e6e73a07f819 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Apr 2015 00:12:41 +0200 Subject: fail gracefully if we cannot start a new thread --- netlib/tcp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index b2f11851..45c60fd8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -560,7 +560,11 @@ class TCPServer(object): self.address.host, self.address.port) ) t.setDaemon(1) - t.start() + try: + t.start() + except threading.ThreadError: + self.handle_error(connection, Address(client_address)) + connection.close() finally: self.__shutdown_request = False self.__is_shut_down.set() -- cgit v1.2.3 From 7f7ccd3a1865e8e73f3d1813182d01c607d6e501 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Apr 2015 00:57:37 +0200 Subject: 100% test coverage --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 45c60fd8..20e7d45f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -302,7 +302,7 @@ class _Connection(object): self.connection.shutdown() except SSL.Error: pass - except KeyError as e: + except KeyError as e: # pragma: no cover # Workaround for https://github.com/pyca/pyopenssl/pull/183 if OpenSSL.__version__ != "0.14": raise e -- cgit v1.2.3 From e58f76aec1db9cc784a3b73c3050d010bb084968 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Apr 2015 02:09:33 +0200 Subject: fix code smell --- netlib/certutils.py | 4 ++-- netlib/http.py | 4 ++-- netlib/http_auth.py | 10 +++++----- netlib/odict.py | 2 +- netlib/tcp.py | 14 +++++++------- netlib/wsgi.py | 8 ++++---- 6 files changed, 21 insertions(+), 21 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 5d8a56b8..f5375c03 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -120,7 +120,7 @@ class CertStoreEntry(object): self.chain_file = chain_file -class CertStore: +class CertStore(object): """ Implements an in-memory certificate store. """ @@ -288,7 +288,7 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert: +class SSLCert(object): def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. diff --git a/netlib/http.py b/netlib/http.py index d2fc6343..26438863 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -333,8 +333,8 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): False ) else: - # if include_body==False then a None content means the body should be - # read separately + # if include_body==False then a None content means the body should be + # read separately content = None return httpversion, code, msg, headers, content diff --git a/netlib/http_auth.py b/netlib/http_auth.py index dca6e2f3..296e094c 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -3,7 +3,7 @@ from argparse import Action, ArgumentTypeError from . import http -class NullProxyAuth(): +class NullProxyAuth(object): """ No proxy auth at all (returns empty challange headers) """ @@ -59,12 +59,12 @@ class BasicProxyAuth(NullProxyAuth): return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} -class PassMan(): +class PassMan(object): def test(self, username, password_token): return False -class PassManNonAnon: +class PassManNonAnon(PassMan): """ Ensure the user specifies a username, accept any password. """ @@ -74,7 +74,7 @@ class PassManNonAnon: return False -class PassManHtpasswd: +class PassManHtpasswd(PassMan): """ Read usernames and passwords from an htpasswd file """ @@ -89,7 +89,7 @@ class PassManHtpasswd: return bool(self.htpasswd.check_password(username, password_token)) -class PassManSingleUser: +class PassManSingleUser(PassMan): def __init__(self, username, password): self.username, self.password = username, password diff --git a/netlib/odict.py b/netlib/odict.py index f97f074b..7a2f611b 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -11,7 +11,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict: +class ODict(object): """ A dictionary-like object for managing ordered (key, value) data. """ diff --git a/netlib/tcp.py b/netlib/tcp.py index 20e7d45f..10269aa4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -64,7 +64,7 @@ class SSLKeyLogger(object): log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) -class _FileLike: +class _FileLike(object): BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o @@ -134,8 +134,8 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error), v: - raise NetLibDisconnect(str(v)) + except (SSL.Error, socket.error) as e: + raise NetLibDisconnect(str(e)) class Reader(_FileLike): @@ -546,10 +546,10 @@ class TCPServer(object): try: r, w, e = select.select([self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover - if ex[0] == EINTR: - continue - else: - raise + if ex[0] == EINTR: + continue + else: + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 568b1f9c..bac27d5a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -3,18 +3,18 @@ import cStringIO, urllib, time, traceback from . import odict, tcp -class ClientConn: +class ClientConn(object): def __init__(self, address): self.address = tcp.Address.wrap(address) -class Flow: +class Flow(object): def __init__(self, address, request): self.client_conn = ClientConn(address) self.request = request -class Request: +class Request(object): def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content @@ -35,7 +35,7 @@ def date_time_string(): return s -class WSGIAdaptor: +class WSGIAdaptor(object): def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion -- cgit v1.2.3 From e41e5cbfdd7b778e6f68e86658e95f9e413133cb Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Thu, 9 Apr 2015 19:35:40 -0700 Subject: netlib websockets --- netlib/http.py | 14 ++ netlib/utils.py | 3 + netlib/websockets/__init__.py | 1 + netlib/websockets/implementations.py | 81 ++++++++ netlib/websockets/websockets.py | 368 +++++++++++++++++++++++++++++++++++ 5 files changed, 467 insertions(+) create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/implementations.py create mode 100644 netlib/websockets/websockets.py (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 26438863..2c72621d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,6 +29,20 @@ def _is_valid_host(host): return None return True +def is_successful_upgrade(request, response): + """ + determines if a client and server successfully agreed to an HTTP protocol upgrade + + https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism + """ + http_switching_protocols_code = 101 + + if request and response: + responseUpgrade = request.headers.get("Upgrade") + requestUpgrade = response.headers.get("Upgrade") + if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: + return requestUpgrade[0] if len(requestUpgrade) > 0 else None + return None def parse_url(url): """ diff --git a/netlib/utils.py b/netlib/utils.py index 79077ac6..03a70977 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,6 +8,9 @@ def isascii(s): return False return True +# best way to do it in python 2.x +def bytes_to_int(i): + return int(i.encode('hex'), 16) def cleanBin(s, fixspacing=False): """ diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py new file mode 100644 index 00000000..78ae5be6 --- /dev/null +++ b/netlib/websockets/implementations.py @@ -0,0 +1,81 @@ +from netlib import tcp +from base64 import b64encode +from StringIO import StringIO +from . import websockets as ws +import struct +import SocketServer +import os + +# Simple websocket client and servers that are used to exercise the functionality in websockets.py +# These are *not* fully RFC6455 compliant + +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__(connection, address, server) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.server_process_handshake(client_hs) + response = ws.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.key = b64encode(os.urandom(16)).decode('utf-8') + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = ws.create_client_handshake( + self.address.host, + self.address.port, + self.key, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + response = ws.read_handshake(self.rfile.read, 1) + + if not response: + self.close() + + def read_next_message(self): + try: + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + except IndexError: + self.close() + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py new file mode 100644 index 00000000..b796ce39 --- /dev/null +++ b/netlib/websockets/websockets.py @@ -0,0 +1,368 @@ +from __future__ import absolute_import + +from base64 import b64encode +from hashlib import sha1 +from mimetools import Message +from netlib import tcp +from netlib import utils +from StringIO import StringIO +import os +import SocketServer +import struct +import io + +# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol +# Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +# +# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +class WebSocketFrameValidationException(Exception): + pass + +class WebSocketsFrame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + use_validation = True # indicates whether or not you care if this frame adheres to the spec + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + self.use_validation = use_validation + + if self.use_validation: + self.validate_frame() + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use read_frame() directly + """ + self.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default_frame_from_message(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def validate_frame(self): + """ + Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame + has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key == None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + apply_mask(self.payload, self.masking_key) == self.decoded_payload + + except AssertionError: + raise WebSocketFrameValidationException() + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)), + ("use_validation - " + str(self.use_validation))]) + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + """ + # validate enforces all the assumptions made by this serializer + # in the spritit of mitmproxy, it's possible to create and serialize invalid frames + # by skipping validation. + if self.use_validation: + self.validate_frame() + + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the frame's integer values + # first shift the significant bit into the correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length + + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + + return bytes + + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + fin = first_byte >> 7 # grab the left most bit + opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 + mask_bit = second_byte >> 7 # grab left most bit + payload_length = second_byte & 127 # grab the next 7 bits + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + +def random_masking_key(): + return os.urandom(4) + +def masking_key_list(masking_key): + return [utils.bytes_to_int(byte) for byte in masking_key] + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key, magic = websockets_magic): + """ + The server response is a valid HTTP 101 response. + """ + digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', digest) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is larger + than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + +def server_process_handshake(handshake): + headers = Message(StringIO(handshake.split('\r\n', 1)[1])) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + +def generate_client_nounce(): + return b64encode(os.urandom(16)).decode('utf-8') + -- cgit v1.2.3 From 0edc04814e3affa71025938ac354707b9b4c481c Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 11:35:15 -0700 Subject: small cleanups, working on tests --- netlib/websockets/implementations.py | 10 +++++----- netlib/websockets/websockets.py | 35 +++++++++++++++++------------------ 2 files changed, 22 insertions(+), 23 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 78ae5be6..ff42ff65 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() def handshake(self): @@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) self.version = "13" - self.key = b64encode(os.urandom(16)).decode('utf-8') + self.key = ws.generate_client_nounce() self.resource = "/" def connect(self): @@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient): self.close() def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index b796ce39..527d55d6 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -65,7 +65,6 @@ class WebSocketsFrame(object): payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer - use_validation = True # indicates whether or not you care if this frame adheres to the spec ): self.fin = fin self.rsv1 = rsv1 @@ -78,21 +77,18 @@ class WebSocketsFrame(object): self.payload = payload self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length - self.use_validation = use_validation - - if self.use_validation: - self.validate_frame() @classmethod def from_bytes(cls, bytestring): """ Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use read_frame() directly + to construct a frame from a stream of bytes, use from_byte_stream() directly """ self.from_byte_stream(io.BytesIO(bytestring).read) + @classmethod - def default_frame_from_message(cls, message, from_client = False): + def default(cls, message, from_client = False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -119,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def validate_frame(self): + def frame_is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -141,10 +137,11 @@ class WebSocketsFrame(object): assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: - apply_mask(self.payload, self.masking_key) == self.decoded_payload + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + return True except AssertionError: - raise WebSocketFrameValidationException() + return False def human_readable(self): return "\n".join([ @@ -161,15 +158,19 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length)), ("use_validation - " + str(self.use_validation))]) + def safe_to_bytes(self): + try: + assert self.frame_is_valid() + return self.to_bytes() + except: + raise WebSocketFrameValidationException() + def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees that the + serialized bytes will be correct. see safe_to_bytes() """ - # validate enforces all the assumptions made by this serializer - # in the spritit of mitmproxy, it's possible to create and serialize invalid frames - # by skipping validation. - if self.use_validation: - self.validate_frame() max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -198,6 +199,7 @@ class WebSocketsFrame(object): pass elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length @@ -284,9 +286,6 @@ def apply_mask(message, masking_key): def random_masking_key(): return os.urandom(4) -def masking_key_list(masking_key): - return [utils.bytes_to_int(byte) for byte in masking_key] - def create_client_handshake(host, port, key, version, resource): """ WebSockets connections are intiated by the client with a valid HTTP upgrade request -- cgit v1.2.3 From 73ce169e3d11eeabeb78143bd86edfdbc3e07fd9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 10:26:09 +1200 Subject: Initial outline of a cookie parsing and serialization module. --- netlib/http_cookies.py | 133 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 netlib/http_cookies.py (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py new file mode 100644 index 00000000..e11e0f90 --- /dev/null +++ b/netlib/http_cookies.py @@ -0,0 +1,133 @@ +""" +A flexible module for cookie parsing and manipulation. + +We try to be as permissive as possible. Parsing accepts formats from RFC6265 an +RFC2109. Serialization follows RFC6265 strictly. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 +""" + +import re + +import odict + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start+1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i+1], i+1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start+1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + pass + else: + ret.append(s[i]) + return "".join(ret), i+1 + + +def _read_value(s, start): + """ + Reads a value - the RHS of a token/value pair in a cookie. + """ + if s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, ";,") + + +def _read_pairs(s): + """ + Read pairs of lhs=rhs values. + """ + off = 0 + vals = [] + while 1: + lhs, off = _read_token(s, off) + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off+1) + vals.append([lhs.lstrip(), rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +ESCAPE = re.compile(r"([\"\\])") +SPECIAL = re.compile(r"^\w+$") + + +def _format_pairs(lst): + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + match = SPECIAL.search(v) + if match: + v = ESCAPE.sub(r"\1", v) + vals.append("%s=%s"%(k, v)) + return "; ".join(vals) + + +def parse_cookies(s): + """ + Parses a Cookie header value. + Returns an ODict object. + """ + pairs, off = _read_pairs(s) + return odict.ODict(pairs) + + +def unparse_cookies(od): + """ + Formats a Cookie header value. + """ + vals = [] + for i in od.lst: + vals.append("%s=%s"%(i[0], i[1])) + return "; ".join(vals) + + + +def parse_set_cookies(s): + start = 0 + + +def unparse_set_cookies(s): + pass -- cgit v1.2.3 From 2630da7263242411d413b5e4b2c520d29848c918 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 11:26:02 +1200 Subject: cookies: Cater for special values, fix some bugs found in real-world testing --- netlib/http_cookies.py | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index e11e0f90..82675418 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -59,29 +59,39 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start): +def _read_value(s, start, special): """ Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. """ - if s[start] == '"': + if start >= len(s): + return "", start + elif s[start] == '"': return _read_quoted_string(s, start) + elif special: + return _read_until(s, start, ";") else: return _read_until(s, start, ";,") -def _read_pairs(s): +def _read_pairs(s, specials=()): """ Read pairs of lhs=rhs values. + + specials: A lower-cased list of keys that may contain commas. """ off = 0 vals = [] while 1: lhs, off = _read_token(s, off) + lhs = lhs.lstrip() rhs = None if off < len(s): if s[off] == "=": - rhs, off = _read_value(s, off+1) - vals.append([lhs.lstrip(), rhs]) + rhs, off = _read_value(s, off+1, lhs.lower() in specials) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break @@ -89,18 +99,30 @@ def _read_pairs(s): ESCAPE = re.compile(r"([\"\\])") -SPECIAL = re.compile(r"^\w+$") -def _format_pairs(lst): +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +def _format_pairs(lst, specials=()): + """ + specials: A lower-cased list of keys that will not be quoted. + """ vals = [] for k, v in lst: if v is None: vals.append(k) else: - match = SPECIAL.search(v) - if match: - v = ESCAPE.sub(r"\1", v) + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"'%v vals.append("%s=%s"%(k, v)) return "; ".join(vals) @@ -118,11 +140,7 @@ def unparse_cookies(od): """ Formats a Cookie header value. """ - vals = [] - for i in od.lst: - vals.append("%s=%s"%(i[0], i[1])) - return "; ".join(vals) - + return _format_pairs(od.lst) def parse_set_cookies(s): -- cgit v1.2.3 From f131f9b855e77554072415c925ed112ec74ee48a Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 15:40:18 -0700 Subject: handshake tests, serialization test --- netlib/websockets/implementations.py | 19 +++++++++----- netlib/websockets/websockets.py | 51 ++++++++++++++++++++++++++---------- 2 files changed, 49 insertions(+), 21 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index ff42ff65..73a84690 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -32,7 +32,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.server_process_handshake(client_hs) + key = ws.process_handshake_from_client(client_hs) response = ws.create_server_handshake(key) self.wfile.write(response) self.wfile.flush() @@ -46,9 +46,9 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.key = ws.generate_client_nounce() - self.resource = "/" + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" def connect(self): super(WebSocketsClient, self).connect() @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): handshake = ws.create_client_handshake( self.address.host, self.address.port, - self.key, + self.client_nounce, self.version, self.resource ) @@ -64,9 +64,14 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - response = ws.read_handshake(self.rfile.read, 1) + server_handshake = ws.read_handshake(self.rfile.read, 1) - if not response: + if not server_handshake: + self.close() + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 527d55d6..cf9a68aa 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -84,7 +84,7 @@ class WebSocketsFrame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_byte_stream() directly """ - self.from_byte_stream(io.BytesIO(bytestring).read) + return cls.from_byte_stream(io.BytesIO(bytestring).read) @classmethod @@ -115,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def frame_is_valid(self): + def is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -155,12 +155,11 @@ class WebSocketsFrame(object): ("masking_key - " + str(self.masking_key)), ("payload - " + str(self.payload)), ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)), - ("use_validation - " + str(self.use_validation))]) + ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): try: - assert self.frame_is_valid() + assert self.is_valid() return self.to_bytes() except: raise WebSocketFrameValidationException() @@ -197,7 +196,7 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - + elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short @@ -267,6 +266,20 @@ class WebSocketsFrame(object): actual_payload_length = actual_payload_length ) + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length) + def apply_mask(message, masking_key): """ Data sent from the server must be masked to prevent malicious clients @@ -300,16 +313,14 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) - -def create_server_handshake(key, magic = websockets_magic): +def create_server_handshake(key): """ The server response is a valid HTTP 101 response. """ - digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', digest) + ('Sec-WebSocket-Accept', create_server_nounce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -322,7 +333,6 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) - def read_handshake(read_bytes, num_bytes_per_read): """ From provided function that reads bytes, read in a @@ -355,13 +365,26 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) -def server_process_handshake(handshake): - headers = Message(StringIO(handshake.split('\r\n', 1)[1])) +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return key = headers['Sec-WebSocket-Key'] return key -def generate_client_nounce(): +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + +def headers_from_http_message(http_message): + return Message(StringIO(http_message.split('\r\n', 1)[1])) + +def create_server_nounce(client_nounce): + return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + +def create_client_nounce(): return b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 2d72a1b6b56f1643cd1d8be59eee55aa7ca2f17f Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Mon, 13 Apr 2015 13:36:09 -0700 Subject: 100% test coverage, though still need plenty more --- netlib/http.py | 14 -------------- netlib/websockets/implementations.py | 10 ++-------- netlib/websockets/websockets.py | 9 ++++----- 3 files changed, 6 insertions(+), 27 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 2c72621d..26438863 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,20 +29,6 @@ def _is_valid_host(host): return None return True -def is_successful_upgrade(request, response): - """ - determines if a client and server successfully agreed to an HTTP protocol upgrade - - https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism - """ - http_switching_protocols_code = 101 - - if request and response: - responseUpgrade = request.headers.get("Upgrade") - requestUpgrade = response.headers.get("Upgrade") - if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: - return requestUpgrade[0] if len(requestUpgrade) > 0 else None - return None def parse_url(url): """ diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 73a84690..1ded3b85 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -65,9 +65,6 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() server_handshake = ws.read_handshake(self.rfile.read, 1) - - if not server_handshake: - self.close() server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) @@ -75,11 +72,8 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - try: - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload - except IndexError: - self.close() - + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + def send_message(self, message): frame = ws.WebSocketsFrame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index cf9a68aa..ea3db21d 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -158,11 +158,10 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): - try: - assert self.is_valid() - return self.to_bytes() - except: - raise WebSocketFrameValidationException() + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() def to_bytes(self): """ -- cgit v1.2.3 From de9e7411253c4f67ea4d0b96f6f9e952024c5fa3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:02:10 +1200 Subject: Firm up cookie parsing and formatting API Make a tough call: we won't support old-style comma-separated set-cookie headers. Real world testing has shown that the latest rfc (6265) is often violated in ways that make the parsing problem indeterminate. Since this is much more common than the old style deprecated set-cookie variant, we focus on the most useful case. --- netlib/http_cookies.py | 112 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 29 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 82675418..a1f240f5 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -1,13 +1,27 @@ """ A flexible module for cookie parsing and manipulation. -We try to be as permissive as possible. Parsing accepts formats from RFC6265 an -RFC2109. Serialization follows RFC6265 strictly. +This module differs from usual standards-compliant cookie modules in a number of +ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple cookies +to be set in a single header. Technically this should be feasible, but it turns +out that violations of RFC6265 that makes the parsing problem indeterminate are +much more common than genuine occurences of the multi-cookie variants. +Serialization follows RFC6265. http://tools.ietf.org/html/rfc6265 http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 """ +# TODO +# - Disallow LHS-only Cookie values + import re import odict @@ -59,7 +73,7 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start, special): +def _read_value(s, start, delims): """ Reads a value - the RHS of a token/value pair in a cookie. @@ -70,37 +84,41 @@ def _read_value(s, start, special): return "", start elif s[start] == '"': return _read_quoted_string(s, start) - elif special: - return _read_until(s, start, ";") else: - return _read_until(s, start, ";,") + return _read_until(s, start, delims) -def _read_pairs(s, specials=()): +def _read_pairs(s, off=0, term=None, specials=()): """ Read pairs of lhs=rhs values. - specials: A lower-cased list of keys that may contain commas. + off: start offset + term: if True, treat a comma as a terminator for the pairs lists + specials: a lower-cased list of keys that may contain commas if term is + True """ - off = 0 vals = [] while 1: lhs, off = _read_token(s, off) lhs = lhs.lstrip() - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off+1, lhs.lower() in specials) - vals.append([lhs, rhs]) + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + if term and lhs.lower() not in specials: + delims = ";," + else: + delims = ";" + rhs, off = _read_value(s, off+1, delims) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break + if term and s[off-1] == ",": + break return vals, off -ESCAPE = re.compile(r"([\"\\])") - - def _has_special(s): for i in s: if i in '",;\\': @@ -111,6 +129,9 @@ def _has_special(s): return False +ESCAPE = re.compile(r"([\"\\])") + + def _format_pairs(lst, specials=()): """ specials: A lower-cased list of keys that will not be quoted. @@ -127,25 +148,58 @@ def _format_pairs(lst, specials=()): return "; ".join(vals) -def parse_cookies(s): +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials = ("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): """ - Parses a Cookie header value. - Returns an ODict object. + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. """ - pairs, off = _read_pairs(s) - return odict.ODict(pairs) + pairs, off = _read_pairs( + s, + specials = ("expires", "path") + ) + return pairs -def unparse_cookies(od): +def parse_set_cookie_header(str): """ - Formats a Cookie header value. + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. """ - return _format_pairs(od.lst) + pairs = _parse_set_cookie_pairs(str) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) -def parse_set_cookies(s): - start = 0 +def parse_cookie_header(str): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off = _read_pairs(str) + return odict.ODict(pairs) -def unparse_set_cookies(s): - pass +def format_cookie_header(od): + """ + Formats a Cookie header value. + """ + return _format_pairs(od.lst) -- cgit v1.2.3 From 6db5e0a4a133e6e6150f9cab87cd56b40d6db0b2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:13:03 +1200 Subject: Remove old-style set-cookie cruft, unit tests to 100% --- netlib/http_cookies.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index a1f240f5..297efb80 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -88,14 +88,12 @@ def _read_value(s, start, delims): return _read_until(s, start, delims) -def _read_pairs(s, off=0, term=None, specials=()): +def _read_pairs(s, off=0, specials=()): """ Read pairs of lhs=rhs values. off: start offset - term: if True, treat a comma as a terminator for the pairs lists - specials: a lower-cased list of keys that may contain commas if term is - True + specials: a lower-cased list of keys that may contain commas """ vals = [] while 1: @@ -105,17 +103,11 @@ def _read_pairs(s, off=0, term=None, specials=()): rhs = None if off < len(s): if s[off] == "=": - if term and lhs.lower() not in specials: - delims = ";," - else: - delims = ";" - rhs, off = _read_value(s, off+1, delims) + rhs, off = _read_value(s, off+1, ";") vals.append([lhs, rhs]) off += 1 if not off < len(s): break - if term and s[off-1] == ",": - break return vals, off -- cgit v1.2.3 From d739882bf2dc65925c001c5bf848f5664640d299 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 13:50:57 +1200 Subject: Add an .extend method for ODicts --- netlib/odict.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 7a2f611b..7a54f282 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -108,6 +108,12 @@ class ODict(object): lst = copy.deepcopy(self.lst) return self.__class__(lst) + def extend(self, other): + """ + Add the contents of other, preserving any duplicates. + """ + self.lst.extend(other.lst) + def __repr__(self): elements = [] for itm in self.lst: -- cgit v1.2.3 From aeebf31927eb3ff74824525005c7b146024de6d5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 16:20:02 +1200 Subject: odict: don't convert values to strings when added --- netlib/odict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 7a54f282..a0ea9e53 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -84,7 +84,7 @@ class ODict(object): return False def add(self, key, value): - self.lst.append([key, str(value)]) + self.lst.append([key, value]) def get(self, k, d=None): if k in self: @@ -117,7 +117,7 @@ class ODict(object): def __repr__(self): elements = [] for itm in self.lst: - elements.append(itm[0] + ": " + itm[1]) + elements.append(itm[0] + ": " + str(itm[1])) elements.append("") return "\r\n".join(elements) -- cgit v1.2.3 From 0c85c72dc43d0d017e2bf5af9c2def46968d0499 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 15 Apr 2015 10:28:17 +1200 Subject: ODict improvements - Setting values now tries to preserve the existing order, rather than just appending to the end. - __repr__ now returns a repr of the tuple list. The old repr becomes a .format() method. This is clearer, makes troubleshooting easier, and doesn't assume all data in ODicts are header-like --- netlib/odict.py | 25 +++++++++++++++++++------ netlib/wsgi.py | 29 ++++++++++++++++++----------- 2 files changed, 37 insertions(+), 17 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index a0ea9e53..dd738c55 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): """ - A dictionary-like object for managing ordered (key, value) data. + A dictionary-like object for managing ordered (key, value) data. Think + about it as a convenient interface to a list of (key, value) tuples. """ def __init__(self, lst=None): self.lst = lst or [] @@ -64,11 +65,20 @@ class ODict(object): key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") - - new = self._filter_lst(k, self.lst) - for i in valuelist: - new.append([k, i]) + raise ValueError( + "Expected list of values instead of string. " + "Example: odict['Host'] = ['www.example.com']" + ) + kc = self._kconv(k) + new = [] + for i in self.lst: + if self._kconv(i[0]) == kc: + if valuelist: + new.append([k, valuelist.pop(0)]) + else: + new.append(i) + while valuelist: + new.append([k, valuelist.pop(0)]) self.lst = new def __delitem__(self, k): @@ -115,6 +125,9 @@ class ODict(object): self.lst.extend(other.lst) def __repr__(self): + return repr(self.lst) + + def format(self): elements = [] for itm in self.lst: elements.append(itm[0] + ": " + str(itm[1])) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index bac27d5a..1b979608 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO, urllib, time, traceback +import cStringIO +import urllib +import time +import traceback from . import odict, tcp @@ -23,15 +26,18 @@ class Request(object): def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + MONTHS = [ + None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' + ] now = time.time() year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss) + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss + ) return s @@ -100,6 +106,7 @@ class WSGIAdaptor(object): status = None, headers = None ) + def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n"%state["status"]) @@ -108,7 +115,7 @@ class WSGIAdaptor(object): h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] - soc.write(str(h)) + soc.write(h.format()) soc.write("\r\n") state["headers_sent"] = True if data: @@ -130,7 +137,9 @@ class WSGIAdaptor(object): errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs, **env), start_response) + dataiter = self.app( + self.make_environ(request, errs, **env), start_response + ) for i in dataiter: write(i) if not state["headers_sent"]: @@ -143,5 +152,3 @@ class WSGIAdaptor(object): except Exception: # pragma: no cover pass return errs.getvalue() - - -- cgit v1.2.3 From c53d89fd7fad6c46458ab3d0140528e344de605f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 16 Apr 2015 08:30:54 +1200 Subject: Improve flexibility of http_cookies._format_pairs --- netlib/http_cookies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 297efb80..dab95ed0 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -124,7 +124,7 @@ def _has_special(s): ESCAPE = re.compile(r"([\"\\])") -def _format_pairs(lst, specials=()): +def _format_pairs(lst, specials=(), sep="; "): """ specials: A lower-cased list of keys that will not be quoted. """ @@ -137,7 +137,7 @@ def _format_pairs(lst, specials=()): v = ESCAPE.sub(r"\\\1", v) v = '"%s"'%v vals.append("%s=%s"%(k, v)) - return "; ".join(vals) + return sep.join(vals) def _format_set_cookie_pairs(lst): -- cgit v1.2.3 From 488c25d812a321f5a03253b62ab33b61ecc13de1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 13:57:39 +1200 Subject: websockets: whitespace, PEP8 --- netlib/websockets/websockets.py | 169 +++++++++++++++++++++++----------------- 1 file changed, 96 insertions(+), 73 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index ea3db21d..8782ea49 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -1,31 +1,34 @@ from __future__ import absolute_import -from base64 import b64encode -from hashlib import sha1 -from mimetools import Message -from netlib import tcp -from netlib import utils -from StringIO import StringIO +import base64 +import hashlib +import mimetools +import StringIO import os -import SocketServer import struct import io -# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol -# Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +from .. import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. # -# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness # +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # # Spec: https://tools.ietf.org/html/rfc6455 -# The magic sha that websocket servers must know to prove they understand RFC6455 +# The magic sha that websocket servers must know to prove they understand +# RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + class WebSocketFrameValidationException(Exception): pass + class WebSocketsFrame(object): """ Represents one websockets frame. @@ -33,7 +36,7 @@ class WebSocketsFrame(object): from_bytes() is also avaliable. WebSockets Frame as defined in RFC6455 - + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-------+-+-------------+-------------------------------+ |F|R|R|R| opcode|M| Payload len | Extended payload length | @@ -62,7 +65,7 @@ class WebSocketsFrame(object): rsv1 = 0, # decimal integer 1 or 0 rsv2 = 0, # decimal integer 1 or 0 rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring + payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer ): @@ -81,18 +84,17 @@ class WebSocketsFrame(object): @classmethod def from_bytes(cls, bytestring): """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_byte_stream() directly - """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ return cls.from_byte_stream(io.BytesIO(bytestring).read) - @classmethod def default(cls, message, from_client = False): """ - Construct a basic websocket frame from some default values. + Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. - """ + """ length_code, actual_length = get_payload_length_pair(message) if from_client: @@ -103,7 +105,7 @@ class WebSocketsFrame(object): mask_bit = 0 masking_key = None payload = message - + return cls( fin = 1, # final frame opcode = 1, # text @@ -117,10 +119,10 @@ class WebSocketsFrame(object): def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame - has not been corrupted. - """ - try: + Validate websocket frame invariants, call at anytime to ensure the + WebSocketsFrame has not been corrupted. + """ + try: assert 0 <= self.fin <= 1 assert 0 <= self.rsv1 <= 1 assert 0 <= self.rsv2 <= 1 @@ -128,18 +130,18 @@ class WebSocketsFrame(object): assert 1 <= self.opcode <= 4 assert 0 <= self.mask_bit <= 1 assert 1 <= self.payload_length_code <= 127 - + if self.mask_bit == 1: assert 1 <= len(self.masking_key) <= 4 else: - assert self.masking_key == None - + assert self.masking_key is None + assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - return True + return True except AssertionError: return False @@ -165,30 +167,32 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees that the - serialized bytes will be correct. see safe_to_bytes() - """ + Serialize the frame back into the wire format, returns a bytestring If + you haven't checked is_valid_frame() then there's no guarentees that + the serialized bytes will be correct. see safe_to_bytes() + """ max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) - # break down of the bit-math used to construct the first byte from the frame's integer values - # first shift the significant bit into the correct position + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position # 00000001 << 7 = 10000000 # ... # then combine: - # + # # 10000000 fin # 01000000 res1 # 00100000 res2 # 00010000 res3 # 00000001 opcode - # -------- OR + # -------- OR # 11110001 = first_byte - first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + second_byte = (self.mask_bit << 7) | self.payload_length_code bytes = chr(first_byte) + chr(second_byte) @@ -199,11 +203,13 @@ class WebSocketsFrame(object): elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short - bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length - + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long - bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: bytes += self.masking_key @@ -212,43 +218,46 @@ class WebSocketsFrame(object): return bytes - @classmethod def from_byte_stream(cls, read_bytes): """ read a websockets frame sent by a server or client - + read_bytes is a function that can be backed - by sockets or by any byte reader. So this + by sockets or by any byte reader. So this function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) + """ + first_byte = utils.bytes_to_int(read_bytes(1)) second_byte = utils.bytes_to_int(read_bytes(1)) - - fin = first_byte >> 7 # grab the left most bit - opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 - mask_bit = second_byte >> 7 # grab left most bit - payload_length = second_byte & 127 # grab the next 7 bits + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 # payload_lengthy > 125 indicates you need to read more bytes # to get the actual payload length if payload_length <= 125: - actual_payload_length = payload_length + actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) + actual_payload_length = utils.bytes_to_int(read_bytes(2)) - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) # masking key only present if mask bit set if mask_bit == 1: masking_key = read_bytes(4) else: masking_key = None - + payload = read_bytes(actual_payload_length) - + if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) else: @@ -295,12 +304,15 @@ def apply_mask(message, masking_key): result += chr(ord(char) ^ masks[len(result) % 4]) return result + def random_masking_key(): return os.urandom(4) + def create_client_handshake(host, port, key, version, resource): """ - WebSockets connections are intiated by the client with a valid HTTP upgrade request + WebSockets connections are intiated by the client with a valid HTTP + upgrade request """ headers = [ ('Host', '%s:%s' % (host, port)), @@ -312,10 +324,11 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) + def create_server_handshake(key): """ - The server response is a valid HTTP 101 response. - """ + The server response is a valid HTTP 101 response. + """ headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -332,12 +345,13 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) + def read_handshake(read_bytes, num_bytes_per_read): """ - From provided function that reads bytes, read in a + From provided function that reads bytes, read in a complete HTTP request, which terminates with a CLRF - """ - response = b'' + """ + response = b'' doubleCLRF = b'\r\n\r\n' while True: bytes = read_bytes(num_bytes_per_read) @@ -348,14 +362,15 @@ def read_handshake(read_bytes, num_bytes_per_read): break return response + def get_payload_length_pair(payload_bytestring): """ A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is larger - than 125 - """ + extended length code to represent the actual length if length code is + larger than 125 + """ actual_length = len(payload_bytestring) - + if actual_length <= 125: length_code = actual_length elif actual_length >= 126 and actual_length <= 65535: @@ -364,6 +379,7 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) + def process_handshake_from_client(handshake): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -371,6 +387,7 @@ def process_handshake_from_client(handshake): key = headers['Sec-WebSocket-Key'] return key + def process_handshake_from_server(handshake, client_nounce): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -378,12 +395,18 @@ def process_handshake_from_server(handshake, client_nounce): key = headers['Sec-WebSocket-Accept'] return key + def headers_from_http_message(http_message): - return Message(StringIO(http_message.split('\r\n', 1)[1])) + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + def create_server_nounce(client_nounce): - return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) -def create_client_nounce(): - return b64encode(os.urandom(16)).decode('utf-8') +def create_client_nounce(): + return base64.b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 7defb5be862a4251da9d7c530593f7e9be3e739e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 14:29:20 +1200 Subject: websockets: more whitespace, WebSocketFrame -> Frame --- netlib/websockets/implementations.py | 12 ++--- netlib/websockets/websockets.py | 100 +++++++++++++++++------------------ 2 files changed, 55 insertions(+), 57 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 1ded3b85..337c5496 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -9,7 +9,7 @@ import os # Simple websocket client and servers that are used to exercise the functionality in websockets.py # These are *not* fully RFC6455 compliant -class WebSocketsEchoHandler(tcp.BaseHandler): +class WebSocketsEchoHandler(tcp.BaseHandler): def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__(connection, address, server) self.handshake_done = False @@ -22,14 +22,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = False) + frame = ws.Frame.default(message, from_client = False) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() - + def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) key = ws.process_handshake_from_client(client_hs) @@ -72,9 +72,9 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + return ws.Frame.from_byte_stream(self.rfile.read).payload def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = True) + frame = ws.Frame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 8782ea49..86d98caf 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -29,7 +29,7 @@ class WebSocketFrameValidationException(Exception): pass -class WebSocketsFrame(object): +class Frame(object): """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -98,29 +98,29 @@ class WebSocketsFrame(object): length_code, actual_length = get_payload_length_pair(message) if from_client: - mask_bit = 1 + mask_bit = 1 masking_key = random_masking_key() - payload = apply_mask(message, masking_key) + payload = apply_mask(message, masking_key) else: - mask_bit = 0 + mask_bit = 0 masking_key = None - payload = message + payload = message return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, actual_payload_length = actual_length ) def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the - WebSocketsFrame has not been corrupted. + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. """ try: assert 0 <= self.fin <= 1 @@ -147,17 +147,18 @@ class WebSocketsFrame(object): def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length))]) + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) def safe_to_bytes(self): if self.is_valid(): @@ -167,11 +168,10 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring If - you haven't checked is_valid_frame() then there's no guarentees that - the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -199,13 +199,10 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length @@ -215,7 +212,6 @@ class WebSocketsFrame(object): bytes += self.masking_key bytes += self.payload # already will be encoded if neccessary - return bytes @classmethod @@ -264,29 +260,31 @@ class WebSocketsFrame(object): decoded_payload = payload return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, actual_payload_length = actual_payload_length ) def __eq__(self, other): return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length) + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + def apply_mask(message, masking_key): """ -- cgit v1.2.3 From 0c2ad1edb1af013576f4ac31e05b308ffb440116 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 17 Apr 2015 16:29:09 +0200 Subject: fix socket_close on Windows, refs mitmproxy/mitmproxy#527 --- netlib/tcp.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 10269aa4..84008e2c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -247,24 +247,32 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - # If we close RD, any further received bytes would result in a RST being set, which we want to avoid - # for our purposes sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux # Section 4.2.2.13 of RFC 1122 tells us that a close() with any # pending readable data could lead to an immediate RST being sent (which is the case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we are in blocking mode. - # Please let us know if you have a better solution to this problem. - - sock.settimeout(sock.gettimeout() or 20) - # may raise a timeout/disconnect exception. - while sock.recv(4096): # pragma: no cover - pass + # This in turn results in the following issue: If we send an error page to the client and then close the socket, + # the RST may be received by the client before the error page and the users sees a connection error rather than + # the error page. Thus, we try to empty the read buffer on Windows first. + # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # + if os.name == "nt": # pragma: no cover + # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we are in blocking mode. + sock.settimeout(sock.gettimeout() or 20) + + # limit at a megabyte so that we don't read infinitely + for _ in xrange(1024 ** 3 // 4096): + # may raise a timeout/disconnect exception. + if not sock.recv(4096): + break + + # Now we can close the other half as well. + sock.shutdown(socket.SHUT_RD) except socket.error: pass -- cgit v1.2.3 From 74389ef04a3fdda4d388acb6d655adde78fccd7d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Apr 2015 09:38:09 +1200 Subject: Websockets: reorganise - websockets.py to top-level - implementations into test suite --- netlib/websockets.py | 410 +++++++++++++++++++++++++++++++++++ netlib/websockets/__init__.py | 1 - netlib/websockets/implementations.py | 80 ------- netlib/websockets/websockets.py | 410 ----------------------------------- 4 files changed, 410 insertions(+), 491 deletions(-) create mode 100644 netlib/websockets.py delete mode 100644 netlib/websockets/__init__.py delete mode 100644 netlib/websockets/implementations.py delete mode 100644 netlib/websockets/websockets.py (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py new file mode 100644 index 00000000..83e90238 --- /dev/null +++ b/netlib/websockets.py @@ -0,0 +1,410 @@ +from __future__ import absolute_import + +import base64 +import hashlib +import mimetools +import StringIO +import os +import struct +import io + +from . import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + + +class WebSocketFrameValidationException(Exception): + pass + + +class Frame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ + return cls.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def is_valid(self): + """ + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key is None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + + return True + except AssertionError: + return False + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) + + def safe_to_bytes(self): + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() + """ + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + return bytes + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + + +def random_masking_key(): + return os.urandom(4) + + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP + upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key): + """ + The server response is a valid HTTP 101 response. + """ + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nounce(key)) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + + +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + + +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + + +def headers_from_http_message(http_message): + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + + +def create_server_nounce(client_nounce): + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) + + +def create_client_nounce(): + return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py deleted file mode 100644 index 9b4faa33..00000000 --- a/netlib/websockets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py deleted file mode 100644 index 337c5496..00000000 --- a/netlib/websockets/implementations.py +++ /dev/null @@ -1,80 +0,0 @@ -from netlib import tcp -from base64 import b64encode -from StringIO import StringIO -from . import websockets as ws -import struct -import SocketServer -import os - -# Simple websocket client and servers that are used to exercise the functionality in websockets.py -# These are *not* fully RFC6455 compliant - -class WebSocketsEchoHandler(tcp.BaseHandler): - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__(connection, address, server) - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload - self.on_message(decoded) - - def send_message(self, message): - frame = ws.Frame.default(message, from_client = False) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - - def handshake(self): - client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake(key) - self.wfile.write(response) - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.client_nounce = ws.create_client_nounce() - self.resource = "/" - - def connect(self): - super(WebSocketsClient, self).connect() - - handshake = ws.create_client_handshake( - self.address.host, - self.address.port, - self.client_nounce, - self.version, - self.resource - ) - - self.wfile.write(handshake) - self.wfile.flush() - - server_handshake = ws.read_handshake(self.rfile.read, 1) - - server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) - - if not server_nounce == ws.create_server_nounce(self.client_nounce): - self.close() - - def read_next_message(self): - return ws.Frame.from_byte_stream(self.rfile.read).payload - - def send_message(self, message): - frame = ws.Frame.default(message, from_client = True) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py deleted file mode 100644 index 86d98caf..00000000 --- a/netlib/websockets/websockets.py +++ /dev/null @@ -1,410 +0,0 @@ -from __future__ import absolute_import - -import base64 -import hashlib -import mimetools -import StringIO -import os -import struct -import io - -from .. import utils - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' - - -class WebSocketFrameValidationException(Exception): - pass - - -class Frame(object): - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - def __init__( - self, - fin, # decmial integer 1 or 0 - opcode, # decmial integer 1 - 4 - mask_bit, # decimal integer 1 or 0 - payload_length_code, # decimal integer 1 - 127 - decoded_payload, # bytestring - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring - masking_key = None, # 32 bit byte string - actual_payload_length = None, # any decimal integer - ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key - self.payload = payload - self.decoded_payload = decoded_payload - self.actual_payload_length = actual_payload_length - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring to construct - a frame from a stream of bytes, use from_byte_stream() directly - """ - return cls.from_byte_stream(io.BytesIO(bytestring).read) - - @classmethod - def default(cls, message, from_client = False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - length_code, actual_length = get_payload_length_pair(message) - - if from_client: - mask_bit = 1 - masking_key = random_masking_key() - payload = apply_mask(message, masking_key) - else: - mask_bit = 0 - masking_key = None - payload = message - - return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, - actual_payload_length = actual_length - ) - - def is_valid(self): - """ - Validate websocket frame invariants, call at anytime to ensure the - Frame has not been corrupted. - """ - try: - assert 0 <= self.fin <= 1 - assert 0 <= self.rsv1 <= 1 - assert 0 <= self.rsv2 <= 1 - assert 0 <= self.rsv3 <= 1 - assert 1 <= self.opcode <= 4 - assert 0 <= self.mask_bit <= 1 - assert 1 <= self.payload_length_code <= 127 - - if self.mask_bit == 1: - assert 1 <= len(self.masking_key) <= 4 - else: - assert self.masking_key is None - - assert self.actual_payload_length == len(self.payload) - - if self.payload is not None and self.masking_key is not None: - assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - - return True - except AssertionError: - return False - - def human_readable(self): - return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)) - ]) - - def safe_to_bytes(self): - if self.is_valid(): - return self.to_bytes() - else: - raise WebSocketFrameValidationException() - - def to_bytes(self): - """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees - that the serialized bytes will be correct. see safe_to_bytes() - """ - max_16_bit_int = (1 << 16) - max_64_bit_int = (1 << 63) - - # break down of the bit-math used to construct the first byte from the - # frame's integer values first shift the significant bit into the - # correct position - # 00000001 << 7 = 10000000 - # ... - # then combine: - # - # 10000000 fin - # 01000000 res1 - # 00100000 res2 - # 00010000 res3 - # 00000001 opcode - # -------- OR - # 11110001 = first_byte - - first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ - (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - - second_byte = (self.mask_bit << 7) | self.payload_length_code - - bytes = chr(first_byte) + chr(second_byte) - - if self.actual_payload_length < 126: - pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - bytes += struct.pack('!Q', self.actual_payload_length) - - if self.masking_key is not None: - bytes += self.masking_key - - bytes += self.payload # already will be encoded if neccessary - return bytes - - @classmethod - def from_byte_stream(cls, read_bytes): - """ - read a websockets frame sent by a server or client - - read_bytes is a function that can be backed - by sockets or by any byte reader. So this - function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) - second_byte = utils.bytes_to_int(read_bytes(1)) - - # grab the left most bit - fin = first_byte >> 7 - # grab right most 4 bits by and-ing with 00001111 - opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 - # grab the next 7 bits - payload_length = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if payload_length <= 125: - actual_payload_length = payload_length - - elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) - - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = read_bytes(4) - else: - masking_key = None - - payload = read_bytes(actual_payload_length) - - if mask_bit == 1: - decoded_payload = apply_mask(payload, masking_key) - else: - decoded_payload = payload - - return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, - actual_payload_length = actual_payload_length - ) - - def __eq__(self, other): - return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length - ) - - -def apply_mask(message, masking_key): - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result - - -def random_masking_key(): - return os.urandom(4) - - -def create_client_handshake(host, port, key, version, resource): - """ - WebSockets connections are intiated by the client with a valid HTTP - upgrade request - """ - headers = [ - ('Host', '%s:%s' % (host, port)), - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ] - request = "GET %s HTTP/1.1" % resource - return build_handshake(headers, request) - - -def create_server_handshake(key): - """ - The server response is a valid HTTP 101 response. - """ - headers = [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nounce(key)) - ] - request = "HTTP/1.1 101 Switching Protocols" - return build_handshake(headers, request) - - -def build_handshake(headers, request): - handshake = [request.encode('utf-8')] - for header, value in headers: - handshake.append(("%s: %s" % (header, value)).encode('utf-8')) - handshake.append(b'\r\n') - return b'\r\n'.join(handshake) - - -def read_handshake(read_bytes, num_bytes_per_read): - """ - From provided function that reads bytes, read in a - complete HTTP request, which terminates with a CLRF - """ - response = b'' - doubleCLRF = b'\r\n\r\n' - while True: - bytes = read_bytes(num_bytes_per_read) - if not bytes: - break - response += bytes - if doubleCLRF in response: - break - return response - - -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - -def process_handshake_from_client(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": - return - key = headers['Sec-WebSocket-Key'] - return key - - -def process_handshake_from_server(handshake, client_nounce): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": - return - key = headers['Sec-WebSocket-Accept'] - return key - - -def headers_from_http_message(http_message): - return mimetools.Message( - StringIO.StringIO(http_message.split('\r\n', 1)[1]) - ) - - -def create_server_nounce(client_nounce): - return base64.b64encode( - hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') - ) - - -def create_client_nounce(): - return base64.b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 4ea1ccb638366fbdac2d294c23ce8052dcf250c2 Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sun, 19 Apr 2015 22:18:30 -0700 Subject: fixing test coverage, adding to_file/from_file reader writes to match socks.py --- netlib/websockets.py | 62 ++++++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 28 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 83e90238..5b9d8fbd 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -25,6 +25,11 @@ from . import utils websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +class CONST(object): + MAX_16_BIT_INT = (1 << 16) + MAX_64_BIT_INT = (1 << 64) + + class WebSocketFrameValidationException(Exception): pass @@ -81,14 +86,6 @@ class Frame(object): self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring to construct - a frame from a stream of bytes, use from_byte_stream() directly - """ - return cls.from_byte_stream(io.BytesIO(bytestring).read) - @classmethod def default(cls, message, from_client = False): """ @@ -145,7 +142,7 @@ class Frame(object): except AssertionError: return False - def human_readable(self): + def human_readable(self): # pragma: nocover return "\n".join([ ("fin - " + str(self.fin)), ("rsv1 - " + str(self.rsv1)), @@ -160,6 +157,14 @@ class Frame(object): ("actual_payload_length - " + str(self.actual_payload_length)) ]) + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(io.BytesIO(bytestring)) + def safe_to_bytes(self): if self.is_valid(): return self.to_bytes() @@ -172,8 +177,6 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) - max_64_bit_int = (1 << 63) # break down of the bit-math used to construct the first byte from the # frame's integer values first shift the significant bit into the @@ -199,11 +202,11 @@ class Frame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: + elif self.actual_payload_length < CONST.MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: + elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length bytes += struct.pack('!Q', self.actual_payload_length) @@ -214,17 +217,20 @@ class Frame(object): bytes += self.payload # already will be encoded if neccessary return bytes + def to_file(self, writer): + writer.write(self.to_bytes()) + writer.flush() + @classmethod - def from_byte_stream(cls, read_bytes): + def from_file(cls, reader): """ read a websockets frame sent by a server or client - - read_bytes is a function that can be backed - by sockets or by any byte reader. So this - function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) - second_byte = utils.bytes_to_int(read_bytes(1)) + + reader is a "file like" object that could be backed by a network stream or a disk + or an in memory stream reader + """ + first_byte = utils.bytes_to_int(reader.read(1)) + second_byte = utils.bytes_to_int(reader.read(1)) # grab the left most bit fin = first_byte >> 7 @@ -241,18 +247,18 @@ class Frame(object): actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) + actual_payload_length = utils.bytes_to_int(reader.read(2)) elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) + actual_payload_length = utils.bytes_to_int(reader.read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = read_bytes(4) + masking_key = reader.read(4) else: masking_key = None - payload = read_bytes(actual_payload_length) + payload = reader.read(actual_payload_length) if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) @@ -344,7 +350,7 @@ def build_handshake(headers, request): return b'\r\n'.join(handshake) -def read_handshake(read_bytes, num_bytes_per_read): +def read_handshake(reader, num_bytes_per_read): """ From provided function that reads bytes, read in a complete HTTP request, which terminates with a CLRF @@ -352,7 +358,7 @@ def read_handshake(read_bytes, num_bytes_per_read): response = b'' doubleCLRF = b'\r\n\r\n' while True: - bytes = read_bytes(num_bytes_per_read) + bytes = reader.read(num_bytes_per_read) if not bytes: break response += bytes @@ -386,7 +392,7 @@ def process_handshake_from_client(handshake): return key -def process_handshake_from_server(handshake, client_nounce): +def process_handshake_from_server(handshake): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return -- cgit v1.2.3 From 2c660d76337b11eb438a2978ec3bda3ac10babd5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 11:05:12 +1200 Subject: Migrate requeset reading from mitmproxy to netlib --- netlib/http.py | 124 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- netlib/utils.py | 2 +- 2 files changed, 123 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 26438863..aacdd1d4 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,7 +1,10 @@ from __future__ import (absolute_import, print_function, division) -import string, urlparse, binascii +import collections +import string +import urlparse +import binascii import sys -from . import odict, utils +from . import odict, utils, tcp class HttpError(Exception): @@ -30,6 +33,19 @@ def _is_valid_host(host): return True +def get_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = fp.readline() + if line == "": + raise tcp.NetLibDisconnect() + return line + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -436,3 +452,107 @@ def expected_http_body_size(headers, is_request, request_method, response_code): if is_request: return 0 return -1 + + +Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] +) + + +def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): + """ + Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = get_line(rfile) + + request_line_parts = parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = read_headers(rfile) + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect") + if expect_header and expect_header.lower() == "100-continue" and httpversion >= (1, 1): + wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + wfile.flush() + del headers['expect'] + + if include_body: + content = read_http_body( + rfile, headers, body_size_limit, method, None, True + ) + + return Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) diff --git a/netlib/utils.py b/netlib/utils.py index 03a70977..57532453 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -46,4 +46,4 @@ def hexdump(s): parts.append( (o, x, cleanBin(part, True)) ) - return parts \ No newline at end of file + return parts -- cgit v1.2.3 From dd7ea896f24514bb2534b3762255e99f0aabc055 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 11:11:16 +1200 Subject: Return a named tuple from read_response --- netlib/http.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index aacdd1d4..5501ce73 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -314,6 +314,18 @@ def parse_response_line(line): return (proto, code, msg) +Response = collections.namedtuple( + "Response", + [ + "httpversion", + "code", + "msg", + "headers", + "content" + ] +) + + def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -352,7 +364,7 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): # if include_body==False then a None content means the body should be # read separately content = None - return httpversion, code, msg, headers, content + return Response(httpversion, code, msg, headers, content) def read_http_body(*args, **kwargs): @@ -531,8 +543,8 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): if headers is None: raise HttpError(400, "Invalid headers") - expect_header = headers.get_first("expect") - if expect_header and expect_header.lower() == "100-continue" and httpversion >= (1, 1): + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): wfile.write( 'HTTP/1.1 100 Continue\r\n' '\r\n' -- cgit v1.2.3 From 7d83e388aa78bb3637f71a4afb60af1baecb0314 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 11:19:00 +1200 Subject: Whitespace, pep8, mixed indentation --- netlib/http.py | 19 +++++++++++++++---- netlib/utils.py | 4 +++- 2 files changed, 18 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 5501ce73..b925fe87 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -331,12 +331,15 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): Return an (httpversion, code, msg, headers, content) tuple. By default, both response header and body are read. - If include_body=False is specified, content may be one of the following: + If include_body=False is specified, content may be one of the + following: - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a response to a HEAD request) + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + # Possible leftover from previous message + if line == "\r\n" or line == "\n": line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") @@ -373,7 +376,15 @@ def read_http_body(*args, **kwargs): ) -def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): +def read_http_body_chunked( + rfile, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None +): """ Read an HTTP message body: diff --git a/netlib/utils.py b/netlib/utils.py index 57532453..66bbdb5e 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,9 +8,11 @@ def isascii(s): return False return True + # best way to do it in python 2.x def bytes_to_int(i): - return int(i.encode('hex'), 16) + return int(i.encode('hex'), 16) + def cleanBin(s, fixspacing=False): """ -- cgit v1.2.3 From e5f12648380cb4401f77e3cae51189ef97b603dc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 13:39:00 +1200 Subject: Whitespace, indentation, nounce -> nonce --- netlib/http_cookies.py | 24 ++++++++++++------------ netlib/websockets.py | 50 +++++++++++++++++++++++++------------------------- 2 files changed, 37 insertions(+), 37 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index dab95ed0..8e245891 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -1,18 +1,18 @@ """ A flexible module for cookie parsing and manipulation. -This module differs from usual standards-compliant cookie modules in a number of -ways. We try to be as permissive as possible, and to retain even mal-formed +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed information. Duplicate cookies are preserved in parsing, and can be set in formatting. We do attempt to escape and quote values where needed, but will not reject data that violate the specs. Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple cookies -to be set in a single header. Technically this should be feasible, but it turns -out that violations of RFC6265 that makes the parsing problem indeterminate are -much more common than genuine occurences of the multi-cookie variants. -Serialization follows RFC6265. +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. http://tools.ietf.org/html/rfc6265 http://tools.ietf.org/html/rfc2109 @@ -32,11 +32,11 @@ def _read_until(s, start, term): Read until one of the characters in term is reached. """ if start == len(s): - return "", start+1 + return "", start + 1 for i in range(start, len(s)): if s[i] in term: return s[start:i], i - return s[start:i+1], i+1 + return s[start:i + 1], i + 1 def _read_token(s, start): @@ -59,7 +59,7 @@ def _read_quoted_string(s, start): escaping = False ret = [] # Skip the first quote - for i in range(start+1, len(s)): + for i in range(start + 1, len(s)): if escaping: ret.append(s[i]) escaping = False @@ -70,7 +70,7 @@ def _read_quoted_string(s, start): pass else: ret.append(s[i]) - return "".join(ret), i+1 + return "".join(ret), i + 1 def _read_value(s, start, delims): @@ -103,7 +103,7 @@ def _read_pairs(s, off=0, specials=()): rhs = None if off < len(s): if s[off] == "=": - rhs, off = _read_value(s, off+1, ";") + rhs, off = _read_value(s, off + 1, ";") vals.append([lhs, rhs]) off += 1 if not off < len(s): diff --git a/netlib/websockets.py b/netlib/websockets.py index 5b9d8fbd..f2d467a5 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -67,23 +67,23 @@ class Frame(object): mask_bit, # decimal integer 1 or 0 payload_length_code, # decimal integer 1 - 127 decoded_payload, # bytestring - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring - masking_key = None, # 32 bit byte string + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key - self.payload = payload - self.decoded_payload = decoded_payload + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length @classmethod @@ -162,7 +162,7 @@ class Frame(object): """ Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_file() directly - """ + """ return cls.from_file(io.BytesIO(bytestring)) def safe_to_bytes(self): @@ -206,7 +206,7 @@ class Frame(object): # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < CONST.MAX_64_BIT_INT: + elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length bytes += struct.pack('!Q', self.actual_payload_length) @@ -225,10 +225,10 @@ class Frame(object): def from_file(cls, reader): """ read a websockets frame sent by a server or client - - reader is a "file like" object that could be backed by a network stream or a disk - or an in memory stream reader - """ + + reader is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ first_byte = utils.bytes_to_int(reader.read(1)) second_byte = utils.bytes_to_int(reader.read(1)) @@ -336,7 +336,7 @@ def create_server_handshake(key): headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nounce(key)) + ('Sec-WebSocket-Accept', create_server_nonce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -406,11 +406,11 @@ def headers_from_http_message(http_message): ) -def create_server_nounce(client_nounce): +def create_server_nonce(client_nonce): return base64.b64encode( - hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') ) -def create_client_nounce(): +def create_client_nonce(): return base64.b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 22:39:45 +1200 Subject: websockets: refactor to use http and header functions in http.py --- netlib/http.py | 126 ++++++++++++++++++++++++++++----------------------- netlib/websockets.py | 108 ++++++++++++++----------------------------- 2 files changed, 104 insertions(+), 130 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index b925fe87..fe27240a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -4,7 +4,7 @@ import string import urlparse import binascii import sys -from . import odict, utils, tcp +from . import odict, utils, tcp, http_status class HttpError(Exception): @@ -314,62 +314,6 @@ def parse_response_line(line): return (proto, code, msg) -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return Response(httpversion, code, msg, headers, content) - - def read_http_body(*args, **kwargs): return "".join( content for _, content, _ in read_http_body_chunked(*args, **kwargs) @@ -579,3 +523,71 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): headers, content ) + + +Response = collections.namedtuple( + "Response", + [ + "httpversion", + "code", + "msg", + "headers", + "content" + ] +) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = http_status.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/websockets.py b/netlib/websockets.py index f2d467a5..a03185fa 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -2,13 +2,11 @@ from __future__ import absolute_import import base64 import hashlib -import mimetools -import StringIO import os import struct import io -from . import utils +from . import utils, odict # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -23,6 +21,7 @@ from . import utils # The magic sha that websocket servers must know to prove they understand # RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" class CONST(object): @@ -151,9 +150,9 @@ class Frame(object): ("opcode - " + str(self.opcode)), ("mask_bit - " + str(self.mask_bit)), ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), + ("masking_key - " + repr(str(self.masking_key))), + ("payload - " + repr(str(self.payload))), + ("decoded_payload - " + repr(str(self.decoded_payload))), ("actual_payload_length - " + str(self.actual_payload_length)) ]) @@ -198,24 +197,24 @@ class Frame(object): second_byte = (self.mask_bit << 7) | self.payload_length_code - bytes = chr(first_byte) + chr(second_byte) + b = chr(first_byte) + chr(second_byte) if self.actual_payload_length < 126: pass elif self.actual_payload_length < CONST.MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length - bytes += struct.pack('!H', self.actual_payload_length) + b += struct.pack('!H', self.actual_payload_length) elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length - bytes += struct.pack('!Q', self.actual_payload_length) + b += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: - bytes += self.masking_key + b += self.masking_key - bytes += self.payload # already will be encoded if neccessary - return bytes + b += self.payload # already will be encoded if neccessary + return b def to_file(self, writer): writer.write(self.to_bytes()) @@ -313,58 +312,35 @@ def random_masking_key(): return os.urandom(4) -def create_client_handshake(host, port, key, version, resource): +def client_handshake_headers(key=None, version=VERSION): """ - WebSockets connections are intiated by the client with a valid HTTP - upgrade request + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless """ - headers = [ - ('Host', '%s:%s' % (host, port)), + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), ('Sec-WebSocket-Key', key), ('Sec-WebSocket-Version', version) - ] - request = "GET %s HTTP/1.1" % resource - return build_handshake(headers, request) + ]) -def create_server_handshake(key): +def server_handshake_headers(key): """ The server response is a valid HTTP 101 response. """ - headers = [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - request = "HTTP/1.1 101 Switching Protocols" - return build_handshake(headers, request) - - -def build_handshake(headers, request): - handshake = [request.encode('utf-8')] - for header, value in headers: - handshake.append(("%s: %s" % (header, value)).encode('utf-8')) - handshake.append(b'\r\n') - return b'\r\n'.join(handshake) - - -def read_handshake(reader, num_bytes_per_read): - """ - From provided function that reads bytes, read in a - complete HTTP request, which terminates with a CLRF - """ - response = b'' - doubleCLRF = b'\r\n\r\n' - while True: - bytes = reader.read(num_bytes_per_read) - if not bytes: - break - response += bytes - if doubleCLRF in response: - break - return response + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nonce(key)) + ] + ) def get_payload_length_pair(payload_bytestring): @@ -384,33 +360,19 @@ def get_payload_length_pair(payload_bytestring): return (length_code, actual_length) -def process_handshake_from_client(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": +def check_client_handshake(req): + if req.headers.get_first("upgrade", None) != "websocket": return - key = headers['Sec-WebSocket-Key'] - return key + return req.headers.get_first('sec-websocket-key') -def process_handshake_from_server(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": +def check_server_handshake(resp): + if resp.headers.get_first("upgrade", None) != "websocket": return - key = headers['Sec-WebSocket-Accept'] - return key - - -def headers_from_http_message(http_message): - return mimetools.Message( - StringIO.StringIO(http_message.split('\r\n', 1)[1]) - ) + return resp.headers.get_first('sec-websocket-accept') def create_server_nonce(client_nonce): return base64.b64encode( hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') ) - - -def create_client_nonce(): - return base64.b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 1b509d5aea31a636b6c8ce854e0dd685e34d03de Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 22:51:01 +1200 Subject: Whitespace, interface simplification - safe_tobytes doesn't buy us much - move masking key generation inline --- netlib/websockets.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index a03185fa..0cd4dba1 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -29,10 +29,6 @@ class CONST(object): MAX_64_BIT_INT = (1 << 64) -class WebSocketFrameValidationException(Exception): - pass - - class Frame(object): """ Represents one websockets frame. @@ -95,7 +91,8 @@ class Frame(object): if from_client: mask_bit = 1 - masking_key = random_masking_key() + # Random masking key + masking_key = os.urandom(4) payload = apply_mask(message, masking_key) else: mask_bit = 0 @@ -164,12 +161,6 @@ class Frame(object): """ return cls.from_file(io.BytesIO(bytestring)) - def safe_to_bytes(self): - if self.is_valid(): - return self.to_bytes() - else: - raise WebSocketFrameValidationException() - def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring @@ -308,10 +299,6 @@ def apply_mask(message, masking_key): return result -def random_masking_key(): - return os.urandom(4) - - def client_handshake_headers(key=None, version=VERSION): """ Create the headers for a valid HTTP upgrade request. If Key is not -- cgit v1.2.3 From 176e29fc094119b036ba76d6e5cc1f2d7fb838e0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 23:13:42 +1200 Subject: websockets: constants, variable names, refactoring --- netlib/websockets.py | 75 ++++++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 35 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 0cd4dba1..1e9c96cc 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -22,11 +22,17 @@ from . import utils, odict # RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) -class CONST(object): - MAX_16_BIT_INT = (1 << 16) - MAX_64_BIT_INT = (1 << 64) +class OPCODE: + CONTINUE = 0x00 + TEXT = 0x01 + BINARY = 0x02 + CLOSE = 0x08 + PING = 0x09 + PONG = 0x0a class Frame(object): @@ -101,7 +107,7 @@ class Frame(object): return cls( fin = 1, # final frame - opcode = 1, # text + opcode = OPCODE.TEXT, # text mask_bit = mask_bit, payload_length_code = length_code, payload = payload, @@ -115,28 +121,27 @@ class Frame(object): Validate websocket frame invariants, call at anytime to ensure the Frame has not been corrupted. """ - try: - assert 0 <= self.fin <= 1 - assert 0 <= self.rsv1 <= 1 - assert 0 <= self.rsv2 <= 1 - assert 0 <= self.rsv3 <= 1 - assert 1 <= self.opcode <= 4 - assert 0 <= self.mask_bit <= 1 - assert 1 <= self.payload_length_code <= 127 - - if self.mask_bit == 1: - assert 1 <= len(self.masking_key) <= 4 - else: - assert self.masking_key is None - - assert self.actual_payload_length == len(self.payload) - - if self.payload is not None and self.masking_key is not None: - assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - - return True - except AssertionError: + constraints = [ + 0 <= self.fin <= 1, + 0 <= self.rsv1 <= 1, + 0 <= self.rsv2 <= 1, + 0 <= self.rsv3 <= 1, + 1 <= self.opcode <= 4, + 0 <= self.mask_bit <= 1, + 1 <= self.payload_length_code <= 127, + self.actual_payload_length == len(self.payload) + ] + if not all(constraints): + return False + elif self.mask_bit == 1 and not 1 <= len(self.masking_key) <= 4: + return False + elif self.mask_bit == 0 and self.masking_key is not None: return False + elif self.payload and self.masking_key: + decoded = apply_mask(self.payload, self.masking_key) + if decoded != self.decoded_payload: + return False + return True def human_readable(self): # pragma: nocover return "\n".join([ @@ -192,11 +197,11 @@ class Frame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < CONST.MAX_16_BIT_INT: + elif self.actual_payload_length < MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length b += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < CONST.MAX_64_BIT_INT: + elif self.actual_payload_length < MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.actual_payload_length) @@ -212,15 +217,15 @@ class Frame(object): writer.flush() @classmethod - def from_file(cls, reader): + def from_file(cls, fp): """ read a websockets frame sent by a server or client - reader is a "file like" object that could be backed by a network + fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader """ - first_byte = utils.bytes_to_int(reader.read(1)) - second_byte = utils.bytes_to_int(reader.read(1)) + first_byte = utils.bytes_to_int(fp.read(1)) + second_byte = utils.bytes_to_int(fp.read(1)) # grab the left most bit fin = first_byte >> 7 @@ -237,18 +242,18 @@ class Frame(object): actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(reader.read(2)) + actual_payload_length = utils.bytes_to_int(fp.read(2)) elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(reader.read(8)) + actual_payload_length = utils.bytes_to_int(fp.read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = reader.read(4) + masking_key = fp.read(4) else: masking_key = None - payload = reader.read(actual_payload_length) + payload = fp.read(actual_payload_length) if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) -- cgit v1.2.3 From 4fb49c8e55cc3c64ac0d5cf8fb913518f1973162 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 23:49:27 +1200 Subject: websockets: (very) slightly nicer is_valid constraints --- netlib/websockets.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 1e9c96cc..d5c5c2fe 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -129,14 +129,12 @@ class Frame(object): 1 <= self.opcode <= 4, 0 <= self.mask_bit <= 1, 1 <= self.payload_length_code <= 127, - self.actual_payload_length == len(self.payload) + self.actual_payload_length == len(self.payload), + 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, + self.masking_key is not None if self.mask_bit else True ] if not all(constraints): return False - elif self.mask_bit == 1 and not 1 <= len(self.masking_key) <= 4: - return False - elif self.mask_bit == 0 and self.masking_key is not None: - return False elif self.payload and self.masking_key: decoded = apply_mask(self.payload, self.masking_key) if decoded != self.decoded_payload: -- cgit v1.2.3 From 42a87a1d8b3eeccfdd8e5e504f1cd4d90ae1dbfb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 23 Apr 2015 08:23:51 +1200 Subject: websockets: handshake checks only take headers --- netlib/http.py | 8 ++++---- netlib/websockets.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index fe27240a..43155486 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -33,7 +33,7 @@ def _is_valid_host(host): return True -def get_line(fp): +def get_request_line(fp): """ Get a line, possibly preceded by a blank. """ @@ -41,8 +41,6 @@ def get_line(fp): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = fp.readline() - if line == "": - raise tcp.NetLibDisconnect() return line @@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): httpversion, host, port, scheme, method, path, headers, content = ( None, None, None, None, None, None, None, None) - request_line = get_line(rfile) + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() request_line_parts = parse_init(request_line) if not request_line_parts: diff --git a/netlib/websockets.py b/netlib/websockets.py index d5c5c2fe..da03768d 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring): return (length_code, actual_length) -def check_client_handshake(req): - if req.headers.get_first("upgrade", None) != "websocket": +def check_client_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": return - return req.headers.get_first('sec-websocket-key') + return headers.get_first('sec-websocket-key') -def check_server_handshake(resp): - if resp.headers.get_first("upgrade", None) != "websocket": +def check_server_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": return - return resp.headers.get_first('sec-websocket-accept') + return headers.get_first('sec-websocket-accept') def create_server_nonce(client_nonce): -- cgit v1.2.3 From bdd52fead339e634022a2251bb2bd85a924ca8d2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 08:47:09 +1200 Subject: websockets: extract frame header creation into a function --- netlib/websockets.py | 263 ++++++++++++++++++++++++++++----------------------- 1 file changed, 143 insertions(+), 120 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index da03768d..abf86262 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,6 +35,139 @@ class OPCODE: PONG = 0x0a +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + + +def client_handshake_headers(key=None, version=VERSION): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless + """ + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ]) + + +def server_handshake_headers(key): + """ + The server response is a valid HTTP 101 response. + """ + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nonce(key)) + ] + ) + + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + + +def make_length_code(len): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if len <= 125: + return len + elif len >= 126 and len <= 65535: + return 126 + else: + return 127 + + +def check_client_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first('sec-websocket-key') + + +def check_server_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first('sec-websocket-accept') + + +def create_server_nonce(client_nonce): + return base64.b64encode( + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + ) + + +def frame_header_bytes( + opcode = 0, + payload_length = 0, + fin = 0, + rsv1 = 0, + rsv2 = 0, + rsv3 = 0, + mask = 0, + masking_key = None, + length_code = None +): + first_byte = (fin << 7) | (rsv1 << 6) |\ + (rsv2 << 4) | (rsv3 << 4) | opcode + + if length_code is None: + length_code = make_length_code(payload_length) + + second_byte = (mask << 7) | length_code + + b = chr(first_byte) + chr(second_byte) + + if payload_length < 126: + pass + elif payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', payload_length) + elif payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', payload_length) + if masking_key is not None: + b += masking_key + return b + + class Frame(object): """ Represents one websockets frame. @@ -170,43 +303,16 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - - # break down of the bit-math used to construct the first byte from the - # frame's integer values first shift the significant bit into the - # correct position - # 00000001 << 7 = 10000000 - # ... - # then combine: - # - # 10000000 fin - # 01000000 res1 - # 00100000 res2 - # 00010000 res3 - # 00000001 opcode - # -------- OR - # 11110001 = first_byte - - first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ - (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - - second_byte = (self.mask_bit << 7) | self.payload_length_code - - b = chr(first_byte) + chr(second_byte) - - if self.actual_payload_length < 126: - pass - elif self.actual_payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.actual_payload_length) - - if self.masking_key is not None: - b += self.masking_key - + b = frame_header_bytes( + opcode = self.opcode, + fin = self.fin, + rsv1 = self.rsv1, + rsv2 = self.rsv2, + rsv3 = self.rsv3, + mask = self.mask_bit, + masking_key = self.masking_key, + payload_length = self.actual_payload_length + ) b += self.payload # already will be encoded if neccessary return b @@ -283,86 +389,3 @@ class Frame(object): self.decoded_payload == other.decoded_payload and self.actual_payload_length == other.actual_payload_length ) - - -def apply_mask(message, masking_key): - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result - - -def client_handshake_headers(key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of ODictCaseless - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ]) - - -def server_handshake_headers(key): - """ - The server response is a valid HTTP 101 response. - """ - return odict.ODictCaseless( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - ) - - -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - -def check_client_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-key') - - -def check_server_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-accept') - - -def create_server_nonce(client_nonce): - return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') - ) -- cgit v1.2.3 From 3519871f340cb0466fc6935d6e8e3b7822d36c52 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 09:21:04 +1200 Subject: websockets: refactor to avoid rundantly specifying payloads and payload lengths --- netlib/websockets.py | 60 ++++++++++++++++++++-------------------------------- 1 file changed, 23 insertions(+), 37 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index abf86262..7c127563 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -198,15 +198,13 @@ class Frame(object): self, fin, # decmial integer 1 or 0 opcode, # decmial integer 1 - 4 - mask_bit, # decimal integer 1 or 0 - payload_length_code, # decimal integer 1 - 127 - decoded_payload, # bytestring + payload = "", # bytestring + masking_key = None, # 32 bit byte string + mask_bit = 0, # decimal integer 1 or 0 + payload_length_code = None, # decimal integer 1 - 127 rsv1 = 0, # decimal integer 1 or 0 rsv2 = 0, # decimal integer 1 or 0 rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring - masking_key = None, # 32 bit byte string - actual_payload_length = None, # any decimal integer ): self.fin = fin self.rsv1 = rsv1 @@ -217,8 +215,6 @@ class Frame(object): self.payload_length_code = payload_length_code self.masking_key = masking_key self.payload = payload - self.decoded_payload = decoded_payload - self.actual_payload_length = actual_payload_length @classmethod def default(cls, message, from_client = False): @@ -226,27 +222,19 @@ class Frame(object): Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. """ - length_code, actual_length = get_payload_length_pair(message) - if from_client: mask_bit = 1 - # Random masking key masking_key = os.urandom(4) - payload = apply_mask(message, masking_key) else: mask_bit = 0 masking_key = None - payload = message return cls( fin = 1, # final frame opcode = OPCODE.TEXT, # text mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, + payload = message, masking_key = masking_key, - decoded_payload = message, - actual_payload_length = actual_length ) def is_valid(self): @@ -261,17 +249,12 @@ class Frame(object): 0 <= self.rsv3 <= 1, 1 <= self.opcode <= 4, 0 <= self.mask_bit <= 1, - 1 <= self.payload_length_code <= 127, - self.actual_payload_length == len(self.payload), + #1 <= self.payload_length_code <= 127, 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, self.masking_key is not None if self.mask_bit else True ] if not all(constraints): return False - elif self.payload and self.masking_key: - decoded = apply_mask(self.payload, self.masking_key) - if decoded != self.decoded_payload: - return False return True def human_readable(self): # pragma: nocover @@ -285,8 +268,6 @@ class Frame(object): ("payload_length_code - " + str(self.payload_length_code)), ("masking_key - " + repr(str(self.masking_key))), ("payload - " + repr(str(self.payload))), - ("decoded_payload - " + repr(str(self.decoded_payload))), - ("actual_payload_length - " + str(self.actual_payload_length)) ]) @classmethod @@ -311,9 +292,12 @@ class Frame(object): rsv3 = self.rsv3, mask = self.mask_bit, masking_key = self.masking_key, - payload_length = self.actual_payload_length + payload_length = len(self.payload) if self.payload else 0 ) - b += self.payload # already will be encoded if neccessary + if self.masking_key: + b += apply_mask(self.payload, self.masking_key) + else: + b += self.payload return b def to_file(self, writer): @@ -359,10 +343,8 @@ class Frame(object): payload = fp.read(actual_payload_length) - if mask_bit == 1: - decoded_payload = apply_mask(payload, masking_key) - else: - decoded_payload = payload + if mask_bit == 1 and masking_key: + payload = apply_mask(payload, masking_key) return cls( fin = fin, @@ -371,11 +353,17 @@ class Frame(object): payload_length_code = payload_length, payload = payload, masking_key = masking_key, - decoded_payload = decoded_payload, - actual_payload_length = actual_payload_length ) def __eq__(self, other): + if self.payload_length_code is None: + myplc = make_length_code(len(self.payload)) + else: + myplc = self.payload_length_code + if other.payload_length_code is None: + otherplc = make_length_code(len(other.payload)) + else: + otherplc = other.payload_length_code return ( self.fin == other.fin and self.rsv1 == other.rsv1 and @@ -383,9 +371,7 @@ class Frame(object): self.rsv3 == other.rsv3 and self.opcode == other.opcode and self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length + self.payload == other.payload, + myplc == otherplc ) -- cgit v1.2.3 From f22bc0b4c74776bcc312fed1f4ceede83f869a6e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:09:21 +1200 Subject: websocket: interface refactoring - Separate out FrameHeader. We need to deal with this separately in many circumstances. - Simpler equality scheme. - Bits are now specified by truthiness - we don't care about the integer value. This means lots of validation is not needed any more. --- netlib/utils.py | 16 +++ netlib/websockets.py | 303 ++++++++++++++++++++++++--------------------------- 2 files changed, 159 insertions(+), 160 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index 66bbdb5e..44bed43a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -49,3 +49,19 @@ def hexdump(s): (o, x, cleanBin(part, True)) ) return parts + + +def setbit(byte, offset, value): + """ + Set a bit in a byte to 1 if value is truthy, 0 if not. + """ + if value: + return byte | (1 << offset) + else: + return byte & ~(1 << offset) + + +def getbit(byte, offset): + mask = 1 << offset + if byte & mask: + return True diff --git a/netlib/websockets.py b/netlib/websockets.py index 7c127563..016e75c2 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -1,5 +1,4 @@ from __future__ import absolute_import - import base64 import hashlib import os @@ -83,23 +82,6 @@ def server_handshake_headers(key): ) -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - def make_length_code(len): """ A websockets frame contains an initial length_code, and an optional @@ -132,40 +114,113 @@ def create_server_nonce(client_nonce): ) -def frame_header_bytes( - opcode = 0, - payload_length = 0, - fin = 0, - rsv1 = 0, - rsv2 = 0, - rsv3 = 0, - mask = 0, - masking_key = None, - length_code = None -): - first_byte = (fin << 7) | (rsv1 << 6) |\ - (rsv2 << 4) | (rsv3 << 4) | opcode - - if length_code is None: - length_code = make_length_code(payload_length) - - second_byte = (mask << 7) | length_code - - b = chr(first_byte) + chr(second_byte) - - if payload_length < 126: - pass - elif payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', payload_length) - elif payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', payload_length) - if masking_key is not None: - b += masking_key - return b +DEFAULT = object() +class FrameHeader: + def __init__( + self, + opcode = OPCODE.TEXT, + payload_length = 0, + fin = False, + rsv1 = False, + rsv2 = False, + rsv3 = False, + masking_key = None, + mask = DEFAULT, + length_code = DEFAULT + ): + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.mask = mask + self.masking_key = masking_key + self.length_code = length_code + + def to_bytes(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode + + if self.length_code is DEFAULT: + length_code = make_length_code(self.payload_length) + else: + length_code = self.length_code + + if self.mask is DEFAULT: + mask = bool(self.masking_key) + else: + mask = self.mask + + second_byte = (mask << 7) | length_code + + b = chr(first_byte) + chr(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', self.payload_length) + elif self.payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', self.payload_length) + if self.masking_key is not None: + b += self.masking_key + return b + + @classmethod + def from_file(klass, fp): + """ + read a websockets frame header + """ + first_byte = utils.bytes_to_int(fp.read(1)) + second_byte = utils.bytes_to_int(fp.read(1)) + + fin = utils.getbit(first_byte, 7) + rsv1 = utils.getbit(first_byte, 6) + rsv2 = utils.getbit(first_byte, 5) + rsv3 = utils.getbit(first_byte, 4) + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + length_code = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_length = length_code + elif length_code == 126: + payload_length = utils.bytes_to_int(fp.read(2)) + elif length_code == 127: + payload_length = utils.bytes_to_int(fp.read(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = fp.read(4) + else: + masking_key = None + + return klass( + fin = fin, + rsv1 = rsv1, + rsv2 = rsv2, + rsv3 = rsv3, + opcode = opcode, + mask = mask_bit, + length_code = length_code, + payload_length = payload_length, + masking_key = masking_key, + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() class Frame(object): @@ -194,27 +249,10 @@ class Frame(object): | Payload Data continued ... | +---------------------------------------------------------------+ """ - def __init__( - self, - fin, # decmial integer 1 or 0 - opcode, # decmial integer 1 - 4 - payload = "", # bytestring - masking_key = None, # 32 bit byte string - mask_bit = 0, # decimal integer 1 or 0 - payload_length_code = None, # decimal integer 1 - 127 - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key + def __init__(self, payload = "", **kwargs): self.payload = payload + kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) + self.header = FrameHeader(**kwargs) @classmethod def default(cls, message, from_client = False): @@ -230,10 +268,10 @@ class Frame(object): masking_key = None return cls( + message, fin = 1, # final frame opcode = OPCODE.TEXT, # text - mask_bit = mask_bit, - payload = message, + mask = mask_bit, masking_key = masking_key, ) @@ -243,30 +281,30 @@ class Frame(object): Frame has not been corrupted. """ constraints = [ - 0 <= self.fin <= 1, - 0 <= self.rsv1 <= 1, - 0 <= self.rsv2 <= 1, - 0 <= self.rsv3 <= 1, - 1 <= self.opcode <= 4, - 0 <= self.mask_bit <= 1, + 0 <= self.header.fin <= 1, + 0 <= self.header.rsv1 <= 1, + 0 <= self.header.rsv2 <= 1, + 0 <= self.header.rsv3 <= 1, + 1 <= self.header.opcode <= 4, + 0 <= self.header.mask <= 1, #1 <= self.payload_length_code <= 127, - 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, - self.masking_key is not None if self.mask_bit else True + 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, + self.header.masking_key is not None if self.header.mask else True ] if not all(constraints): return False return True - def human_readable(self): # pragma: nocover + def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + repr(str(self.masking_key))), + ("fin - " + str(self.header.fin)), + ("rsv1 - " + str(self.header.rsv1)), + ("rsv2 - " + str(self.header.rsv2)), + ("rsv3 - " + str(self.header.rsv3)), + ("opcode - " + str(self.header.opcode)), + ("mask - " + str(self.header.mask)), + ("length_code - " + str(self.header.length_code)), + ("masking_key - " + repr(str(self.header.masking_key))), ("payload - " + repr(str(self.payload))), ]) @@ -284,18 +322,9 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - b = frame_header_bytes( - opcode = self.opcode, - fin = self.fin, - rsv1 = self.rsv1, - rsv2 = self.rsv2, - rsv3 = self.rsv3, - mask = self.mask_bit, - masking_key = self.masking_key, - payload_length = len(self.payload) if self.payload else 0 - ) - if self.masking_key: - b += apply_mask(self.payload, self.masking_key) + b = self.header.to_bytes() + if self.header.masking_key: + b += apply_mask(self.payload, self.header.masking_key) else: b += self.payload return b @@ -312,66 +341,20 @@ class Frame(object): fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) - - # grab the left most bit - fin = first_byte >> 7 - # grab right most 4 bits by and-ing with 00001111 - opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 - # grab the next 7 bits - payload_length = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if payload_length <= 125: - actual_payload_length = payload_length - - elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(fp.read(2)) - - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(fp.read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.read(4) - else: - masking_key = None - - payload = fp.read(actual_payload_length) + header = FrameHeader.from_file(fp) + payload = fp.read(header.payload_length) - if mask_bit == 1 and masking_key: - payload = apply_mask(payload, masking_key) + if header.mask == 1 and header.masking_key: + payload = apply_mask(payload, header.masking_key) return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, + payload, + fin = header.fin, + opcode = header.opcode, + mask = header.mask, + payload_length = header.payload_length, + masking_key = header.masking_key, ) def __eq__(self, other): - if self.payload_length_code is None: - myplc = make_length_code(len(self.payload)) - else: - myplc = self.payload_length_code - if other.payload_length_code is None: - otherplc = make_length_code(len(other.payload)) - else: - otherplc = other.payload_length_code - return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.masking_key == other.masking_key and - self.payload == other.payload, - myplc == otherplc - ) + return self.to_bytes() == other.to_bytes() -- cgit v1.2.3 From def93ea8cae69676a91b01e149e8a406fa03eacd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:23:00 +1200 Subject: websockets: remove validation We don't really need this any more. The interface is much less error prone because bit flags are no longer integers, we have a range check on opcode on header instantiation, and we've deferred length code calculation and so forth into the byte render methods. --- netlib/websockets.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 016e75c2..b1afa620 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -115,6 +115,8 @@ def create_server_nonce(client_nonce): DEFAULT = object() + + class FrameHeader: def __init__( self, @@ -128,6 +130,8 @@ class FrameHeader: mask = DEFAULT, length_code = DEFAULT ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") self.opcode = opcode self.payload_length = payload_length self.fin = fin @@ -275,26 +279,6 @@ class Frame(object): masking_key = masking_key, ) - def is_valid(self): - """ - Validate websocket frame invariants, call at anytime to ensure the - Frame has not been corrupted. - """ - constraints = [ - 0 <= self.header.fin <= 1, - 0 <= self.header.rsv1 <= 1, - 0 <= self.header.rsv2 <= 1, - 0 <= self.header.rsv3 <= 1, - 1 <= self.header.opcode <= 4, - 0 <= self.header.mask <= 1, - #1 <= self.payload_length_code <= 127, - 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, - self.header.masking_key is not None if self.header.mask else True - ] - if not all(constraints): - return False - return True - def human_readable(self): return "\n".join([ ("fin - " + str(self.header.fin)), -- cgit v1.2.3 From 192fd1db7f233b71398c5255cbdebe1928768b55 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:31:14 +1200 Subject: websockets: include all header values in frame roundtrip --- netlib/websockets.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index b1afa620..85aad9c6 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -159,7 +159,7 @@ class FrameHeader: else: mask = self.mask - second_byte = (mask << 7) | length_code + second_byte = utils.setbit(length_code, 7, mask) b = chr(first_byte) + chr(second_byte) @@ -189,10 +189,9 @@ class FrameHeader: rsv1 = utils.getbit(first_byte, 6) rsv2 = utils.getbit(first_byte, 5) rsv3 = utils.getbit(first_byte, 4) - # grab right most 4 bits by and-ing with 00001111 + # grab right-most 4 bits opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 + mask_bit = utils.getbit(second_byte, 7) # grab the next 7 bits length_code = second_byte & 127 @@ -279,6 +278,14 @@ class Frame(object): masking_key = masking_key, ) + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(io.BytesIO(bytestring)) + def human_readable(self): return "\n".join([ ("fin - " + str(self.header.fin)), @@ -292,14 +299,6 @@ class Frame(object): ("payload - " + repr(str(self.payload))), ]) - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(io.BytesIO(bytestring)) - def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring @@ -338,6 +337,10 @@ class Frame(object): mask = header.mask, payload_length = header.payload_length, masking_key = header.masking_key, + rsv1 = header.rsv1, + rsv2 = header.rsv2, + rsv3 = header.rsv3, + length_code = header.length_code ) def __eq__(self, other): -- cgit v1.2.3 From 18df329930eb822395caf279862589d2a40413c9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:42:31 +1200 Subject: websockets: nicer frame construction - Resolve unspecified values on instantiation - Add a check for masking key length - Smarter resolution for masking_key and mask values. Do the right thing unless told not to. --- netlib/websockets.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 85aad9c6..493bb18a 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -126,7 +126,7 @@ class FrameHeader: rsv1 = False, rsv2 = False, rsv3 = False, - masking_key = None, + masking_key = DEFAULT, mask = DEFAULT, length_code = DEFAULT ): @@ -138,9 +138,27 @@ class FrameHeader: self.rsv1 = rsv1 self.rsv2 = rsv2 self.rsv3 = rsv3 - self.mask = mask - self.masking_key = masking_key - self.length_code = length_code + + if length_code is DEFAULT: + self.length_code = make_length_code(self.payload_length) + else: + self.length_code = length_code + + if mask is DEFAULT and masking_key is DEFAULT: + self.mask = False + self.masking_key = "" + elif mask is DEFAULT: + self.mask = 1 + self.masking_key = masking_key + elif masking_key is DEFAULT: + self.mask = mask + self.masking_key = os.urandom(4) + else: + self.mask = mask + self.masking_key = masking_key + + if self.masking_key and len(self.masking_key) != 4: + raise ValueError("Masking key must be 4 bytes.") def to_bytes(self): first_byte = utils.setbit(0, 7, self.fin) @@ -149,17 +167,7 @@ class FrameHeader: first_byte = utils.setbit(first_byte, 4, self.rsv3) first_byte = first_byte | self.opcode - if self.length_code is DEFAULT: - length_code = make_length_code(self.payload_length) - else: - length_code = self.length_code - - if self.mask is DEFAULT: - mask = bool(self.masking_key) - else: - mask = self.mask - - second_byte = utils.setbit(length_code, 7, mask) + second_byte = utils.setbit(self.length_code, 7, self.mask) b = chr(first_byte) + chr(second_byte) -- cgit v1.2.3 From 80860229209b4c6eb8384e1bca3cabdbe062fe6e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 30 Apr 2015 09:04:22 +1200 Subject: Add a tiny utility class for keeping bi-directional mappings. Use it in websocket and socks. --- netlib/socks.py | 60 ++++++++++++++++++++++++++++------------------------ netlib/utils.py | 26 +++++++++++++++++++++++ netlib/websockets.py | 25 ++++++++++++++++------ 3 files changed, 77 insertions(+), 34 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index a3c4e9a2..497b8eef 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import socket import struct import array -from . import tcp +from . import tcp, utils class SocksError(Exception): @@ -11,40 +11,45 @@ class SocksError(Exception): self.code = code -class VERSION(object): - SOCKS4 = 0x04 +VERSION = utils.BiDi( + SOCKS4 = 0x04, SOCKS5 = 0x05 +) -class CMD(object): - CONNECT = 0x01 - BIND = 0x02 +CMD = utils.BiDi( + CONNECT = 0x01, + BIND = 0x02, UDP_ASSOCIATE = 0x03 +) -class ATYP(object): - IPV4_ADDRESS = 0x01 - DOMAINNAME = 0x03 +ATYP = utils.BiDi( + IPV4_ADDRESS = 0x01, + DOMAINNAME = 0x03, IPV6_ADDRESS = 0x04 - - -class REP(object): - SUCCEEDED = 0x00 - GENERAL_SOCKS_SERVER_FAILURE = 0x01 - CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 - NETWORK_UNREACHABLE = 0x03 - HOST_UNREACHABLE = 0x04 - CONNECTION_REFUSED = 0x05 - TTL_EXPIRED = 0x06 - COMMAND_NOT_SUPPORTED = 0x07 - ADDRESS_TYPE_NOT_SUPPORTED = 0x08 - - -class METHOD(object): - NO_AUTHENTICATION_REQUIRED = 0x00 - GSSAPI = 0x01 - USERNAME_PASSWORD = 0x02 +) + + +REP = utils.BiDi( + SUCCEEDED = 0x00, + GENERAL_SOCKS_SERVER_FAILURE = 0x01, + CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02, + NETWORK_UNREACHABLE = 0x03, + HOST_UNREACHABLE = 0x04, + CONNECTION_REFUSED = 0x05, + TTL_EXPIRED = 0x06, + COMMAND_NOT_SUPPORTED = 0x07, + ADDRESS_TYPE_NOT_SUPPORTED = 0x08, +) + + +METHOD = utils.BiDi( + NO_AUTHENTICATION_REQUIRED = 0x00, + GSSAPI = 0x01, + USERNAME_PASSWORD = 0x02, NO_ACCEPTABLE_METHODS = 0xFF +) def _read(f, n): @@ -146,4 +151,3 @@ class Message(object): "Unknown ATYP: %s" % self.atyp ) f.write(struct.pack("!H", self.addr.port)) - diff --git a/netlib/utils.py b/netlib/utils.py index 44bed43a..905d948f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -65,3 +65,29 @@ def getbit(byte, offset): mask = 1 << offset if byte & mask: return True + + +class BiDi: + """ + A wee utility class for keeping bi-directional mappings, like field + constants in protocols: + + CONST = BiDi(a=1, b=2) + assert CONST.a == 1 + assert CONST[1] == "a" + """ + def __init__(self, **kwargs): + self.names = kwargs + self.values = {} + for k, v in kwargs.items(): + self.values[v] = k + if len(self.names) != len(self.values): + raise ValueError("Duplicate values not allowed.") + + def __getattr__(self, k): + if k in self.names: + return self.names[k] + raise AttributeError("No such attribute: %s", k) + + def __getitem__(self, k): + return self.values[k] diff --git a/netlib/websockets.py b/netlib/websockets.py index 493bb18a..d358ed53 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -25,13 +25,14 @@ MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) -class OPCODE: - CONTINUE = 0x00 - TEXT = 0x01 - BINARY = 0x02 - CLOSE = 0x08 - PING = 0x09 +OPCODE = utils.BiDi( + CONTINUE = 0x00, + TEXT = 0x01, + BINARY = 0x02, + CLOSE = 0x08, + PING = 0x09, PONG = 0x0a +) def apply_mask(message, masking_key): @@ -160,6 +161,18 @@ class FrameHeader: if self.masking_key and len(self.masking_key) != 4: raise ValueError("Masking key must be 4 bytes.") + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask - " + str(self.mask)), + ("length_code - " + str(self.length_code)), + ("masking_key - " + repr(str(self.masking_key))), + ]) + def to_bytes(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) -- cgit v1.2.3 From 4dce7ee074c242f5b6530ff64879875d98c1d255 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 30 Apr 2015 12:10:08 +1200 Subject: websockets: more compact and legible human_readable --- netlib/utils.py | 25 +++++++++++++++++++++---- netlib/websockets.py | 38 +++++++++++++++++--------------------- 2 files changed, 38 insertions(+), 25 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index 905d948f..7e539977 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -70,11 +70,12 @@ def getbit(byte, offset): class BiDi: """ A wee utility class for keeping bi-directional mappings, like field - constants in protocols: + constants in protocols. Names are attributes on the object, dict-like + access maps values to names: CONST = BiDi(a=1, b=2) assert CONST.a == 1 - assert CONST[1] == "a" + assert CONST.get_name(1) == "a" """ def __init__(self, **kwargs): self.names = kwargs @@ -89,5 +90,21 @@ class BiDi: return self.names[k] raise AttributeError("No such attribute: %s", k) - def __getitem__(self, k): - return self.values[k] + def get_name(self, n, default=None): + return self.values.get(n, default) + + +def pretty_size(size): + suffixes = [ + ("B", 2**10), + ("kB", 2**20), + ("MB", 2**30), + ] + for suf, lim in suffixes: + if size >= lim: + continue + else: + x = round(size/float(lim/2**10), 2) + if x == int(x): + x = int(x) + return str(x) + suf diff --git a/netlib/websockets.py b/netlib/websockets.py index d358ed53..1d02d684 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -162,16 +162,21 @@ class FrameHeader: raise ValueError("Masking key must be 4 bytes.") def human_readable(self): - return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask - " + str(self.mask)), - ("length_code - " + str(self.length_code)), - ("masking_key - " + repr(str(self.masking_key))), - ]) + vals = [ + "wf:", + OPCODE.get_name(self.opcode, hex(self.opcode)).lower() + ] + flags = [] + for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: + if getattr(self, i): + flags.append(i) + if flags: + vals.extend([":", "|".join(flags)]) + if self.masking_key: + vals.append(":key=%s"%repr(self.masking_key)) + if self.payload_length: + vals.append(" %s"%utils.pretty_size(self.payload_length)) + return "".join(vals) def to_bytes(self): first_byte = utils.setbit(0, 7, self.fin) @@ -308,17 +313,8 @@ class Frame(object): return cls.from_file(io.BytesIO(bytestring)) def human_readable(self): - return "\n".join([ - ("fin - " + str(self.header.fin)), - ("rsv1 - " + str(self.header.rsv1)), - ("rsv2 - " + str(self.header.rsv2)), - ("rsv3 - " + str(self.header.rsv3)), - ("opcode - " + str(self.header.opcode)), - ("mask - " + str(self.header.mask)), - ("length_code - " + str(self.header.length_code)), - ("masking_key - " + repr(str(self.header.masking_key))), - ("payload - " + repr(str(self.payload))), - ]) + hdr = self.header.human_readable() + return hdr + "\n" + repr(self.payload) def to_bytes(self): """ -- cgit v1.2.3 From 7d9e38ffb10e92b5127f203c2d8a524da8698b00 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 1 May 2015 10:09:35 +1200 Subject: websockets: A progressive masker. --- netlib/websockets.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 1d02d684..84eb03ba 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,21 +35,25 @@ OPCODE = utils.BiDi( ) -def apply_mask(message, masking_key): +class Masker: """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def __call__(self, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[self.offset % 4]) + self.offset += 1 + return result def client_handshake_headers(key=None, version=VERSION): @@ -324,7 +328,7 @@ class Frame(object): """ b = self.header.to_bytes() if self.header.masking_key: - b += apply_mask(self.payload, self.header.masking_key) + b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b @@ -345,7 +349,7 @@ class Frame(object): payload = fp.read(header.payload_length) if header.mask == 1 and header.masking_key: - payload = apply_mask(payload, header.masking_key) + payload = Masker(header.masking_key)(payload) return cls( payload, -- cgit v1.2.3 From 08b2e2a6a98fd175e1b49d62dffde34e91c77b1c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 1 May 2015 10:31:20 +1200 Subject: websockets: more flexible masking interface. --- netlib/websockets.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 84eb03ba..0ad0e294 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -48,13 +48,18 @@ class Masker: self.masks = [utils.bytes_to_int(byte) for byte in key] self.offset = 0 - def __call__(self, data): + def mask(self, offset, data): result = "" for c in data: - result += chr(ord(c) ^ self.masks[self.offset % 4]) - self.offset += 1 + result += chr(ord(c) ^ self.masks[offset % 4]) + offset += 1 return result + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret + def client_handshake_headers(key=None, version=VERSION): """ -- cgit v1.2.3 From f2bc58cdd2f2b9b0025a88c0faccf55e10b29353 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 5 May 2015 10:47:02 +1200 Subject: Add tcp.Reader.safe_read, use it in socks and websockets safe_read is guaranteed to raise or return a byte string of the requested length. It's particularly useful for implementing binary protocols. --- netlib/socks.py | 32 +++++++++----------------------- netlib/tcp.py | 48 ++++++++++++++++++++++++++++++++++-------------- netlib/websockets.py | 16 ++++++++-------- 3 files changed, 51 insertions(+), 45 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index 497b8eef..6f9f57bd 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -52,20 +52,6 @@ METHOD = utils.BiDi( ) -def _read(f, n): - try: - d = f.read(n) - if len(d) == n: - return d - else: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Incomplete Read" - ) - except socket.error as e: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) - - class ClientGreeting(object): __slots__ = ("ver", "methods") @@ -75,9 +61,9 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack("!BB", _read(f, 2)) + ver, nmethods = struct.unpack("!BB", f.safe_read(2)) methods = array.array("B") - methods.fromstring(_read(f, nmethods)) + methods.fromstring(f.safe_read(nmethods)) return cls(ver, methods) def to_file(self, f): @@ -94,7 +80,7 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack("!BB", _read(f, 2)) + ver, method = struct.unpack("!BB", f.safe_read(2)) return cls(ver, method) def to_file(self, f): @@ -112,27 +98,27 @@ class Message(object): @classmethod def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: # We use tnoa here as ntop is not commonly available on Windows. - host = socket.inet_ntoa(_read(f, 4)) + host = socket.inet_ntoa(f.safe_read(4)) use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) + host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", _read(f, 1)) - host = _read(f, length) + length, = struct.unpack("!B", f.safe_read(1)) + host = f.safe_read(length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port, = struct.unpack("!H", _read(f, 2)) + port, = struct.unpack("!H", f.safe_read(2)) addr = tcp.Address((host, port), use_ipv6=use_ipv6) return cls(ver, msg, atyp, addr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 84008e2c..dbe114a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3 class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass +class NetLibIncomplete(NetLibError): pass class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass @@ -195,10 +196,23 @@ class Reader(_FileLike): break return result + def safe_read(self, length): + """ + Like .read, but is guaranteed to either return length bytes, or + raise an exception. + """ + result = self.read(length) + if length != -1 and len(result) != length: + raise NetLibIncomplete( + "Expected %s bytes, got %s"%(length, len(result)) + ) + return result + class Address(object): """ - This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + This class wraps an IPv4/IPv6 tuple to provide named attributes and + ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = tuple(address) @@ -247,22 +261,28 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux + sock.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending + # readable data could lead to an immediate RST being sent (which is the + # case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # This in turn results in the following issue: If we send an error page to the client and then close the socket, - # the RST may be received by the client before the error page and the users sees a connection error rather than - # the error page. Thus, we try to empty the read buffer on Windows first. - # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # This in turn results in the following issue: If we send an error page + # to the client and then close the socket, the RST may be received by + # the client before the error page and the users sees a connection + # error rather than the error page. Thus, we try to empty the read + # buffer on Windows first. (see + # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) # + if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we are in blocking mode. + # We cannot rely on the shutdown()-followed-by-read()-eof technique + # proposed by the page above: Some remote machines just don't send + # a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. As a workaround, we set a timeout + # here even if we are in blocking mode. sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely @@ -292,10 +312,10 @@ class _Connection(object): def finish(self): self.finished = True - # If we have an SSL connection, wfile.close == connection.close # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close then. + # Closing the socket is not our task, therefore we don't call close + # then. if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): try: diff --git a/netlib/websockets.py b/netlib/websockets.py index 0ad0e294..6d08e101 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -5,7 +5,7 @@ import os import struct import io -from . import utils, odict +from . import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -217,8 +217,8 @@ class FrameHeader: """ read a websockets frame header """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) fin = utils.getbit(first_byte, 7) rsv1 = utils.getbit(first_byte, 6) @@ -235,13 +235,13 @@ class FrameHeader: if length_code <= 125: payload_length = length_code elif length_code == 126: - payload_length = utils.bytes_to_int(fp.read(2)) + payload_length = utils.bytes_to_int(fp.safe_read(2)) elif length_code == 127: - payload_length = utils.bytes_to_int(fp.read(8)) + payload_length = utils.bytes_to_int(fp.safe_read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = fp.read(4) + masking_key = fp.safe_read(4) else: masking_key = None @@ -319,7 +319,7 @@ class Frame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_file() directly """ - return cls.from_file(io.BytesIO(bytestring)) + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) def human_readable(self): hdr = self.header.human_readable() @@ -351,7 +351,7 @@ class Frame(object): stream or a disk or an in memory stream reader """ header = FrameHeader.from_file(fp) - payload = fp.read(header.payload_length) + payload = fp.safe_read(header.payload_length) if header.mask == 1 and header.masking_key: payload = Masker(header.masking_key)(payload) -- cgit v1.2.3 From ace4454523a81303b6432714f8ff73dab02a7e33 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 16 May 2015 11:32:18 +1200 Subject: Zap outdated comment --- netlib/websockets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 6d08e101..a2d55c19 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -327,9 +327,7 @@ class Frame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees - that the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame to wire format. Returns a string. """ b = self.header.to_bytes() if self.header.masking_key: -- cgit v1.2.3 From f40bf865b1e767d4f15e0e829b9ca3132c33d11d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 18 May 2015 10:46:00 +1200 Subject: release prep: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 826c66fe..502dce3a 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 11, 2) +IVERSION = (0, 12, 0) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 46fadfc82386265c26b77ea0d8c3801585c84fbc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 May 2015 17:16:42 +0200 Subject: improve displaying tcp addresses --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index dbe114a1..a5f43ea3 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -245,7 +245,10 @@ class Address(object): self.family = socket.AF_INET6 if b else socket.AF_INET def __repr__(self): - return repr(self.address) + return "{}:{}".format(self.host, self.port) + + def __str__(self): + return str(self.address) def __eq__(self, other): other = Address.wrap(other) -- cgit v1.2.3 From ae749975e537990f3db767b4d0d4c6ec2321a088 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 May 2015 10:43:28 +1200 Subject: Post release version bump. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 502dce3a..3eb0ffc9 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 12, 0) +IVERSION = (0, 12, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 4ce6f43616db9c23a29484610045aecd88ed2cfc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 25 May 2015 12:10:21 +0200 Subject: implement basic HTTP/2 frame classes --- netlib/h2/__init__.py | 1 + netlib/h2/frame.py | 375 ++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/h2/h2.py | 25 ++++ 3 files changed, 401 insertions(+) create mode 100644 netlib/h2/__init__.py create mode 100644 netlib/h2/frame.py create mode 100644 netlib/h2/h2.py (limited to 'netlib') diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/h2/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py new file mode 100644 index 00000000..52cc2992 --- /dev/null +++ b/netlib/h2/frame.py @@ -0,0 +1,375 @@ +import base64 +import hashlib +import os +import struct +import io + +from .. import utils, odict, tcp + +class Frame(object): + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__(self, length, flags, stream_id): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + self.length = length + self.flags = flags + self.stream_id = stream_id + + @classmethod + def from_bytes(self, data): + fields = struct.unpack("!HBBBL", data[:9]) + length = (fields[0] << 8) + fields[1] + # type is already deducted from class + flags = fields[3] + stream_id = fields[4] + return FRAMES[fields[2]].from_bytes(length, flags, stream_id, data[9:]) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'', pad_length=0): + super(DataFrame, self).__init__(length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + super(HeadersFrame, self).__init__(length, flags, stream_id) + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & self.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack('!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, exclusive=False, stream_dependency=0x0, weight=0): + super(PriorityFrame, self).__init__(length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, error_code=0x0): + super(RstStreamFrame, self).__init__(length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE = 0x1, + SETTINGS_ENABLE_PUSH = 0x2, + SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, + SETTINGS_INITIAL_WINDOW_SIZE = 0x4, + SETTINGS_MAX_FRAME_SIZE = 0x5, + SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, + ) + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): + super(SettingsFrame, self).__init__(length, flags, stream_id) + self.settings = settings + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i+6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError('SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, promised_stream=0x0, header_block_fragment=b'', pad_length=0): + super(PushPromiseFrame, self).__init__(length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b''): + super(PingFrame, self).__init__(length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError('PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, last_stream=0x0, error_code=0x0, data=b''): + super(GoAwayFrame, self).__init__(length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError('GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2**31: + raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b''): + super(ContinuationFrame, self).__init__(length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py new file mode 100644 index 00000000..5d74c1c8 --- /dev/null +++ b/netlib/h2/h2.py @@ -0,0 +1,25 @@ +import base64 +import hashlib +import os +import struct +import io + +# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" +CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + +ERROR_CODES = utils.BiDi( + NO_ERROR = 0x0, + PROTOCOL_ERROR = 0x1, + INTERNAL_ERROR = 0x2, + FLOW_CONTROL_ERROR = 0x3, + SETTINGS_TIMEOUT = 0x4, + STREAM_CLOSED = 0x5, + FRAME_SIZE_ERROR = 0x6, + REFUSED_STREAM = 0x7, + CANCEL = 0x8, + COMPRESSION_ERROR = 0x9, + CONNECT_ERROR = 0xa, + ENHANCE_YOUR_CALM = 0xb, + INADEQUATE_SECURITY = 0xc, + HTTP_1_1_REQUIRED = 0xd + ) -- cgit v1.2.3 From d6a68e1394ac57854ac1fa09fd19b88d015789e1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 10:21:50 +0200 Subject: remove outdated workarounds --- netlib/tcp.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a5f43ea3..399203bb 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -307,10 +307,10 @@ class _Connection(object): def get_current_cipher(self): if not self.ssl_established: return None - c = SSL._lib.SSL_get_current_cipher(self.connection._ssl) - name = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_name(c))) - bits = SSL._lib.SSL_CIPHER_get_bits(c, SSL._ffi.NULL) - version = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_version(c))) + + name = self.connection.get_cipher_name() + bits = self.connection.get_cipher_bits() + version = self.connection.get_cipher_version() return name, bits, version def finish(self): @@ -333,10 +333,6 @@ class _Connection(object): self.connection.shutdown() except SSL.Error: pass - except KeyError as e: # pragma: no cover - # Workaround for https://github.com/pyca/pyopenssl/pull/183 - if OpenSSL.__version__ != "0.14": - raise e """ Creates an SSL Context. -- cgit v1.2.3 From 041ca5c499369ffbf115e4451b85aee77e3095c0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 10:53:23 +0200 Subject: update TLS defaults: signature hash and DH params * SHA1 is deprecated (use SHA256) * increase RSA key to 2048 bits * increase DH params to 4096 bits (LogJam attack) --- netlib/certutils.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index f5375c03..507241b2 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -8,15 +8,25 @@ import OpenSSL DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. -DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- -MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 -zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK -1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC ------END DH PARAMETERS-----""" +DEFAULT_DHPARAM = """ +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- +""" def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() - key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) + key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) cert = OpenSSL.crypto.X509() cert.set_serial_number(int(time.time()*10000)) cert.set_version(2) @@ -39,7 +49,7 @@ def create_ca(o, cn, exp): OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", subject=cert), ]) - cert.sign(key, "sha1") + cert.sign(key, "sha256") return key, cert @@ -69,7 +79,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) - cert.sign(privkey, "sha1") + cert.sign(privkey, "sha256") return SSLCert(cert) @@ -124,7 +134,7 @@ class CertStore(object): """ Implements an in-memory certificate store. """ - def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams): self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file @@ -148,7 +158,7 @@ class CertStore(object): ) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh - + @classmethod def from_store(cls, path, basename): ca_path = os.path.join(path, basename + "-ca.pem") @@ -296,7 +306,7 @@ class SSLCert(object): self.x509 = cert def __eq__(self, other): - return self.digest("sha1") == other.digest("sha1") + return self.digest("sha256") == other.digest("sha256") def __ne__(self, other): return not self.__eq__(other) -- cgit v1.2.3 From e3d390e036430b9d7cc4b93679229fe118eb583a Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 11:18:54 +0200 Subject: cleanup code with autopep8 run the following command: $ autopep8 -i -r -a -a . --- netlib/certutils.py | 56 ++++++++++++++++++-------------- netlib/h2/frame.py | 34 ++++++++++++++------ netlib/h2/h2.py | 30 ++++++++--------- netlib/http.py | 13 ++++---- netlib/http_auth.py | 22 ++++++++++--- netlib/http_cookies.py | 10 +++--- netlib/http_status.py | 84 ++++++++++++++++++++++++------------------------ netlib/odict.py | 10 ++++-- netlib/socks.py | 43 +++++++++++++------------ netlib/tcp.py | 62 +++++++++++++++++++++++------------ netlib/test.py | 24 ++++++++------ netlib/utils.py | 10 +++--- netlib/websockets.py | 87 ++++++++++++++++++++++++++------------------------ netlib/wsgi.py | 52 ++++++++++++++++-------------- 14 files changed, 308 insertions(+), 229 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index f5375c03..da0e3355 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,12 +1,15 @@ from __future__ import (absolute_import, print_function, division) -import os, ssl, time, datetime +import os +import ssl +import time +import datetime import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 +DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 @@ -14,31 +17,32 @@ zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK 1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC -----END DH PARAMETERS-----""" + def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) cert = OpenSSL.crypto.X509() - cert.set_serial_number(int(time.time()*10000)) + cert.set_serial_number(int(time.time() * 10000)) cert.set_version(2) cert.get_subject().CN = cn cert.get_subject().O = o - cert.gmtime_adj_notBefore(-3600*48) + cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(exp) cert.set_issuer(cert.get_subject()) cert.set_pubkey(key) cert.add_extensions([ - OpenSSL.crypto.X509Extension("basicConstraints", True, - "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", False, - "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", False, - "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" - ), - OpenSSL.crypto.X509Extension("keyUsage", True, - "keyCertSign, cRLSign"), - OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", - subject=cert), - ]) + OpenSSL.crypto.X509Extension("basicConstraints", True, + "CA:TRUE"), + OpenSSL.crypto.X509Extension("nsCertType", False, + "sslCA"), + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, + "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension("keyUsage", True, + "keyCertSign, cRLSign"), + OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", + subject=cert), + ]) cert.sign(key, "sha1") return key, cert @@ -56,15 +60,15 @@ def dummy_cert(privkey, cacert, commonname, sans): """ ss = [] for i in sans: - ss.append("DNS: %s"%i) + ss.append("DNS: %s" % i) ss = ", ".join(ss) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600*48) + cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname - cert.set_serial_number(int(time.time()*10000)) + cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) @@ -114,6 +118,7 @@ def dummy_cert(privkey, cacert, commonname, sans): class CertStoreEntry(object): + def __init__(self, cert, privatekey, chain_file): self.cert = cert self.privatekey = privatekey @@ -121,9 +126,11 @@ class CertStoreEntry(object): class CertStore(object): + """ Implements an in-memory certificate store. """ + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): self.default_privatekey = default_privatekey self.default_ca = default_ca @@ -144,11 +151,11 @@ class CertStore(object): if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL - ) + bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL + ) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh - + @classmethod def from_store(cls, path, basename): ca_path = os.path.join(path, basename + "-ca.pem") @@ -277,8 +284,8 @@ class _GeneralName(univ.Choice): # other types. componentType = namedtype.NamedTypes( namedtype.NamedType('dNSName', char.IA5String().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) - ) + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) ), ) @@ -289,6 +296,7 @@ class _GeneralNames(univ.SequenceOf): class SSLCert(object): + def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 52cc2992..d846b3b9 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -5,8 +5,11 @@ import struct import io from .. import utils, odict, tcp +from functools import reduce + class Frame(object): + """ Baseclass Frame contains header @@ -53,6 +56,7 @@ class Frame(object): def __eq__(self, other): return self.to_bytes() == other.to_bytes() + class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -89,11 +93,13 @@ class DataFrame(Frame): return b + class HeadersFrame(Frame): TYPE = 0x1 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', + pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) self.header_block_fragment = header_block_fragment self.pad_length = pad_length @@ -137,6 +143,7 @@ class HeadersFrame(Frame): return b + class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] @@ -166,6 +173,7 @@ class PriorityFrame(Frame): return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] @@ -186,18 +194,19 @@ class RstStreamFrame(Frame): return struct.pack('!L', self.error_code) + class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE = 0x1, - SETTINGS_ENABLE_PUSH = 0x2, - SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, - SETTINGS_INITIAL_WINDOW_SIZE = 0x4, - SETTINGS_MAX_FRAME_SIZE = 0x5, - SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, - ) + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): super(SettingsFrame, self).__init__(length, flags, stream_id) @@ -208,7 +217,7 @@ class SettingsFrame(Frame): f = self(length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i+6]) + identifier, value = struct.unpack("!HL", payload[i:i + 6]) f.settings[identifier] = value return f @@ -223,6 +232,7 @@ class SettingsFrame(Frame): return b + class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] @@ -267,6 +277,7 @@ class PushPromiseFrame(Frame): return b + class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] @@ -289,6 +300,7 @@ class PingFrame(Frame): b += b'\0' * (8 - len(b)) return b + class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] @@ -317,6 +329,7 @@ class GoAwayFrame(Frame): b += bytes(self.data) return b + class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] @@ -335,11 +348,12 @@ class WindowUpdateFrame(Frame): return f def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2**31: + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 5d74c1c8..1a39a635 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -8,18 +8,18 @@ import io CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' ERROR_CODES = utils.BiDi( - NO_ERROR = 0x0, - PROTOCOL_ERROR = 0x1, - INTERNAL_ERROR = 0x2, - FLOW_CONTROL_ERROR = 0x3, - SETTINGS_TIMEOUT = 0x4, - STREAM_CLOSED = 0x5, - FRAME_SIZE_ERROR = 0x6, - REFUSED_STREAM = 0x7, - CANCEL = 0x8, - COMPRESSION_ERROR = 0x9, - CONNECT_ERROR = 0xa, - ENHANCE_YOUR_CALM = 0xb, - INADEQUATE_SECURITY = 0xc, - HTTP_1_1_REQUIRED = 0xd - ) + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) diff --git a/netlib/http.py b/netlib/http.py index 43155486..47658097 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -8,6 +8,7 @@ from . import odict, utils, tcp, http_status class HttpError(Exception): + def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -95,7 +96,7 @@ def read_headers(fp): """ ret = [] name = '' - while 1: + while True: line = fp.readline() if not line or line == '\r\n' or line == '\n': break @@ -337,7 +338,7 @@ def read_http_body_chunked( otherwise """ if max_chunk_size is None: - max_chunk_size = limit or sys.maxint + max_chunk_size = limit or sys.maxsize expected_size = expected_http_body_size( headers, is_request, request_method, response_code @@ -399,10 +400,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code): request_method = request_method.upper() if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): return 0 if has_chunked_encoding(headers): return None diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 296e094c..261b6654 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -4,9 +4,11 @@ from . import http class NullProxyAuth(object): + """ No proxy auth at all (returns empty challange headers) """ + def __init__(self, password_manager): self.password_manager = password_manager @@ -48,7 +50,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower()!='basic': + if scheme.lower() != 'basic': return False if not self.password_manager.test(username, password): return False @@ -56,18 +58,21 @@ class BasicProxyAuth(NullProxyAuth): return True def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} class PassMan(object): + def test(self, username, password_token): return False class PassManNonAnon(PassMan): + """ Ensure the user specifies a username, accept any password. """ + def test(self, username, password_token): if username: return True @@ -75,9 +80,11 @@ class PassManNonAnon(PassMan): class PassManHtpasswd(PassMan): + """ Read usernames and passwords from an htpasswd file """ + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. @@ -90,14 +97,16 @@ class PassManHtpasswd(PassMan): class PassManSingleUser(PassMan): + def __init__(self, username, password): self.username, self.password = username, password def test(self, username, password_token): - return self.username==username and self.password==password_token + return self.username == username and self.password == password_token class AuthAction(Action): + """ Helper class to allow seamless integration int argparse. Example usage: parser.add_argument( @@ -106,16 +115,18 @@ class AuthAction(Action): help="Allow access to any user long as a credentials are specified." ) """ + def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, self.dest, authenticator) - def getPasswordManager(self, s): # pragma: nocover + def getPasswordManager(self, s): # pragma: nocover raise NotImplementedError() class SingleuserAuthAction(AuthAction): + def getPasswordManager(self, s): if len(s.split(':')) != 2: raise ArgumentTypeError( @@ -126,11 +137,12 @@ class SingleuserAuthAction(AuthAction): class NonanonymousAuthAction(AuthAction): + def getPasswordManager(self, s): return PassManNonAnon() class HtpasswdAuthAction(AuthAction): + def getPasswordManager(self, s): return PassManHtpasswd(s) - diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 8e245891..73e3f589 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -96,7 +96,7 @@ def _read_pairs(s, off=0, specials=()): specials: a lower-cased list of keys that may contain commas """ vals = [] - while 1: + while True: lhs, off = _read_token(s, off) lhs = lhs.lstrip() if lhs: @@ -135,15 +135,15 @@ def _format_pairs(lst, specials=(), sep="; "): else: if k.lower() not in specials and _has_special(v): v = ESCAPE.sub(r"\\\1", v) - v = '"%s"'%v - vals.append("%s=%s"%(k, v)) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) return sep.join(vals) def _format_set_cookie_pairs(lst): return _format_pairs( lst, - specials = ("expires", "path") + specials=("expires", "path") ) @@ -154,7 +154,7 @@ def _parse_set_cookie_pairs(s): """ pairs, off = _read_pairs( s, - specials = ("expires", "path") + specials=("expires", "path") ) return pairs diff --git a/netlib/http_status.py b/netlib/http_status.py index 7dba2d56..dc09f465 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,51 +1,51 @@ from __future__ import (absolute_import, print_function, division) -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 +EXPECTATION_FAILED = 417 -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 RESPONSES = { # 100 diff --git a/netlib/odict.py b/netlib/odict.py index dd738c55..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@ from __future__ import (absolute_import, print_function, division) -import re, copy +import re +import copy def safe_subn(pattern, repl, target, *args, **kwargs): @@ -12,10 +13,12 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): + """ A dictionary-like object for managing ordered (key, value) data. Think about it as a convenient interface to a list of (key, value) tuples. """ + def __init__(self, lst=None): self.lst = lst or [] @@ -157,7 +160,7 @@ class ODict(object): "key: value" """ for k, v in self.lst: - s = "%s: %s"%(k, v) + s = "%s: %s" % (k, v) if re.search(expr, s): return True return False @@ -192,11 +195,12 @@ class ODict(object): return klass([list(i) for i in state]) - class ODictCaseless(ODict): + """ A variant of ODict with "caseless" keys. This version _preserves_ key case, but does not consider case when setting or getting items. """ + def _kconv(self, s): return s.lower() diff --git a/netlib/socks.py b/netlib/socks.py index 6f9f57bd..5a73c61a 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -6,49 +6,50 @@ from . import tcp, utils class SocksError(Exception): + def __init__(self, code, message): super(SocksError, self).__init__(message) self.code = code VERSION = utils.BiDi( - SOCKS4 = 0x04, - SOCKS5 = 0x05 + SOCKS4=0x04, + SOCKS5=0x05 ) CMD = utils.BiDi( - CONNECT = 0x01, - BIND = 0x02, - UDP_ASSOCIATE = 0x03 + CONNECT=0x01, + BIND=0x02, + UDP_ASSOCIATE=0x03 ) ATYP = utils.BiDi( - IPV4_ADDRESS = 0x01, - DOMAINNAME = 0x03, - IPV6_ADDRESS = 0x04 + IPV4_ADDRESS=0x01, + DOMAINNAME=0x03, + IPV6_ADDRESS=0x04 ) REP = utils.BiDi( - SUCCEEDED = 0x00, - GENERAL_SOCKS_SERVER_FAILURE = 0x01, - CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02, - NETWORK_UNREACHABLE = 0x03, - HOST_UNREACHABLE = 0x04, - CONNECTION_REFUSED = 0x05, - TTL_EXPIRED = 0x06, - COMMAND_NOT_SUPPORTED = 0x07, - ADDRESS_TYPE_NOT_SUPPORTED = 0x08, + SUCCEEDED=0x00, + GENERAL_SOCKS_SERVER_FAILURE=0x01, + CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, + NETWORK_UNREACHABLE=0x03, + HOST_UNREACHABLE=0x04, + CONNECTION_REFUSED=0x05, + TTL_EXPIRED=0x06, + COMMAND_NOT_SUPPORTED=0x07, + ADDRESS_TYPE_NOT_SUPPORTED=0x08, ) METHOD = utils.BiDi( - NO_AUTHENTICATION_REQUIRED = 0x00, - GSSAPI = 0x01, - USERNAME_PASSWORD = 0x02, - NO_ACCEPTABLE_METHODS = 0xFF + NO_AUTHENTICATION_REQUIRED=0x00, + GSSAPI=0x01, + USERNAME_PASSWORD=0x02, + NO_ACCEPTABLE_METHODS=0xFF ) diff --git a/netlib/tcp.py b/netlib/tcp.py index 399203bb..7c115554 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,14 +22,28 @@ OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -class NetLibError(Exception): pass -class NetLibDisconnect(NetLibError): pass -class NetLibIncomplete(NetLibError): pass -class NetLibTimeout(NetLibError): pass -class NetLibSSLError(NetLibError): pass +class NetLibError(Exception): + pass + + +class NetLibDisconnect(NetLibError): + pass + + +class NetLibIncomplete(NetLibError): + pass + + +class NetLibTimeout(NetLibError): + pass + + +class NetLibSSLError(NetLibError): + pass class SSLKeyLogger(object): + def __init__(self, filename): self.filename = filename self.f = None @@ -67,6 +81,7 @@ log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or class _FileLike(object): BLOCKSIZE = 1024 * 32 + def __init__(self, o): self.o = o self._log = None @@ -112,6 +127,7 @@ class _FileLike(object): class Writer(_FileLike): + def flush(self): """ May raise NetLibDisconnect @@ -119,7 +135,7 @@ class Writer(_FileLike): if hasattr(self.o, "flush"): try: self.o.flush() - except (socket.error, IOError), v: + except (socket.error, IOError) as v: raise NetLibDisconnect(str(v)) def write(self, v): @@ -135,11 +151,12 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error) as e: + except (SSL.Error, socket.error) as e: raise NetLibDisconnect(str(e)) class Reader(_FileLike): + def read(self, length): """ If length is -1, we read until connection closes. @@ -180,7 +197,7 @@ class Reader(_FileLike): self.add_log(result) return result - def readline(self, size = None): + def readline(self, size=None): result = '' bytes_read = 0 while True: @@ -204,16 +221,18 @@ class Reader(_FileLike): result = self.read(length) if length != -1 and len(result) != length: raise NetLibIncomplete( - "Expected %s bytes, got %s"%(length, len(result)) + "Expected %s bytes, got %s" % (length, len(result)) ) return result class Address(object): + """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ + def __init__(self, address, use_ipv6=False): self.address = tuple(address) self.use_ipv6 = use_ipv6 @@ -304,6 +323,7 @@ def close_socket(sock): class _Connection(object): + def get_current_cipher(self): if not self.ssl_established: return None @@ -319,7 +339,7 @@ class _Connection(object): # (We call _FileLike.set_descriptor(conn)) # Closing the socket is not our task, therefore we don't call close # then. - if type(self.connection) != SSL.Connection: + if not isinstance(self.connection, SSL.Connection): if not getattr(self.wfile, "closed", False): try: self.wfile.flush() @@ -337,6 +357,7 @@ class _Connection(object): """ Creates an SSL Context. """ + def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), @@ -362,8 +383,8 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) + except SSL.Error as v: + raise NetLibError("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: @@ -380,7 +401,7 @@ class TCPClient(_Connection): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... - if type(self.connection) == SSL.Connection: + if isinstance(self.connection, SSL.Connection): close_socket(self.connection._socket) else: close_socket(self.connection) @@ -400,8 +421,8 @@ class TCPClient(_Connection): try: context.use_privatekey_file(cert) context.use_certificate_file(cert) - except SSL.Error, v: - raise NetLibError("SSL client certificate error: %s"%str(v)) + except SSL.Error as v: + raise NetLibError("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, **sslctx_kwargs): @@ -418,8 +439,8 @@ class TCPClient(_Connection): self.connection.set_connect_state() try: self.connection.do_handshake() - except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%repr(v)) + except SSL.Error as v: + raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) @@ -435,7 +456,7 @@ class TCPClient(_Connection): self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - except (socket.error, IOError), err: + except (socket.error, IOError) as err: raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection @@ -447,6 +468,7 @@ class TCPClient(_Connection): class BaseHandler(_Connection): + """ The instantiator is expected to call the handle() and finish() methods. @@ -531,8 +553,8 @@ class BaseHandler(_Connection): self.connection.set_accept_state() try: self.connection.do_handshake() - except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%repr(v)) + except SSL.Error as v: + raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) diff --git a/netlib/test.py b/netlib/test.py index db30c0e6..b6f94273 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,9 +1,13 @@ from __future__ import (absolute_import, print_function, division) -import threading, Queue, cStringIO +import threading +import Queue +import cStringIO import OpenSSL from . import tcp, certutils + class ServerThread(threading.Thread): + def __init__(self, server): self.server = server threading.Thread.__init__(self) @@ -19,6 +23,7 @@ class ServerTestBase(object): ssl = None handler = None addr = ("localhost", 0) + @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -41,10 +46,11 @@ class ServerTestBase(object): class TServer(tcp.TCPServer): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A dictionary of SSL parameters: - + cert, key, request_client_cert, cipher_list, dhparams, v3_only """ @@ -70,13 +76,13 @@ class TServer(tcp.TCPServer): options = None h.convert_to_ssl( cert, key, - method = method, - options = options, - handle_sni = getattr(h, "handle_sni", None), - request_client_cert = self.ssl["request_client_cert"], - cipher_list = self.ssl.get("cipher_list", None), - dhparams = self.ssl.get("dhparams", None), - chain_file = self.ssl.get("chain_file", None) + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl["request_client_cert"], + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None) ) h.handle() h.finish() diff --git a/netlib/utils.py b/netlib/utils.py index 7e539977..9c5404e6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -68,6 +68,7 @@ def getbit(byte, offset): class BiDi: + """ A wee utility class for keeping bi-directional mappings, like field constants in protocols. Names are attributes on the object, dict-like @@ -77,6 +78,7 @@ class BiDi: assert CONST.a == 1 assert CONST.get_name(1) == "a" """ + def __init__(self, **kwargs): self.names = kwargs self.values = {} @@ -96,15 +98,15 @@ class BiDi: def pretty_size(size): suffixes = [ - ("B", 2**10), - ("kB", 2**20), - ("MB", 2**30), + ("B", 2 ** 10), + ("kB", 2 ** 20), + ("MB", 2 ** 30), ] for suf, lim in suffixes: if size >= lim: continue else: - x = round(size/float(lim/2**10), 2) + x = round(size / float(lim / 2 ** 10), 2) if x == int(x): x = int(x) return str(x) + suf diff --git a/netlib/websockets.py b/netlib/websockets.py index a2d55c19..63dc03f1 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -26,16 +26,17 @@ MAX_64_BIT_INT = (1 << 64) OPCODE = utils.BiDi( - CONTINUE = 0x00, - TEXT = 0x01, - BINARY = 0x02, - CLOSE = 0x08, - PING = 0x09, - PONG = 0x0a + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a ) class Masker: + """ Data sent from the server must be masked to prevent malicious clients from sending data over the wire in predictable patterns @@ -43,6 +44,7 @@ class Masker: Servers do not have to mask data they send to the client. https://tools.ietf.org/html/rfc6455#section-5.3 """ + def __init__(self, key): self.key = key self.masks = [utils.bytes_to_int(byte) for byte in key] @@ -128,17 +130,18 @@ DEFAULT = object() class FrameHeader: + def __init__( self, - opcode = OPCODE.TEXT, - payload_length = 0, - fin = False, - rsv1 = False, - rsv2 = False, - rsv3 = False, - masking_key = DEFAULT, - mask = DEFAULT, - length_code = DEFAULT + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -182,9 +185,9 @@ class FrameHeader: if flags: vals.extend([":", "|".join(flags)]) if self.masking_key: - vals.append(":key=%s"%repr(self.masking_key)) + vals.append(":key=%s" % repr(self.masking_key)) if self.payload_length: - vals.append(" %s"%utils.pretty_size(self.payload_length)) + vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) def to_bytes(self): @@ -246,15 +249,15 @@ class FrameHeader: masking_key = None return klass( - fin = fin, - rsv1 = rsv1, - rsv2 = rsv2, - rsv3 = rsv3, - opcode = opcode, - mask = mask_bit, - length_code = length_code, - payload_length = payload_length, - masking_key = masking_key, + fin=fin, + rsv1=rsv1, + rsv2=rsv2, + rsv3=rsv3, + opcode=opcode, + mask=mask_bit, + length_code=length_code, + payload_length=payload_length, + masking_key=masking_key, ) def __eq__(self, other): @@ -262,6 +265,7 @@ class FrameHeader: class Frame(object): + """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -287,13 +291,14 @@ class Frame(object): | Payload Data continued ... | +---------------------------------------------------------------+ """ - def __init__(self, payload = "", **kwargs): + + def __init__(self, payload="", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @classmethod - def default(cls, message, from_client = False): + def default(cls, message, from_client=False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -307,10 +312,10 @@ class Frame(object): return cls( message, - fin = 1, # final frame - opcode = OPCODE.TEXT, # text - mask = mask_bit, - masking_key = masking_key, + fin=1, # final frame + opcode=OPCODE.TEXT, # text + mask=mask_bit, + masking_key=masking_key, ) @classmethod @@ -356,15 +361,15 @@ class Frame(object): return cls( payload, - fin = header.fin, - opcode = header.opcode, - mask = header.mask, - payload_length = header.payload_length, - masking_key = header.masking_key, - rsv1 = header.rsv1, - rsv2 = header.rsv2, - rsv3 = header.rsv3, - length_code = header.length_code + fin=header.fin, + opcode=header.opcode, + mask=header.mask, + payload_length=header.payload_length, + masking_key=header.masking_key, + rsv1=header.rsv1, + rsv2=header.rsv2, + rsv3=header.rsv3, + length_code=header.length_code ) def __eq__(self, other): diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 1b979608..f393039a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -7,17 +7,20 @@ from . import odict, tcp class ClientConn(object): + def __init__(self, address): self.address = tcp.Address.wrap(address) class Flow(object): + def __init__(self, address, request): self.client_conn = ClientConn(address) self.request = request class Request(object): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content @@ -42,6 +45,7 @@ def date_time_string(): class WSGIAdaptor(object): + def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion @@ -52,24 +56,24 @@ class WSGIAdaptor(object): path_info = flow.request.path query = '' environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), - 'wsgi.errors': errsoc, - 'wsgi.multithread': True, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': flow.request.method, - 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), - 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], - 'SERVER_NAME': self.domain, - 'SERVER_PORT': str(self.port), + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.errors': errsoc, + 'wsgi.multithread': True, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': self.sversion, + 'REQUEST_METHOD': flow.request.method, + 'SCRIPT_NAME': '', + 'PATH_INFO': urllib.unquote(path_info), + 'QUERY_STRING': query, + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], + 'SERVER_NAME': self.domain, + 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", + 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) if flow.client_conn.address: @@ -91,25 +95,25 @@ class WSGIAdaptor(object):

Internal Server Error

%s"
- """%s + """ % s if not headers_sent: soc.write("HTTP/1.1 500 Internal Server Error\r\n") soc.write("Content-Type: text/html\r\n") - soc.write("Content-Length: %s\r\n"%len(c)) + soc.write("Content-Length: %s\r\n" % len(c)) soc.write("\r\n") soc.write(c) def serve(self, request, soc, **env): state = dict( - response_started = False, - headers_sent = False, - status = None, - headers = None + response_started=False, + headers_sent=False, + status=None, + headers=None ) def write(data): if not state["headers_sent"]: - soc.write("HTTP/1.1 %s\r\n"%state["status"]) + soc.write("HTTP/1.1 %s\r\n" % state["status"]) h = state["headers"] if 'server' not in h: h["Server"] = [self.sversion] -- cgit v1.2.3 From 161bc2cfaa8b70b4c2cab5562784df34013452e1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 11:25:33 +0200 Subject: cleanup code with autoflake run the following command: $ autoflake -r -i --remove-all-unused-imports --remove-unused-variables . --- netlib/h2/frame.py | 6 +----- netlib/h2/h2.py | 5 ----- netlib/http_auth.py | 1 - netlib/http_cookies.py | 1 - netlib/tcp.py | 2 -- 5 files changed, 1 insertion(+), 14 deletions(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index d846b3b9..a7e81f48 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,10 +1,6 @@ -import base64 -import hashlib -import os import struct -import io -from .. import utils, odict, tcp +from .. import utils from functools import reduce diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 1a39a635..7a85226f 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -1,8 +1,3 @@ -import base64 -import hashlib -import os -import struct -import io # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 261b6654..0143760c 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -16,7 +16,6 @@ class NullProxyAuth(object): """ Clean up authentication headers, so they're not passed upstream. """ - pass def authenticate(self, headers): """ diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 73e3f589..5cb39e5c 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -67,7 +67,6 @@ def _read_quoted_string(s, start): break elif s[i] == "\\": escaping = True - pass else: ret.append(s[i]) return "".join(ret), i + 1 diff --git a/netlib/tcp.py b/netlib/tcp.py index 7c115554..49f92e4a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,7 +7,6 @@ import threading import time import traceback from OpenSSL import SSL -import OpenSSL from . import certutils @@ -650,4 +649,3 @@ class TCPServer(object): """ Called after server shutdown. """ - pass -- cgit v1.2.3 From 1dda164d0381161d3d0ad4e65199f6382aa2bf0d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 May 2015 12:18:56 +1200 Subject: Satisfy autobots. --- netlib/certutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 05408a0c..abf1a28b 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -42,7 +42,7 @@ def create_ca(o, cn, exp): cert.set_pubkey(key) cert.add_extensions([ OpenSSL.crypto.X509Extension( - "basicConstraints", + "basicConstraints", True, "CA:TRUE" ), @@ -155,6 +155,7 @@ class CertStore(object): """ Implements an in-memory certificate store. """ + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams): self.default_privatekey = default_privatekey self.default_ca = default_ca -- cgit v1.2.3 From 5288aa36403bc4b350700a0bf97adc4413f2a398 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 12:58:55 +0200 Subject: add human_readable() to each frame for debugging --- netlib/h2/frame.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index a7e81f48..51de7d4d 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -25,6 +25,7 @@ class Frame(object): raise ValueError('invalid flags detected.') self.length = length + self.type = self.TYPE self.flags = flags self.stream_id = stream_id @@ -49,10 +50,27 @@ class Frame(object): return b + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self): + return "\n".join([ + "============================================================", + "length: %d bytes" % self.length, + "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), + "flags: %#x" % self.flags, + "stream_id: %#x" % self.stream_id, + "------------------------------------------------------------", + self.payload_human_readable(), + "============================================================", + ]) + def __eq__(self, other): return self.to_bytes() == other.to_bytes() - class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -89,6 +107,8 @@ class DataFrame(Frame): return b + def payload_human_readable(self): + return "payload: %s" % str(self.payload) class HeadersFrame(Frame): TYPE = 0x1 @@ -139,6 +159,19 @@ class HeadersFrame(Frame): return b + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + return "\n".join(s) class PriorityFrame(Frame): TYPE = 0x2 @@ -169,6 +202,12 @@ class PriorityFrame(Frame): return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) class RstStreamFrame(Frame): TYPE = 0x3 @@ -190,6 +229,8 @@ class RstStreamFrame(Frame): return struct.pack('!L', self.error_code) + def payload_human_readable(self): + return "error code: %#x" % self.error_code class SettingsFrame(Frame): TYPE = 0x4 @@ -228,6 +269,16 @@ class SettingsFrame(Frame): return b + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) class PushPromiseFrame(Frame): TYPE = 0x5 @@ -273,6 +324,15 @@ class PushPromiseFrame(Frame): return b + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + return "\n".join(s) class PingFrame(Frame): TYPE = 0x6 @@ -296,6 +356,8 @@ class PingFrame(Frame): b += b'\0' * (8 - len(b)) return b + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) class GoAwayFrame(Frame): TYPE = 0x7 @@ -325,6 +387,12 @@ class GoAwayFrame(Frame): b += bytes(self.data) return b + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) class WindowUpdateFrame(Frame): TYPE = 0x8 @@ -349,6 +417,8 @@ class WindowUpdateFrame(Frame): return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment class ContinuationFrame(Frame): TYPE = 0x9 @@ -370,6 +440,9 @@ class ContinuationFrame(Frame): return self.header_block_fragment + def payload_human_readable(self): + return "header_block_fragment: %s" % str(self.header_block_fragment) + _FRAME_CLASSES = [ DataFrame, HeadersFrame, -- cgit v1.2.3 From 754f929187e3954eb05971e38bcd3358d3a5e3be Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 17:31:18 +0200 Subject: fix default argument Python evaluates default args during method definition. So you get the same dict each time you call this method. Therefore the dict is the SAME actual object each time. --- netlib/h2/frame.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 51de7d4d..ed6af200 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -245,8 +245,12 @@ class SettingsFrame(Frame): SETTINGS_MAX_HEADER_LIST_SIZE=0x6, ) - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings=None): super(SettingsFrame, self).__init__(length, flags, stream_id) + + if settings is None: + settings = {} + self.settings = settings @classmethod -- cgit v1.2.3 From 4c469fdee1b5b01a7e847a75fbbd902dc3bfbd70 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 17:53:06 +0200 Subject: add hpack to encode and decode headers --- netlib/h2/frame.py | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index ed6af200..179634b0 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,4 +1,6 @@ import struct +import io +from hpack.hpack import Encoder, Decoder from .. import utils from functools import reduce @@ -71,6 +73,7 @@ class Frame(object): def __eq__(self, other): return self.to_bytes() == other.to_bytes() + class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -110,14 +113,18 @@ class DataFrame(Frame): def payload_human_readable(self): return "payload: %s" % str(self.payload) + class HeadersFrame(Frame): TYPE = 0x1 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', - pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, headers=None, pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) - self.header_block_fragment = header_block_fragment + + if headers is None: + headers = [] + + self.headers = headers self.pad_length = pad_length self.exclusive = exclusive self.stream_dependency = stream_dependency @@ -129,15 +136,18 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] + header_block_fragment = payload[1:-f.pad_length] else: - f.header_block_fragment = payload[0:] + header_block_fragment = payload[0:] if f.flags & self.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack('!LB', f.header_block_fragment[:5]) + f.stream_dependency, f.weight = struct.unpack('!LB', header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] + header_block_fragment = header_block_fragment[5:] + + for header, value in Decoder().decode(header_block_fragment): + f.headers.append((header, value)) return f @@ -152,7 +162,7 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PRIORITY: b += struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) - b += bytes(self.header_block_fragment) + b += Encoder().encode(self.headers) if self.flags & self.FLAG_PADDED: b += b'\0' * self.pad_length @@ -170,9 +180,15 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PADDED: s.append("padding: %d" % self.pad_length) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + if not self.headers: + s.append("headers: None") + else: + for header, value in self.headers: + s.append("%s: %s" % (header, value)) + return "\n".join(s) + class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] @@ -209,6 +225,7 @@ class PriorityFrame(Frame): s.append("weight: %d" % self.weight) return "\n".join(s) + class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] @@ -232,6 +249,7 @@ class RstStreamFrame(Frame): def payload_human_readable(self): return "error code: %#x" % self.error_code + class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] @@ -284,6 +302,7 @@ class SettingsFrame(Frame): else: return "\n".join(s) + class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] @@ -338,6 +357,7 @@ class PushPromiseFrame(Frame): s.append("header_block_fragment: %s" % str(self.header_block_fragment)) return "\n".join(s) + class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] @@ -363,6 +383,7 @@ class PingFrame(Frame): def payload_human_readable(self): return "opaque data: %s" % str(self.payload) + class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] @@ -398,6 +419,7 @@ class GoAwayFrame(Frame): s.append("debug data: %s" % str(self.data)) return "\n".join(s) + class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] @@ -424,6 +446,7 @@ class WindowUpdateFrame(Frame): def payload_human_readable(self): return "window size increment: %#x" % self.window_size_increment + class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] -- cgit v1.2.3 From d50b9be0d5dab1772f0edcbfa89542ef9425e7bf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 17:53:45 +0200 Subject: add generic frame parsing method --- netlib/h2/frame.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 179634b0..11687316 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -31,6 +31,23 @@ class Frame(object): self.flags = flags self.stream_id = stream_id + @classmethod + def from_file(self, fp): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes(length, flags, stream_id, payload) + @classmethod def from_bytes(self, data): fields = struct.unpack("!HBBBL", data[:9]) -- cgit v1.2.3 From 780836b182cd982b978f16218299f2b77a8ed204 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 17:46:44 +0200 Subject: add ALPN support to TCP abstraction --- netlib/tcp.py | 35 +++++++++++++++++++++++++++-------- netlib/test.py | 3 ++- 2 files changed, 29 insertions(+), 9 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 49f92e4a..fc2c144e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -360,7 +360,9 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), - cipher_list=None + cipher_list=None, + alpn_protos=None, + alpn_select=None, ): """ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD @@ -389,6 +391,17 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) + # advertise application layer protocols + if alpn_protos is not None: + context.set_alpn_protos(alpn_protos) + + # select application layer protocol + if alpn_select is not None: + def alpn_select_f(conn, options): + return bytes(alpn_select) + + context.set_alpn_select_callback(alpn_select_f) + return context @@ -413,8 +426,8 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def create_ssl_context(self, cert=None, **sslctx_kwargs): - context = self._create_ssl_context(**sslctx_kwargs) + def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): + context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) # Client Certs if cert: try: @@ -424,13 +437,13 @@ class TCPClient(_Connection): raise NetLibError("SSL client certificate error: %s" % str(v)) return context - def convert_to_ssl(self, sni=None, **sslctx_kwargs): + def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): """ cert: Path to a file containing both client cert and private key. options: A bit field consisting of OpenSSL.SSL.OP_* values """ - context = self.create_ssl_context(**sslctx_kwargs) + context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni @@ -465,6 +478,9 @@ class TCPClient(_Connection): def gettimeout(self): return self.connection.gettimeout() + def get_alpn_proto_negotiated(self): + return self.connection.get_alpn_proto_negotiated() + class BaseHandler(_Connection): @@ -492,6 +508,7 @@ class BaseHandler(_Connection): request_client_cert=None, chain_file=None, dhparams=None, + alpn_select=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. @@ -517,7 +534,8 @@ class BaseHandler(_Connection): we may be able to make the proper behaviour the default again, but until then we're conservative. """ - context = self._create_ssl_context(**sslctx_kwargs) + + context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -542,12 +560,13 @@ class BaseHandler(_Connection): return context - def convert_to_ssl(self, cert, key, **sslctx_kwargs): + def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - context = self.create_ssl_context(cert, key, **sslctx_kwargs) + + context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: diff --git a/netlib/test.py b/netlib/test.py index b6f94273..63b493a9 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -82,7 +82,8 @@ class TServer(tcp.TCPServer): request_client_cert=self.ssl["request_client_cert"], cipher_list=self.ssl.get("cipher_list", None), dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None) + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) ) h.handle() h.finish() -- cgit v1.2.3 From e2de49596d0e60e343c71c73e0847b17fb27ac3c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 17:46:30 +0200 Subject: add HTTP/2-capable client --- netlib/h2/h2.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) (limited to 'netlib') diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 7a85226f..bfe5832b 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -1,3 +1,5 @@ +from .. import utils, odict, tcp +from frame import * # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' @@ -18,3 +20,66 @@ ERROR_CODES = utils.BiDi( INADEQUATE_SECURITY=0xc, HTTP_1_1_REQUIRED=0xd ) + + +class H2Client(tcp.TCPClient): + ALPN_PROTO_H2 = b'h2' + + DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ^ 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ^ 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, + } + + def __init__(self, address, source_address=None): + super(H2Client, self).__init__(address, source_address) + self.settings = self.DEFAULT_SETTINGS.copy() + + def connect(self, send_preface=True): + super(H2Client, self).connect() + self.convert_to_ssl(alpn_protos=[self.ALPN_PROTO_H2]) + + alp = self.get_alpn_proto_negotiated() + if alp != b'h2': + raise NotImplementedError("H2Client can not handle unknown protocol: %s" % alp) + print "-> Successfully negotiated 'h2' application layer protocol." + + if send_preface: + self.wfile.write(bytes(CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(SettingsFrame()) + + frame = Frame.from_file(self.rfile) + print frame.human_readable() + assert isinstance(frame, SettingsFrame) + self.apply_settings(frame.settings) + + print "-> Connection Preface completed." + + print "-> H2Client is ready..." + + def send_frame(self, frame): + self.wfile.write(frame.to_bytes()) + self.wfile.flush() + + def read_frame(self): + frame = Frame.from_file(self.rfile) + if isinstance(frame, SettingsFrame): + self.apply_settings(frame.settings) + + return frame + + def apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.settings[setting] + if not old_value: + old_value = '-' + + self.settings[setting] = value + print "-> Setting changed: %s to %d (was %s)" % + (SettingsFrame.SETTINGS.get_name(setting), value, str(old_value)) + + self.send_frame(SettingsFrame(flags=Frame.FLAG_ACK)) + print "-> New settings acknowledged." -- cgit v1.2.3 From c32d8189faa24cbe016bb3c859f64c816e0871fe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 29 May 2015 16:59:50 +0200 Subject: cleanup imports --- netlib/h2/frame.py | 1 - 1 file changed, 1 deletion(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 11687316..d4294052 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,5 +1,4 @@ import struct -import io from hpack.hpack import Encoder, Decoder from .. import utils -- cgit v1.2.3 From f76bfabc5d4ce36c56b1d1fd571728ee06f37b78 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 30 May 2015 12:02:58 +1200 Subject: Adjust pep8 parameters, reformat --- netlib/certutils.py | 75 +++++++++++++++++++++------- netlib/h2/frame.py | 127 ++++++++++++++++++++++++++++++++++++++--------- netlib/h2/h2.py | 8 ++- netlib/http.py | 3 +- netlib/http_uastrings.py | 91 ++++++++++++--------------------- netlib/tcp.py | 42 ++++++++++++---- netlib/test.py | 4 +- netlib/wsgi.py | 3 +- 8 files changed, 236 insertions(+), 117 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index abf1a28b..ade61bb5 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -96,7 +96,8 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) - cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) + cert.add_extensions( + [OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha256") return SSLCert(cert) @@ -156,7 +157,12 @@ class CertStore(object): Implements an in-memory certificate store. """ - def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams): + def __init__( + self, + default_privatekey, + default_ca, + default_chain_file, + dhparams): self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file @@ -176,8 +182,10 @@ class CertStore(object): if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL - ) + bio, + OpenSSL.SSL._ffi.NULL, + OpenSSL.SSL._ffi.NULL, + OpenSSL.SSL._ffi.NULL) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh @@ -189,8 +197,12 @@ class CertStore(object): else: with open(ca_path, "rb") as f: raw = f.read() - ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + ca = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) dh_path = os.path.join(path, basename + "-dhparam.pem") dh = cls.load_dhparam(dh_path) return cls(key, ca, ca_path, dh) @@ -206,16 +218,28 @@ class CertStore(object): key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.write( + OpenSSL.crypto.dump_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + key)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) # Dump the certificate in PEM format with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) # Create a .cer file with the same contents for Android with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) # Dump the certificate in PKCS12 format for Windows devices with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: @@ -232,9 +256,14 @@ class CertStore(object): def add_cert_file(self, spec, path): with open(path, "rb") as f: raw = f.read() - cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) + cert = SSLCert( + OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw)) try: - privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) except Exception: privatekey = self.default_privatekey self.add_cert( @@ -284,15 +313,22 @@ class CertStore(object): potential_keys.extend(self.asterisk_forms(s)) potential_keys.append((commonname, tuple(sans))) - name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + name = next( + itertools.ifilter( + lambda key: key in self.certs, + potential_keys), + None) if name: entry = self.certs[name] else: entry = CertStoreEntry( - cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + cert=dummy_cert( + self.default_privatekey, + self.default_ca, + commonname, + sans), privatekey=self.default_privatekey, - chain_file=self.default_chain_file - ) + chain_file=self.default_chain_file) self.certs[(commonname, tuple(sans))] = entry return entry.cert, entry.privatekey, entry.chain_file @@ -317,7 +353,8 @@ class _GeneralName(univ.Choice): class _GeneralNames(univ.SequenceOf): componentType = _GeneralName() - sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) + sizeSpec = univ.SequenceOf.sizeSpec + \ + constraint.ValueSizeConstraint(1, 1024) class SSLCert(object): @@ -345,7 +382,9 @@ class SSLCert(object): return klass.from_pem(pem) def to_pem(self): - return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, self.x509) + return OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + self.x509) def digest(self, name): return self.x509.digest(name) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index d4294052..36456c46 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -94,7 +94,13 @@ class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'', pad_length=0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): super(DataFrame, self).__init__(length, flags, stream_id) self.payload = payload self.pad_length = pad_length @@ -132,9 +138,22 @@ class DataFrame(Frame): class HeadersFrame(Frame): TYPE = 0x1 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, headers=None, pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + headers=None, + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) if headers is None: @@ -157,7 +176,9 @@ class HeadersFrame(Frame): header_block_fragment = payload[0:] if f.flags & self.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack('!LB', header_block_fragment[:5]) + f.stream_dependency, f.weight = struct.unpack( + '!LB', header_block_fragment[ + :5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF header_block_fragment = header_block_fragment[5:] @@ -176,7 +197,9 @@ class HeadersFrame(Frame): b += struct.pack('!B', self.pad_length) if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) b += Encoder().encode(self.headers) @@ -209,7 +232,14 @@ class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): super(PriorityFrame, self).__init__(length, flags, stream_id) self.exclusive = exclusive self.stream_dependency = stream_dependency @@ -227,12 +257,17 @@ class PriorityFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('PRIORITY frames MUST be associated with a stream.') + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') if self.stream_dependency == 0x0: raise ValueError('stream dependency is invalid.') - return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) def payload_human_readable(self): s = [] @@ -246,7 +281,12 @@ class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, error_code=0x0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): super(RstStreamFrame, self).__init__(length, flags, stream_id) self.error_code = error_code @@ -258,7 +298,8 @@ class RstStreamFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('RST_STREAM frames MUST be associated with a stream.') + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') return struct.pack('!L', self.error_code) @@ -279,7 +320,12 @@ class SettingsFrame(Frame): SETTINGS_MAX_HEADER_LIST_SIZE=0x6, ) - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings=None): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): super(SettingsFrame, self).__init__(length, flags, stream_id) if settings is None: @@ -299,7 +345,8 @@ class SettingsFrame(Frame): def payload_bytes(self): if self.stream_id != 0x0: - raise ValueError('SETTINGS frames MUST NOT be associated with a stream.') + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') b = b'' for identifier, value in self.settings.items(): @@ -323,7 +370,14 @@ class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, promised_stream=0x0, header_block_fragment=b'', pad_length=0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): super(PushPromiseFrame, self).__init__(length, flags, stream_id) self.pad_length = pad_length self.promised_stream = promised_stream @@ -346,7 +400,8 @@ class PushPromiseFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('PUSH_PROMISE frames MUST be associated with a stream.') + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') if self.promised_stream == 0x0: raise ValueError('Promised stream id not valid.') @@ -378,7 +433,12 @@ class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b''): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): super(PingFrame, self).__init__(length, flags, stream_id) self.payload = payload @@ -390,7 +450,8 @@ class PingFrame(Frame): def payload_bytes(self): if self.stream_id != 0x0: - raise ValueError('PING frames MUST NOT be associated with a stream.') + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') b = self.payload[0:8] b += b'\0' * (8 - len(b)) @@ -404,7 +465,14 @@ class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, last_stream=0x0, error_code=0x0, data=b''): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): super(GoAwayFrame, self).__init__(length, flags, stream_id) self.last_stream = last_stream self.error_code = error_code @@ -422,7 +490,8 @@ class GoAwayFrame(Frame): def payload_bytes(self): if self.stream_id != 0x0: - raise ValueError('GOAWAY frames MUST NOT be associated with a stream.') + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) b += bytes(self.data) @@ -440,7 +509,12 @@ class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): super(WindowUpdateFrame, self).__init__(length, flags, stream_id) self.window_size_increment = window_size_increment @@ -455,7 +529,8 @@ class WindowUpdateFrame(Frame): def payload_bytes(self): if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) @@ -467,7 +542,12 @@ class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b''): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): super(ContinuationFrame, self).__init__(length, flags, stream_id) self.header_block_fragment = header_block_fragment @@ -479,7 +559,8 @@ class ContinuationFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('CONTINUATION frames MUST be associated with a stream.') + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') return self.header_block_fragment diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index bfe5832b..707b1465 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -44,7 +44,9 @@ class H2Client(tcp.TCPClient): alp = self.get_alpn_proto_negotiated() if alp != b'h2': - raise NotImplementedError("H2Client can not handle unknown protocol: %s" % alp) + raise NotImplementedError( + "H2Client can not handle unknown protocol: %s" % + alp) print "-> Successfully negotiated 'h2' application layer protocol." if send_preface: @@ -79,7 +81,9 @@ class H2Client(tcp.TCPClient): self.settings[setting] = value print "-> Setting changed: %s to %d (was %s)" % - (SettingsFrame.SETTINGS.get_name(setting), value, str(old_value)) + (SettingsFrame.SETTINGS.get_name(setting), + value, + str(old_value)) self.send_frame(SettingsFrame(flags=Frame.FLAG_ACK)) print "-> New settings acknowledged." diff --git a/netlib/http.py b/netlib/http.py index 47658097..a2af9e49 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -124,7 +124,8 @@ def read_chunked(fp, limit, is_request): May raise HttpError. """ # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. total = 0 code = 400 if is_request else 502 while True: diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index d0d145da..d9869531 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -8,66 +8,37 @@ from __future__ import (absolute_import, print_function, division) # A collection of (name, shortcut, string) tuples. UASTRINGS = [ - ( - "android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02" - ), - - ( - "blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+" - ), - - ( - "bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)" - ), - - ( - "chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1" - ), - - ( - "firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1" - ), - - ( - "googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)" - ), - - ( - "ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))" - ), - - ( - "ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3" - ), - - ( - "iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", - ), - - ( - "safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10" - ) -] + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", + ), + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10")] def get_by_shortcut(s): diff --git a/netlib/tcp.py b/netlib/tcp.py index fc2c144e..a705c95b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,7 +48,8 @@ class SSLKeyLogger(object): self.f = None self.lock = threading.Lock() - __name__ = "SSLKeyLogger" # required for functools.wraps, which pyOpenSSL uses. + # required for functools.wraps, which pyOpenSSL uses. + __name__ = "SSLKeyLogger" def __call__(self, connection, where, ret): if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: @@ -61,7 +62,10 @@ class SSLKeyLogger(object): self.f.write("\r\n") client_random = connection.client_random().encode("hex") masterkey = connection.master_key().encode("hex") - self.f.write("CLIENT_RANDOM {} {}\r\n".format(client_random, masterkey)) + self.f.write( + "CLIENT_RANDOM {} {}\r\n".format( + client_random, + masterkey)) self.f.flush() def close(self): @@ -75,7 +79,8 @@ class SSLKeyLogger(object): return SSLKeyLogger(filename) return False -log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) +log_ssl_key = SSLKeyLogger.create_logfun( + os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) class _FileLike(object): @@ -378,7 +383,8 @@ class _Connection(object): # Workaround for # https://github.com/pyca/pyopenssl/issues/190 # https://github.com/mitmproxy/mitmproxy/issues/472 - context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared. + # Options already set before are not cleared. + context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Cipher List if cipher_list: @@ -420,14 +426,17 @@ class TCPClient(_Connection): def __init__(self, address, source_address=None): self.address = Address.wrap(address) - self.source_address = Address.wrap(source_address) if source_address else None + self.source_address = Address.wrap( + source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False self.sni = None def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): - context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) + context = self._create_ssl_context( + alpn_protos=alpn_protos, + **sslctx_kwargs) # Client Certs if cert: try: @@ -443,7 +452,9 @@ class TCPClient(_Connection): options: A bit field consisting of OpenSSL.SSL.OP_* values """ - context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) + context = self.create_ssl_context( + alpn_protos=alpn_protos, + **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni @@ -469,7 +480,9 @@ class TCPClient(_Connection): self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError) as err: - raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) + raise NetLibError( + 'Error connecting to "%s": %s' % + (self.address.host, err)) self.connection = connection def settimeout(self, n): @@ -535,7 +548,9 @@ class BaseHandler(_Connection): until then we're conservative. """ - context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs) + context = self._create_ssl_context( + alpn_select=alpn_select, + **sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -566,7 +581,11 @@ class BaseHandler(_Connection): For a list of parameters, see BaseHandler._create_ssl_context(...) """ - context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs) + context = self.create_ssl_context( + cert, + key, + alpn_select=alpn_select, + **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: @@ -611,7 +630,8 @@ class TCPServer(object): try: while not self.__shutdown_request: try: - r, w, e = select.select([self.socket], [], [], poll_interval) + r, w, e = select.select( + [self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover if ex[0] == EINTR: continue diff --git a/netlib/test.py b/netlib/test.py index 63b493a9..14f50157 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -67,7 +67,9 @@ class TServer(tcp.TCPServer): file(self.ssl["cert"], "rb").read() ) raw = file(self.ssl["key"], "rb").read() - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 diff --git a/netlib/wsgi.py b/netlib/wsgi.py index f393039a..827cf6f0 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -77,7 +77,8 @@ class WSGIAdaptor(object): } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = flow.client_conn.address() + environ["REMOTE_ADDR"], environ[ + "REMOTE_PORT"] = flow.client_conn.address() for key, value in flow.request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') -- cgit v1.2.3 From b395049a853aa378773aebae83468c1b889c2d4e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 10:56:00 +0200 Subject: distribute cffi correctly --- netlib/certffi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/certffi.py b/netlib/certffi.py index 81dc72e8..451f4493 100644 --- a/netlib/certffi.py +++ b/netlib/certffi.py @@ -1,8 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cffi +from cffi import FFI import OpenSSL -xffi = cffi.FFI() +xffi = FFI() xffi.cdef(""" struct rsa_meth_st { int flags; -- cgit v1.2.3 From 4ec181c1403670702c2f163062b92de4dec3d2cc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 13:12:01 +1200 Subject: Move version check to netlib, unit test it. --- netlib/version_check.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 netlib/version_check.py (limited to 'netlib') diff --git a/netlib/version_check.py b/netlib/version_check.py new file mode 100644 index 00000000..09dc23ae --- /dev/null +++ b/netlib/version_check.py @@ -0,0 +1,49 @@ +from __future__ import print_function, absolute_import +import sys +import inspect +import os.path + +import OpenSSL +from . import version + +PYOPENSSL_MIN_VERSION = (0, 15) + + +def version_check( + mitmproxy_version, + pyopenssl_min_version=PYOPENSSL_MIN_VERSION, + fp=sys.stderr): + """ + Having installed a wrong version of pyOpenSSL or netlib is unfortunately a + very common source of error. Check before every start that both versions + are somewhat okay. + """ + # We don't introduce backward-incompatible changes in patch versions. Only + # consider major and minor version. + if version.IVERSION[:2] != mitmproxy_version[:2]: + print( + "You are using mitmproxy %s with netlib %s. " + "Most likely, that won't work - please upgrade!" % ( + mitmproxy_version, version.VERSION + ), + file=fp + ) + sys.exit(1) + v = tuple([int(x) for x in OpenSSL.__version__.split(".")][:2]) + if v < pyopenssl_min_version: + print( + "You are using an outdated version of pyOpenSSL:" + " mitmproxy requires pyOpenSSL %x or greater." % + pyopenssl_min_version, + file=fp + ) + # Some users apparently have multiple versions of pyOpenSSL installed. + # Report which one we got. + pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) + print( + "Your pyOpenSSL %s installation is located at %s" % ( + OpenSSL.__version__, pyopenssl_path + ), + file=fp + ) + sys.exit(1) -- cgit v1.2.3 From 73376e605a61fab239213da375a612ed7d3274b5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 16:54:14 +1200 Subject: Save first byte timestamp for writers too. --- netlib/tcp.py | 1 + 1 file changed, 1 insertion(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a705c95b..c8545d4f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -147,6 +147,7 @@ class Writer(_FileLike): May raise NetLibDisconnect """ if v: + self.first_byte_timestamp = self.first_byte_timestamp or time.time() try: if hasattr(self.o, "sendall"): self.add_log(v) -- cgit v1.2.3 From f7bd690e3aba0be05c30a3b9a4d499de8dbd5e06 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 17:18:55 +1200 Subject: When we see an incomplete read with 0 bytes, it's a disconnect Partially fixes mitmproxy/mitmproxy:#593 --- netlib/tcp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c8545d4f..f6179faa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -225,9 +225,12 @@ class Reader(_FileLike): """ result = self.read(length) if length != -1 and len(result) != length: - raise NetLibIncomplete( - "Expected %s bytes, got %s" % (length, len(result)) - ) + if not result: + raise NetLibDisconnect() + else: + raise NetLibIncomplete( + "Expected %s bytes, got %s" % (length, len(result)) + ) return result -- cgit v1.2.3 From 35856ead075829d5b086e60c60ac20fdfc8560f1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 17:24:44 +1200 Subject: websockets: nicer human readable --- netlib/websockets.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index 63dc03f1..bf920897 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -175,7 +175,7 @@ class FrameHeader: def human_readable(self): vals = [ - "wf:", + "ws frame:", OPCODE.get_name(self.opcode, hex(self.opcode)).lower() ] flags = [] @@ -327,8 +327,10 @@ class Frame(object): return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) def human_readable(self): - hdr = self.header.human_readable() - return hdr + "\n" + repr(self.payload) + ret = self.header.human_readable() + if self.payload: + ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + return ret def to_bytes(self): """ -- cgit v1.2.3 From 113c5c187f0c37ce0c13c399248f4bf91e3a3149 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 4 Jun 2015 11:14:47 +1200 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 3eb0ffc9..bc9a1a57 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 12, 1) +IVERSION = (0, 12, 2) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 4ca62e0d9bd09aa286cde9bafceff7204304d00c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 5 Jun 2015 11:42:06 +1200 Subject: tcp: clear_log to clear socket logs --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index f6179faa..2ebfae96 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -126,6 +126,9 @@ class _FileLike(object): if self.is_logging(): self._log.append(v) + def clear_log(self): + self._log = [] + def reset_timestamps(self): self.first_byte_timestamp = None -- cgit v1.2.3 From 2d9b9be1f4fb67d6989b57b68858896d8512293e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 5 Jun 2015 11:50:29 +1200 Subject: Revert "tcp: clear_log to clear socket logs" start_log also clears the log, which is good enough. This reverts commit 4ca62e0d9bd09aa286cde9bafceff7204304d00c. --- netlib/tcp.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 2ebfae96..f6179faa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -126,9 +126,6 @@ class _FileLike(object): if self.is_logging(): self._log.append(v) - def clear_log(self): - self._log = [] - def reset_timestamps(self): self.first_byte_timestamp = None -- cgit v1.2.3 From 0269d0fb8b8726f8a84ebe916a553ef435a3a50d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 5 Jun 2015 17:08:22 +1200 Subject: repr for websocket frames --- netlib/websockets.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py index bf920897..346adf1b 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -332,6 +332,9 @@ class Frame(object): ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) return ret + def __repr__(self): + return self.header.human_readable() + def to_bytes(self): """ Serialize the frame to wire format. Returns a string. -- cgit v1.2.3 From 9883509f894dde57c8a71340a69581ac46c44f51 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 12:44:29 +0200 Subject: simplify default ssl params for test servers --- netlib/test.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'netlib') diff --git a/netlib/test.py b/netlib/test.py index 14f50157..ee8c6685 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -4,6 +4,7 @@ import Queue import cStringIO import OpenSSL from . import tcp, certutils +import tutils class ServerThread(threading.Thread): @@ -55,22 +56,33 @@ class TServer(tcp.TCPServer): dhparams, v3_only """ tcp.TCPServer.__init__(self, addr) - self.ssl, self.q = ssl, q + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h - if self.ssl: - cert = certutils.SSLCert.from_pem( - file(self.ssl["cert"], "rb").read() - ) - raw = file(self.ssl["key"], "rb").read() + if self.ssl is not None: + raw_cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + cert = certutils.SSLCert.from_pem(file(raw_cert, "rb").read()) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - raw) - if self.ssl["v3_only"]: + file(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: @@ -81,7 +93,7 @@ class TServer(tcp.TCPServer): method=method, options=options, handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl["request_client_cert"], + request_client_cert=self.ssl.get("request_client_cert", None), cipher_list=self.ssl.get("cipher_list", None), dhparams=self.ssl.get("dhparams", None), chain_file=self.ssl.get("chain_file", None), -- cgit v1.2.3 From 436291764c4e557155d7e4e87482a4e378a2ccce Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 15:14:31 +0200 Subject: http2: fix default settings --- netlib/h2/h2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 707b1465..227139a3 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -29,8 +29,8 @@ class H2Client(tcp.TCPClient): SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ^ 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ^ 14, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, } -- cgit v1.2.3 From e4c129026fbf4228c13ae64da19a9a85fc7ff2a5 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 15:17:50 +0200 Subject: http2: introduce state for connection objects --- netlib/h2/frame.py | 102 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 39 deletions(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 36456c46..174ceebd 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -20,18 +20,28 @@ class Frame(object): FLAG_PADDED = 0x8 FLAG_PRIORITY = 0x20 - def __init__(self, length, flags, stream_id): + def __init__(self, state=None, length=0, flags=FLAG_NO_FLAGS, stream_id=0x0): valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) if flags | valid_flags != valid_flags: raise ValueError('invalid flags detected.') + if state is None: + class State(object): + pass + + state = State() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + self.length = length self.type = self.TYPE self.flags = flags self.stream_id = stream_id @classmethod - def from_file(self, fp): + def from_file(self, fp, state=None): """ read a HTTP/2 frame sent by a server or client fp is a "file like" object that could be backed by a network @@ -45,16 +55,16 @@ class Frame(object): stream_id = fields[4] payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes(length, flags, stream_id, payload) + return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, payload) @classmethod - def from_bytes(self, data): + def from_bytes(self, data, state=None): fields = struct.unpack("!HBBBL", data[:9]) length = (fields[0] << 8) + fields[1] # type is already deducted from class flags = fields[3] stream_id = fields[4] - return FRAMES[fields[2]].from_bytes(length, flags, stream_id, data[9:]) + return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, data[9:]) def to_bytes(self): payload = self.payload_bytes() @@ -96,18 +106,19 @@ class DataFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'', pad_length=0): - super(DataFrame, self).__init__(length, flags, stream_id) + super(DataFrame, self).__init__(state, length, flags, stream_id) self.payload = payload self.pad_length = pad_length @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] @@ -146,6 +157,7 @@ class HeadersFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, @@ -154,7 +166,7 @@ class HeadersFrame(Frame): exclusive=False, stream_dependency=0x0, weight=0): - super(HeadersFrame, self).__init__(length, flags, stream_id) + super(HeadersFrame, self).__init__(state, length, flags, stream_id) if headers is None: headers = [] @@ -166,8 +178,8 @@ class HeadersFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] @@ -177,18 +189,22 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( - '!LB', header_block_fragment[ - :5]) + '!LB', header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF header_block_fragment = header_block_fragment[5:] - for header, value in Decoder().decode(header_block_fragment): + for header, value in f.state.decoder.decode(header_block_fragment): f.headers.append((header, value)) return f def payload_bytes(self): + """ + This encodes all headers with HPACK + Do NOT call this method twice - it will change the encoder state! + """ + if self.stream_id == 0x0: raise ValueError('HEADERS frames MUST be associated with a stream.') @@ -201,7 +217,7 @@ class HeadersFrame(Frame): (int(self.exclusive) << 31) | self.stream_dependency, self.weight) - b += Encoder().encode(self.headers) + b += self.state.encoder.encode(self.headers) if self.flags & self.FLAG_PADDED: b += b'\0' * self.pad_length @@ -234,20 +250,21 @@ class PriorityFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, exclusive=False, stream_dependency=0x0, weight=0): - super(PriorityFrame, self).__init__(length, flags, stream_id) + super(PriorityFrame, self).__init__(state, length, flags, stream_id) self.exclusive = exclusive self.stream_dependency = stream_dependency self.weight = weight @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.stream_dependency, f.weight = struct.unpack('!LB', payload) f.exclusive = bool(f.stream_dependency >> 31) @@ -283,16 +300,17 @@ class RstStreamFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, error_code=0x0): - super(RstStreamFrame, self).__init__(length, flags, stream_id) + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) self.error_code = error_code @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.error_code = struct.unpack('!L', payload)[0] return f @@ -322,11 +340,12 @@ class SettingsFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings=None): - super(SettingsFrame, self).__init__(length, flags, stream_id) + super(SettingsFrame, self).__init__(state, length, flags, stream_id) if settings is None: settings = {} @@ -334,8 +353,8 @@ class SettingsFrame(Frame): self.settings = settings @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) @@ -372,20 +391,21 @@ class PushPromiseFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, promised_stream=0x0, header_block_fragment=b'', pad_length=0): - super(PushPromiseFrame, self).__init__(length, flags, stream_id) + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) self.pad_length = pad_length self.promised_stream = promised_stream self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) if f.flags & self.FLAG_PADDED: f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) @@ -435,16 +455,17 @@ class PingFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b''): - super(PingFrame, self).__init__(length, flags, stream_id) + super(PingFrame, self).__init__(state, length, flags, stream_id) self.payload = payload @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.payload = payload return f @@ -467,20 +488,21 @@ class GoAwayFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, last_stream=0x0, error_code=0x0, data=b''): - super(GoAwayFrame, self).__init__(length, flags, stream_id) + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) self.last_stream = last_stream self.error_code = error_code self.data = data @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) f.last_stream &= 0x7FFFFFFF @@ -511,16 +533,17 @@ class WindowUpdateFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(length, flags, stream_id) + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) self.window_size_increment = window_size_increment @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.window_size_increment = struct.unpack("!L", payload)[0] f.window_size_increment &= 0x7FFFFFFF @@ -544,16 +567,17 @@ class ContinuationFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b''): - super(ContinuationFrame, self).__init__(length, flags, stream_id) + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.header_block_fragment = payload return f -- cgit v1.2.3 From 5cecbdc1687346bb2bf139c904ffda2b37dc8276 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 12:34:50 +0200 Subject: http2: add basic protocol handling --- netlib/h2/__init__.py | 169 ++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/h2/frame.py | 50 ++++++++++++--- netlib/h2/h2.py | 89 -------------------------- 3 files changed, 211 insertions(+), 97 deletions(-) delete mode 100644 netlib/h2/h2.py (limited to 'netlib') diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py index 9b4faa33..054ba91c 100644 --- a/netlib/h2/__init__.py +++ b/netlib/h2/__init__.py @@ -1 +1,170 @@ from __future__ import (absolute_import, print_function, division) +import itertools + +from .. import utils +from .frame import * + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + + ALPN_PROTO_H2 = b'h2' + + HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, + } + + def __init__(self): + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + + def check_alpn(self): + alp = self.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "H2Client can not handle unknown ALP: %s" % alp) + print("-> Successfully negotiated 'h2' application layer protocol.") + + def send_connection_preface(self): + self.wfile.write(bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(SettingsFrame(state=self)) + + frame = Frame.from_file(self.rfile, self) + assert isinstance(frame, SettingsFrame) + self._apply_settings(frame.settings) + self.read_frame() # read setting ACK frame + + print("-> Connection Preface completed.") + + def next_stream_id(self): + if self.current_stream_id is None: + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frame): + raw_bytes = frame.to_bytes() + self.wfile.write(raw_bytes) + self.wfile.flush() + + def read_frame(self): + frame = Frame.from_file(self.rfile, self) + if isinstance(frame, SettingsFrame): + self._apply_settings(frame.settings) + + return frame + + def _apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + + self.http2_settings[setting] = value + print("-> Setting changed: %s to %d (was %s)" % ( + SettingsFrame.SETTINGS.get_name(setting), + value, + str(old_value))) + + self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) + print("-> New settings acknowledged.") + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = Frame.FLAG_END_HEADERS + if end_stream: + flags |= Frame.FLAG_END_STREAM + + bytes = HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + headers=headers).to_bytes() + return [bytes] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + bytes = DataFrame( + state=self, + flags=Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body).to_bytes() + return [bytes] + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https')] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + header_block_fragment = b'' + body = b'' + + while True: + frame = self.read_frame() + if isinstance(frame, HeadersFrame): + header_block_fragment += frame.header_block_fragment + if frame.flags | Frame.FLAG_END_HEADERS: + break + else: + print("Unexpected frame received:") + print(frame.human_readable()) + + while True: + frame = self.read_frame() + if isinstance(frame, DataFrame): + body += frame.payload + if frame.flags | Frame.FLAG_END_STREAM: + break + else: + print("Unexpected frame received:") + print(frame.human_readable()) + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + return headers[':status'], headers, body diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 174ceebd..137cbb3d 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -20,16 +20,24 @@ class Frame(object): FLAG_PADDED = 0x8 FLAG_PRIORITY = 0x20 - def __init__(self, state=None, length=0, flags=FLAG_NO_FLAGS, stream_id=0x0): + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) if flags | valid_flags != valid_flags: raise ValueError('invalid flags detected.') if state is None: + from . import HTTP2Protocol + class State(object): pass state = State() + state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() state.encoder = Encoder() state.decoder = Decoder() @@ -40,6 +48,14 @@ class Frame(object): self.flags = flags self.stream_id = stream_id + def _check_frame_size(self, length): + max_length = self.state.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + if length > max_length: + raise NotImplementedError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_length)) + @classmethod def from_file(self, fp, state=None): """ @@ -54,8 +70,15 @@ class Frame(object): flags = fields[3] stream_id = fields[4] + # TODO: check frame size if <= current SETTINGS_MAX_FRAME_SIZE + payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, payload) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) @classmethod def from_bytes(self, data, state=None): @@ -64,12 +87,20 @@ class Frame(object): # type is already deducted from class flags = fields[3] stream_id = fields[4] - return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, data[9:]) + + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + data[9:]) def to_bytes(self): payload = self.payload_bytes() self.length = len(payload) + self._check_frame_size(self.length) + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) b += struct.pack('!B', self.TYPE) b += struct.pack('!B', self.flags) @@ -183,19 +214,20 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] - header_block_fragment = payload[1:-f.pad_length] + f.header_block_fragment = payload[1:-f.pad_length] else: - header_block_fragment = payload[0:] + f.header_block_fragment = payload[0:] if f.flags & self.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( '!LB', header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF - header_block_fragment = header_block_fragment[5:] + f.header_block_fragment = f.header_block_fragment[5:] - for header, value in f.state.decoder.decode(header_block_fragment): - f.headers.append((header, value)) + # TODO only do this if END_HEADERS or something... + # for header, value in f.state.decoder.decode(f.header_block_fragment): + # f.headers.append((header, value)) return f @@ -217,6 +249,8 @@ class HeadersFrame(Frame): (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + # TODO: maybe remove that and only deal with header_block_fragments + # inside frames b += self.state.encoder.encode(self.headers) if self.flags & self.FLAG_PADDED: diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py deleted file mode 100644 index 227139a3..00000000 --- a/netlib/h2/h2.py +++ /dev/null @@ -1,89 +0,0 @@ -from .. import utils, odict, tcp -from frame import * - -# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" -CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - -ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd -) - - -class H2Client(tcp.TCPClient): - ALPN_PROTO_H2 = b'h2' - - DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, - } - - def __init__(self, address, source_address=None): - super(H2Client, self).__init__(address, source_address) - self.settings = self.DEFAULT_SETTINGS.copy() - - def connect(self, send_preface=True): - super(H2Client, self).connect() - self.convert_to_ssl(alpn_protos=[self.ALPN_PROTO_H2]) - - alp = self.get_alpn_proto_negotiated() - if alp != b'h2': - raise NotImplementedError( - "H2Client can not handle unknown protocol: %s" % - alp) - print "-> Successfully negotiated 'h2' application layer protocol." - - if send_preface: - self.wfile.write(bytes(CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(SettingsFrame()) - - frame = Frame.from_file(self.rfile) - print frame.human_readable() - assert isinstance(frame, SettingsFrame) - self.apply_settings(frame.settings) - - print "-> Connection Preface completed." - - print "-> H2Client is ready..." - - def send_frame(self, frame): - self.wfile.write(frame.to_bytes()) - self.wfile.flush() - - def read_frame(self): - frame = Frame.from_file(self.rfile) - if isinstance(frame, SettingsFrame): - self.apply_settings(frame.settings) - - return frame - - def apply_settings(self, settings): - for setting, value in settings.items(): - old_value = self.settings[setting] - if not old_value: - old_value = '-' - - self.settings[setting] = value - print "-> Setting changed: %s to %d (was %s)" % - (SettingsFrame.SETTINGS.get_name(setting), - value, - str(old_value)) - - self.send_frame(SettingsFrame(flags=Frame.FLAG_ACK)) - print "-> New settings acknowledged." -- cgit v1.2.3 From 40fa113116a2d3a549bc57c1b1381bbb55c7014b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 4 Jun 2015 14:11:19 +0200 Subject: http2: change header_block_fragment handling --- netlib/h2/frame.py | 65 +++++++++++++++++------------------------------------- 1 file changed, 20 insertions(+), 45 deletions(-) (limited to 'netlib') diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 137cbb3d..0755c96c 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -48,13 +48,21 @@ class Frame(object): self.flags = flags self.stream_id = stream_id - def _check_frame_size(self, length): - max_length = self.state.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - if length > max_length: + @classmethod + def _check_frame_size(self, length, state): + from . import HTTP2Protocol + + if state: + settings = state.http2_settings + else: + settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS + + max_frame_size = settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: raise NotImplementedError( "Frame size exceeded: %d, but only %d allowed." % ( - length, max_length)) + length, max_frame_size)) @classmethod def from_file(self, fp, state=None): @@ -70,7 +78,7 @@ class Frame(object): flags = fields[3] stream_id = fields[4] - # TODO: check frame size if <= current SETTINGS_MAX_FRAME_SIZE + self._check_frame_size(length, state) payload = fp.safe_read(length) return FRAMES[fields[2]].from_bytes( @@ -80,26 +88,11 @@ class Frame(object): stream_id, payload) - @classmethod - def from_bytes(self, data, state=None): - fields = struct.unpack("!HBBBL", data[:9]) - length = (fields[0] << 8) + fields[1] - # type is already deducted from class - flags = fields[3] - stream_id = fields[4] - - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - data[9:]) - def to_bytes(self): payload = self.payload_bytes() self.length = len(payload) - self._check_frame_size(self.length) + self._check_frame_size(self.length, self.state) b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) b += struct.pack('!B', self.TYPE) @@ -192,17 +185,14 @@ class HeadersFrame(Frame): length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, - headers=None, + header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(state, length, flags, stream_id) - if headers is None: - headers = [] - - self.headers = headers + self.header_block_fragment = header_block_fragment self.pad_length = pad_length self.exclusive = exclusive self.stream_dependency = stream_dependency @@ -220,23 +210,14 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( - '!LB', header_block_fragment[:5]) + '!LB', f.header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF f.header_block_fragment = f.header_block_fragment[5:] - # TODO only do this if END_HEADERS or something... - # for header, value in f.state.decoder.decode(f.header_block_fragment): - # f.headers.append((header, value)) - return f def payload_bytes(self): - """ - This encodes all headers with HPACK - Do NOT call this method twice - it will change the encoder state! - """ - if self.stream_id == 0x0: raise ValueError('HEADERS frames MUST be associated with a stream.') @@ -249,9 +230,7 @@ class HeadersFrame(Frame): (int(self.exclusive) << 31) | self.stream_dependency, self.weight) - # TODO: maybe remove that and only deal with header_block_fragments - # inside frames - b += self.state.encoder.encode(self.headers) + b += self.header_block_fragment if self.flags & self.FLAG_PADDED: b += b'\0' * self.pad_length @@ -269,11 +248,7 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PADDED: s.append("padding: %d" % self.pad_length) - if not self.headers: - s.append("headers: None") - else: - for header, value in self.headers: - s.append("%s: %s" % (header, value)) + s.append("header_block_fragment: %s" % self.header_block_fragment.encode('hex')) return "\n".join(s) -- cgit v1.2.3 From 623dd850e0ce15630e0950b4de843c0af8046618 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 4 Jun 2015 14:28:09 +0200 Subject: http2: add logging and error handling --- netlib/h2/__init__.py | 28 ++++++++++++++++++---------- netlib/h2/frame.py | 16 ++++++++++++---- 2 files changed, 30 insertions(+), 14 deletions(-) (limited to 'netlib') diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py index 054ba91c..c06f7a11 100644 --- a/netlib/h2/__init__.py +++ b/netlib/h2/__init__.py @@ -1,8 +1,11 @@ from __future__ import (absolute_import, print_function, division) import itertools +import logging -from .. import utils from .frame import * +from .. import utils + +log = logging.getLogger(__name__) class HTTP2Protocol(object): @@ -49,7 +52,7 @@ class HTTP2Protocol(object): if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "H2Client can not handle unknown ALP: %s" % alp) - print("-> Successfully negotiated 'h2' application layer protocol.") + log.debug("ALP 'h2' successfully negotiated.") def send_connection_preface(self): self.wfile.write(bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) @@ -60,7 +63,7 @@ class HTTP2Protocol(object): self._apply_settings(frame.settings) self.read_frame() # read setting ACK frame - print("-> Connection Preface completed.") + log.debug("Connection Preface completed.") def next_stream_id(self): if self.current_stream_id is None: @@ -88,13 +91,13 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - print("-> Setting changed: %s to %d (was %s)" % ( + log.debug("Setting changed: %s to %d (was %s)" % ( SettingsFrame.SETTINGS.get_name(setting), value, str(old_value))) self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) - print("-> New settings acknowledged.") + log.debug("New settings acknowledged.") def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -103,11 +106,13 @@ class HTTP2Protocol(object): if end_stream: flags |= Frame.FLAG_END_STREAM + header_block_fragment = self.encoder.encode(headers) + bytes = HeadersFrame( state=self, flags=flags, stream_id=stream_id, - headers=headers).to_bytes() + header_block_fragment=header_block_fragment).to_bytes() return [bytes] def _create_body(self, body, stream_id): @@ -150,8 +155,8 @@ class HTTP2Protocol(object): if frame.flags | Frame.FLAG_END_HEADERS: break else: - print("Unexpected frame received:") - print(frame.human_readable()) + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) while True: frame = self.read_frame() @@ -160,11 +165,14 @@ class HTTP2Protocol(object): if frame.flags | Frame.FLAG_END_STREAM: break else: - print("Unexpected frame received:") - print(frame.human_readable()) + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) headers = {} for header, value in self.decoder.decode(header_block_fragment): headers[header] = value + for header, value in headers.items(): + log.debug("%s: %s" % (header, value)) + return headers[':status'], headers, body diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 0755c96c..018e822f 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,9 +1,14 @@ import struct +import logging +from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils -from functools import reduce +log = logging.getLogger(__name__) + +class FrameSizeError(Exception): + pass class Frame(object): @@ -57,10 +62,11 @@ class Frame(object): else: settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS - max_frame_size = settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] if length > max_frame_size: - raise NotImplementedError( + raise FrameSizeError( "Frame size exceeded: %d, but only %d allowed." % ( length, max_frame_size)) @@ -248,7 +254,9 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PADDED: s.append("padding: %d" % self.pad_length) - s.append("header_block_fragment: %s" % self.header_block_fragment.encode('hex')) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) return "\n".join(s) -- cgit v1.2.3 From f003f87197a6dffe1b51a82f7dd218121c75e206 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 4 Jun 2015 19:44:48 +0200 Subject: http2: rename module and refactor as strategy --- netlib/h2/__init__.py | 178 -------------- netlib/h2/frame.py | 623 ---------------------------------------------- netlib/http2/__init__.py | 181 ++++++++++++++ netlib/http2/frame.py | 625 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 806 insertions(+), 801 deletions(-) delete mode 100644 netlib/h2/__init__.py delete mode 100644 netlib/h2/frame.py create mode 100644 netlib/http2/__init__.py create mode 100644 netlib/http2/frame.py (limited to 'netlib') diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py deleted file mode 100644 index c06f7a11..00000000 --- a/netlib/h2/__init__.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import logging - -from .frame import * -from .. import utils - -log = logging.getLogger(__name__) - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - - ALPN_PROTO_H2 = b'h2' - - HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, - } - - def __init__(self): - self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - - def check_alpn(self): - alp = self.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "H2Client can not handle unknown ALP: %s" % alp) - log.debug("ALP 'h2' successfully negotiated.") - - def send_connection_preface(self): - self.wfile.write(bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(SettingsFrame(state=self)) - - frame = Frame.from_file(self.rfile, self) - assert isinstance(frame, SettingsFrame) - self._apply_settings(frame.settings) - self.read_frame() # read setting ACK frame - - log.debug("Connection Preface completed.") - - def next_stream_id(self): - if self.current_stream_id is None: - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frame): - raw_bytes = frame.to_bytes() - self.wfile.write(raw_bytes) - self.wfile.flush() - - def read_frame(self): - frame = Frame.from_file(self.rfile, self) - if isinstance(frame, SettingsFrame): - self._apply_settings(frame.settings) - - return frame - - def _apply_settings(self, settings): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - - self.http2_settings[setting] = value - log.debug("Setting changed: %s to %d (was %s)" % ( - SettingsFrame.SETTINGS.get_name(setting), - value, - str(old_value))) - - self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) - log.debug("New settings acknowledged.") - - def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = Frame.FLAG_END_HEADERS - if end_stream: - flags |= Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - bytes = HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - bytes = DataFrame( - state=self, - flags=Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body).to_bytes() - return [bytes] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https')] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - header_block_fragment = b'' - body = b'' - - while True: - frame = self.read_frame() - if isinstance(frame, HeadersFrame): - header_block_fragment += frame.header_block_fragment - if frame.flags | Frame.FLAG_END_HEADERS: - break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) - - while True: - frame = self.read_frame() - if isinstance(frame, DataFrame): - body += frame.payload - if frame.flags | Frame.FLAG_END_STREAM: - break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - for header, value in headers.items(): - log.debug("%s: %s" % (header, value)) - - return headers[':status'], headers, body diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py deleted file mode 100644 index 018e822f..00000000 --- a/netlib/h2/frame.py +++ /dev/null @@ -1,623 +0,0 @@ -import struct -import logging -from functools import reduce -from hpack.hpack import Encoder, Decoder - -from .. import utils - -log = logging.getLogger(__name__) - -class FrameSizeError(Exception): - pass - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - from . import HTTP2Protocol - - class State(object): - pass - - state = State() - state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(self, length, state): - from . import HTTP2Protocol - - if state: - settings = state.http2_settings - else: - settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(self, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - self._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self): - return "\n".join([ - "============================================================", - "length: %d bytes" % self.length, - "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), - "flags: %#x" % self.flags, - "stream_id: %#x" % self.stream_id, - "------------------------------------------------------------", - self.payload_human_readable(), - "============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & self.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & self.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & self.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & self.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - return "header_block_fragment: %s" % str(self.header_block_fragment) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py new file mode 100644 index 00000000..d6f2c51c --- /dev/null +++ b/netlib/http2/__init__.py @@ -0,0 +1,181 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import logging + +from .frame import * +from .. import utils + +log = logging.getLogger(__name__) + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + + ALPN_PROTO_H2 = b'h2' + + HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, + } + + def __init__(self, tcp_client): + self.tcp_client = tcp_client + + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + + def check_alpn(self): + alp = self.tcp_client.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "H2Client can not handle unknown ALP: %s" % alp) + log.debug("ALP 'h2' successfully negotiated.") + + def send_connection_preface(self): + self.tcp_client.wfile.write( + bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(SettingsFrame(state=self)) + + frame = Frame.from_file(self.tcp_client.rfile, self) + assert isinstance(frame, SettingsFrame) + self._apply_settings(frame.settings) + self.read_frame() # read setting ACK frame + + log.debug("Connection Preface completed.") + + def next_stream_id(self): + if self.current_stream_id is None: + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frame): + raw_bytes = frame.to_bytes() + self.tcp_client.wfile.write(raw_bytes) + self.tcp_client.wfile.flush() + + def read_frame(self): + frame = Frame.from_file(self.tcp_client.rfile, self) + if isinstance(frame, SettingsFrame): + self._apply_settings(frame.settings) + + return frame + + def _apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + + self.http2_settings[setting] = value + log.debug("Setting changed: %s to %d (was %s)" % ( + SettingsFrame.SETTINGS.get_name(setting), + value, + str(old_value))) + + self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) + log.debug("New settings acknowledged.") + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = Frame.FLAG_END_HEADERS + if end_stream: + flags |= Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + bytes = HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment).to_bytes() + return [bytes] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + bytes = DataFrame( + state=self, + flags=Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body).to_bytes() + return [bytes] + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https')] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + header_block_fragment = b'' + body = b'' + + while True: + frame = self.read_frame() + if isinstance(frame, HeadersFrame): + header_block_fragment += frame.header_block_fragment + if frame.flags | Frame.FLAG_END_HEADERS: + break + else: + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) + + while True: + frame = self.read_frame() + if isinstance(frame, DataFrame): + body += frame.payload + if frame.flags | Frame.FLAG_END_STREAM: + break + else: + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + for header, value in headers.items(): + log.debug("%s: %s" % (header, value)) + + return headers[':status'], headers, body diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py new file mode 100644 index 00000000..1497380a --- /dev/null +++ b/netlib/http2/frame.py @@ -0,0 +1,625 @@ +import struct +import logging +from functools import reduce +from hpack.hpack import Encoder, Decoder + +from .. import utils + +log = logging.getLogger(__name__) + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + from . import HTTP2Protocol + + class State(object): + pass + + state = State() + state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(self, length, state): + from . import HTTP2Protocol + + if state: + settings = state.http2_settings + else: + settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(self, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + self._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self): + return "\n".join([ + "============================================================", + "length: %d bytes" % self.length, + "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), + "flags: %#x" % self.flags, + "stream_id: %#x" % self.stream_id, + "------------------------------------------------------------", + self.payload_human_readable(), + "============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & self.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + return "header_block_fragment: %s" % str(self.header_block_fragment) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} -- cgit v1.2.3 From fdc908cb9811628435ef02e3168c4d5931c6a3c5 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 13:28:09 +0200 Subject: http2: add protocol tests --- netlib/http2/__init__.py | 25 +++++++++++++------------ netlib/test.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index d6f2c51c..2803cccb 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -30,7 +30,7 @@ class HTTP2Protocol(object): # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - ALPN_PROTO_H2 = b'h2' + ALPN_PROTO_H2 = 'h2' HTTP2_DEFAULT_SETTINGS = { SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, @@ -53,18 +53,25 @@ class HTTP2Protocol(object): alp = self.tcp_client.get_alpn_proto_negotiated() if alp != self.ALPN_PROTO_H2: raise NotImplementedError( - "H2Client can not handle unknown ALP: %s" % alp) + "HTTP2Protocol can not handle unknown ALP: %s" % alp) log.debug("ALP 'h2' successfully negotiated.") + return True - def send_connection_preface(self): + def perform_connection_preface(self): self.tcp_client.wfile.write( bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) self.send_frame(SettingsFrame(state=self)) + # read server settings frame frame = Frame.from_file(self.tcp_client.rfile, self) assert isinstance(frame, SettingsFrame) self._apply_settings(frame.settings) - self.read_frame() # read setting ACK frame + + # read setting ACK frame + settings_ack_frame = self.read_frame() + assert isinstance(settings_ack_frame, SettingsFrame) + assert settings_ack_frame.flags & Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 log.debug("Connection Preface completed.") @@ -94,9 +101,9 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - log.debug("Setting changed: %s to %d (was %s)" % ( + log.debug("Setting changed: %s to %s (was %s)" % ( SettingsFrame.SETTINGS.get_name(setting), - value, + str(value), str(old_value))) self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) @@ -157,9 +164,6 @@ class HTTP2Protocol(object): header_block_fragment += frame.header_block_fragment if frame.flags | Frame.FLAG_END_HEADERS: break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) while True: frame = self.read_frame() @@ -167,9 +171,6 @@ class HTTP2Protocol(object): body += frame.payload if frame.flags | Frame.FLAG_END_STREAM: break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) headers = {} for header, value in self.decoder.decode(header_block_fragment): diff --git a/netlib/test.py b/netlib/test.py index ee8c6685..4b0b6bd2 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -4,7 +4,7 @@ import Queue import cStringIO import OpenSSL from . import tcp, certutils -import tutils +from test import tutils class ServerThread(threading.Thread): -- cgit v1.2.3 From f2db8abbe859266bb28117e1ffa4b0b99d62e321 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 20:52:11 +0200 Subject: use open instead of file --- netlib/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/test.py b/netlib/test.py index 4b0b6bd2..1e1b5e9d 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -75,13 +75,13 @@ class TServer(tcp.TCPServer): raw_cert = self.ssl.get( "cert", tutils.test_data.path("data/server.crt")) - cert = certutils.SSLCert.from_pem(file(raw_cert, "rb").read()) + cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) raw_key = self.ssl.get( "key", tutils.test_data.path("data/server.key")) key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - file(raw_key, "rb").read()) + open(raw_key, "rb").read()) if self.ssl.get("v3_only", False): method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 -- cgit v1.2.3 From f2d784896dd18ea7ded9b3a95bedcdceb3325213 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 6 Jun 2015 12:26:48 +1200 Subject: http2: resolve module structure and circular dependencies - Move implementation out of __init__.py to protocol.py (an anti-pattern because it makes the kind of structural refactoring we need hard) - protocol imports frame, frame does not import protocol. To do this, we shift the default settings to frame. If this feels wrong, we can move them to a separate module (defaults.py?.). --- netlib/http2/__init__.py | 183 +---------------------------------------------- netlib/http2/frame.py | 18 +++-- netlib/http2/protocol.py | 174 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 187 deletions(-) create mode 100644 netlib/http2/protocol.py (limited to 'netlib') diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index 2803cccb..92897b5d 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -1,182 +1,3 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import logging -from .frame import * -from .. import utils - -log = logging.getLogger(__name__) - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - - ALPN_PROTO_H2 = 'h2' - - HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, - } - - def __init__(self, tcp_client): - self.tcp_client = tcp_client - - self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - - def check_alpn(self): - alp = self.tcp_client.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - log.debug("ALP 'h2' successfully negotiated.") - return True - - def perform_connection_preface(self): - self.tcp_client.wfile.write( - bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(SettingsFrame(state=self)) - - # read server settings frame - frame = Frame.from_file(self.tcp_client.rfile, self) - assert isinstance(frame, SettingsFrame) - self._apply_settings(frame.settings) - - # read setting ACK frame - settings_ack_frame = self.read_frame() - assert isinstance(settings_ack_frame, SettingsFrame) - assert settings_ack_frame.flags & Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 - - log.debug("Connection Preface completed.") - - def next_stream_id(self): - if self.current_stream_id is None: - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frame): - raw_bytes = frame.to_bytes() - self.tcp_client.wfile.write(raw_bytes) - self.tcp_client.wfile.flush() - - def read_frame(self): - frame = Frame.from_file(self.tcp_client.rfile, self) - if isinstance(frame, SettingsFrame): - self._apply_settings(frame.settings) - - return frame - - def _apply_settings(self, settings): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - - self.http2_settings[setting] = value - log.debug("Setting changed: %s to %s (was %s)" % ( - SettingsFrame.SETTINGS.get_name(setting), - str(value), - str(old_value))) - - self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) - log.debug("New settings acknowledged.") - - def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = Frame.FLAG_END_HEADERS - if end_stream: - flags |= Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - bytes = HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - bytes = DataFrame( - state=self, - flags=Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body).to_bytes() - return [bytes] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https')] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - header_block_fragment = b'' - body = b'' - - while True: - frame = self.read_frame() - if isinstance(frame, HeadersFrame): - header_block_fragment += frame.header_block_fragment - if frame.flags | Frame.FLAG_END_HEADERS: - break - - while True: - frame = self.read_frame() - if isinstance(frame, DataFrame): - body += frame.payload - if frame.flags | Frame.FLAG_END_STREAM: - break - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - for header, value in headers.items(): - log.debug("%s: %s" % (header, value)) - - return headers[':status'], headers, body +from frame import * +from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 1497380a..fc86c228 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -38,13 +38,11 @@ class Frame(object): raise ValueError('invalid flags detected.') if state is None: - from . import HTTP2Protocol - class State(object): pass state = State() - state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() state.encoder = Encoder() state.decoder = Decoder() @@ -57,12 +55,10 @@ class Frame(object): @classmethod def _check_frame_size(self, length, state): - from . import HTTP2Protocol - if state: settings = state.http2_settings else: - settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS + settings = HTTP2_DEFAULT_SETTINGS.copy() max_frame_size = settings[ SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] @@ -623,3 +619,13 @@ _FRAME_CLASSES = [ ContinuationFrame ] FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py new file mode 100644 index 00000000..9bab431c --- /dev/null +++ b/netlib/http2/protocol.py @@ -0,0 +1,174 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import logging + +from hpack.hpack import Encoder, Decoder +from .. import utils +from . import frame + +log = logging.getLogger(__name__) + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + + ALPN_PROTO_H2 = 'h2' + + def __init__(self, tcp_client): + self.tcp_client = tcp_client + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + + def check_alpn(self): + alp = self.tcp_client.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + log.debug("ALP 'h2' successfully negotiated.") + return True + + def perform_connection_preface(self): + self.tcp_client.wfile.write( + bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(frame.SettingsFrame(state=self)) + + # read server settings frame + frm = frame.Frame.from_file(self.tcp_client.rfile, self) + assert isinstance(frm, frame.SettingsFrame) + self._apply_settings(frm.settings) + + # read setting ACK frame + settings_ack_frame = self.read_frame() + assert isinstance(settings_ack_frame, frame.SettingsFrame) + assert settings_ack_frame.flags & frame.Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 + + log.debug("Connection Preface completed.") + + def next_stream_id(self): + if self.current_stream_id is None: + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frame): + raw_bytes = frame.to_bytes() + self.tcp_client.wfile.write(raw_bytes) + self.tcp_client.wfile.flush() + + def read_frame(self): + frm = frame.Frame.from_file(self.tcp_client.rfile, self) + if isinstance(frm, frame.SettingsFrame): + self._apply_settings(frm.settings) + + return frm + + def _apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + + self.http2_settings[setting] = value + log.debug("Setting changed: %s to %s (was %s)" % ( + frame.SettingsFrame.SETTINGS.get_name(setting), + str(value), + str(old_value))) + + self.send_frame(frame.SettingsFrame(state=self, flags=frame.Frame.FLAG_ACK)) + log.debug("New settings acknowledged.") + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + flags |= frame.Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + bytes = frame.HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment).to_bytes() + return [bytes] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + bytes = frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body).to_bytes() + return [bytes] + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https')] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if isinstance(frm, frame.HeadersFrame): + header_block_fragment += frm.header_block_fragment + if frm.flags | frame.Frame.FLAG_END_HEADERS: + break + + while True: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame): + body += frm.payload + if frm.flags | frame.Frame.FLAG_END_STREAM: + break + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + for header, value in headers.items(): + log.debug("%s: %s" % (header, value)) + + return headers[':status'], headers, body -- cgit v1.2.3 From 9c48bfb2a53bf3ac3c29408511e3126ada16afd8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 6 Jun 2015 12:30:53 +1200 Subject: http2: ditch the logging for now The API is well designed: it looks like we can get all the information we need to expose debugging in the caller of the API. --- netlib/http2/frame.py | 3 --- netlib/http2/protocol.py | 13 ------------- 2 files changed, 16 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index fc86c228..ac9b8d50 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,12 +1,9 @@ import struct -import logging from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils -log = logging.getLogger(__name__) - class FrameSizeError(Exception): pass diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 9bab431c..459c2293 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -1,13 +1,10 @@ from __future__ import (absolute_import, print_function, division) import itertools -import logging from hpack.hpack import Encoder, Decoder from .. import utils from . import frame -log = logging.getLogger(__name__) - class HTTP2Protocol(object): @@ -46,7 +43,6 @@ class HTTP2Protocol(object): if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) - log.debug("ALP 'h2' successfully negotiated.") return True def perform_connection_preface(self): @@ -65,7 +61,6 @@ class HTTP2Protocol(object): assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 - log.debug("Connection Preface completed.") def next_stream_id(self): if self.current_stream_id is None: @@ -93,13 +88,8 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - log.debug("Setting changed: %s to %s (was %s)" % ( - frame.SettingsFrame.SETTINGS.get_name(setting), - str(value), - str(old_value))) self.send_frame(frame.SettingsFrame(state=self, flags=frame.Frame.FLAG_ACK)) - log.debug("New settings acknowledged.") def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -168,7 +158,4 @@ class HTTP2Protocol(object): for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - for header, value in headers.items(): - log.debug("%s: %s" % (header, value)) - return headers[':status'], headers, body -- cgit v1.2.3 From 359ef469054b6a80ff8a5a3148a52e864a76fe9b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 12:21:08 +0200 Subject: fix coding style --- netlib/http2/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 459c2293..feac220c 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -61,7 +61,6 @@ class HTTP2Protocol(object): assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 - def next_stream_id(self): if self.current_stream_id is None: self.current_stream_id = 1 @@ -89,7 +88,10 @@ class HTTP2Protocol(object): self.http2_settings[setting] = value - self.send_frame(frame.SettingsFrame(state=self, flags=frame.Frame.FLAG_ACK)) + self.send_frame( + frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK)) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks -- cgit v1.2.3 From 4666d1e7bbf77b470d938d873d1a760283963adf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 11:29:01 +0200 Subject: improve ALPN support on travis --- netlib/tcp.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index f6179faa..fc2ce115 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -6,6 +6,8 @@ import sys import threading import time import traceback + +import OpenSSL from OpenSSL import SSL from . import certutils @@ -401,16 +403,17 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - # advertise application layer protocols - if alpn_protos is not None: - context.set_alpn_protos(alpn_protos) + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + # advertise application layer protocols + if alpn_protos is not None: + context.set_alpn_protos(alpn_protos) - # select application layer protocol - if alpn_select is not None: - def alpn_select_f(conn, options): - return bytes(alpn_select) + # select application layer protocol + if alpn_select is not None: + def alpn_select_f(conn, options): + return bytes(alpn_select) - context.set_alpn_select_callback(alpn_select_f) + context.set_alpn_select_callback(alpn_select_f) return context -- cgit v1.2.3 From abbe88c8ce4f19de33723ac0828cd24b8ec5f38b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 13:25:42 +0200 Subject: fix non-ALPN supported OpenSSL-related tests --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index fc2ce115..09c43ffc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -499,7 +499,10 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - return self.connection.get_alpn_proto_negotiated() + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + return self.connection.get_alpn_proto_negotiated() + else: + return None class BaseHandler(_Connection): -- cgit v1.2.3 From fdbb3b76cf8cd7caaa644dc31e48521096ed5349 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 16:54:19 +0200 Subject: http2: add warning if raw data looks like HTTP/1 --- netlib/http2/frame.py | 4 ++++ netlib/tcp.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index ac9b8d50..4a305d82 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,3 +1,4 @@ +import sys import struct from functools import reduce from hpack.hpack import Encoder, Decoder @@ -79,6 +80,9 @@ class Frame(object): flags = fields[3] stream_id = fields[4] + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + self._check_frame_size(length, state) payload = fp.safe_read(length) diff --git a/netlib/tcp.py b/netlib/tcp.py index 09c43ffc..62545244 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -501,7 +501,7 @@ class TCPClient(_Connection): def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN: return self.connection.get_alpn_proto_negotiated() - else: + else: # pragma no cover return None -- cgit v1.2.3 From 0595585974dd889a10e05cade06f5534c85d7401 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 17:00:03 +0200 Subject: fix coding style --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 62545244..9a980035 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -501,7 +501,7 @@ class TCPClient(_Connection): def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: # pragma no cover return None -- cgit v1.2.3 From eeaed93a83fbe14762e263e9f25b5361088daa15 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 11 Jun 2015 15:37:17 +0200 Subject: improve ALPN integration --- netlib/tcp.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 9a980035..98b17c50 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -404,16 +404,17 @@ class _Connection(object): context.set_info_callback(log_ssl_key) if OpenSSL._util.lib.Cryptography_HAS_ALPN: - # advertise application layer protocols if alpn_protos is not None: + # advertise application layer protocols context.set_alpn_protos(alpn_protos) - - # select application layer protocol - if alpn_select is not None: - def alpn_select_f(conn, options): - return bytes(alpn_select) - - context.set_alpn_select_callback(alpn_select_f) + elif alpn_select is not None: + # select application layer protocol + def alpn_select_callback(conn, options): + if alpn_select in options: + return bytes(alpn_select) + else: + return options[0] + context.set_alpn_select_callback(alpn_select_callback) return context @@ -612,6 +613,12 @@ class BaseHandler(_Connection): def settimeout(self, n): self.connection.settimeout(n) + def get_alpn_proto_negotiated(self): + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + return self.connection.get_alpn_proto_negotiated() + else: # pragma no cover + return None + class TCPServer(object): request_queue_size = 20 -- cgit v1.2.3 From 8ea157775debeccfa0f2fab3aa7e009d13ce4391 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 11 Jun 2015 15:38:32 +0200 Subject: http2: general improvements --- netlib/http2/protocol.py | 63 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index feac220c..4b69764f 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -26,12 +26,13 @@ class HTTP2Protocol(object): ) # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_client): - self.tcp_client = tcp_client + def __init__(self, tcp_handler, is_server=False): + self.tcp_handler = tcp_handler + self.is_server = is_server self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None @@ -39,28 +40,39 @@ class HTTP2Protocol(object): self.decoder = Decoder() def check_alpn(self): - alp = self.tcp_client.get_alpn_proto_negotiated() + alp = self.tcp_handler.get_alpn_proto_negotiated() if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True - def perform_connection_preface(self): - self.tcp_client.wfile.write( - bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(frame.SettingsFrame(state=self)) - - # read server settings frame - frm = frame.Frame.from_file(self.tcp_client.rfile, self) + def _receive_settings(self): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) assert isinstance(frm, frame.SettingsFrame) self._apply_settings(frm.settings) - # read setting ACK frame + def _read_settings_ack(self): settings_ack_frame = self.read_frame() assert isinstance(settings_ack_frame, frame.SettingsFrame) assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 + def perform_server_connection_preface(self): + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() + + def perform_client_connection_preface(self): + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() + def next_stream_id(self): if self.current_stream_id is None: self.current_stream_id = 1 @@ -70,11 +82,11 @@ class HTTP2Protocol(object): def send_frame(self, frame): raw_bytes = frame.to_bytes() - self.tcp_client.wfile.write(raw_bytes) - self.tcp_client.wfile.flush() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() def read_frame(self): - frm = frame.Frame.from_file(self.tcp_client.rfile, self) + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) if isinstance(frm, frame.SettingsFrame): self._apply_settings(frm.settings) @@ -139,25 +151,36 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): + headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + header_block_fragment = b'' body = b'' while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame): + if isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment - if frm.flags | frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False break - while True: + while body_expected: frm = self.read_frame() if isinstance(frm, frame.DataFrame): body += frm.payload - if frm.flags | frame.Frame.FLAG_END_STREAM: + if frm.flags & frame.Frame.FLAG_END_STREAM: break headers = {} for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - return headers[':status'], headers, body + return headers, body -- cgit v1.2.3 From a901bc3032747faf00adf82c3187d38213c070ca Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 12 Jun 2015 14:41:54 +0200 Subject: http2: add response creation --- netlib/http2/protocol.py | 56 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 4b69764f..56aee490 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -26,7 +26,8 @@ class HTTP2Protocol(object): ) # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') ALPN_PROTO_H2 = 'h2' @@ -38,6 +39,7 @@ class HTTP2Protocol(object): self.current_stream_id = None self.encoder = Encoder() self.decoder = Decoder() + self.connection_preface_performed = False def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -57,25 +59,36 @@ class HTTP2Protocol(object): assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 - def perform_server_connection_preface(self): - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE - def perform_client_connection_preface(self): - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() def next_stream_id(self): if self.current_stream_id is None: - self.current_stream_id = 1 + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 else: self.current_stream_id += 2 return self.current_stream_id @@ -165,7 +178,8 @@ class HTTP2Protocol(object): while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame): + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment if frm.flags & frame.Frame.FLAG_END_HEADERS: if frm.flags & frame.Frame.FLAG_END_STREAM: @@ -184,3 +198,15 @@ class HTTP2Protocol(object): headers[header] = value return headers, body + + def create_response(self, code, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) -- cgit v1.2.3 From 5fab755a05f2ddd1b3e8e446e10fdcbded894e70 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 12 Jun 2015 15:21:23 +0200 Subject: add more tests --- netlib/tcp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 98b17c50..eb8a523f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -412,7 +412,7 @@ class _Connection(object): def alpn_select_callback(conn, options): if alpn_select in options: return bytes(alpn_select) - else: + else: # pragma no cover return options[0] context.set_alpn_select_callback(alpn_select_callback) @@ -500,9 +500,9 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN: + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: return None @@ -616,7 +616,7 @@ class BaseHandler(_Connection): def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: return None -- cgit v1.2.3 From 9c6d237d02290c2388f19ec8f215827d4f921e4b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 12 Jun 2015 16:03:01 +0200 Subject: add new TLS methods --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index eb8a523f..74fe70d4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -19,6 +19,9 @@ SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD +TLSv1_1_METHOD = SSL.TLSv1_1_METHOD +TLSv1_2_METHOD = SSL.TLSv1_2_METHOD + OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 @@ -376,7 +379,7 @@ class _Connection(object): alpn_select=None, ): """ - :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context -- cgit v1.2.3 From 8d71a5b4aba8248b97918b11b12275bbf5197337 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 14 Jun 2015 19:17:34 +0200 Subject: http2: add authority header --- netlib/http2/protocol.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 56aee490..1e722dfb 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -152,10 +152,13 @@ class HTTP2Protocol(object): if headers is None: headers = [] + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host headers = [ (b':method', bytes(method)), (b':path', bytes(path)), - (b':scheme', b'https')] + headers + (b':scheme', b'https'), + (b':authority', authority), + ] + headers stream_id = self.next_stream_id() @@ -192,6 +195,7 @@ class HTTP2Protocol(object): body += frm.payload if frm.flags & frame.Frame.FLAG_END_STREAM: break + # TODO: implement window update & flow headers = {} for header, value in self.decoder.decode(header_block_fragment): -- cgit v1.2.3 From 0d137eac6f4c00a72d3aa4d11fce7d1ea15f0f21 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 14 Jun 2015 19:50:35 +0200 Subject: simplify ALPN --- netlib/tcp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 74fe70d4..897e3e65 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -535,7 +535,6 @@ class BaseHandler(_Connection): request_client_cert=None, chain_file=None, dhparams=None, - alpn_select=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. @@ -562,9 +561,7 @@ class BaseHandler(_Connection): until then we're conservative. """ - context = self._create_ssl_context( - alpn_select=alpn_select, - **sslctx_kwargs) + context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -589,7 +586,7 @@ class BaseHandler(_Connection): return context - def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs): + def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) @@ -598,7 +595,6 @@ class BaseHandler(_Connection): context = self.create_ssl_context( cert, key, - alpn_select=alpn_select, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() -- cgit v1.2.3 From fe764cde5229046b8447062971c61fac745d2d58 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Mon, 15 Jun 2015 10:16:44 -0700 Subject: Adding support for upstream certificate validation when using SSL/TLS with an instance of TCPClient. --- netlib/tcp.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 9a980035..ca948514 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -21,6 +21,7 @@ SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 +VERIFY_NONE = SSL.VERIFY_NONE class NetLibError(Exception): @@ -371,6 +372,9 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), + verify_options=VERIFY_NONE, + ca_path=None, + ca_pemfile=None, cipher_list=None, alpn_protos=None, alpn_select=None, @@ -378,6 +382,9 @@ class _Connection(object): """ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + :param ca_pemfile: Path to a PEM formatted trusted CA certificate :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context """ @@ -386,6 +393,19 @@ class _Connection(object): if options is not None: context.set_options(options) + # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) + if verify_options is not None and verify_options is not VERIFY_NONE: + def verify_cert(conn, cert, errno, err_depth, is_cert_verified): + if is_cert_verified: + return True + raise NetLibError( + "Upstream certificate validation failed at depth: %s with error number: %s" % + (err_depth, errno)) + + context.set_verify(verify_options, verify_cert) + if ca_path is not None or ca_pemfile is not None: + context.load_verify_locations(ca_pemfile, ca_path) + # Workaround for # https://github.com/pyca/pyopenssl/issues/190 # https://github.com/mitmproxy/mitmproxy/issues/472 @@ -458,6 +478,9 @@ class TCPClient(_Connection): cert: Path to a file containing both client cert and private key. options: A bit field consisting of OpenSSL.SSL.OP_* values + verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + ca_pemfile: Path to a PEM formatted trusted CA certificate """ context = self.create_ssl_context( alpn_protos=alpn_protos, -- cgit v1.2.3 From 12702b9a01fb6baf4d675d6f974c140581982843 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 13:15:06 +0200 Subject: http2: improve frame output --- netlib/http2/frame.py | 11 +++------ netlib/http2/protocol.py | 61 ++++++++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 33 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 4a305d82..3e285cba 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -113,16 +113,11 @@ class Frame(object): def payload_human_readable(self): # pragma: no cover raise NotImplementedError() - def human_readable(self): + def human_readable(self, direction="-"): return "\n".join([ - "============================================================", - "length: %d bytes" % self.length, - "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), - "flags: %#x" % self.flags, - "stream_id: %#x" % self.stream_id, - "------------------------------------------------------------", + "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), - "============================================================", + "===============================================================", ]) def __eq__(self, other): diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 1e722dfb..7bf68602 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -48,13 +48,12 @@ class HTTP2Protocol(object): "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True - def _receive_settings(self): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + def _receive_settings(self, hide=False): + frm = self.read_frame(hide) assert isinstance(frm, frame.SettingsFrame) - self._apply_settings(frm.settings) - def _read_settings_ack(self): - settings_ack_frame = self.read_frame() + def _read_settings_ack(self, hide=False): + settings_ack_frame = self.read_frame(hide) assert isinstance(settings_ack_frame, frame.SettingsFrame) assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 @@ -67,9 +66,8 @@ class HTTP2Protocol(object): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) def perform_client_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -77,9 +75,8 @@ class HTTP2Protocol(object): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) def next_stream_id(self): if self.current_stream_id is None: @@ -93,30 +90,35 @@ class HTTP2Protocol(object): self.current_stream_id += 2 return self.current_stream_id - def send_frame(self, frame): - raw_bytes = frame.to_bytes() + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() + if not hide and self.tcp_handler.http2_framedump: + print(frm.human_readable(">>")) - def read_frame(self): + def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if isinstance(frm, frame.SettingsFrame): - self._apply_settings(frm.settings) + if not hide and self.tcp_handler.http2_framedump: + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) return frm - def _apply_settings(self, settings): + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] if not old_value: old_value = '-' - self.http2_settings[setting] = value self.send_frame( frame.SettingsFrame( state=self, - flags=frame.Frame.FLAG_ACK)) + flags=frame.Frame.FLAG_ACK), + hide) + self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -127,12 +129,16 @@ class HTTP2Protocol(object): header_block_fragment = self.encoder.encode(headers) - bytes = frame.HeadersFrame( + frm = frame.HeadersFrame( state=self, flags=flags, stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] + header_block_fragment=header_block_fragment) + + if self.tcp_handler.http2_framedump: + print(frm.human_readable(">>")) + + return [frm.to_bytes()] def _create_body(self, body, stream_id): if body is None or len(body) == 0: @@ -141,12 +147,17 @@ class HTTP2Protocol(object): # TODO: implement max frame size checks and sending in chunks # TODO: implement flow-control window - bytes = frame.DataFrame( + frm = frame.DataFrame( state=self, flags=frame.Frame.FLAG_END_STREAM, stream_id=stream_id, - payload=body).to_bytes() - return [bytes] + payload=body) + + if self.tcp_handler.http2_framedump: + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + def create_request(self, method, path, headers=None, body=None): if headers is None: -- cgit v1.2.3 From 79ff43993018209a76a2a7cff995e912eb20d4c3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 09:47:43 +0200 Subject: add elliptic curve during TLS handshake --- netlib/tcp.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 953cef6e..2e847d83 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,11 +22,6 @@ TLSv1_METHOD = SSL.TLSv1_METHOD TLSv1_1_METHOD = SSL.TLSv1_1_METHOD TLSv1_2_METHOD = SSL.TLSv1_2_METHOD -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -VERIFY_NONE = SSL.VERIFY_NONE - - class NetLibError(Exception): pass @@ -374,8 +369,8 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, - options=(OP_NO_SSLv2 | OP_NO_SSLv3), - verify_options=VERIFY_NONE, + options=(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_CIPHER_SERVER_PREFERENCE | SSL.OP_NO_COMPRESSION), + verify_options=SSL.VERIFY_NONE, ca_path=None, ca_pemfile=None, cipher_list=None, @@ -397,7 +392,7 @@ class _Connection(object): context.set_options(options) # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) - if verify_options is not None and verify_options is not VERIFY_NONE: + if verify_options is not None and verify_options is not SSL.VERIFY_NONE: def verify_cert(conn, cert, errno, err_depth, is_cert_verified): if is_cert_verified: return True @@ -426,6 +421,8 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) + if OpenSSL._util.lib.Cryptography_HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols -- cgit v1.2.3 From e3db241a2fa47a38fcb85532ed52eeecf1a7b965 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 13:43:23 +0200 Subject: http2: improve frame output --- netlib/http2/protocol.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 7bf68602..24fcb712 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -31,7 +31,7 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False): + def __init__(self, tcp_handler, is_server=False, dump_frames=False): self.tcp_handler = tcp_handler self.is_server = is_server @@ -40,6 +40,7 @@ class HTTP2Protocol(object): self.encoder = Encoder() self.decoder = Decoder() self.connection_preface_performed = False + self.dump_frames = dump_frames def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -94,12 +95,12 @@ class HTTP2Protocol(object): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() - if not hide and self.tcp_handler.http2_framedump: + if not hide and self.dump_frames: print(frm.human_readable(">>")) def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.tcp_handler.http2_framedump: + if not hide and self.dump_frames: print(frm.human_readable("<<")) if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) @@ -135,7 +136,7 @@ class HTTP2Protocol(object): stream_id=stream_id, header_block_fragment=header_block_fragment) - if self.tcp_handler.http2_framedump: + if self.dump_frames: print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -153,7 +154,7 @@ class HTTP2Protocol(object): stream_id=stream_id, payload=body) - if self.tcp_handler.http2_framedump: + if self.dump_frames: print(frm.human_readable(">>")) return [frm.to_bytes()] -- cgit v1.2.3 From d0a9d3cdda6d1f784a23ea4bd9efd3134e292628 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 14:21:34 +0200 Subject: http2: only first headers frame as END_STREAM flag --- netlib/http2/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 24fcb712..682b7863 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -196,9 +196,9 @@ class HTTP2Protocol(object): if isinstance(frm, frame.HeadersFrame)\ or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False if frm.flags & frame.Frame.FLAG_END_HEADERS: - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False break while body_expected: -- cgit v1.2.3 From 1c124421e34d310c6e0577f20b595413d639a5c3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 15:31:58 +0200 Subject: http2: fix header_block_fragments and length --- netlib/http2/frame.py | 13 +++++++++++-- netlib/http2/protocol.py | 23 +++++++++++++++-------- 2 files changed, 26 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 3e285cba..98ced904 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -114,6 +114,8 @@ class Frame(object): raise NotImplementedError() def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + return "\n".join([ "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), @@ -456,7 +458,10 @@ class PushPromiseFrame(Frame): s.append("padding: %d" % self.pad_length) s.append("promised stream: %#x" % self.promised_stream) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) @@ -600,7 +605,11 @@ class ContinuationFrame(Frame): return self.header_block_fragment def payload_human_readable(self): - return "header_block_fragment: %s" % str(self.header_block_fragment) + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) _FRAME_CLASSES = [ DataFrame, diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 682b7863..f17f998f 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -50,14 +50,18 @@ class HTTP2Protocol(object): return True def _receive_settings(self, hide=False): - frm = self.read_frame(hide) - assert isinstance(frm, frame.SettingsFrame) + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break def _read_settings_ack(self, hide=False): - settings_ack_frame = self.read_frame(hide) - assert isinstance(settings_ack_frame, frame.SettingsFrame) - assert settings_ack_frame.flags & frame.Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert settings_ack_frame.flags & frame.Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 + break def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -119,7 +123,7 @@ class HTTP2Protocol(object): state=self, flags=frame.Frame.FLAG_ACK), hide) - self._read_settings_ack(hide) + # self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -219,10 +223,13 @@ class HTTP2Protocol(object): if headers is None: headers = [] + body='foobar' + headers = [(b':status', bytes(str(code)))] + headers stream_id = self.next_stream_id() return list(itertools.chain( self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) + self._create_body(body, stream_id), + )) -- cgit v1.2.3 From 20c136e070cee0e93e870bf32199cb36b1b85275 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 15:51:40 +0200 Subject: http2: return stream_id from request for response --- netlib/http2/protocol.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index f17f998f..a77edd9b 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -183,7 +183,7 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): - headers, body = self._receive_transmission() + stream_id, headers, body = self._receive_transmission() return headers[':status'], headers, body def read_request(self): @@ -192,6 +192,7 @@ class HTTP2Protocol(object): def _receive_transmission(self): body_expected = True + stream_id = 0 header_block_fragment = b'' body = b'' @@ -199,6 +200,7 @@ class HTTP2Protocol(object): frm = self.read_frame() if isinstance(frm, frame.HeadersFrame)\ or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment if frm.flags & frame.Frame.FLAG_END_STREAM: body_expected = False @@ -217,9 +219,9 @@ class HTTP2Protocol(object): for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - return headers, body + return stream_id, headers, body - def create_response(self, code, headers=None, body=None): + def create_response(self, code, stream_id=None, headers=None, body=None): if headers is None: headers = [] @@ -227,7 +229,8 @@ class HTTP2Protocol(object): headers = [(b':status', bytes(str(code)))] + headers - stream_id = self.next_stream_id() + if not stream_id: + stream_id = self.next_stream_id() return list(itertools.chain( self._create_headers(headers, stream_id, end_stream=(body is None)), -- cgit v1.2.3 From abb37a3ef52ab9a0f68dc46e4a8ca165e365139b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 17:31:08 +0200 Subject: http2: improve test suite --- netlib/http2/protocol.py | 16 ++++++++-------- netlib/tcp.py | 9 +++++---- 2 files changed, 13 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index a77edd9b..8191090c 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -55,7 +55,7 @@ class HTTP2Protocol(object): if isinstance(frm, frame.SettingsFrame): break - def _read_settings_ack(self, hide=False): + def _read_settings_ack(self, hide=False): # pragma no cover while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): @@ -99,12 +99,12 @@ class HTTP2Protocol(object): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: + if not hide and self.dump_frames: # pragma no cover print(frm.human_readable(">>")) def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: + if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) @@ -123,7 +123,9 @@ class HTTP2Protocol(object): state=self, flags=frame.Frame.FLAG_ACK), hide) - # self._read_settings_ack(hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -140,7 +142,7 @@ class HTTP2Protocol(object): stream_id=stream_id, header_block_fragment=header_block_fragment) - if self.dump_frames: + if self.dump_frames: # pragma no cover print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -158,7 +160,7 @@ class HTTP2Protocol(object): stream_id=stream_id, payload=body) - if self.dump_frames: + if self.dump_frames: # pragma no cover print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -225,8 +227,6 @@ class HTTP2Protocol(object): if headers is None: headers = [] - body='foobar' - headers = [(b':status', bytes(str(code)))] + headers if not stream_id: diff --git a/netlib/tcp.py b/netlib/tcp.py index 2e847d83..cafc3ed9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -414,6 +414,9 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) + + # TODO: maybe change this to with newer pyOpenSSL APIs + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: raise NetLibError("SSL cipher specification error: %s" % str(v)) @@ -421,8 +424,6 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) - if OpenSSL._util.lib.Cryptography_HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols @@ -526,7 +527,7 @@ class TCPClient(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return None + return "" class BaseHandler(_Connection): @@ -636,7 +637,7 @@ class BaseHandler(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return None + return "" class TCPServer(object): -- cgit v1.2.3 From eb823a04a19de7fd9e15d225064ae4581f0b85bf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 23:36:14 +0200 Subject: http2: improve :authority header --- netlib/http2/protocol.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 8191090c..ac89bac4 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -171,6 +171,9 @@ class HTTP2Protocol(object): headers = [] authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + headers = [ (b':method', bytes(method)), (b':path', bytes(path)), -- cgit v1.2.3 From c9c93af453ec332b660f70402b78ae8f269280f0 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 16 Jun 2015 11:11:10 -0700 Subject: Adding certifi as default CA bundle. --- netlib/tcp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index ca948514..b523bea4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback +import certifi import OpenSSL from OpenSSL import SSL @@ -373,7 +374,7 @@ class _Connection(object): method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), verify_options=VERIFY_NONE, - ca_path=None, + ca_path=certifi.where(), ca_pemfile=None, cipher_list=None, alpn_protos=None, @@ -403,8 +404,7 @@ class _Connection(object): (err_depth, errno)) context.set_verify(verify_options, verify_cert) - if ca_path is not None or ca_pemfile is not None: - context.load_verify_locations(ca_pemfile, ca_path) + context.load_verify_locations(ca_pemfile, ca_path) # Workaround for # https://github.com/pyca/pyopenssl/issues/190 -- cgit v1.2.3 From 836b1eab9700230991822102d411aed067308123 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 17 Jun 2015 13:10:27 +0200 Subject: fix warnings and code smells use prospector to find them --- netlib/http2/__init__.py | 1 - netlib/http2/frame.py | 55 ++++++++++++++++++++++++------------------------ netlib/http_cookies.py | 8 +++---- netlib/http_uastrings.py | 24 +++++++++++---------- netlib/tcp.py | 8 +++---- netlib/utils.py | 2 +- netlib/websockets.py | 16 +++++++------- 7 files changed, 56 insertions(+), 58 deletions(-) (limited to 'netlib') diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index 92897b5d..5acf7696 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -1,3 +1,2 @@ - from frame import * from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 4a305d82..43676623 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,6 +1,5 @@ import sys import struct -from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils @@ -52,7 +51,7 @@ class Frame(object): self.stream_id = stream_id @classmethod - def _check_frame_size(self, length, state): + def _check_frame_size(cls, length, state): if state: settings = state.http2_settings else: @@ -67,7 +66,7 @@ class Frame(object): length, max_frame_size)) @classmethod - def from_file(self, fp, state=None): + def from_file(cls, fp, state=None): """ read a HTTP/2 frame sent by a server or client fp is a "file like" object that could be backed by a network @@ -83,7 +82,7 @@ class Frame(object): if raw_header[:4] == b'HTTP': # pragma no cover print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - self._check_frame_size(length, state) + cls._check_frame_size(length, state) payload = fp.safe_read(length) return FRAMES[fields[2]].from_bytes( @@ -146,10 +145,10 @@ class DataFrame(Frame): self.pad_length = pad_length @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] f.payload = payload[1:-f.pad_length] else: @@ -204,16 +203,16 @@ class HeadersFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] f.header_block_fragment = payload[1:-f.pad_length] else: f.header_block_fragment = payload[0:] - if f.flags & self.FLAG_PRIORITY: + if f.flags & Frame.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( '!LB', f.header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) @@ -279,8 +278,8 @@ class PriorityFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.stream_dependency, f.weight = struct.unpack('!LB', payload) f.exclusive = bool(f.stream_dependency >> 31) @@ -325,8 +324,8 @@ class RstStreamFrame(Frame): self.error_code = error_code @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.error_code = struct.unpack('!L', payload)[0] return f @@ -369,8 +368,8 @@ class SettingsFrame(Frame): self.settings = settings @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) @@ -420,10 +419,10 @@ class PushPromiseFrame(Frame): self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) f.header_block_fragment = payload[5:-f.pad_length] else: @@ -480,8 +479,8 @@ class PingFrame(Frame): self.payload = payload @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.payload = payload return f @@ -517,8 +516,8 @@ class GoAwayFrame(Frame): self.data = data @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) f.last_stream &= 0x7FFFFFFF @@ -558,8 +557,8 @@ class WindowUpdateFrame(Frame): self.window_size_increment = window_size_increment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.window_size_increment = struct.unpack("!L", payload)[0] f.window_size_increment &= 0x7FFFFFFF @@ -592,8 +591,8 @@ class ContinuationFrame(Frame): self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.header_block_fragment = payload return f diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 5cb39e5c..b7311714 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -158,7 +158,7 @@ def _parse_set_cookie_pairs(s): return pairs -def parse_set_cookie_header(str): +def parse_set_cookie_header(line): """ Parse a Set-Cookie header value @@ -166,7 +166,7 @@ def parse_set_cookie_header(str): ODictCaseless set of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ - pairs = _parse_set_cookie_pairs(str) + pairs = _parse_set_cookie_pairs(line) if pairs: return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) @@ -180,12 +180,12 @@ def format_set_cookie_header(name, value, attrs): return _format_set_cookie_pairs(pairs) -def parse_cookie_header(str): +def parse_cookie_header(line): """ Parse a Cookie header value. Returns a (possibly empty) ODict object. """ - pairs, off = _read_pairs(str) + pairs, off = _read_pairs(line) return odict.ODict(pairs) diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index d9869531..c1ef557c 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -5,40 +5,42 @@ from __future__ import (absolute_import, print_function, division) kept reasonably current to reflect common usage. """ +# pylint: line-too-long + # A collection of (name, shortcut, string) tuples. UASTRINGS = [ ("android", "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa ("blackberry", "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa ("bingbot", "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa ("chrome", "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa ("firefox", "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa ("googlebot", "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa ("ie9", "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), # noqa ("ipad", "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa ("iphone", "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", - ), + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa ("safari", "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10")] + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] def get_by_shortcut(s): diff --git a/netlib/tcp.py b/netlib/tcp.py index 953cef6e..807015c8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -297,7 +297,7 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux sock.shutdown(socket.SHUT_WR) # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending @@ -368,10 +368,6 @@ class _Connection(object): except SSL.Error: pass - """ - Creates an SSL Context. - """ - def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), @@ -383,6 +379,8 @@ class _Connection(object): alpn_select=None, ): """ + Creates an SSL Context. + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values diff --git a/netlib/utils.py b/netlib/utils.py index 9c5404e6..ac42bd53 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -67,7 +67,7 @@ def getbit(byte, offset): return True -class BiDi: +class BiDi(object): """ A wee utility class for keeping bi-directional mappings, like field diff --git a/netlib/websockets.py b/netlib/websockets.py index 346adf1b..c45db4df 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,7 +35,7 @@ OPCODE = utils.BiDi( ) -class Masker: +class Masker(object): """ Data sent from the server must be masked to prevent malicious clients @@ -94,15 +94,15 @@ def server_handshake_headers(key): ) -def make_length_code(len): +def make_length_code(length): """ A websockets frame contains an initial length_code, and an optional extended length code to represent the actual length if length code is larger than 125 """ - if len <= 125: - return len - elif len >= 126 and len <= 65535: + if length <= 125: + return length + elif length >= 126 and length <= 65535: return 126 else: return 127 @@ -129,7 +129,7 @@ def create_server_nonce(client_nonce): DEFAULT = object() -class FrameHeader: +class FrameHeader(object): def __init__( self, @@ -216,7 +216,7 @@ class FrameHeader: return b @classmethod - def from_file(klass, fp): + def from_file(cls, fp): """ read a websockets frame header """ @@ -248,7 +248,7 @@ class FrameHeader: else: masking_key = None - return klass( + return cls( fin=fin, rsv1=rsv1, rsv2=rsv2, -- cgit v1.2.3 From 6e301f37d0597d86008c440f62526f906f0ae9f4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 18 Jun 2015 12:18:22 +1200 Subject: Only set OP_NO_COMPRESSION by default if it exists in our version of OpenSSL We'll need to start testing under both new and old versions of OpenSSL somehow to catch these... --- netlib/tcp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a1d1fe62..52ebc3c0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,6 +22,17 @@ TLSv1_METHOD = SSL.TLSv1_METHOD TLSv1_1_METHOD = SSL.TLSv1_1_METHOD TLSv1_2_METHOD = SSL.TLSv1_2_METHOD + +SSL_DEFAULT_OPTIONS = ( + SSL.OP_NO_SSLv2 | + SSL.OP_NO_SSLv3 | + SSL.OP_CIPHER_SERVER_PREFERENCE +) + +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION + + class NetLibError(Exception): pass @@ -365,7 +376,7 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, - options=(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_CIPHER_SERVER_PREFERENCE | SSL.OP_NO_COMPRESSION), + options=SSL_DEFAULT_OPTIONS, verify_options=SSL.VERIFY_NONE, ca_path=None, ca_pemfile=None, -- cgit v1.2.3 From 69e71097f7a9633a43d566b2a46aab370f07dce3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Jun 2015 15:32:52 +0200 Subject: mark unused variables and arguments --- netlib/certutils.py | 2 +- netlib/http2/frame.py | 3 ++- netlib/http2/protocol.py | 15 +++++++-------- netlib/http_auth.py | 9 +++++---- netlib/http_cookies.py | 9 +++------ netlib/tcp.py | 10 +++++----- netlib/wsgi.py | 2 +- 7 files changed, 24 insertions(+), 26 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index ade61bb5..c6f0e628 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -333,7 +333,7 @@ class CertStore(object): return entry.cert, entry.privatekey, entry.chain_file - def gen_pkey(self, cert): + def gen_pkey(self, cert_): # FIXME: We should do something with cert here? from . import certffi certffi.set_flags(self.default_privatekey, 1) diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index b4783a02..f7e60471 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -116,7 +116,8 @@ class Frame(object): self.length = len(self.payload_bytes()) return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), "===============================================================", ]) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index ac89bac4..8e5f5429 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -59,8 +59,8 @@ class HTTP2Protocol(object): while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): - assert settings_ack_frame.flags & frame.Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 break def perform_server_connection_preface(self, force=False): @@ -118,11 +118,10 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - self.send_frame( - frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK), - hide) + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) # be liberal in what we expect from the other end # to be more strict use: self._read_settings_ack(hide) @@ -188,7 +187,7 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): - stream_id, headers, body = self._receive_transmission() + stream_id_, headers, body = self._receive_transmission() return headers[':status'], headers, body def read_request(self): diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 0143760c..adab4aed 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -12,12 +12,13 @@ class NullProxyAuth(object): def __init__(self, password_manager): self.password_manager = password_manager - def clean(self, headers): + def clean(self, headers_): """ Clean up authentication headers, so they're not passed upstream. """ + pass - def authenticate(self, headers): + def authenticate(self, headers_): """ Tests that the user is allowed to use the proxy """ @@ -62,7 +63,7 @@ class BasicProxyAuth(NullProxyAuth): class PassMan(object): - def test(self, username, password_token): + def test(self, username_, password_token_): return False @@ -72,7 +73,7 @@ class PassManNonAnon(PassMan): Ensure the user specifies a username, accept any password. """ - def test(self, username, password_token): + def test(self, username, password_token_): if username: return True return False diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index b7311714..e91ee5c0 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -87,7 +87,7 @@ def _read_value(s, start, delims): return _read_until(s, start, delims) -def _read_pairs(s, off=0, specials=()): +def _read_pairs(s, off=0): """ Read pairs of lhs=rhs values. @@ -151,10 +151,7 @@ def _parse_set_cookie_pairs(s): For Set-Cookie, we support multiple cookies as described in RFC2109. This function therefore returns a list of lists. """ - pairs, off = _read_pairs( - s, - specials=("expires", "path") - ) + pairs, off_ = _read_pairs(s) return pairs @@ -185,7 +182,7 @@ def parse_cookie_header(line): Parse a Cookie header value. Returns a (possibly empty) ODict object. """ - pairs, off = _read_pairs(line) + pairs, off_ = _read_pairs(line) return odict.ODict(pairs) diff --git a/netlib/tcp.py b/netlib/tcp.py index 65075776..77eb7b52 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -403,7 +403,7 @@ class _Connection(object): # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) if verify_options is not None and verify_options is not SSL.VERIFY_NONE: - def verify_cert(conn, cert, errno, err_depth, is_cert_verified): + def verify_cert(conn_, cert_, errno, err_depth, is_cert_verified): if is_cert_verified: return True raise NetLibError( @@ -439,7 +439,7 @@ class _Connection(object): context.set_alpn_protos(alpn_protos) elif alpn_select is not None: # select application layer protocol - def alpn_select_callback(conn, options): + def alpn_select_callback(conn_, options): if alpn_select in options: return bytes(alpn_select) else: # pragma no cover @@ -601,7 +601,7 @@ class BaseHandler(_Connection): context.set_tlsext_servername_callback(handle_sni) if request_client_cert: - def save_cert(conn, cert, errno, depth, preverify_ok): + def save_cert(conn_, cert, errno_, depth_, preverify_ok_): self.clientcert = certutils.SSLCert(cert) # Return true to prevent cert verification error return True @@ -676,7 +676,7 @@ class TCPServer(object): try: while not self.__shutdown_request: try: - r, w, e = select.select( + r, w_, e_ = select.select( [self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover if ex[0] == EINTR: @@ -708,7 +708,7 @@ class TCPServer(object): self.socket.close() self.handle_shutdown() - def handle_error(self, connection, client_address, fp=sys.stderr): + def handle_error(self, connection_, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 827cf6f0..ad43dc19 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -35,7 +35,7 @@ def date_time_string(): 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' ] now = time.time() - year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) + year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( WEEKS[wd], day, MONTHS[month], year, -- cgit v1.2.3 From f5c5deb2aea047394238f3b993ddf24c60845768 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Jun 2015 17:36:58 +0200 Subject: fix http user agents --- netlib/http_uastrings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index c1ef557c..e8681908 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -30,10 +30,10 @@ UASTRINGS = [ "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa ("ie9", "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), # noqa + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa ("ipad", "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa ("iphone", "h", "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa -- cgit v1.2.3 From 2aa1b98fbf8d03005e022da86e3e534cf25ebf62 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 22 Jun 2015 14:52:23 +1200 Subject: netlib/test.py -> test/tservers.py --- netlib/test.py | 108 --------------------------------------------------------- 1 file changed, 108 deletions(-) delete mode 100644 netlib/test.py (limited to 'netlib') diff --git a/netlib/test.py b/netlib/test.py deleted file mode 100644 index 1e1b5e9d..00000000 --- a/netlib/test.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import threading -import Queue -import cStringIO -import OpenSSL -from . import tcp, certutils -from test import tutils - - -class ServerThread(threading.Thread): - - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase(object): - ssl = None - handler = None - addr = ("localhost", 0) - - @classmethod - def setupAll(cls): - cls.q = Queue.Queue() - s = cls.makeserver() - cls.port = s.address.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr) - - @classmethod - def teardownAll(cls): - cls.server.shutdown() - - @property - def last_handler(self): - return self.server.server.last_handler - - -class TServer(tcp.TCPServer): - - def __init__(self, ssl, q, handler_klass, addr): - """ - ssl: A dictionary of SSL parameters: - - cert, key, request_client_cert, cipher_list, - dhparams, v3_only - """ - tcp.TCPServer.__init__(self, addr) - - if ssl is True: - self.ssl = dict() - elif isinstance(ssl, dict): - self.ssl = ssl - else: - self.ssl = None - - self.q = q - self.handler_klass = handler_klass - self.last_handler = None - - def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl is not None: - raw_cert = self.ssl.get( - "cert", - tutils.test_data.path("data/server.crt")) - cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) - raw_key = self.ssl.get( - "key", - tutils.test_data.path("data/server.key")) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - open(raw_key, "rb").read()) - if self.ssl.get("v3_only", False): - method = tcp.SSLv3_METHOD - options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 - else: - method = tcp.SSLv23_METHOD - options = None - h.convert_to_ssl( - cert, key, - method=method, - options=options, - handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl.get("request_client_cert", None), - cipher_list=self.ssl.get("cipher_list", None), - dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None), - alpn_select=self.ssl.get("alpn_select", None) - ) - h.handle() - h.finish() - - def handle_error(self, connection, client_address, fp=None): - s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, connection, client_address, s) - self.q.put(s.getvalue()) -- cgit v1.2.3 From 58118d607e810e95fe8a0c0e6d7b8f4423f1f558 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 22 Jun 2015 20:39:30 +0200 Subject: unify SSL version/method handling --- netlib/tcp.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 77eb7b52..705cc311 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,13 +16,24 @@ from . import certutils EINTR = 4 -SSLv2_METHOD = SSL.SSLv2_METHOD -SSLv3_METHOD = SSL.SSLv3_METHOD -SSLv23_METHOD = SSL.SSLv23_METHOD -TLSv1_METHOD = SSL.TLSv1_METHOD -TLSv1_1_METHOD = SSL.TLSv1_1_METHOD -TLSv1_2_METHOD = SSL.TLSv1_2_METHOD +# To enable all SSL methods use: SSLv23 +# then add options to disable certain methods +# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +# Use ONLY for parsing of CLI arguments! +# All code internals should use OpenSSL constants directly! +SSL_VERSIONS = { + 'TLSv1.2': SSL.TLSv1_2_METHOD, + 'TLSv1.1': SSL.TLSv1_1_METHOD, + 'TLSv1': SSL.TLSv1_METHOD, + 'SSLv3': SSL.SSLv3_METHOD, + 'SSLv2': SSL.SSLv2_METHOD, + 'SSLv23': SSL.SSLv23_METHOD, +} + +SSL_DEFAULT_VERSION = 'SSLv23' + +SSL_DEFAULT_METHOD = SSL_VERSIONS[SSL_DEFAULT_VERSION] SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | @@ -376,7 +387,7 @@ class _Connection(object): pass def _create_ssl_context(self, - method=SSLv23_METHOD, + method=SSL_DEFAULT_METHOD, options=SSL_DEFAULT_OPTIONS, verify_options=SSL.VERIFY_NONE, ca_path=certifi.where(), -- cgit v1.2.3 From 7afe44ba4ee8810e24abfa32f74dfac61e5551d3 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Sat, 20 Jun 2015 12:54:03 -0700 Subject: Updating TCPServer to allow tests (and potentially other use cases) to serve certificate chains instead of only single certificates. --- netlib/tcp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 77eb7b52..61306e4e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -567,7 +567,8 @@ class BaseHandler(_Connection): dhparams=None, **sslctx_kwargs): """ - cert: A certutils.SSLCert object. + cert: A certutils.SSLCert object or the path to a certificate + chain file. handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: @@ -594,7 +595,10 @@ class BaseHandler(_Connection): context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) - context.use_certificate(cert.x509) + if isinstance(cert, certutils.SSLCert): + context.use_certificate(cert.x509) + else: + context.use_certificate_chain_file(cert) if handle_sni: # SNI callback happens during do_handshake() -- cgit v1.2.3 From d1452424beced04dc42bbadd68878d9e1c24da9c Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Sat, 20 Jun 2015 13:07:23 -0700 Subject: Cleaning up upstream server verification. Adding storage of cerificate verification errors on TCPClient object to enable warnings in downstream projects. --- netlib/tcp.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 61306e4e..2cae34ec 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -401,14 +401,13 @@ class _Connection(object): if options is not None: context.set_options(options) - # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) - if verify_options is not None and verify_options is not SSL.VERIFY_NONE: - def verify_cert(conn_, cert_, errno, err_depth, is_cert_verified): - if is_cert_verified: - return True - raise NetLibError( - "Upstream certificate validation failed at depth: %s with error number: %s" % - (err_depth, errno)) + # Verify Options (NONE/PEER and trusted CAs) + if verify_options is not None: + def verify_cert(conn, x509, errno, err_depth, is_cert_verified): + if not is_cert_verified: + self.ssl_verification_error = dict(errno=errno, + depth=err_depth) + return is_cert_verified context.set_verify(verify_options, verify_cert) context.load_verify_locations(ca_pemfile, ca_path) @@ -469,6 +468,7 @@ class TCPClient(_Connection): self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.ssl_verification_error = None self.sni = None def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): -- cgit v1.2.3 From 239f4758afa65995769e896d8f4faa9e12414d28 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 23 Jun 2015 22:16:03 +1200 Subject: Remove dependence on pathod in test suite. --- netlib/utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index ac42bd53..bee412f9 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,5 +1,5 @@ from __future__ import (absolute_import, print_function, division) - +import os.path def isascii(s): try: @@ -110,3 +110,22 @@ def pretty_size(size): if x == int(x): x = int(x) return str(x) + suf + + +class Data(object): + def __init__(self, name): + m = __import__(name) + dirname, _ = os.path.split(m.__file__) + self.dirname = os.path.abspath(dirname) + + def path(self, path): + """ + Returns a path to the package data housed at 'path' under this + module.Path can be a path to a file, or to a directory. + + This function will raise ValueError if the path does not exist. + """ + fullpath = os.path.join(self.dirname, path) + if not os.path.exists(fullpath): + raise ValueError("dataPath: %s does not exist." % fullpath) + return fullpath -- cgit v1.2.3 From 41925b01f71831c33424d5cd9e612d003b99a69d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 25 Jun 2015 10:37:01 +1200 Subject: Fix printing of SSL version error Fixes #73 --- netlib/version_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version_check.py b/netlib/version_check.py index 09dc23ae..df1612a2 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -33,7 +33,7 @@ def version_check( if v < pyopenssl_min_version: print( "You are using an outdated version of pyOpenSSL:" - " mitmproxy requires pyOpenSSL %x or greater." % + " mitmproxy requires pyOpenSSL %s or greater." % pyopenssl_min_version, file=fp ) -- cgit v1.2.3 From 2723a0e5739412953f60c37d0dab81d684ba5f26 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 13:26:35 +0200 Subject: remove certffi --- netlib/certffi.py | 41 ----------------------------------------- netlib/certutils.py | 6 ------ 2 files changed, 47 deletions(-) delete mode 100644 netlib/certffi.py (limited to 'netlib') diff --git a/netlib/certffi.py b/netlib/certffi.py deleted file mode 100644 index 451f4493..00000000 --- a/netlib/certffi.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from cffi import FFI -import OpenSSL - -xffi = FFI() -xffi.cdef(""" - struct rsa_meth_st { - int flags; - ...; - }; - struct rsa_st { - int pad; - long version; - struct rsa_meth_st *meth; - ...; - }; -""") -xffi.verify( - """#include """, - extra_compile_args=['-w'] -) - - -def handle(privkey): - new = xffi.new("struct rsa_st*") - newbuf = xffi.buffer(new) - rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey) - oldbuf = OpenSSL.SSL._ffi.buffer(rsa) - newbuf[:] = oldbuf[:] - return new - - -def set_flags(privkey, val): - hdl = handle(privkey) - hdl.meth.flags = val - return privkey - - -def get_flags(privkey): - hdl = handle(privkey) - return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index c6f0e628..c699af00 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -333,12 +333,6 @@ class CertStore(object): return entry.cert, entry.privatekey, entry.chain_file - def gen_pkey(self, cert_): - # FIXME: We should do something with cert here? - from . import certffi - certffi.set_flags(self.default_privatekey, 1) - return self.default_privatekey - class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore -- cgit v1.2.3 From 0a2b25187faea1fa29a3b21935cd55294b173bf8 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Fri, 26 Jun 2015 14:57:00 -0700 Subject: Fixing how certifi is made the default ca_path to simplify calling logic. --- netlib/tcp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 74a275c9..38b77c9e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -390,7 +390,7 @@ class _Connection(object): method=SSL_DEFAULT_METHOD, options=SSL_DEFAULT_OPTIONS, verify_options=SSL.VERIFY_NONE, - ca_path=certifi.where(), + ca_path=None, ca_pemfile=None, cipher_list=None, alpn_protos=None, @@ -421,6 +421,8 @@ class _Connection(object): return is_cert_verified context.set_verify(verify_options, verify_cert) + if ca_path is None and ca_pemfile is None: + ca_path = certifi.where() context.load_verify_locations(ca_pemfile, ca_path) # Workaround for -- cgit v1.2.3 From 9aaf10120d08e12e7aa82fc2184ca7faa35349c3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 3 Jul 2015 02:01:30 +0200 Subject: socks: add assert_socks5 method --- netlib/socks.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index 5a73c61a..eef98f5c 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -6,7 +6,6 @@ from . import tcp, utils class SocksError(Exception): - def __init__(self, code, message): super(SocksError, self).__init__(message) self.code = code @@ -17,21 +16,18 @@ VERSION = utils.BiDi( SOCKS5=0x05 ) - CMD = utils.BiDi( CONNECT=0x01, BIND=0x02, UDP_ASSOCIATE=0x03 ) - ATYP = utils.BiDi( IPV4_ADDRESS=0x01, DOMAINNAME=0x03, IPV6_ADDRESS=0x04 ) - REP = utils.BiDi( SUCCEEDED=0x00, GENERAL_SOCKS_SERVER_FAILURE=0x01, @@ -44,7 +40,6 @@ REP = utils.BiDi( ADDRESS_TYPE_NOT_SUPPORTED=0x08, ) - METHOD = utils.BiDi( NO_AUTHENTICATION_REQUIRED=0x00, GSSAPI=0x01, @@ -58,14 +53,27 @@ class ClientGreeting(object): def __init__(self, ver, methods): self.ver = ver - self.methods = methods + self.methods = array.array("B") + self.methods.extend(methods) + + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + if self.ver == ord("G") and len(self.methods) == ord("E"): + guess = "Probably not a SOCKS request but a regular HTTP request. " + else: + guess = "" + + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) @classmethod def from_file(cls, f): ver, nmethods = struct.unpack("!BB", f.safe_read(2)) methods = array.array("B") methods.fromstring(f.safe_read(nmethods)) - return cls(ver, methods) + return cls(ver, methods.tolist()) def to_file(self, f): f.write(struct.pack("!BB", self.ver, len(self.methods))) @@ -79,6 +87,18 @@ class ServerGreeting(object): self.ver = ver self.method = method + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + if self.ver == ord("H") and self.method == ord("T"): + guess = "Probably not a SOCKS request but a regular HTTP response. " + else: + guess = "" + + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + @classmethod def from_file(cls, f): ver, method = struct.unpack("!BB", f.safe_read(2)) @@ -97,6 +117,13 @@ class Message(object): self.atyp = atyp self.addr = addr + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + @classmethod def from_file(cls, f): ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) -- cgit v1.2.3 From 880c66fe48c5a6bb4779a8149a3551f007ff5b09 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 3 Jul 2015 02:45:12 +0200 Subject: socks: optionally fail early --- netlib/socks.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/socks.py b/netlib/socks.py index eef98f5c..d38b88c8 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -69,11 +69,16 @@ class ClientGreeting(object): ) @classmethod - def from_file(cls, f): + def from_file(cls, f, fail_early=False): + """ + :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. + """ ver, nmethods = struct.unpack("!BB", f.safe_read(2)) - methods = array.array("B") - methods.fromstring(f.safe_read(nmethods)) - return cls(ver, methods.tolist()) + client_greeting = cls(ver, []) + if fail_early: + client_greeting.assert_socks5() + client_greeting.methods.fromstring(f.safe_read(nmethods)) + return client_greeting def to_file(self, f): f.write(struct.pack("!BB", self.ver, len(self.methods))) @@ -115,7 +120,7 @@ class Message(object): self.ver = ver self.msg = msg self.atyp = atyp - self.addr = addr + self.addr = tcp.Address.wrap(addr) def assert_socks5(self): if self.ver != VERSION.SOCKS5: -- cgit v1.2.3 From 397b3bba5e718da8fca7131d5e1823c4ce5363ca Mon Sep 17 00:00:00 2001 From: "M. Utku Altinkaya" Date: Tue, 21 Jul 2015 13:17:46 +0300 Subject: Fixed version error formatting issue --- netlib/version_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version_check.py b/netlib/version_check.py index df1612a2..2081c410 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -34,7 +34,7 @@ def version_check( print( "You are using an outdated version of pyOpenSSL:" " mitmproxy requires pyOpenSSL %s or greater." % - pyopenssl_min_version, + str(pyopenssl_min_version), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. -- cgit v1.2.3 From 9fdc412fa043072f44eddec0b07659c161e4ca90 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 22 Jul 2015 00:17:05 +0200 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index bc9a1a57..ba426d74 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 12, 2) +IVERSION = (0, 13) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 155bdeb12352065bc36256ba8014003480361a0c Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 21 Jul 2015 18:01:51 -0700 Subject: Fixing default CA which ought to be read as a pemfile and not a directory --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 38b77c9e..47ce8c0e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -422,7 +422,7 @@ class _Connection(object): context.set_verify(verify_options, verify_cert) if ca_path is None and ca_pemfile is None: - ca_path = certifi.where() + ca_pemfile = certifi.where() context.load_verify_locations(ca_pemfile, ca_path) # Workaround for -- cgit v1.2.3 From c17af4162b5a2946c4bf53bf1d17fca41dc68da7 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 21 Jul 2015 19:06:20 -0700 Subject: Added a fix for pre-1.0 OpenSSL which wasn't correctly erring on failed certificate validation --- netlib/tcp.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 47ce8c0e..5c4094d7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -518,6 +518,13 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error as v: raise NetLibError("SSL handshake error: %s" % repr(v)) + + # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on + # certificate validation failure + verification_mode = sslctx_kwargs.get('verify_options', None) + if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: + raise NetLibError("SSL handshake error: certificate verify failed") + self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) -- cgit v1.2.3 From e316a9cdb44444667e26938f8c1c3969e56c2f0e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 22 Jul 2015 13:39:48 +0200 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index ba426d74..de42ace1 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 13) +IVERSION = (0, 13, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 6dcfc35011208f4bfde7f37a63d7b980f6c41ce0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Jul 2015 09:20:25 +0200 Subject: introduce http_semantics module used for generic HTTP representation everything should apply for HTTP/1 and HTTP/2 --- netlib/http.py | 16 ++-------------- netlib/http_semantics.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 14 deletions(-) create mode 100644 netlib/http_semantics.py (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index a2af9e49..073e9a3f 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -4,7 +4,7 @@ import string import urlparse import binascii import sys -from . import odict, utils, tcp, http_status +from . import odict, utils, tcp, http_semantics, http_status class HttpError(Exception): @@ -527,18 +527,6 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): ) -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -580,7 +568,7 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): # if include_body==False then a None content means the body should be # read separately content = None - return Response(httpversion, code, msg, headers, content) + return http_semantics.Response(httpversion, code, msg, headers, content) def request_preamble(method, resource, http_major="1", http_minor="1"): diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py new file mode 100644 index 00000000..e8313e3c --- /dev/null +++ b/netlib/http_semantics.py @@ -0,0 +1,23 @@ +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) -- cgit v1.2.3 From bd5ee212840e3be731ea93e14ef1375745383d88 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Jul 2015 09:34:10 +0200 Subject: refactor websockets into protocol --- netlib/websockets.py | 381 ------------------------------------------ netlib/websockets/__init__.py | 2 + netlib/websockets/frame.py | 288 +++++++++++++++++++++++++++++++ netlib/websockets/protocol.py | 111 ++++++++++++ 4 files changed, 401 insertions(+), 381 deletions(-) delete mode 100644 netlib/websockets.py create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/frame.py create mode 100644 netlib/websockets/protocol.py (limited to 'netlib') diff --git a/netlib/websockets.py b/netlib/websockets.py deleted file mode 100644 index c45db4df..00000000 --- a/netlib/websockets.py +++ /dev/null @@ -1,381 +0,0 @@ -from __future__ import absolute_import -import base64 -import hashlib -import os -import struct -import io - -from . import utils, odict, tcp - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" -MAX_16_BIT_INT = (1 << 16) -MAX_64_BIT_INT = (1 << 64) - - -OPCODE = utils.BiDi( - CONTINUE=0x00, - TEXT=0x01, - BINARY=0x02, - CLOSE=0x08, - PING=0x09, - PONG=0x0a -) - - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.masks = [utils.bytes_to_int(byte) for byte in key] - self.offset = 0 - - def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -def client_handshake_headers(key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of ODictCaseless - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ]) - - -def server_handshake_headers(key): - """ - The server response is a valid HTTP 101 response. - """ - return odict.ODictCaseless( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - ) - - -def make_length_code(length): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - if length <= 125: - return length - elif length >= 126 and length <= 65535: - return 126 - else: - return 127 - - -def check_client_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-key') - - -def check_server_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-accept') - - -def create_server_nonce(client_nonce): - return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') - ) - - -DEFAULT = object() - - -class FrameHeader(object): - - def __init__( - self, - opcode=OPCODE.TEXT, - payload_length=0, - fin=False, - rsv1=False, - rsv2=False, - rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT - ): - if not 0 <= opcode < 2 ** 4: - raise ValueError("opcode must be 0-16") - self.opcode = opcode - self.payload_length = payload_length - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - - if length_code is DEFAULT: - self.length_code = make_length_code(self.payload_length) - else: - self.length_code = length_code - - if mask is DEFAULT and masking_key is DEFAULT: - self.mask = False - self.masking_key = "" - elif mask is DEFAULT: - self.mask = 1 - self.masking_key = masking_key - elif masking_key is DEFAULT: - self.mask = mask - self.masking_key = os.urandom(4) - else: - self.mask = mask - self.masking_key = masking_key - - if self.masking_key and len(self.masking_key) != 4: - raise ValueError("Masking key must be 4 bytes.") - - def human_readable(self): - vals = [ - "ws frame:", - OPCODE.get_name(self.opcode, hex(self.opcode)).lower() - ] - flags = [] - for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: - if getattr(self, i): - flags.append(i) - if flags: - vals.extend([":", "|".join(flags)]) - if self.masking_key: - vals.append(":key=%s" % repr(self.masking_key)) - if self.payload_length: - vals.append(" %s" % utils.pretty_size(self.payload_length)) - return "".join(vals) - - def to_bytes(self): - first_byte = utils.setbit(0, 7, self.fin) - first_byte = utils.setbit(first_byte, 6, self.rsv1) - first_byte = utils.setbit(first_byte, 5, self.rsv2) - first_byte = utils.setbit(first_byte, 4, self.rsv3) - first_byte = first_byte | self.opcode - - second_byte = utils.setbit(self.length_code, 7, self.mask) - - b = chr(first_byte) + chr(second_byte) - - if self.payload_length < 126: - pass - elif self.payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.payload_length) - elif self.payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: - b += self.masking_key - return b - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame header - """ - first_byte = utils.bytes_to_int(fp.safe_read(1)) - second_byte = utils.bytes_to_int(fp.safe_read(1)) - - fin = utils.getbit(first_byte, 7) - rsv1 = utils.getbit(first_byte, 6) - rsv2 = utils.getbit(first_byte, 5) - rsv3 = utils.getbit(first_byte, 4) - # grab right-most 4 bits - opcode = first_byte & 15 - mask_bit = utils.getbit(second_byte, 7) - # grab the next 7 bits - length_code = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if length_code <= 125: - payload_length = length_code - elif length_code == 126: - payload_length = utils.bytes_to_int(fp.safe_read(2)) - elif length_code == 127: - payload_length = utils.bytes_to_int(fp.safe_read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.safe_read(4) - else: - masking_key = None - - return cls( - fin=fin, - rsv1=rsv1, - rsv2=rsv2, - rsv3=rsv3, - opcode=opcode, - mask=mask_bit, - length_code=length_code, - payload_length=payload_length, - masking_key=masking_key, - ) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class Frame(object): - - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - - def __init__(self, payload="", **kwargs): - self.payload = payload - kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) - self.header = FrameHeader(**kwargs) - - @classmethod - def default(cls, message, from_client=False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - if from_client: - mask_bit = 1 - masking_key = os.urandom(4) - else: - mask_bit = 0 - masking_key = None - - return cls( - message, - fin=1, # final frame - opcode=OPCODE.TEXT, # text - mask=mask_bit, - masking_key=masking_key, - ) - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - - def human_readable(self): - ret = self.header.human_readable() - if self.payload: - ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) - return ret - - def __repr__(self): - return self.header.human_readable() - - def to_bytes(self): - """ - Serialize the frame to wire format. Returns a string. - """ - b = self.header.to_bytes() - if self.header.masking_key: - b += Masker(self.header.masking_key)(self.payload) - else: - b += self.payload - return b - - def to_file(self, writer): - writer.write(self.to_bytes()) - writer.flush() - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame sent by a server or client - - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - header = FrameHeader.from_file(fp) - payload = fp.safe_read(header.payload_length) - - if header.mask == 1 and header.masking_key: - payload = Masker(header.masking_key)(payload) - - return cls( - payload, - fin=header.fin, - opcode=header.opcode, - mask=header.mask, - payload_length=header.payload_length, - masking_key=header.masking_key, - rsv1=header.rsv1, - rsv2=header.rsv2, - rsv3=header.rsv3, - length_code=header.length_code - ) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py new file mode 100644 index 00000000..d41059fa --- /dev/null +++ b/netlib/websockets/frame.py @@ -0,0 +1,288 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from .protocol import Masker +from .. import utils, odict, tcp + +DEFAULT = object() + +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) + +OPCODE = utils.BiDi( + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a +) + +class FrameHeader(object): + + def __init__( + self, + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT + ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + + if length_code is DEFAULT: + self.length_code = self._make_length_code(self.payload_length) + else: + self.length_code = length_code + + if mask is DEFAULT and masking_key is DEFAULT: + self.mask = False + self.masking_key = "" + elif mask is DEFAULT: + self.mask = 1 + self.masking_key = masking_key + elif masking_key is DEFAULT: + self.mask = mask + self.masking_key = os.urandom(4) + else: + self.mask = mask + self.masking_key = masking_key + + if self.masking_key and len(self.masking_key) != 4: + raise ValueError("Masking key must be 4 bytes.") + + @classmethod + def _make_length_code(self, length): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if length <= 125: + return length + elif length >= 126 and length <= 65535: + return 126 + else: + return 127 + + def human_readable(self): + vals = [ + "ws frame:", + OPCODE.get_name(self.opcode, hex(self.opcode)).lower() + ] + flags = [] + for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: + if getattr(self, i): + flags.append(i) + if flags: + vals.extend([":", "|".join(flags)]) + if self.masking_key: + vals.append(":key=%s" % repr(self.masking_key)) + if self.payload_length: + vals.append(" %s" % utils.pretty_size(self.payload_length)) + return "".join(vals) + + def to_bytes(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode + + second_byte = utils.setbit(self.length_code, 7, self.mask) + + b = chr(first_byte) + chr(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', self.payload_length) + elif self.payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', self.payload_length) + if self.masking_key is not None: + b += self.masking_key + return b + + @classmethod + def from_file(cls, fp): + """ + read a websockets frame header + """ + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) + + fin = utils.getbit(first_byte, 7) + rsv1 = utils.getbit(first_byte, 6) + rsv2 = utils.getbit(first_byte, 5) + rsv3 = utils.getbit(first_byte, 4) + # grab right-most 4 bits + opcode = first_byte & 15 + mask_bit = utils.getbit(second_byte, 7) + # grab the next 7 bits + length_code = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_length = length_code + elif length_code == 126: + payload_length = utils.bytes_to_int(fp.safe_read(2)) + elif length_code == 127: + payload_length = utils.bytes_to_int(fp.safe_read(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = fp.safe_read(4) + else: + masking_key = None + + return cls( + fin=fin, + rsv1=rsv1, + rsv2=rsv2, + rsv3=rsv3, + opcode=opcode, + mask=mask_bit, + length_code=length_code, + payload_length=payload_length, + masking_key=masking_key, + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class Frame(object): + + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + + def __init__(self, payload="", **kwargs): + self.payload = payload + kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) + self.header = FrameHeader(**kwargs) + + @classmethod + def default(cls, message, from_client=False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + if from_client: + mask_bit = 1 + masking_key = os.urandom(4) + else: + mask_bit = 0 + masking_key = None + + return cls( + message, + fin=1, # final frame + opcode=OPCODE.TEXT, # text + mask=mask_bit, + masking_key=masking_key, + ) + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) + + def human_readable(self): + ret = self.header.human_readable() + if self.payload: + ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + return ret + + def __repr__(self): + return self.header.human_readable() + + def to_bytes(self): + """ + Serialize the frame to wire format. Returns a string. + """ + b = self.header.to_bytes() + if self.header.masking_key: + b += Masker(self.header.masking_key)(self.payload) + else: + b += self.payload + return b + + def to_file(self, writer): + writer.write(self.to_bytes()) + writer.flush() + + @classmethod + def from_file(cls, fp): + """ + read a websockets frame sent by a server or client + + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + header = FrameHeader.from_file(fp) + payload = fp.safe_read(header.payload_length) + + if header.mask == 1 and header.masking_key: + payload = Masker(header.masking_key)(payload) + + return cls( + payload, + fin=header.fin, + opcode=header.opcode, + mask=header.mask, + payload_length=header.payload_length, + masking_key=header.masking_key, + rsv1=header.rsv1, + rsv2=header.rsv2, + rsv3=header.rsv3, + length_code=header.length_code + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py new file mode 100644 index 00000000..dcab53fb --- /dev/null +++ b/netlib/websockets/protocol.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from .. import utils, odict, tcp + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" + +HEADER_WEBSOCKET_KEY = 'sec-websocket-key' +HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' +HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + +class Masker(object): + + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def mask(self, offset, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[offset % 4]) + offset += 1 + return result + + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret + +class WebsocketsProtocol(object): + + def __init__(self): + pass + + @classmethod + def client_handshake_headers(self, key=None, version=VERSION): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless + """ + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_KEY, key), + (HEADER_WEBSOCKET_VERSION, version) + ]) + + @classmethod + def server_handshake_headers(self, key): + """ + The server response is a valid HTTP 101 response. + """ + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) + ] + ) + + + @classmethod + def check_client_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_KEY) + + + @classmethod + def check_server_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + + + @classmethod + def create_server_nonce(self, client_nonce): + return base64.b64encode( + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + ) -- cgit v1.2.3 From f50deb7b763d093a22a4d331e16465a2fb0329cf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 14 Jul 2015 23:02:14 +0200 Subject: move bits around --- netlib/http.py | 583 -------------------------------------- netlib/http/__init__.py | 2 + netlib/http/authentication.py | 149 ++++++++++ netlib/http/cookies.py | 193 +++++++++++++ netlib/http/exceptions.py | 9 + netlib/http/http1/__init__.py | 1 + netlib/http/http1/protocol.py | 518 ++++++++++++++++++++++++++++++++++ netlib/http/http2/__init__.py | 2 + netlib/http/http2/frame.py | 636 ++++++++++++++++++++++++++++++++++++++++++ netlib/http/http2/protocol.py | 240 ++++++++++++++++ netlib/http/semantics.py | 94 +++++++ netlib/http/status_codes.py | 104 +++++++ netlib/http/user_agents.py | 52 ++++ netlib/http2/__init__.py | 2 - netlib/http2/frame.py | 636 ------------------------------------------ netlib/http2/protocol.py | 240 ---------------- netlib/http_auth.py | 148 ---------- netlib/http_cookies.py | 193 ------------- netlib/http_semantics.py | 23 -- netlib/http_status.py | 104 ------- netlib/http_uastrings.py | 52 ---- netlib/websockets/frame.py | 2 +- netlib/websockets/protocol.py | 2 +- 23 files changed, 2002 insertions(+), 1983 deletions(-) delete mode 100644 netlib/http.py create mode 100644 netlib/http/__init__.py create mode 100644 netlib/http/authentication.py create mode 100644 netlib/http/cookies.py create mode 100644 netlib/http/exceptions.py create mode 100644 netlib/http/http1/__init__.py create mode 100644 netlib/http/http1/protocol.py create mode 100644 netlib/http/http2/__init__.py create mode 100644 netlib/http/http2/frame.py create mode 100644 netlib/http/http2/protocol.py create mode 100644 netlib/http/semantics.py create mode 100644 netlib/http/status_codes.py create mode 100644 netlib/http/user_agents.py delete mode 100644 netlib/http2/__init__.py delete mode 100644 netlib/http2/frame.py delete mode 100644 netlib/http2/protocol.py delete mode 100644 netlib/http_auth.py delete mode 100644 netlib/http_cookies.py delete mode 100644 netlib/http_semantics.py delete mode 100644 netlib/http_status.py delete mode 100644 netlib/http_uastrings.py (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py deleted file mode 100644 index 073e9a3f..00000000 --- a/netlib/http.py +++ /dev/null @@ -1,583 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import collections -import string -import urlparse -import binascii -import sys -from . import odict, utils, tcp, http_semantics, http_status - - -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass - - -def _is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def _is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not _is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not _is_valid_port(port): - return None - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not _is_valid_port(port): - return None - if not _is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) - - -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - wfile.flush() - del headers['expect'] - - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http_semantics.Response(httpversion, code, msg, headers, content) - - -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = http_status.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py new file mode 100644 index 00000000..9b4b0e6b --- /dev/null +++ b/netlib/http/__init__.py @@ -0,0 +1,2 @@ +from exceptions import * +from semantics import * diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py new file mode 100644 index 00000000..26e3c2c4 --- /dev/null +++ b/netlib/http/authentication.py @@ -0,0 +1,149 @@ +from __future__ import (absolute_import, print_function, division) +from argparse import Action, ArgumentTypeError + +from .. import http + + +class NullProxyAuth(object): + + """ + No proxy auth at all (returns empty challange headers) + """ + + def __init__(self, password_manager): + self.password_manager = password_manager + + def clean(self, headers_): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + + def authenticate(self, headers_): + """ + Tests that the user is allowed to use the proxy + """ + return True + + def auth_challenge_headers(self): + """ + Returns a dictionary containing the headers require to challenge the user + """ + return {} + + +class BasicProxyAuth(NullProxyAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + + def __init__(self, password_manager, realm): + NullProxyAuth.__init__(self, password_manager) + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] + + def authenticate(self, headers): + auth_value = headers.get(self.AUTH_HEADER, []) + if not auth_value: + return False + parts = http.http1.parse_http_basic_auth(auth_value[0]) + if not parts: + return False + scheme, username, password = parts + if scheme.lower() != 'basic': + return False + if not self.password_manager.test(username, password): + return False + self.username = username + return True + + def auth_challenge_headers(self): + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} + + +class PassMan(object): + + def test(self, username_, password_token_): + return False + + +class PassManNonAnon(PassMan): + + """ + Ensure the user specifies a username, accept any password. + """ + + def test(self, username, password_token_): + if username: + return True + return False + + +class PassManHtpasswd(PassMan): + + """ + Read usernames and passwords from an htpasswd file + """ + + def __init__(self, path): + """ + Raises ValueError if htpasswd file is invalid. + """ + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) + + def test(self, username, password_token): + return bool(self.htpasswd.check_password(username, password_token)) + + +class PassManSingleUser(PassMan): + + def __init__(self, username, password): + self.username, self.password = username, password + + def test(self, username, password_token): + return self.username == username and self.password == password_token + + +class AuthAction(Action): + + """ + Helper class to allow seamless integration int argparse. Example usage: + parser.add_argument( + "--nonanonymous", + action=NonanonymousAuthAction, nargs=0, + help="Allow access to any user long as a credentials are specified." + ) + """ + + def __call__(self, parser, namespace, values, option_string=None): + passman = self.getPasswordManager(values) + authenticator = BasicProxyAuth(passman, "mitmproxy") + setattr(namespace, self.dest, authenticator) + + def getPasswordManager(self, s): # pragma: nocover + raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + + def getPasswordManager(self, s): + if len(s.split(':')) != 2: + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) + username, password = s.split(':') + return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManHtpasswd(s) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py new file mode 100644 index 00000000..b77e3503 --- /dev/null +++ b/netlib/http/cookies.py @@ -0,0 +1,193 @@ +import re + +from .. import odict + +""" +A flexible module for cookie parsing and manipulation. + +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 +""" + +# TODO +# - Disallow LHS-only Cookie values + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start + 1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i + 1], i + 1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start + 1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + else: + ret.append(s[i]) + return "".join(ret), i + 1 + + +def _read_value(s, start, delims): + """ + Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. + """ + if start >= len(s): + return "", start + elif s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, delims) + + +def _read_pairs(s, off=0): + """ + Read pairs of lhs=rhs values. + + off: start offset + specials: a lower-cased list of keys that may contain commas + """ + vals = [] + while True: + lhs, off = _read_token(s, off) + lhs = lhs.lstrip() + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off + 1, ";") + vals.append([lhs, rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +ESCAPE = re.compile(r"([\"\\])") + + +def _format_pairs(lst, specials=(), sep="; "): + """ + specials: A lower-cased list of keys that will not be quoted. + """ + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) + return sep.join(vals) + + +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials=("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): + """ + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. + """ + pairs, off_ = _read_pairs(s) + return pairs + + +def parse_set_cookie_header(line): + """ + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. + """ + pairs = _parse_set_cookie_pairs(line) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(line): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off_ = _read_pairs(line) + return odict.ODict(pairs) + + +def format_cookie_header(od): + """ + Formats a Cookie header value. + """ + return _format_pairs(od.lst) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py new file mode 100644 index 00000000..8a2bbebc --- /dev/null +++ b/netlib/http/exceptions.py @@ -0,0 +1,9 @@ +class HttpError(Exception): + + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code + + +class HttpErrorConnClosed(HttpError): + pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py new file mode 100644 index 00000000..6b5043af --- /dev/null +++ b/netlib/http/http1/__init__.py @@ -0,0 +1 @@ +from protocol import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py new file mode 100644 index 00000000..0f7a0bd3 --- /dev/null +++ b/netlib/http/http1/protocol.py @@ -0,0 +1,518 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from netlib import odict, utils, tcp, http +from .. import status_codes +from ..exceptions import * + + +def get_request_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = fp.readline() + return line + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + +def read_chunked(fp, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = fp.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] + + +def parse_http_protocol(s): + """ + Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or + None. + """ + if not s.startswith("HTTP/"): + return None + _, version = s.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor + + +def parse_http_basic_auth(s): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + +def parse_init(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion + + +def parse_init_connect(line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + +def parse_init_proxy(line): + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + +def connection_close(httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. + if "connection" in headers: + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + if httpversion == (1, 1): + return False + return True + + +def parse_response_line(line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + +def read_http_body(*args, **kwargs): + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) + + +def read_http_body_chunked( + rfile, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None +): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + +def expected_http_body_size(headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + +# TODO: make this a regular class - just like Response +Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] +) + + +def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): + """ + Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = read_headers(rfile) + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + wfile.flush() + del headers['expect'] + + if include_body: + content = read_http_body( + rfile, headers, body_size_limit, method, None, True + ) + + return Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/http/http2/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 00000000..f7e60471 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,636 @@ +import sys +import struct +from hpack.hpack import Encoder, Decoder + +from .. import utils + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py new file mode 100644 index 00000000..8e5f5429 --- /dev/null +++ b/netlib/http/http2/protocol.py @@ -0,0 +1,240 @@ +from __future__ import (absolute_import, print_function, division) +import itertools + +from hpack.hpack import Encoder, Decoder +from .. import utils +from . import frame + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + ALPN_PROTO_H2 = 'h2' + + def __init__(self, tcp_handler, is_server=False, dump_frames=False): + self.tcp_handler = tcp_handler + self.is_server = is_server + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + self.connection_preface_performed = False + self.dump_frames = dump_frames + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + + return frm + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + flags |= frame.Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + frm = frame.HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + frm = frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + stream_id_, headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + + stream_id = 0 + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id + header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: + break + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame): + body += frm.payload + if frm.flags & frame.Frame.FLAG_END_STREAM: + break + # TODO: implement window update & flow + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + return stream_id, headers, body + + def create_response(self, code, stream_id=None, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + if not stream_id: + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id), + )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py new file mode 100644 index 00000000..e7e84fe3 --- /dev/null +++ b/netlib/http/semantics.py @@ -0,0 +1,94 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from .. import utils + +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) + + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not utils.isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py new file mode 100644 index 00000000..dc09f465 --- /dev/null +++ b/netlib/http/status_codes.py @@ -0,0 +1,104 @@ +from __future__ import (absolute_import, print_function, division) + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: "Use Proxy", + # 306 not defined?? + TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py new file mode 100644 index 00000000..e8681908 --- /dev/null +++ b/netlib/http/user_agents.py @@ -0,0 +1,52 @@ +from __future__ import (absolute_import, print_function, division) + +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# pylint: line-too-long + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. + """ + for i in UASTRINGS: + if s == i[1]: + return i diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py deleted file mode 100644 index 5acf7696..00000000 --- a/netlib/http2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py deleted file mode 100644 index f7e60471..00000000 --- a/netlib/http2/frame.py +++ /dev/null @@ -1,636 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py deleted file mode 100644 index 8e5f5429..00000000 --- a/netlib/http2/protocol.py +++ /dev/null @@ -1,240 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools - -from hpack.hpack import Encoder, Decoder -from .. import utils -from . import frame - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE =\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - ALPN_PROTO_H2 = 'h2' - - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler - self.is_server = is_server - - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - self.connection_preface_performed = False - self.dump_frames = dump_frames - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - return frm - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) - self.send_frame(frm, hide) - - # be liberal in what we expect from the other end - # to be more strict use: self._read_settings_ack(hide) - - def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - flags |= frame.Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - frm = frame.HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - frm = frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body - - def read_request(self): - return self._receive_transmission() - - def _receive_transmission(self): - body_expected = True - - stream_id = 0 - header_block_fragment = b'' - body = b'' - - while True: - frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame)\ - or isinstance(frm, frame.ContinuationFrame): - stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: - break - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame): - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: - break - # TODO: implement window update & flow - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http_auth.py b/netlib/http_auth.py deleted file mode 100644 index adab4aed..00000000 --- a/netlib/http_auth.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from argparse import Action, ArgumentTypeError -from . import http - - -class NullProxyAuth(object): - - """ - No proxy auth at all (returns empty challange headers) - """ - - def __init__(self, password_manager): - self.password_manager = password_manager - - def clean(self, headers_): - """ - Clean up authentication headers, so they're not passed upstream. - """ - pass - - def authenticate(self, headers_): - """ - Tests that the user is allowed to use the proxy - """ - return True - - def auth_challenge_headers(self): - """ - Returns a dictionary containing the headers require to challenge the user - """ - return {} - - -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' - - def __init__(self, password_manager, realm): - NullProxyAuth.__init__(self, password_manager) - self.realm = realm - - def clean(self, headers): - del headers[self.AUTH_HEADER] - - def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) - if not auth_value: - return False - parts = http.parse_http_basic_auth(auth_value[0]) - if not parts: - return False - scheme, username, password = parts - if scheme.lower() != 'basic': - return False - if not self.password_manager.test(username, password): - return False - self.username = username - return True - - def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class PassMan(object): - - def test(self, username_, password_token_): - return False - - -class PassManNonAnon(PassMan): - - """ - Ensure the user specifies a username, accept any password. - """ - - def test(self, username, password_token_): - if username: - return True - return False - - -class PassManHtpasswd(PassMan): - - """ - Read usernames and passwords from an htpasswd file - """ - - def __init__(self, path): - """ - Raises ValueError if htpasswd file is invalid. - """ - import passlib.apache - self.htpasswd = passlib.apache.HtpasswdFile(path) - - def test(self, username, password_token): - return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - - def __init__(self, username, password): - self.username, self.password = username, password - - def test(self, username, password_token): - return self.username == username and self.password == password_token - - -class AuthAction(Action): - - """ - Helper class to allow seamless integration int argparse. Example usage: - parser.add_argument( - "--nonanonymous", - action=NonanonymousAuthAction, nargs=0, - help="Allow access to any user long as a credentials are specified." - ) - """ - - def __call__(self, parser, namespace, values, option_string=None): - passman = self.getPasswordManager(values) - authenticator = BasicProxyAuth(passman, "mitmproxy") - setattr(namespace, self.dest, authenticator) - - def getPasswordManager(self, s): # pragma: nocover - raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - - def getPasswordManager(self, s): - if len(s.split(':')) != 2: - raise ArgumentTypeError( - "Invalid single-user specification. Please use the format username:password" - ) - username, password = s.split(':') - return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManHtpasswd(s) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py deleted file mode 100644 index e91ee5c0..00000000 --- a/netlib/http_cookies.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Technically this should be feasible, but -it turns out that violations of RFC6265 that makes the parsing problem -indeterminate are much more common than genuine occurences of the multi-cookie -variants. Serialization follows RFC6265. - - http://tools.ietf.org/html/rfc6265 - http://tools.ietf.org/html/rfc2109 - http://tools.ietf.org/html/rfc2965 -""" - -# TODO -# - Disallow LHS-only Cookie values - -import re - -import odict - - -def _read_until(s, start, term): - """ - Read until one of the characters in term is reached. - """ - if start == len(s): - return "", start + 1 - for i in range(start, len(s)): - if s[i] in term: - return s[start:i], i - return s[start:i + 1], i + 1 - - -def _read_token(s, start): - """ - Read a token - the LHS of a token/value pair in a cookie. - """ - return _read_until(s, start, ";=") - - -def _read_quoted_string(s, start): - """ - start: offset to the first quote of the string to be read - - A sort of loose super-set of the various quoted string specifications. - - RFC6265 disallows backslashes or double quotes within quoted strings. - Prior RFCs use backslashes to escape. This leaves us free to apply - backslash escaping by default and be compatible with everything. - """ - escaping = False - ret = [] - # Skip the first quote - for i in range(start + 1, len(s)): - if escaping: - ret.append(s[i]) - escaping = False - elif s[i] == '"': - break - elif s[i] == "\\": - escaping = True - else: - ret.append(s[i]) - return "".join(ret), i + 1 - - -def _read_value(s, start, delims): - """ - Reads a value - the RHS of a token/value pair in a cookie. - - special: If the value is special, commas are premitted. Else comma - terminates. This helps us support old and new style values. - """ - if start >= len(s): - return "", start - elif s[start] == '"': - return _read_quoted_string(s, start) - else: - return _read_until(s, start, delims) - - -def _read_pairs(s, off=0): - """ - Read pairs of lhs=rhs values. - - off: start offset - specials: a lower-cased list of keys that may contain commas - """ - vals = [] - while True: - lhs, off = _read_token(s, off) - lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - vals.append([lhs, rhs]) - off += 1 - if not off < len(s): - break - return vals, off - - -def _has_special(s): - for i in s: - if i in '",;\\': - return True - o = ord(i) - if o < 0x21 or o > 0x7e: - return True - return False - - -ESCAPE = re.compile(r"([\"\\])") - - -def _format_pairs(lst, specials=(), sep="; "): - """ - specials: A lower-cased list of keys that will not be quoted. - """ - vals = [] - for k, v in lst: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) - return sep.join(vals) - - -def _format_set_cookie_pairs(lst): - return _format_pairs( - lst, - specials=("expires", "path") - ) - - -def _parse_set_cookie_pairs(s): - """ - For Set-Cookie, we support multiple cookies as described in RFC2109. - This function therefore returns a list of lists. - """ - pairs, off_ = _read_pairs(s) - return pairs - - -def parse_set_cookie_header(line): - """ - Parse a Set-Cookie header value - - Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute - values - they are treated purely as strings. - """ - pairs = _parse_set_cookie_pairs(line) - if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) - - -def format_set_cookie_header(name, value, attrs): - """ - Formats a Set-Cookie header value. - """ - pairs = [[name, value]] - pairs.extend(attrs.lst) - return _format_set_cookie_pairs(pairs) - - -def parse_cookie_header(line): - """ - Parse a Cookie header value. - Returns a (possibly empty) ODict object. - """ - pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) - - -def format_cookie_header(od): - """ - Formats a Cookie header value. - """ - return _format_pairs(od.lst) diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py deleted file mode 100644 index e8313e3c..00000000 --- a/netlib/http_semantics.py +++ /dev/null @@ -1,23 +0,0 @@ -class Response(object): - - def __init__( - self, - httpversion, - status_code, - msg, - headers, - content, - sslinfo=None, - ): - self.httpversion = httpversion - self.status_code = status_code - self.msg = msg - self.headers = headers - self.content = content - self.sslinfo = sslinfo - - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __repr__(self): - return "Response(%s - %s)" % (self.status_code, self.msg) diff --git a/netlib/http_status.py b/netlib/http_status.py deleted file mode 100644 index dc09f465..00000000 --- a/netlib/http_status.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { - # 100 - CONTINUE: "Continue", - SWITCHING: "Switching Protocols", - - # 200 - OK: "OK", - CREATED: "Created", - ACCEPTED: "Accepted", - NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", - NO_CONTENT: "No Content", - RESET_CONTENT: "Reset Content.", - PARTIAL_CONTENT: "Partial Content", - MULTI_STATUS: "Multi-Status", - - # 300 - MULTIPLE_CHOICE: "Multiple Choices", - MOVED_PERMANENTLY: "Moved Permanently", - FOUND: "Found", - SEE_OTHER: "See Other", - NOT_MODIFIED: "Not Modified", - USE_PROXY: "Use Proxy", - # 306 not defined?? - TEMPORARY_REDIRECT: "Temporary Redirect", - - # 400 - BAD_REQUEST: "Bad Request", - UNAUTHORIZED: "Unauthorized", - PAYMENT_REQUIRED: "Payment Required", - FORBIDDEN: "Forbidden", - NOT_FOUND: "Not Found", - NOT_ALLOWED: "Method Not Allowed", - NOT_ACCEPTABLE: "Not Acceptable", - PROXY_AUTH_REQUIRED: "Proxy Authentication Required", - REQUEST_TIMEOUT: "Request Time-out", - CONFLICT: "Conflict", - GONE: "Gone", - LENGTH_REQUIRED: "Length Required", - PRECONDITION_FAILED: "Precondition Failed", - REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", - REQUEST_URI_TOO_LONG: "Request-URI Too Long", - UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", - REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", - EXPECTATION_FAILED: "Expectation Failed", - - # 500 - INTERNAL_SERVER_ERROR: "Internal Server Error", - NOT_IMPLEMENTED: "Not Implemented", - BAD_GATEWAY: "Bad Gateway", - SERVICE_UNAVAILABLE: "Service Unavailable", - GATEWAY_TIMEOUT: "Gateway Time-out", - HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", - INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", - NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py deleted file mode 100644 index e8681908..00000000 --- a/netlib/http_uastrings.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -""" - A small collection of useful user-agent header strings. These should be - kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. - -UASTRINGS = [ - ("android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa - ("blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa - ("bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa - ("chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa - ("firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa - ("googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa - ("ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa - ("ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa - ("iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa - ("safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa -] - - -def get_by_shortcut(s): - """ - Retrieve a user agent entry by shortcut. - """ - for i in UASTRINGS: - if s == i[1]: - return i diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index d41059fa..49d8ee10 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -6,7 +6,7 @@ import struct import io from .protocol import Masker -from .. import utils, odict, tcp +from netlib import utils, odict, tcp DEFAULT = object() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index dcab53fb..29b4db3d 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -5,7 +5,7 @@ import os import struct import io -from .. import utils, odict, tcp +from netlib import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. -- cgit v1.2.3 From bab6cbff1e5444aea72a188d57812130c375e0f0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 15 Jul 2015 22:32:14 +0200 Subject: extract authentication methods from protocol --- netlib/http/authentication.py | 22 +++++++++++++++++++++- netlib/http/http1/protocol.py | 39 ++------------------------------------- netlib/http/semantics.py | 14 +++++++++++++- 3 files changed, 36 insertions(+), 39 deletions(-) (limited to 'netlib') diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 26e3c2c4..9a227010 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -1,8 +1,28 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError +import binascii from .. import http +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + class NullProxyAuth(object): @@ -47,7 +67,7 @@ class BasicProxyAuth(NullProxyAuth): auth_value = headers.get(self.AUTH_HEADER, []) if not auth_value: return False - parts = http.http1.parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value[0]) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 0f7a0bd3..97c119a9 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -85,22 +85,9 @@ def read_chunked(fp, limit, is_request): return -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - def has_chunked_encoding(headers): return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") ] @@ -123,28 +110,6 @@ def parse_http_protocol(s): return major, minor -def parse_http_basic_auth(s): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - def parse_init(line): try: method, url, protocol = string.split(line) @@ -221,7 +186,7 @@ def connection_close(httpversion, headers): """ # At first, check if we have an explicit Connection header. if "connection" in headers: - toks = get_header_tokens(headers, "connection") + toks = http.get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7e84fe3..a62c93e3 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -49,7 +49,6 @@ def is_valid_host(host): return True - def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -92,3 +91,16 @@ def parse_url(url): if not is_valid_port(port): return None return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks -- cgit v1.2.3 From 230c16122b06f5c6af60e6ddc2d8e2e83cd75273 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 16 Jul 2015 22:50:24 +0200 Subject: change HTTP2 interface to match HTTP1 --- netlib/http/http2/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 8e5f5429..0d6eac85 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from .. import utils +from netlib import http, utils from . import frame @@ -186,9 +186,9 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self): + def read_response(self, *args): stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body + return http.Response("HTTP/2", headers[':status'], "", headers, body) def read_request(self): return self._receive_transmission() -- cgit v1.2.3 From 808b294865257fc3f52b33ed2a796009658b126f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 16 Jul 2015 22:56:34 +0200 Subject: refactor HTTP/1 as protocol --- netlib/http/http1/protocol.py | 901 +++++++++++++++++++++--------------------- 1 file changed, 457 insertions(+), 444 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 97c119a9..401654c1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -9,475 +9,488 @@ from netlib import odict, utils, tcp, http from .. import status_codes from ..exceptions import * +class HTTP1Protocol(object): + + # TODO: make this a regular class - just like Response + Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] + ) -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) + def __init__(self, tcp_handler): + self.tcp_handler = tcp_handler + + def get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + def read_headers(self): + """ + Read a set of headers. + Stop once a blank line is reached. + + Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = self.tcp_handler.rfile.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not http.is_valid_port(port): - return None - if not http.is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = http.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = http.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + + def read_chunked(self, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = self.tcp_handler.rfile.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = self.tcp_handler.rfile.read(length) + suffix = self.tcp_handler.rfile.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + @classmethod + def parse_http_protocol(self, line): + """ + Parse an HTTP protocol declaration. + Returns a (major, minor) tuple, or None. + """ + if not line.startswith("HTTP/"): + return None + _, version = line.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) + @classmethod + def parse_init(self, line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = self.parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) + @classmethod + def parse_init_connect(self, line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: + if method.upper() != 'CONNECT': + return None try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size + host, port = url.split(":") except ValueError: return None - if is_request: - return 0 - return -1 - - -# TODO: make this a regular class - just like Response -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + @classmethod + def parse_init_proxy(self, line): + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + @classmethod + def parse_init_http(self, line): + """ + Returns (method, url, httpversion) + """ + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + + @classmethod + def connection_close(self, httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. + if "connection" in headers: + toks = http.get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return httpversion != (1, 1) + + + @classmethod + def parse_response_line(self, line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code ) - method, path, httpversion = request_line_parts - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self.read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + def read_request(self, include_body=True, body_size_limit=None): + """ + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = self.get_request_line() + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = self.parse_init(request_line) + if not request_line_parts: raise HttpError( 400, "Bad HTTP request line: %s" % repr(request_line) ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = self.parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = self.parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + method, + None, + True ) - _, scheme, host, port, path, _ = r - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' + return self.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content ) - wfile.flush() - del headers['expect'] - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = self.tcp_handler.rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self.parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) -- cgit v1.2.3 From 4617ab8a3a981f3abd8d62b561c80f9ad141e57b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 17 Jul 2015 09:37:57 +0200 Subject: add Request class and unify read_request interface --- netlib/http/__init__.py | 1 + netlib/http/http1/protocol.py | 22 +++++----------------- netlib/http/http2/protocol.py | 20 +++++++++++++++++--- netlib/http/semantics.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..b01afc6d 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,3 @@ +from . import * from exceptions import * from semantics import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 401654c1..8d631a13 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -11,25 +11,10 @@ from ..exceptions import * class HTTP1Protocol(object): - # TODO: make this a regular class - just like Response - Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] - ) - def __init__(self, tcp_handler): self.tcp_handler = tcp_handler + def get_request_line(self): """ Get a line, possibly preceded by a blank. @@ -40,6 +25,7 @@ class HTTP1Protocol(object): line = self.tcp_handler.rfile.readline() return line + def read_headers(self): """ Read a set of headers. @@ -175,6 +161,7 @@ class HTTP1Protocol(object): return None return host, port, httpversion + @classmethod def parse_init_proxy(self, line): v = self.parse_init(line) @@ -188,6 +175,7 @@ class HTTP1Protocol(object): scheme, host, port, path = parts return method, scheme, host, port, path, httpversion + @classmethod def parse_init_http(self, line): """ @@ -425,7 +413,7 @@ class HTTP1Protocol(object): True ) - return self.Request( + return http.Request( form_in, method, scheme, diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 0d6eac85..1dfdda21 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -187,11 +187,25 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self, *args): - stream_id_, headers, body = self._receive_transmission() - return http.Response("HTTP/2", headers[':status'], "", headers, body) + stream_id, headers, body = self._receive_transmission() + + response = http.Response("HTTP/2", headers[':status'], "", headers, body) + response.stream_id = stream_id + return response def read_request(self): - return self._receive_transmission() + stream_id, headers, body = self._receive_transmission() + + form_in = "" + method = headers.get(':method', '') + scheme = headers.get(':scheme', '') + host = headers.get(':host', '') + port = '' # TODO: parse port number? + path = headers.get(':path', '') + + request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request.stream_id = stream_id + return request def _receive_transmission(self): body_expected = True diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index a62c93e3..9a010318 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,37 @@ import urlparse from .. import utils +class Request(object): + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content, + ): + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.headers = headers + self.content = content + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + + class Response(object): def __init__( -- cgit v1.2.3 From 37a0cb858cda255bac8f06749a81859c82c5177f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 17:52:10 +0200 Subject: introduce ConnectRequest class --- netlib/http/http1/protocol.py | 2 +- netlib/http/semantics.py | 24 +++++++++++++++++++----- netlib/odict.py | 2 ++ 3 files changed, 22 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 8d631a13..257efb19 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -380,7 +380,7 @@ class HTTP1Protocol(object): "Bad HTTP request line: %s" % repr(request_line) ) host, port, _ = r - path = None + return http.ConnectRequest(host, port) else: form_in = "absolute" r = self.parse_init_proxy(request_line) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9a010318..664f9def 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -19,7 +19,7 @@ class Request(object): path, httpversion, headers, - content, + body, ): self.form_in = form_in self.method = method @@ -29,7 +29,7 @@ class Request(object): self.path = path self.httpversion = httpversion self.headers = headers - self.content = content + self.body = body def __eq__(self, other): return self.__dict__ == other.__dict__ @@ -38,6 +38,21 @@ class Request(object): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) +class ConnectRequest(Request): + def __init__(self, host, port): + super(ConnectRequest, self).__init__( + form_in="authority", + method="CONNECT", + scheme="", + host=host, + port=port, + path="", + httpversion="", + headers="", + body="", + ) + + class Response(object): def __init__( @@ -46,14 +61,14 @@ class Response(object): status_code, msg, headers, - content, + body, sslinfo=None, ): self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers - self.content = content + self.body = body self.sslinfo = sslinfo def __eq__(self, other): @@ -63,7 +78,6 @@ class Response(object): return "Response(%s - %s)" % (self.status_code, self.msg) - def is_valid_port(port): if not 0 <= port <= 65535: return False diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..ee1e6938 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -20,6 +20,8 @@ class ODict(object): """ def __init__(self, lst=None): + if isinstance(lst, ODict): + lst = lst.items() self.lst = lst or [] def _kconv(self, s): -- cgit v1.2.3 From d62dbee0f6cd47b4cad1ee7cc731b413600c0add Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 18:17:30 +0200 Subject: rename content -> body --- netlib/wsgi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index ad43dc19..99afe00e 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -21,9 +21,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, body): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.content = headers, content + self.headers, self.body = headers, body def date_time_string(): @@ -58,7 +58,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.input': cStringIO.StringIO(flow.request.body or ""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, -- cgit v1.2.3 From 83f013fca13c7395ca4e3da3fac60c8d907172b6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 20:46:26 +0200 Subject: introduce EmptyRequest class --- netlib/http/http1/protocol.py | 7 +++++-- netlib/http/semantics.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 257efb19..d2a77399 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -333,7 +333,7 @@ class HTTP1Protocol(object): return -1 - def read_request(self, include_body=True, body_size_limit=None): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): """ Parse an HTTP request from a file stream @@ -354,7 +354,10 @@ class HTTP1Protocol(object): request_line = self.get_request_line() if not request_line: - raise tcp.NetLibDisconnect() + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() request_line_parts = self.parse_init(request_line) if not request_line_parts: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 664f9def..355906dd 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -38,6 +38,20 @@ class Request(object): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) +class EmptyRequest(Request): + def __init__(self): + super(EmptyRequest, self).__init__( + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion="", + headers="", + body="", + ) + class ConnectRequest(Request): def __init__(self, host, port): super(ConnectRequest, self).__init__( -- cgit v1.2.3 From ecc7ffe9282ae9d1b652a88946d6edc550dc9633 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 23:25:15 +0200 Subject: reduce public interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit use private indicator pattern “_methodname” --- netlib/http/http1/protocol.py | 569 +++++++++++++++++++++--------------------- 1 file changed, 285 insertions(+), 284 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index d2a77399..e7727e00 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -15,15 +15,144 @@ class HTTP1Protocol(object): self.tcp_handler = tcp_handler - def get_request_line(self): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): """ - Get a line, possibly preceded by a blank. + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = self._get_request_line() + if not request_line: + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() + + request_line_parts = self._parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = self._parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + return http.ConnectRequest(host, port) + else: + form_in = "absolute" + r = self._parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + method, + None, + True + ) + + return http.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message if line == "\r\n" or line == "\n": - # Possible leftover from previous message line = self.tcp_handler.rfile.readline() - return line + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self._parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) def read_headers(self): @@ -56,7 +185,146 @@ class HTTP1Protocol(object): return odict.ODictCaseless(ret) - def read_chunked(self, limit, is_request): + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self._read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + def _get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + + + def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -88,20 +356,13 @@ class HTTP1Protocol(object): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") - ] + yield line, chunk, '\r\n' + if length == 0: + return @classmethod - def parse_http_protocol(self, line): + def _parse_http_protocol(self, line): """ Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or None. @@ -121,12 +382,12 @@ class HTTP1Protocol(object): @classmethod - def parse_init(self, line): + def _parse_init(self, line): try: method, url, protocol = string.split(line) except ValueError: return None - httpversion = self.parse_http_protocol(protocol) + httpversion = self._parse_http_protocol(protocol) if not httpversion: return None if not utils.isascii(method): @@ -135,12 +396,12 @@ class HTTP1Protocol(object): @classmethod - def parse_init_connect(self, line): + def _parse_init_connect(self, line): """ Returns (host, port, httpversion) if line is a valid CONNECT line. http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 """ - v = self.parse_init(line) + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -163,8 +424,8 @@ class HTTP1Protocol(object): @classmethod - def parse_init_proxy(self, line): - v = self.parse_init(line) + def _parse_init_proxy(self, line): + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -177,11 +438,11 @@ class HTTP1Protocol(object): @classmethod - def parse_init_http(self, line): + def _parse_init_http(self, line): """ Returns (method, url, httpversion) """ - v = self.parse_init(line) + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -225,263 +486,3 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) - - - def read_http_body(self, *args, **kwargs): - return "".join( - content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) - ) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self.read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", self.tcp_handler.rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - - @classmethod - def expected_http_body_size(self, headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = self.get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self.parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = self.parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - return http.ConnectRequest(host, port) - else: - form_in = "absolute" - r = self.parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - content = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - - def read_response(self, request_method, body_size_limit, include_body=True): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self.parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) - - - @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) -- cgit v1.2.3 From faf17d3d60e658d0cd1df30a10be4f11035502f8 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 20 Jul 2015 16:33:00 +0200 Subject: http2: make proper use of odict --- netlib/http/http2/protocol.py | 19 +++++++++++-------- netlib/odict.py | 2 -- 2 files changed, 11 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 1dfdda21..55b5ca76 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from netlib import http, utils +from netlib import http, utils, odict from . import frame @@ -189,7 +189,8 @@ class HTTP2Protocol(object): def read_response(self, *args): stream_id, headers, body = self._receive_transmission() - response = http.Response("HTTP/2", headers[':status'], "", headers, body) + status = headers[':status'][0] + response = http.Response("HTTP/2", status, "", headers, body) response.stream_id = stream_id return response @@ -197,11 +198,11 @@ class HTTP2Protocol(object): stream_id, headers, body = self._receive_transmission() form_in = "" - method = headers.get(':method', '') - scheme = headers.get(':scheme', '') - host = headers.get(':host', '') + method = headers.get(':method', [''])[0] + scheme = headers.get(':scheme', [''])[0] + host = headers.get(':host', [''])[0] port = '' # TODO: parse port number? - path = headers.get(':path', '') + path = headers.get(':path', [''])[0] request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) request.stream_id = stream_id @@ -233,15 +234,17 @@ class HTTP2Protocol(object): break # TODO: implement window update & flow - headers = {} + headers = odict.ODictCaseless() for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value + headers.add(header, value) return stream_id, headers, body def create_response(self, code, stream_id=None, headers=None, body=None): if headers is None: headers = [] + if isinstance(headers, odict.ODict): + headers = headers.items() headers = [(b':status', bytes(str(code)))] + headers diff --git a/netlib/odict.py b/netlib/odict.py index ee1e6938..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -20,8 +20,6 @@ class ODict(object): """ def __init__(self, lst=None): - if isinstance(lst, ODict): - lst = lst.items() self.lst = lst or [] def _kconv(self, s): -- cgit v1.2.3 From 657973eca3b091cdf07a65f8363affd3d36f0d0f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 22 Jul 2015 13:01:24 +0200 Subject: fix bugs --- netlib/http/http1/protocol.py | 26 +++++++++++++++++--------- netlib/http/semantics.py | 28 +++++++++++----------------- 2 files changed, 28 insertions(+), 26 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e7727e00..e46ad7ab 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -9,10 +9,18 @@ from netlib import odict, utils, tcp, http from .. import status_codes from ..exceptions import * +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + class HTTP1Protocol(object): - def __init__(self, tcp_handler): - self.tcp_handler = tcp_handler + def __init__(self, tcp_handler=None, rfile=None, wfile=None): + if tcp_handler: + self.tcp_handler = tcp_handler + else: + self.tcp_handler = TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -31,7 +39,7 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ - httpversion, host, port, scheme, method, path, headers, content = ( + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) request_line = self._get_request_line() @@ -56,7 +64,7 @@ class HTTP1Protocol(object): 400, "Bad HTTP request line: %s" % repr(request_line) ) - elif method.upper() == 'CONNECT': + elif method == 'CONNECT': form_in = "authority" r = self._parse_init_connect(request_line) if not r: @@ -64,8 +72,8 @@ class HTTP1Protocol(object): 400, "Bad HTTP request line: %s" % repr(request_line) ) - host, port, _ = r - return http.ConnectRequest(host, port) + host, port, httpversion = r + path = None else: form_in = "absolute" r = self._parse_init_proxy(request_line) @@ -81,7 +89,7 @@ class HTTP1Protocol(object): raise HttpError(400, "Invalid headers") expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): + if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' '\r\n' @@ -90,7 +98,7 @@ class HTTP1Protocol(object): del headers['expect'] if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, method, @@ -107,7 +115,7 @@ class HTTP1Protocol(object): path, httpversion, headers, - content + body ) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 355906dd..9e13edaa 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -5,7 +5,7 @@ import string import sys import urlparse -from .. import utils +from .. import utils, odict class Request(object): @@ -37,6 +37,10 @@ class Request(object): def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + @property + def content(self): + return self.body + class EmptyRequest(Request): def __init__(self): @@ -47,22 +51,8 @@ class EmptyRequest(Request): host="", port="", path="", - httpversion="", - headers="", - body="", - ) - -class ConnectRequest(Request): - def __init__(self, host, port): - super(ConnectRequest, self).__init__( - form_in="authority", - method="CONNECT", - scheme="", - host=host, - port=port, - path="", - httpversion="", - headers="", + httpversion=(0, 0), + headers=odict.ODictCaseless(), body="", ) @@ -91,6 +81,10 @@ class Response(object): def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) + @property + def content(self): + return self.body + def is_valid_port(port): if not 0 <= port <= 65535: -- cgit v1.2.3 From 1b261613826565dc5453b2846904c23773243921 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 24 Jul 2015 16:47:28 +0200 Subject: add distinct error for cert verification issues --- netlib/certutils.py | 2 -- netlib/tcp.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index c699af00..cc143a50 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -304,8 +304,6 @@ class CertStore(object): valid, plain-ASCII, IDNA-encoded domain name. sans: A list of Subject Alternate Names. - - Return None if the certificate could not be found or generated. """ potential_keys = self.asterisk_forms(commonname) diff --git a/netlib/tcp.py b/netlib/tcp.py index 5c4094d7..77c2a531 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -65,6 +65,10 @@ class NetLibSSLError(NetLibError): pass +class NetLibInvalidCertificateError(NetLibSSLError): + pass + + class SSLKeyLogger(object): def __init__(self, filename): @@ -517,13 +521,16 @@ class TCPClient(_Connection): try: self.connection.do_handshake() except SSL.Error as v: - raise NetLibError("SSL handshake error: %s" % repr(v)) + if self.ssl_verification_error: + raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v)) + else: + raise NetLibError("SSL handshake error: %s" % repr(v)) # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on # certificate validation failure verification_mode = sslctx_kwargs.get('verify_options', None) if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise NetLibError("SSL handshake error: certificate verify failed") + raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed") self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) -- cgit v1.2.3 From fb482172241b6235da083f6dbf154b641772a4fc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 25 Jul 2015 13:30:25 +0200 Subject: improve pyopenssl version check --- netlib/version_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version_check.py b/netlib/version_check.py index 2081c410..5465c901 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -29,7 +29,7 @@ def version_check( file=fp ) sys.exit(1) - v = tuple([int(x) for x in OpenSSL.__version__.split(".")][:2]) + v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) if v < pyopenssl_min_version: print( "You are using an outdated version of pyOpenSSL:" -- cgit v1.2.3 From 827fe824d97d96779512c8a4032d9b30d516d63f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 27 Jul 2015 09:36:50 +0200 Subject: move code from mitmproxy to netlib --- netlib/http/http1/protocol.py | 52 +++++++++++++++++++----- netlib/http/http2/protocol.py | 92 ++++++++++++++++++++++++++++++++++--------- netlib/http/semantics.py | 49 ++++++++++++++++++++++- 3 files changed, 163 insertions(+), 30 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e46ad7ab..af9882e8 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -4,6 +4,7 @@ import collections import string import sys import urlparse +import time from netlib import odict, utils, tcp, http from .. import status_codes @@ -17,10 +18,7 @@ class TCPHandler(object): class HTTP1Protocol(object): def __init__(self, tcp_handler=None, rfile=None, wfile=None): - if tcp_handler: - self.tcp_handler = tcp_handler - else: - self.tcp_handler = TCPHandler(rfile, wfile) + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -39,6 +37,10 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) @@ -106,6 +108,12 @@ class HTTP1Protocol(object): True ) + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + return http.Request( form_in, method, @@ -115,7 +123,9 @@ class HTTP1Protocol(object): path, httpversion, headers, - body + body, + timestamp_start, + timestamp_end, ) @@ -124,12 +134,15 @@ class HTTP1Protocol(object): Returns an http.Response By default, both response header and body are read. - If include_body=False is specified, content may be one of the + If include_body=False is specified, body may be one of the following: - None, if the response is technically allowed to have a response body - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() line = self.tcp_handler.rfile.readline() # Possible leftover from previous message @@ -149,7 +162,7 @@ class HTTP1Protocol(object): raise HttpError(502, "Invalid headers.") if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, request_method, @@ -157,10 +170,29 @@ class HTTP1Protocol(object): False ) else: - # if include_body==False then a None content means the body should be + # if include_body==False then a None body means the body should be # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + body = None + + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + return http.Response( + httpversion, + code, + msg, + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) def read_headers(self): diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 55b5ca76..41321fdc 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -1,11 +1,18 @@ from __future__ import (absolute_import, print_function, division) import itertools +import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict from . import frame +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( @@ -31,16 +38,26 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() self.connection_preface_performed = False - self.dump_frames = dump_frames def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -186,29 +203,68 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self, *args): - stream_id, headers, body = self._receive_transmission() + def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - status = headers[':status'][0] - response = http.Response("HTTP/2", status, "", headers, body) + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + headers[':status'][0], + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) response.stream_id = stream_id + return response - def read_request(self): - stream_id, headers, body = self._receive_transmission() + def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() - form_in = "" - method = headers.get(':method', [''])[0] - scheme = headers.get(':scheme', [''])[0] - host = headers.get(':host', [''])[0] port = '' # TODO: parse port number? - path = headers.get(':path', [''])[0] - request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request = http.Request( + "", + headers.get_first(':method', ['']), + headers.get_first(':scheme', ['']), + headers.get_first(':host', ['']), + port, + headers.get_first(':path', ['']), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) request.stream_id = stream_id + return request - def _receive_transmission(self): + def _receive_transmission(self, include_body=True): body_expected = True stream_id = 0 diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9e13edaa..63b6beb9 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -20,7 +20,11 @@ class Request(object): httpversion, headers, body, + timestamp_start=None, + timestamp_end=None, ): + assert isinstance(headers, odict.ODictCaseless) or not headers + self.form_in = form_in self.method = method self.scheme = scheme @@ -30,17 +34,30 @@ class Request(object): self.httpversion = httpversion self.headers = headers self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + class EmptyRequest(Request): def __init__(self): @@ -67,24 +84,52 @@ class Response(object): headers, body, sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): + assert isinstance(headers, odict.ODictCaseless) or not headers + self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers self.body = body self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): + # TODO: remove deprecated setter + self.status_code = code + + def is_valid_port(port): if not 0 <= port <= 65535: -- cgit v1.2.3 From c7fcc2cca5ff85641febbb908d11d22336bbd81c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 29 Jul 2015 11:27:43 +0200 Subject: add on-the-wire representation methods --- netlib/http/http1/protocol.py | 101 +++++++++++++++- netlib/http/http2/protocol.py | 261 +++++++++++++++++++++--------------------- netlib/http/semantics.py | 46 ++++++-- netlib/utils.py | 10 ++ 4 files changed, 279 insertions(+), 139 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index af9882e8..b098110a 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -7,6 +7,7 @@ import urlparse import time from netlib import odict, utils, tcp, http +from netlib.http import semantics from .. import status_codes from ..exceptions import * @@ -15,7 +16,7 @@ class TCPHandler(object): self.rfile = rfile self.wfile = wfile -class HTTP1Protocol(object): +class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) @@ -195,6 +196,32 @@ class HTTP1Protocol(object): ) + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + if request.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_request_first_line(request) + headers = self._assemble_request_headers(request) + return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) + + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + if response.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_response_first_line(response) + headers = self._assemble_response_headers(response) + return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) + + def read_headers(self): """ Read a set of headers. @@ -363,7 +390,6 @@ class HTTP1Protocol(object): return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -526,3 +552,74 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) + + + @classmethod + def _assemble_request_first_line(self, request): + if request.form_in == "relative": + request_line = '%s %s HTTP/%s.%s' % ( + request.method, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % ( + request.method, + request.host, + request.port, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_request_headers(self, request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + del headers[k] + if 'host' not in headers and request.scheme and request.host and request.port: + headers["Host"] = [utils.hostport(request.scheme, + request.host, + request.port)] + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == "": + headers["Content-Length"] = [str(len(request.body))] + + return headers.format() + + + def _assemble_response_first_line(self, response): + return 'HTTP/%s.%s %s %s' % ( + response.httpversion[0], + response.httpversion[1], + response.status_code, + response.msg, + ) + + def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + headers = response.headers.copy() + for k in response._headers_to_strip_off: + del headers[k] + if not preserve_transfer_encoding: + del headers['Transfer-Encoding'] + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == "": + headers["Content-Length"] = [str(len(response.body))] + + return headers.format() diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 41321fdc..618476e2 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -4,6 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict +from netlib.http import semantics from . import frame @@ -13,7 +14,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(object): +class HTTP2Protocol(semantics.ProtocolMixin): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, @@ -59,26 +60,104 @@ class HTTP2Protocol(object): self.current_stream_id = None self.connection_preface_performed = False - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True + def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break + stream_id, headers, body = self._receive_transmission(include_body) - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + port = '' # TODO: parse port number? + + request = http.Request( + "", + headers.get_first(':method', ['']), + headers.get_first(':scheme', ['']), + headers.get_first(':host', ['']), + port, + headers.get_first(':path', ['']), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + request.stream_id = stream_id + + return request + + def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + headers[':status'][0], + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(request.method)), + (b':path', bytes(request.path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + request.headers.items() + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items() + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None)), + self._create_body(response.body, stream_id), + )) def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -100,18 +179,6 @@ class HTTP2Protocol(object): self.send_frame(frame.SettingsFrame(state=self), hide=True) self._receive_settings(hide=True) - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - def send_frame(self, frm, hide=False): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) @@ -128,6 +195,39 @@ class HTTP2Protocol(object): return frm + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] @@ -181,89 +281,6 @@ class HTTP2Protocol(object): return [frm.to_bytes()] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self, request_method_='', body_size_limit_=None, include_body=True): - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission(include_body) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - response = http.Response( - (2, 0), - headers[':status'][0], - "", - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - response.stream_id = stream_id - - return response - - def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission(include_body) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - port = '' # TODO: parse port number? - - request = http.Request( - "", - headers.get_first(':method', ['']), - headers.get_first(':scheme', ['']), - headers.get_first(':host', ['']), - port, - headers.get_first(':path', ['']), - (2, 0), - headers, - body, - timestamp_start, - timestamp_end, - ) - request.stream_id = stream_id - - return request - def _receive_transmission(self, include_body=True): body_expected = True @@ -295,19 +312,3 @@ class HTTP2Protocol(object): headers.add(header, value) return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - if isinstance(headers, odict.ODict): - headers = headers.items() - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 63b6beb9..54bf83d2 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,32 @@ import urlparse from .. import utils, odict +CONTENT_MISSING = 0 + + +class ProtocolMixin(object): + + def read_request(self): + raise NotImplemented + + def read_response(self): + raise NotImplemented + + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + + def assemble_request(self, request): + raise NotImplemented + + def assemble_response(self, response): + raise NotImplemented + + class Request(object): def __init__( @@ -18,12 +44,14 @@ class Request(object): port, path, httpversion, - headers, - body, + headers=None, + body=None, timestamp_start=None, timestamp_end=None, ): - assert isinstance(headers, odict.ODictCaseless) or not headers + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) self.form_in = form_in self.method = method @@ -37,6 +65,7 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -80,14 +109,16 @@ class Response(object): self, httpversion, status_code, - msg, - headers, - body, + msg=None, + headers=None, + body=None, sslinfo=None, timestamp_start=None, timestamp_end=None, ): - assert isinstance(headers, odict.ODictCaseless) or not headers + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) self.httpversion = httpversion self.status_code = status_code @@ -98,6 +129,7 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] diff --git a/netlib/utils.py b/netlib/utils.py index bee412f9..86e33f33 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -129,3 +129,13 @@ class Data(object): if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https")]: + return host + else: + return "%s:%s" % (host, port) -- cgit v1.2.3 From 7b10817670b30550dd45af48491ed8cf3cacd5e6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 30 Jul 2015 13:52:13 +0200 Subject: http2: improve protocol --- netlib/http/http2/protocol.py | 61 +++++++++++++++++++++++++++++-------------- netlib/odict.py | 7 +++-- 2 files changed, 46 insertions(+), 22 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 618476e2..a1ca4a18 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -60,7 +60,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.current_stream_id = None self.connection_preface_performed = False - def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + self.perform_connection_preface() + timestamp_start = time.time() if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() @@ -73,15 +75,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - port = '' # TODO: parse port number? - request = http.Request( - "", - headers.get_first(':method', ['']), - headers.get_first(':scheme', ['']), - headers.get_first(':host', ['']), - port, - headers.get_first(':path', ['']), + "relative", # TODO: use the correct value + headers.get_first(':method', 'GET'), + headers.get_first(':scheme', 'https'), + headers.get_first(':host', 'localhost'), + 443, # TODO: parse port number from host? + headers.get_first(':path', '/'), (2, 0), headers, body, @@ -92,7 +92,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): return request - def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + def read_response(self, request_method='', body_size_limit=None, include_body=True): + self.perform_connection_preface() + timestamp_start = time.time() if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() @@ -110,7 +112,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - headers[':status'][0], + int(headers.get_first(':status')), "", headers, body, @@ -121,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response + def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -128,12 +131,18 @@ class HTTP2Protocol(semantics.ProtocolMixin): if self.tcp_handler.address.port != 443: authority += ":%d" % self.tcp_handler.address.port - headers = [ - (b':method', bytes(request.method)), - (b':path', bytes(request.path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + request.headers.items() + headers = request.headers.copy() + + if not ':authority' in headers.keys(): + headers.add(':authority', bytes(authority), prepend=True) + if not ':scheme' in headers.keys(): + headers.add(':scheme', bytes(request.scheme), prepend=True) + if not ':path' in headers.keys(): + headers.add(':path', bytes(request.path), prepend=True) + if not ':method' in headers.keys(): + headers.add(':method', bytes(request.method), prepend=True) + + headers = headers.items() if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -141,13 +150,18 @@ class HTTP2Protocol(semantics.ProtocolMixin): stream_id = self._next_stream_id() return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(request.body is None)), + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), self._create_body(request.body, stream_id))) def assemble_response(self, response): assert isinstance(response, semantics.Response) - headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items() + headers = response.headers.copy() + + if not ':status' in headers.keys(): + headers.add(':status', bytes(str(response.status_code)), prepend=True) + + headers = headers.items() if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -155,10 +169,17 @@ class HTTP2Protocol(semantics.ProtocolMixin): stream_id = self._next_stream_id() return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(response.body is None)), + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), self._create_body(response.body, stream_id), )) + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: self.connection_preface_performed = True diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..d02de08d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,8 +96,11 @@ class ODict(object): return True return False - def add(self, key, value): - self.lst.append([key, value]) + def add(self, key, value, prepend=False): + if prepend: + self.lst.insert(0, [key, value]) + else: + self.lst.append([key, value]) def get(self, k, d=None): if k in self: -- cgit v1.2.3 From a837230320378d629ba9f25960b1dfd25c892ad9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 1 Aug 2015 10:39:14 +0200 Subject: move code from mitmproxy to netlib --- netlib/encoding.py | 82 ++++++++++ netlib/http/exceptions.py | 13 ++ netlib/http/http1/protocol.py | 39 +---- netlib/http/semantics.py | 366 +++++++++++++++++++++++++++++++++--------- netlib/tutils.py | 125 +++++++++++++++ netlib/utils.py | 100 ++++++++++++ 6 files changed, 616 insertions(+), 109 deletions(-) create mode 100644 netlib/encoding.py create mode 100644 netlib/tutils.py (limited to 'netlib') diff --git a/netlib/encoding.py b/netlib/encoding.py new file mode 100644 index 00000000..f107eb5f --- /dev/null +++ b/netlib/encoding.py @@ -0,0 +1,82 @@ +""" + Utility functions for decoding response bodies. +""" +from __future__ import absolute_import +import cStringIO +import gzip +import zlib + +__ALL__ = ["ENCODINGS"] + +ENCODINGS = set(["identity", "gzip", "deflate"]) + + +def decode(e, content): + encoding_map = { + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, + } + if e not in encoding_map: + return None + return encoding_map[e](content) + + +def encode(e, content): + encoding_map = { + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, + } + if e not in encoding_map: + return None + return encoding_map[e](content) + + +def identity(content): + """ + Returns content unchanged. Identity is the default value of + Accept-Encoding headers. + """ + return content + + +def decode_gzip(content): + gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content)) + try: + return gfile.read() + except (IOError, EOFError): + return None + + +def encode_gzip(content): + s = cStringIO.StringIO() + gf = gzip.GzipFile(fileobj=s, mode='wb') + gf.write(content) + gf.close() + return s.getvalue() + + +def decode_deflate(content): + """ + Returns decompressed data for DEFLATE. Some servers may respond with + compressed data without a zlib header or checksum. An undocumented + feature of zlib permits the lenient decompression of data missing both + values. + + http://bugs.python.org/issue5784 + """ + try: + try: + return zlib.decompress(content) + except zlib.error: + return zlib.decompress(content, -15) + except zlib.error: + return None + + +def encode_deflate(content): + """ + Returns compressed content, always including zlib header and checksum. + """ + return zlib.compress(content) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 8a2bbebc..45bd2dce 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -7,3 +7,16 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass + + + +class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): + super(HttpAuthenticationError, self).__init__( + "Proxy Authentication Required" + ) + self.headers = auth_headers + self.code = 407 + + def __repr__(self): + return "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index b098110a..a189bffc 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -375,7 +375,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def has_chunked_encoding(self, headers): return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") ] @@ -482,9 +482,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): port = int(port) except ValueError: return None - if not http.is_valid_port(port): + if not utils.is_valid_port(port): return None - if not http.is_valid_host(host): + if not utils.is_valid_host(host): return None return host, port, httpversion @@ -496,7 +496,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None method, url, httpversion = v - parts = http.parse_url(url) + parts = utils.parse_url(url) if not parts: return None scheme, host, port, path = parts @@ -528,7 +528,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): """ # At first, check if we have an explicit Connection header. if "connection" in headers: - toks = http.get_header_tokens(headers, "connection") + toks = utils.get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: @@ -556,34 +556,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def _assemble_request_first_line(self, request): - if request.form_in == "relative": - request_line = '%s %s HTTP/%s.%s' % ( - request.method, - request.path, - request.httpversion[0], - request.httpversion[1], - ) - elif request.form_in == "authority": - request_line = '%s %s:%s HTTP/%s.%s' % ( - request.method, - request.host, - request.port, - request.httpversion[0], - request.httpversion[1], - ) - elif request.form_in == "absolute": - request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( - request.method, - request.scheme, - request.host, - request.port, - request.path, - request.httpversion[0], - request.httpversion[1], - ) - else: - raise http.HttpError(400, "Invalid request form") - return request_line + return request.legacy_first_line() def _assemble_request_headers(self, request): headers = request.headers.copy() diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 54bf83d2..e7ae2b5f 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -3,9 +3,15 @@ import binascii import collections import string import sys +import urllib import urlparse from .. import utils, odict +from . import cookies +from netlib import utils, encoding + +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 @@ -75,7 +81,240 @@ class Request(object): return False def __repr__(self): - return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + # return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + + return "".format( + self.legacy_first_line()[:-9] + ) + + def legacy_first_line(self): + if self.form_in == "relative": + return '%s %s HTTP/%s.%s' % ( + self.method, + self.path, + self.httpversion[0], + self.httpversion[1], + ) + elif self.form_in == "authority": + return '%s %s:%s HTTP/%s.%s' % ( + self.method, + self.host, + self.port, + self.httpversion[0], + self.httpversion[1], + ) + elif self.form_in == "absolute": + return '%s %s://%s:%s%s HTTP/%s.%s' % ( + self.method, + self.scheme, + self.host, + self.port, + self.path, + self.httpversion[0], + self.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + del self.headers[i] + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = ["identity"] + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + if self.headers["accept-encoding"]: + self.headers["accept-encoding"] = [ + ', '.join( + e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = [self.host] + + def get_form(self): + """ + Retrieves the URL-encoded or multipart form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body: + if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + return self.get_form_urlencoded() + elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + return self.get_form_multipart() + return odict.ODict([]) + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body and self.headers.in_any( + "content-type", + HDR_FORM_URLENCODED, + True): + return odict.ODict(utils.urldecode(self.body)) + return odict.ODict([]) + + def get_form_multipart(self): + if self.body and self.headers.in_any( + "content-type", + HDR_FORM_MULTIPART, + True): + return odict.ODict( + utils.multipartdecode( + self.headers, + self.body)) + return odict.ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.body = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urlparse.urlparse(self.url) + return [urllib.unquote(i) for i in path.split("/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.quote(i, safe="") for i in lst] + path = "/" + "/".join(lst) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) + self.url = urlparse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. + """ + _, _, _, _, query, _ = urlparse.urlparse(self.url) + if query: + return odict.ODict(utils.urldecode(query)) + return odict.ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) + query = utils.urlencode(odict.lst) + self.url = urlparse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def pretty_host(self, hostheader): + """ + Heuristic to get the host of the request. + + Note that pretty_host() does not always return the TCP destination + of the request, e.g. if an upstream proxy is in place + + If hostheader is set to True, the Host: header will be used as + additional (and preferred) data source. This is handy in + transparent mode, where only the IO of the destination is known, + but not the resolved name. This is disabled by default, as an + attacker may spoof the host header to confuse an analyst. + """ + host = None + if hostheader: + host = self.headers.get_first("host") + if not host: + host = self.host + if host: + try: + return host.encode("idna") + except ValueError: + return host + else: + return None + + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') + + def get_cookies(self): + """ + Returns a possibly empty netlib.odict.ODict object. + """ + ret = odict.ODict() + for i in self.headers["cookie"]: + ret.extend(cookies.parse_cookie_header(i)) + return ret + + def set_cookies(self, odict): + """ + Takes an netlib.odict.ODict object. Over-writes any existing Cookie + headers. + """ + v = cookies.format_cookie_header(odict) + self.headers["Cookie"] = [v] + + @property + def url(self): + """ + Returns a URL string, constructed from the Request's URL components. + """ + return utils.unparse_url( + self.scheme, + self.host, + self.port, + self.path + ).encode('ascii') + + @url.setter + def url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Returns False if the URL was invalid, True if the request succeeded. + """ + parts = utils.parse_url(url) + if not parts: + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts @property def content(self): @@ -139,7 +378,56 @@ class Response(object): return False def __repr__(self): - return "Response(%s - %s)" % (self.status_code, self.msg) + # return "Response(%s - %s)" % (self.status_code, self.msg) + + if self.body: + size = utils.pretty_size(len(self.body)) + else: + size = "content missing" + return "".format( + status_code=self.status_code, + msg=self.msg, + contenttype=self.headers.get_first( + "content-type", "unknown content type" + ), + size=size + ) + + + def get_cookies(self): + """ + Get the contents of all Set-Cookie headers. + + Returns a possibly empty ODict, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers["set-cookie"]: + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return odict.ODict(ret) + + def set_cookies(self, odict): + """ + Set the Set-Cookie headers on this response, over-writing existing + headers. + + Accepts an ODict of the same format as that returned by get_cookies. + """ + values = [] + for i in odict.lst: + values.append( + cookies.format_set_cookie_header( + i[0], + i[1][0], + i[1][1] + ) + ) + self.headers["Set-Cookie"] = values @property def content(self): @@ -160,77 +448,3 @@ class Response(object): def code(self, code): # TODO: remove deprecated setter self.status_code = code - - - -def is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not is_valid_port(port): - return None - return scheme, host, port, path - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks diff --git a/netlib/tutils.py b/netlib/tutils.py new file mode 100644 index 00000000..5018b9e8 --- /dev/null +++ b/netlib/tutils.py @@ -0,0 +1,125 @@ +import cStringIO +import tempfile +import os +import time +import shutil +from contextlib import contextmanager + +from netlib import tcp, utils, odict, http + + +def treader(bytes): + """ + Construct a tcp.Read object from bytes. + """ + fp = cStringIO.StringIO(bytes) + return tcp.Reader(fp) + + +@contextmanager +def tmpdir(*args, **kwargs): + orig_workdir = os.getcwd() + temp_workdir = tempfile.mkdtemp(*args, **kwargs) + os.chdir(temp_workdir) + + yield temp_workdir + + os.chdir(orig_workdir) + shutil.rmtree(temp_workdir) + + +def raises(exc, obj, *args, **kwargs): + """ + Assert that a callable raises a specified exception. + + :exc An exception class or a string. If a class, assert that an + exception of this type is raised. If a string, assert that the string + occurs in the string representation of the exception, based on a + case-insenstivie match. + + :obj A callable object. + + :args Arguments to be passsed to the callable. + + :kwargs Arguments to be passed to the callable. + """ + try: + ret = obj(*args, **kwargs) + except Exception as v: + if isinstance(exc, basestring): + if exc.lower() in str(v).lower(): + return + else: + raise AssertionError( + "Expected %s, but caught %s" % ( + repr(str(exc)), v + ) + ) + else: + if isinstance(v, exc): + return + else: + raise AssertionError( + "Expected %s, but caught %s %s" % ( + exc.__name__, v.__class__.__name__, str(v) + ) + ) + raise AssertionError("No exception raised. Return value: {}".format(ret)) + +test_data = utils.Data(__name__) + + + + +def treq(content="content", scheme="http", host="address", port=22): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + headers = odict.ODictCaseless() + headers["header"] = ["qvalue"] + req = http.Request( + "relative", + "GET", + scheme, + host, + port, + "/path", + (1, 1), + headers, + content, + None, + None, + ) + return req + + +def treq_absolute(content="content"): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + r = treq(content) + r.form_in = r.form_out = "absolute" + r.host = "address" + r.port = 22 + r.scheme = "http" + return r + + +def tresp(content="message"): + """ + @return: libmproxy.protocol.http.HTTPResponse + """ + + headers = odict.ODictCaseless() + headers["header_response"] = ["svalue"] + + resp = http.semantics.Response( + (1, 1), + 200, + "OK", + headers, + content, + time.time(), + time.time(), + ) + return resp diff --git a/netlib/utils.py b/netlib/utils.py index 86e33f33..39354605 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,5 +1,10 @@ from __future__ import (absolute_import, print_function, division) import os.path +import cgi +import urllib +import urlparse +import string + def isascii(s): try: @@ -131,6 +136,81 @@ class Data(object): return fullpath + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks + + def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. @@ -139,3 +219,23 @@ def hostport(scheme, host, port): return host else: return "%s:%s" % (host, port) + +def unparse_url(scheme, host, port, path=""): + """ + Returns a URL string, constructed from the specified compnents. + """ + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + + +def urlencode(s): + """ + Takes a list of (key, value) tuples and returns a urlencoded string. + """ + s = [tuple(i) for i in s] + return urllib.urlencode(s, False) + +def urldecode(s): + """ + Takes a urlencoded string and returns a list of (key, value) tuples. + """ + return cgi.parse_qsl(s, keep_blank_values=True) -- cgit v1.2.3 From 0be84fd6b96c170db6020b5aed1e962d64ffedda Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 1 Aug 2015 14:49:15 +0200 Subject: fix tutils imports --- netlib/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index 39354605..35ea0ec7 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -130,7 +130,7 @@ class Data(object): This function will raise ValueError if the path does not exist. """ - fullpath = os.path.join(self.dirname, path) + fullpath = os.path.join(self.dirname, '../test/', path) if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath -- cgit v1.2.3 From 6a678d86e16ccab7d16a74c79a6a0b928007d532 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 2 Aug 2015 11:27:01 +0200 Subject: fix mitmproxy tests --- netlib/http/exceptions.py | 6 ++++-- netlib/http/http1/protocol.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 45bd2dce..7cd26c12 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,5 +1,6 @@ -class HttpError(Exception): +from netlib import odict +class HttpError(Exception): def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -9,12 +10,13 @@ class HttpErrorConnClosed(HttpError): pass - class HttpAuthenticationError(Exception): def __init__(self, auth_headers=None): super(HttpAuthenticationError, self).__init__( "Proxy Authentication Required" ) + if isinstance(auth_headers, dict): + auth_headers = odict.ODictCaseless(auth_headers.items()) self.headers = auth_headers self.code = 407 diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index a189bffc..2e85a762 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -302,7 +302,8 @@ class HTTP1Protocol(semantics.ProtocolMixin): bytes_left = expected_size while bytes_left: chunk_size = min(bytes_left, max_chunk_size) - yield "", self.tcp_handler.rfile.read(chunk_size), "" + content = self.tcp_handler.rfile.read(chunk_size) + yield "", content, "" bytes_left -= chunk_size else: bytes_left = limit or -1 -- cgit v1.2.3 From c2832ef72bd4eed485a1c8d4bcb732da69896444 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 3 Aug 2015 18:06:31 +0200 Subject: fix mitmproxy/mitmproxy#705 --- netlib/tcp.py | 6 +++++- netlib/version_check.py | 25 ++++++++++++------------- 2 files changed, 17 insertions(+), 14 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 77c2a531..c355cfdd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -11,7 +11,11 @@ import certifi import OpenSSL from OpenSSL import SSL -from . import certutils +from . import certutils, version_check + +# This is a rather hackish way to make sure that +# the latest version of pyOpenSSL is actually installed. +version_check.check_pyopenssl_version() EINTR = 4 diff --git a/netlib/version_check.py b/netlib/version_check.py index 5465c901..aae4e8c7 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -1,23 +1,19 @@ -from __future__ import print_function, absolute_import +""" +Having installed a wrong version of pyOpenSSL or netlib is unfortunately a +very common source of error. Check before every start that both versions +are somewhat okay. +""" +from __future__ import division, absolute_import, print_function, unicode_literals import sys import inspect import os.path - import OpenSSL from . import version PYOPENSSL_MIN_VERSION = (0, 15) -def version_check( - mitmproxy_version, - pyopenssl_min_version=PYOPENSSL_MIN_VERSION, - fp=sys.stderr): - """ - Having installed a wrong version of pyOpenSSL or netlib is unfortunately a - very common source of error. Check before every start that both versions - are somewhat okay. - """ +def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): # We don't introduce backward-incompatible changes in patch versions. Only # consider major and minor version. if version.IVERSION[:2] != mitmproxy_version[:2]: @@ -29,12 +25,15 @@ def version_check( file=fp ) sys.exit(1) + + +def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) - if v < pyopenssl_min_version: + if v < min_version: print( "You are using an outdated version of pyOpenSSL:" " mitmproxy requires pyOpenSSL %s or greater." % - str(pyopenssl_min_version), + str(min_version), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. -- cgit v1.2.3 From 690b8b4f4e00d60b373b5a1481930f21bbc5054a Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 5 Aug 2015 21:32:53 +0200 Subject: add move tests and code from mitmproxy --- netlib/http/http1/protocol.py | 14 ----------- netlib/http/semantics.py | 43 ++++++++++++++++++++------------- netlib/odict.py | 3 ++- netlib/tutils.py | 4 ++-- netlib/utils.py | 56 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 87 insertions(+), 33 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 2e85a762..31e9cc85 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -359,20 +359,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return -1 - @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) - - @classmethod def has_chunked_encoding(self, headers): return "chunked" in [ diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7ae2b5f..974fe6e6 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,7 +7,7 @@ import urllib import urlparse from .. import utils, odict -from . import cookies +from . import cookies, exceptions from netlib import utils, encoding HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" @@ -18,10 +18,10 @@ CONTENT_MISSING = 0 class ProtocolMixin(object): - def read_request(self): + def read_request(self, *args, **kwargs): # pragma: no cover raise NotImplemented - def read_response(self): + def read_response(self, *args, **kwargs): # pragma: no cover raise NotImplemented def assemble(self, message): @@ -32,14 +32,23 @@ class ProtocolMixin(object): else: raise ValueError("HTTP message not supported.") - def assemble_request(self, request): + def assemble_request(self, request): # pragma: no cover raise NotImplemented - def assemble_response(self, response): + def assemble_response(self, response): # pragma: no cover raise NotImplemented class Request(object): + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', + ] def __init__( self, @@ -71,7 +80,6 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -114,7 +122,7 @@ class Request(object): self.httpversion[1], ) else: - raise http.HttpError(400, "Invalid request form") + raise exceptions.HttpError(400, "Invalid request form") def anticache(self): """ @@ -143,7 +151,7 @@ class Request(object): if self.headers["accept-encoding"]: self.headers["accept-encoding"] = [ ', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] def update_host_header(self): """ @@ -317,12 +325,12 @@ class Request(object): self.scheme, self.host, self.port, self.path = parts @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @@ -343,6 +351,11 @@ class EmptyRequest(Request): class Response(object): + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', + ] def __init__( self, @@ -368,7 +381,6 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -393,7 +405,6 @@ class Response(object): size=size ) - def get_cookies(self): """ Get the contents of all Set-Cookie headers. @@ -430,21 +441,21 @@ class Response(object): self.headers["Set-Cookie"] = values @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @property - def code(self): + def code(self): # pragma: no cover # TODO: remove deprecated getter return self.status_code @code.setter - def code(self, code): + def code(self, code): # pragma: no cover # TODO: remove deprecated setter self.status_code = code diff --git a/netlib/odict.py b/netlib/odict.py index d02de08d..11d5d52a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -91,8 +91,9 @@ class ODict(object): self.lst = self._filter_lst(k, self.lst) def __contains__(self, k): + k = self._kconv(k) for i in self.lst: - if self._kconv(i[0]) == self._kconv(k): + if self._kconv(i[0]) == k: return True return False diff --git a/netlib/tutils.py b/netlib/tutils.py index 5018b9e8..3c471d0d 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -119,7 +119,7 @@ def tresp(content="message"): "OK", headers, content, - time.time(), - time.time(), + timestamp_start=time.time(), + timestamp_end=time.time(), ) return resp diff --git a/netlib/utils.py b/netlib/utils.py index 35ea0ec7..2dfcafc6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -4,6 +4,7 @@ import cgi import urllib import urlparse import string +import re def isascii(s): @@ -239,3 +240,58 @@ def urldecode(s): Takes a urlencoded string and returns a list of (key, value) tuples. """ return cgi.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(hdrs, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = hdrs.get_first("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + boundary = v[2].get("boundary") + if not boundary: + return [] + + rx = re.compile(r'\bname="([^"]+)"') + r = [] + + for i in content.split("--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != "--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = "".join(parts[3 + parts[2:].index(""):]) + r.append((key, value)) + return r + return [] -- cgit v1.2.3 From 476badf45cd085d69b6162cd48983e3cd22cefcc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 10 Aug 2015 20:36:47 +0200 Subject: cleanup imports --- netlib/http/authentication.py | 2 -- netlib/http/http1/protocol.py | 4 ---- netlib/http/semantics.py | 4 ---- netlib/websockets/frame.py | 5 ++--- netlib/websockets/protocol.py | 5 ++--- 5 files changed, 4 insertions(+), 16 deletions(-) (limited to 'netlib') diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 9a227010..29b9eb3c 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -2,7 +2,6 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError import binascii -from .. import http def parse_http_basic_auth(s): words = s.split() @@ -37,7 +36,6 @@ class NullProxyAuth(object): """ Clean up authentication headers, so they're not passed upstream. """ - pass def authenticate(self, headers_): """ diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 31e9cc85..c797e930 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -1,14 +1,10 @@ from __future__ import (absolute_import, print_function, division) -import binascii -import collections import string import sys -import urlparse import time from netlib import odict, utils, tcp, http from netlib.http import semantics -from .. import status_codes from ..exceptions import * class TCPHandler(object): diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 974fe6e6..15add957 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,8 +1,4 @@ from __future__ import (absolute_import, print_function, division) -import binascii -import collections -import string -import sys import urllib import urlparse diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 49d8ee10..ad4ad0ee 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -1,12 +1,11 @@ from __future__ import absolute_import -import base64 -import hashlib import os import struct import io from .protocol import Masker -from netlib import utils, odict, tcp +from netlib import tcp +from netlib import utils DEFAULT = object() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 29b4db3d..8169309a 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -2,10 +2,9 @@ from __future__ import absolute_import import base64 import hashlib import os -import struct -import io -from netlib import utils, odict, tcp +from netlib import odict +from netlib import utils # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. -- cgit v1.2.3 From ff27d65f08d00c312a162965c5b1db711aa8f6ed Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 10 Aug 2015 20:44:36 +0200 Subject: cleanup whitespace --- netlib/http/exceptions.py | 3 +++ netlib/http/http1/protocol.py | 50 ++++++++++++++++++++++++------------------- netlib/http/http2/frame.py | 2 +- netlib/http/http2/protocol.py | 17 +++++++++++---- netlib/http/semantics.py | 10 ++++----- netlib/tutils.py | 2 -- netlib/utils.py | 5 +++-- netlib/websockets/frame.py | 1 + netlib/websockets/protocol.py | 2 ++ 9 files changed, 56 insertions(+), 36 deletions(-) (limited to 'netlib') diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 7cd26c12..987a7908 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,8 @@ from netlib import odict + class HttpError(Exception): + def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -11,6 +13,7 @@ class HttpErrorConnClosed(HttpError): class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): super(HttpAuthenticationError, self).__init__( "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index c797e930..8eeb7744 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -7,18 +7,25 @@ from netlib import odict, utils, tcp, http from netlib.http import semantics from ..exceptions import * + class TCPHandler(object): + def __init__(self, rfile, wfile=None): self.rfile = rfile self.wfile = wfile + class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): """ Parse an HTTP request from a file stream @@ -125,8 +132,12 @@ class HTTP1Protocol(semantics.ProtocolMixin): timestamp_end, ) - - def read_response(self, request_method, body_size_limit, include_body=True): + def read_response( + self, + request_method, + body_size_limit, + include_body=True, + ): """ Returns an http.Response @@ -171,7 +182,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): # read separately body = None - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start timestamp_start = self.tcp_handler.rfile.first_byte_timestamp @@ -191,7 +201,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): timestamp_end=timestamp_end, ) - def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -204,7 +213,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): headers = self._assemble_request_headers(request) return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - def assemble_response(self, response): assert isinstance(response, semantics.Response) @@ -217,7 +225,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): headers = self._assemble_response_headers(response) return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - def read_headers(self): """ Read a set of headers. @@ -262,7 +269,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): response_code, is_request, max_chunk_size=None - ): + ): """ Read an HTTP message body: headers: An ODictCaseless object @@ -317,9 +324,14 @@ class HTTP1Protocol(semantics.ProtocolMixin): "HTTP Body too large. Limit is %s," % limit ) - @classmethod - def expected_http_body_size(self, headers, is_request, request_method, response_code): + def expected_http_body_size( + self, + headers, + is_request, + request_method, + response_code, + ): """ Returns the expected body length: - a positive integer, if the size is known in advance @@ -372,7 +384,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): line = self.tcp_handler.rfile.readline() return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -409,7 +420,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): if length == 0: return - @classmethod def _parse_http_protocol(self, line): """ @@ -429,7 +439,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return major, minor - @classmethod def _parse_init(self, line): try: @@ -443,7 +452,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return method, url, httpversion - @classmethod def _parse_init_connect(self, line): """ @@ -471,7 +479,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return host, port, httpversion - @classmethod def _parse_init_proxy(self, line): v = self._parse_init(line) @@ -485,7 +492,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): scheme, host, port, path = parts return method, scheme, host, port, path, httpversion - @classmethod def _parse_init_http(self, line): """ @@ -501,7 +507,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return method, url, httpversion - @classmethod def connection_close(self, httpversion, headers): """ @@ -521,7 +526,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): # be persistent return httpversion != (1, 1) - @classmethod def parse_response_line(self, line): parts = line.strip().split(" ", 2) @@ -536,7 +540,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return (proto, code, msg) - @classmethod def _assemble_request_first_line(self, request): return request.legacy_first_line() @@ -557,7 +560,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return headers.format() - def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( response.httpversion[0], @@ -566,7 +568,11 @@ class HTTP1Protocol(semantics.ProtocolMixin): response.msg, ) - def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + def _assemble_response_headers( + self, + response, + preserve_transfer_encoding=False, + ): headers = response.headers.copy() for k in response._headers_to_strip_off: del headers[k] diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index f7e60471..aa1fbae4 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -117,7 +117,7 @@ class Frame(object): return "\n".join([ "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), "===============================================================", ]) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index a1ca4a18..896b728b 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -9,6 +9,7 @@ from . import frame class TCPHandler(object): + def __init__(self, rfile, wfile=None): self.rfile = rfile self.wfile = wfile @@ -39,7 +40,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): ALPN_PROTO_H2 = 'h2' - def __init__( self, tcp_handler=None, @@ -60,7 +60,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.current_stream_id = None self.connection_preface_performed = False - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): self.perform_connection_preface() timestamp_start = time.time() @@ -92,7 +97,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): return request - def read_response(self, request_method='', body_size_limit=None, include_body=True): + def read_response( + self, + request_method='', + body_size_limit=None, + include_body=True, + ): self.perform_connection_preface() timestamp_start = time.time() @@ -123,7 +133,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response - def assemble_request(self, request): assert isinstance(request, semantics.Request) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 15add957..d9dbb559 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -332,6 +332,7 @@ class Request(object): class EmptyRequest(Request): + def __init__(self): super(EmptyRequest, self).__init__( form_in="", @@ -343,7 +344,7 @@ class EmptyRequest(Request): httpversion=(0, 0), headers=odict.ODictCaseless(), body="", - ) + ) class Response(object): @@ -396,10 +397,9 @@ class Response(object): status_code=self.status_code, msg=self.msg, contenttype=self.headers.get_first( - "content-type", "unknown content type" - ), - size=size - ) + "content-type", + "unknown content type"), + size=size) def get_cookies(self): """ diff --git a/netlib/tutils.py b/netlib/tutils.py index 3c471d0d..7434c108 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -69,8 +69,6 @@ def raises(exc, obj, *args, **kwargs): test_data = utils.Data(__name__) - - def treq(content="content", scheme="http", host="address", port=22): """ @return: libmproxy.protocol.http.HTTPRequest diff --git a/netlib/utils.py b/netlib/utils.py index 2dfcafc6..31dcd622 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -119,6 +119,7 @@ def pretty_size(size): class Data(object): + def __init__(self, name): m = __import__(name) dirname, _ = os.path.split(m.__file__) @@ -137,8 +138,6 @@ class Data(object): return fullpath - - def is_valid_port(port): if not 0 <= port <= 65535: return False @@ -221,6 +220,7 @@ def hostport(scheme, host, port): else: return "%s:%s" % (host, port) + def unparse_url(scheme, host, port, path=""): """ Returns a URL string, constructed from the specified compnents. @@ -235,6 +235,7 @@ def urlencode(s): s = [tuple(i) for i in s] return urllib.urlencode(s, False) + def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index ad4ad0ee..1c4a03b2 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -21,6 +21,7 @@ OPCODE = utils.BiDi( PONG=0x0a ) + class FrameHeader(object): def __init__( diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 8169309a..6ce32eac 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -25,6 +25,7 @@ HEADER_WEBSOCKET_KEY = 'sec-websocket-key' HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + class Masker(object): """ @@ -52,6 +53,7 @@ class Masker(object): self.offset += len(ret) return ret + class WebsocketsProtocol(object): def __init__(self): -- cgit v1.2.3 From 6a30ad2ad236fa20d086e271ff962ebc907da027 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 10 Aug 2015 20:50:05 +0200 Subject: fix minor style offences --- netlib/http/http2/protocol.py | 10 +++++----- netlib/http/semantics.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 896b728b..c2ad5edd 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -142,13 +142,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if not ':authority' in headers.keys(): + if ':authority' not in headers.keys(): headers.add(':authority', bytes(authority), prepend=True) - if not ':scheme' in headers.keys(): + if ':scheme' not in headers.keys(): headers.add(':scheme', bytes(request.scheme), prepend=True) - if not ':path' in headers.keys(): + if ':path' not in headers.keys(): headers.add(':path', bytes(request.path), prepend=True) - if not ':method' in headers.keys(): + if ':method' not in headers.keys(): headers.add(':method', bytes(request.method), prepend=True) headers = headers.items() @@ -167,7 +167,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if not ':status' in headers.keys(): + if ':status' not in headers.keys(): headers.add(':status', bytes(str(response.status_code)), prepend=True) headers = headers.items() diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index d9dbb559..76213cd1 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -15,10 +15,10 @@ CONTENT_MISSING = 0 class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover - raise NotImplemented + raise NotImplementedError def read_response(self, *args, **kwargs): # pragma: no cover - raise NotImplemented + raise NotImplementedError def assemble(self, message): if isinstance(message, Request): @@ -28,11 +28,11 @@ class ProtocolMixin(object): else: raise ValueError("HTTP message not supported.") - def assemble_request(self, request): # pragma: no cover - raise NotImplemented + def assemble_request(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError - def assemble_response(self, response): # pragma: no cover - raise NotImplemented + def assemble_response(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError class Request(object): -- cgit v1.2.3 From b7e6e1c9b2c57270ee0c49af9235a2b119600056 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 15 Aug 2015 17:49:59 +0200 Subject: add HTTP/1.1 ALPN version string --- netlib/http/http1/protocol.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 8eeb7744..dc33a8af 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -17,6 +17,8 @@ class TCPHandler(object): class HTTP1Protocol(semantics.ProtocolMixin): + ALPN_PROTO_HTTP1 = 'http/1.1' + def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) -- cgit v1.2.3 From 85cede47aa8f9ffd770ad2830084e53b04b4e77e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 16 Aug 2015 11:41:34 +0200 Subject: allow direct ALPN callback method --- netlib/tcp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c355cfdd..b3171a1c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -403,6 +403,7 @@ class _Connection(object): cipher_list=None, alpn_protos=None, alpn_select=None, + alpn_select_callback=None, ): """ Creates an SSL Context. @@ -457,7 +458,7 @@ class _Connection(object): if alpn_protos is not None: # advertise application layer protocols context.set_alpn_protos(alpn_protos) - elif alpn_select is not None: + elif alpn_select is not None and alpn_select_callback is None: # select application layer protocol def alpn_select_callback(conn_, options): if alpn_select in options: @@ -465,6 +466,10 @@ class _Connection(object): else: # pragma no cover return options[0] context.set_alpn_select_callback(alpn_select_callback) + elif alpn_select_callback is not None and alpn_select is None: + context.set_alpn_select_callback(alpn_select_callback) + elif alpn_select_callback is not None and alpn_select is not None: + raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") return context -- cgit v1.2.3 From 3d306671251723a781b6e69c826bb94117f86188 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 17 Aug 2015 10:21:30 +1200 Subject: Bump netlib version - 0.13.1 is already out --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index de42ace1..044fde2c 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 13, 1) +IVERSION = (0, 13, 2) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From c92dc1b8682ed15b68890f18c65b3f31122e9fa4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 15 Aug 2015 20:30:22 +0200 Subject: re-add form_out --- netlib/http/semantics.py | 12 ++++++++---- netlib/tcp.py | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 76213cd1..5b7fb80f 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -59,6 +59,7 @@ class Request(object): body=None, timestamp_start=None, timestamp_end=None, + form_out=None ): if not headers: headers = odict.ODictCaseless() @@ -75,6 +76,7 @@ class Request(object): self.body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + self.form_out = form_out or form_in def __eq__(self, other): try: @@ -91,15 +93,17 @@ class Request(object): self.legacy_first_line()[:-9] ) - def legacy_first_line(self): - if self.form_in == "relative": + def legacy_first_line(self, form=None): + if form is None: + form = self.form_out + if form == "relative": return '%s %s HTTP/%s.%s' % ( self.method, self.path, self.httpversion[0], self.httpversion[1], ) - elif self.form_in == "authority": + elif form == "authority": return '%s %s:%s HTTP/%s.%s' % ( self.method, self.host, @@ -107,7 +111,7 @@ class Request(object): self.httpversion[0], self.httpversion[1], ) - elif self.form_in == "absolute": + elif form == "absolute": return '%s %s://%s:%s%s HTTP/%s.%s' % ( self.method, self.scheme, diff --git a/netlib/tcp.py b/netlib/tcp.py index b3171a1c..22cd0965 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -310,6 +310,8 @@ class Address(object): return str(self.address) def __eq__(self, other): + if not other: + return False other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) -- cgit v1.2.3 From 62416daa4a3776563556fb45ef9bd749fb44c334 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 25 Jul 2015 13:31:04 +0200 Subject: add Reader.peek() --- netlib/tcp.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 22cd0965..b05e84f5 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -265,6 +265,24 @@ class Reader(_FileLike): ) return result + def peek(self, length): + """ + Tries to peek into the underlying file object. + + Returns: + Up to the next N bytes if peeking is successful. + None, otherwise. + + Raises: + NetLibSSLError if there was an error with pyOpenSSL. + """ + if isinstance(self.o, SSL.Connection) or isinstance(self.o, socket._fileobject): + try: + return self.o._sock.recv(length, socket.MSG_PEEK) + except SSL.Error as e: + raise NetLibSSLError(str(e)) + + class Address(object): -- cgit v1.2.3 From 231656859fcf82cb1252d1aad8dbc0f77dfb8bba Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 16 Aug 2015 23:33:11 +0200 Subject: TCPClient: more sophisticated address handling --- netlib/http/semantics.py | 3 ++- netlib/tcp.py | 34 +++++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 5b7fb80f..836af550 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -397,7 +397,8 @@ class Response(object): size = utils.pretty_size(len(self.body)) else: size = "content missing" - return "".format( + # TODO: Remove "(unknown content type, content missing)" edge-case + return "".format( status_code=self.status_code, msg=self.msg, contenttype=self.headers.get_first( diff --git a/netlib/tcp.py b/netlib/tcp.py index b05e84f5..289618a7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -283,7 +283,6 @@ class Reader(_FileLike): raise NetLibSSLError(str(e)) - class Address(object): """ @@ -498,6 +497,29 @@ class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + def __init__(self, address, source_address=None): + self.connection, self.rfile, self.wfile = None, None, None + self.address = address + self.source_address = Address.wrap( + source_address) if source_address else None + self.cert = None + self.ssl_established = False + self.ssl_verification_error = None + self.sni = None + + @property + def address(self): + return self.__address + + @address.setter + def address(self, address): + if self.connection: + raise RuntimeError("Cannot change server address after establishing connection") + if address: + self.__address = Address.wrap(address) + else: + self.__address = None + def close(self): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, @@ -507,16 +529,6 @@ class TCPClient(_Connection): else: close_socket(self.connection) - def __init__(self, address, source_address=None): - self.address = Address.wrap(address) - self.source_address = Address.wrap( - source_address) if source_address else None - self.connection, self.rfile, self.wfile = None, None, None - self.cert = None - self.ssl_established = False - self.ssl_verification_error = None - self.sni = None - def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): context = self._create_ssl_context( alpn_protos=alpn_protos, -- cgit v1.2.3 From 0d384ac2a91898d4c8623290ae0fb3a60a35e514 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 17 Aug 2015 22:55:33 +0200 Subject: http2: add support for too large data frames --- netlib/http/http2/protocol.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index c2ad5edd..cc8daba8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -297,19 +297,22 @@ class HTTP2Protocol(semantics.ProtocolMixin): if body is None or len(body) == 0: return b'' - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - frm = frame.DataFrame( + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(body), chunk_size) + frms = [frame.DataFrame( state=self, - flags=frame.Frame.FLAG_END_STREAM, + flags=frame.Frame.FLAG_NO_FLAGS, stream_id=stream_id, - payload=body) + payload=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags = frame.Frame.FLAG_END_STREAM + + # TODO: implement flow-control window if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) + for frm in frms: + print(frm.human_readable(">>")) - return [frm.to_bytes()] + return [frm.to_bytes() for frm in frms] def _receive_transmission(self, include_body=True): body_expected = True -- cgit v1.2.3 From 07a1356e2f155d5b9e3a5f97bf90515ed9f1011f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 18 Aug 2015 09:49:56 +0200 Subject: http2: add support for too large header frames --- netlib/http/http2/protocol.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index cc8daba8..c27b4e9e 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -274,24 +274,33 @@ class HTTP2Protocol(semantics.ProtocolMixin): # to be more strict use: self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - flags |= frame.Frame.FLAG_END_STREAM + def frame_cls(chunks): + for i in chunks: + if i == 0: + yield frame.HeadersFrame, i + else: + yield frame.ContinuationFrame, i header_block_fragment = self.encoder.encode(headers) - frm = frame.HeadersFrame( + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(header_block_fragment), chunk_size) + frms = [frm_cls( state=self, - flags=flags, + flags=frame.Frame.FLAG_NO_FLAGS, stream_id=stream_id, - header_block_fragment=header_block_fragment) + header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + + last_flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + last_flags |= frame.Frame.FLAG_END_STREAM + frms[-1].flags = last_flags if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) + for frm in frms: + print(frm.human_readable(">>")) - return [frm.to_bytes()] + return [frm.to_bytes() for frm in frms] def _create_body(self, body, stream_id): if body is None or len(body) == 0: -- cgit v1.2.3 From 9686a77dcb640ace74f923c1f0f7f7307f79edfe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 16 Aug 2015 20:02:18 +0200 Subject: http2: implement request target --- netlib/http/cookies.py | 3 +-- netlib/http/http2/protocol.py | 39 +++++++++++++++++++++++++++++++++------ 2 files changed, 34 insertions(+), 8 deletions(-) (limited to 'netlib') diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index b77e3503..78b03a83 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -23,8 +23,7 @@ variants. Serialization follows RFC6265. http://tools.ietf.org/html/rfc2965 """ -# TODO -# - Disallow LHS-only Cookie values +# TODO: Disallow LHS-only Cookie values def _read_until(s, start, term): diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index c27b4e9e..eacbd2d8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -80,13 +80,39 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() + authority = headers.get_first(':authority', '') + method = headers.get_first(':method', 'GET') + scheme = headers.get_first(':scheme', 'https') + path = headers.get_first(':path', '/') + host = None + port = None + + if path == '*' or path.startswith("/"): + form_in = "relative" + elif method == 'CONNECT': + form_in = "authority" + if ":" in authority: + host, port = authority.split(":", 1) + else: + host = authority + else: + form_in = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = utils.parse_url(path) + + if host is None: + host = 'localhost' + if port is None: + port = 80 if scheme == 'http' else 443 + port = int(port) + request = http.Request( - "relative", # TODO: use the correct value - headers.get_first(':method', 'GET'), - headers.get_first(':scheme', 'https'), - headers.get_first(':host', 'localhost'), - 443, # TODO: parse port number from host? - headers.get_first(':path', '/'), + form_in, + method, + scheme, + host, + port, + path, (2, 0), headers, body, @@ -324,6 +350,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] def _receive_transmission(self, include_body=True): + # TODO: include_body is not respected body_expected = True stream_id = 0 -- cgit v1.2.3 From 6810fba54ef9c885215d5ff02534b93bb6868b2e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 19 Aug 2015 16:05:42 +0200 Subject: add ssl peek polyfill --- netlib/tcp.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 289618a7..c6638177 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -271,16 +271,32 @@ class Reader(_FileLike): Returns: Up to the next N bytes if peeking is successful. - None, otherwise. Raises: + NetLibError if there was an error with the socket NetLibSSLError if there was an error with pyOpenSSL. + NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, SSL.Connection) or isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket._fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) + except socket.error as e: + raise NetLibError(str(e)) + elif isinstance(self.o, SSL.Connection): + try: + if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): + return self.o.recv(length, socket.MSG_PEEK) + else: + # Polyfill for pyOpenSSL <= 0.15.1 + # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 + buf = SSL._ffi.new("char[]", length) + result = SSL._lib.SSL_peek(self.o._ssl, buf, length) + self.o._raise_ssl_error(self.o._ssl, result) + return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: raise NetLibSSLError(str(e)) + else: + raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") class Address(object): -- cgit v1.2.3 From 9920de1e153d4a85bbc4fa1dfd8fe5db45d56ab3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 19 Aug 2015 16:06:33 +0200 Subject: tcp._Connection: clean up code, fix inheritance --- netlib/tcp.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index c6638177..a0e2ab5e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -399,6 +399,22 @@ def close_socket(sock): class _Connection(object): + rbufsize = -1 + wbufsize = -1 + + def __init__(self, connection): + if connection: + self.connection = connection + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.connection = None + self.rfile = None + self.wfile = None + + self.ssl_established = False + self.finished = False + def get_current_cipher(self): if not self.ssl_established: return None @@ -510,16 +526,13 @@ class _Connection(object): class TCPClient(_Connection): - rbufsize = -1 - wbufsize = -1 def __init__(self, address, source_address=None): - self.connection, self.rfile, self.wfile = None, None, None + super(TCPClient, self).__init__(None) self.address = address self.source_address = Address.wrap( source_address) if source_address else None self.cert = None - self.ssl_established = False self.ssl_verification_error = None self.sni = None @@ -627,20 +640,12 @@ class BaseHandler(_Connection): """ The instantiator is expected to call the handle() and finish() methods. - """ - rbufsize = -1 - wbufsize = -1 def __init__(self, connection, address, server): - self.connection = connection + super(BaseHandler, self).__init__(connection) self.address = Address.wrap(address) self.server = server - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - - self.finished = False - self.ssl_established = False self.clientcert = None def create_ssl_context(self, -- cgit v1.2.3 From 1025c15242b1f9324bf17ceb53224c84e026b3dc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 09:54:45 +0200 Subject: fix typo --- netlib/http/http2/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index aa1fbae4..ad00a59a 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -569,7 +569,7 @@ class WindowUpdateFrame(Frame): def payload_bytes(self): if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + 'Window Size Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) -- cgit v1.2.3 From e20d4e5c027ad7000f0d997ffb327817ef0dd557 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 19 Aug 2015 21:09:15 +0200 Subject: http2: add callback to handle unexpected frames --- netlib/http/http2/protocol.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index eacbd2d8..aa52bd71 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -49,12 +49,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): dump_frames=False, encoder=None, decoder=None, + unhandled_frame_cb=None, ): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server self.dump_frames = dump_frames self.encoder = encoder or Encoder() self.decoder = decoder or Decoder() + self.unhandled_frame_cb = unhandled_frame_cb self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None @@ -258,11 +260,17 @@ class HTTP2Protocol(semantics.ProtocolMixin): "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True + def _handle_unexpected_frame(self, frm): + if self.unhandled_frame_cb is not None: + self.unhandled_frame_cb(frm) + def _receive_settings(self, hide=False): while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): break + else: + self._handle_unexpected_frame(frm) def _read_settings_ack(self, hide=False): # pragma no cover while True: @@ -271,6 +279,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): assert frm.flags & frame.Frame.FLAG_ACK assert len(frm.settings) == 0 break + else: + self._handle_unexpected_frame(frm) def _next_stream_id(self): if self.current_stream_id is None: @@ -367,6 +377,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): body_expected = False if frm.flags & frame.Frame.FLAG_END_HEADERS: break + else: + self._handle_unexpected_frame(frm) while body_expected: frm = self.read_frame() @@ -374,6 +386,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): body += frm.payload if frm.flags & frame.Frame.FLAG_END_STREAM: break + else: + self._handle_unexpected_frame(frm) + # TODO: implement window update & flow headers = odict.ODictCaseless() -- cgit v1.2.3 From eb343055185fabc892a590c6220b125283036b4e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 10:21:22 +0200 Subject: http2: fix frame length field --- netlib/http/http2/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index ad00a59a..24e6510a 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -98,7 +98,7 @@ class Frame(object): self._check_frame_size(self.length, self.state) - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) b += struct.pack('!B', self.TYPE) b += struct.pack('!B', self.flags) b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) -- cgit v1.2.3 From 94b7beae2a818ac873fb63991ab5237de1c104dd Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 10:21:38 +0200 Subject: http2: implement basic flow control updates --- netlib/http/http2/protocol.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index aa52bd71..bf0b364f 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -251,6 +251,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.length > 0: + self._update_flow_control_window(frm.stream_id, frm.length) + return frm def check_alpn(self): @@ -309,6 +312,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): # be liberal in what we expect from the other end # to be more strict use: self._read_settings_ack(hide) + def _update_flow_control_window(self, stream_id, increment): + frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) + self.send_frame(frm) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) + self.send_frame(frm) + def _create_headers(self, headers, stream_id, end_stream=True): def frame_cls(chunks): for i in chunks: @@ -351,8 +360,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): payload=body[i:i+chunk_size]) for i in chunks] frms[-1].flags = frame.Frame.FLAG_END_STREAM - # TODO: implement flow-control window - if self.dump_frames: # pragma no cover for frm in frms: print(frm.human_readable(">>")) @@ -369,8 +376,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame)\ - or isinstance(frm, frame.ContinuationFrame): + if ( + (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and + (stream_id == 0 or frm.stream_id == stream_id) + ): stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment if frm.flags & frame.Frame.FLAG_END_STREAM: @@ -382,15 +391,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): while body_expected: frm = self.read_frame() - if isinstance(frm, frame.DataFrame): + if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: body += frm.payload if frm.flags & frame.Frame.FLAG_END_STREAM: break else: self._handle_unexpected_frame(frm) - # TODO: implement window update & flow - headers = odict.ODictCaseless() for header, value in self.decoder.decode(header_block_fragment): headers.add(header, value) -- cgit v1.2.3 From 16f697f68a7f94375bd1435f5eec6e00911b7019 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 10:26:43 +0200 Subject: http2: disable features we do not support yet --- netlib/http/http2/protocol.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index bf0b364f..cf46a130 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -225,7 +225,11 @@ class HTTP2Protocol(semantics.ProtocolMixin): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - self.send_frame(frame.SettingsFrame(state=self), hide=True) + frm = frame.SettingsFrame(state=self, settings={ + frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, + frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, + }) + self.send_frame(frm, hide=True) self._receive_settings(hide=True) def perform_client_connection_preface(self, force=False): -- cgit v1.2.3 From 53f2582313ce5e8d1c875bea8b3f1a270db35b5b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 20:36:51 +0200 Subject: http2: fix unhandled settings frame --- netlib/http/http2/protocol.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index cf46a130..66ce19c8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -239,7 +239,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) + self._receive_settings(hide=True) # server announces own settings + self._receive_settings(hide=True) # server acks my settings def send_frame(self, frm, hide=False): raw_bytes = frm.to_bytes() @@ -279,16 +280,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break - else: - self._handle_unexpected_frame(frm) - def _next_stream_id(self): if self.current_stream_id is None: if self.is_server: @@ -313,9 +304,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): flags=frame.Frame.FLAG_ACK) self.send_frame(frm, hide) - # be liberal in what we expect from the other end - # to be more strict use: self._read_settings_ack(hide) - def _update_flow_control_window(self, stream_id, increment): frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) self.send_frame(frm) -- cgit v1.2.3 From cd9701050f58f90c757a34f7e4e6b5711700d649 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 21 Aug 2015 10:03:57 +0200 Subject: read_response depends on request for stream_id --- netlib/http/http1/protocol.py | 4 ++-- netlib/http/http2/protocol.py | 18 +++++++++++------- netlib/http/semantics.py | 34 ++++++++++++++++++++++++---------- 3 files changed, 37 insertions(+), 19 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index dc33a8af..107a48d1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -136,7 +136,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_response( self, - request_method, + request, body_size_limit, include_body=True, ): @@ -175,7 +175,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): body = self.read_http_body( headers, body_size_limit, - request_method, + request.method, code, False ) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 66ce19c8..e032c2a0 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -74,7 +74,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() - stream_id, headers, body = self._receive_transmission(include_body) + stream_id, headers, body = self._receive_transmission( + include_body=include_body, + ) if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start @@ -127,7 +129,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_response( self, - request_method='', + request='', body_size_limit=None, include_body=True, ): @@ -137,7 +139,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() - stream_id, headers, body = self._receive_transmission(include_body) + stream_id, headers, body = self._receive_transmission( + stream_id=request.stream_id, + include_body=include_body, + ) if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start @@ -145,7 +150,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): if include_body: timestamp_end = time.time() - else: + else: # pragma: no cover timestamp_end = None response = http.Response( @@ -358,11 +363,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] - def _receive_transmission(self, include_body=True): + def _receive_transmission(self, stream_id=None, include_body=True): # TODO: include_body is not respected body_expected = True - stream_id = 0 header_block_fragment = b'' body = b'' @@ -370,7 +374,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): frm = self.read_frame() if ( (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and - (stream_id == 0 or frm.stream_id == stream_id) + (stream_id is None or frm.stream_id == stream_id) ): stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 836af550..e388a344 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -337,18 +337,32 @@ class Request(object): class EmptyRequest(Request): - def __init__(self): + def __init__( + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=None, + headers=None, + body="", + stream_id=None + ): super(EmptyRequest, self).__init__( - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=odict.ODictCaseless(), - body="", + form_in=form_in, + method=method, + scheme=scheme, + host=host, + port=port, + path=path, + httpversion=(httpversion or (0, 0)), + headers=(headers or odict.ODictCaseless()), + body=body, ) + if stream_id: + self.stream_id = stream_id class Response(object): -- cgit v1.2.3 From 622665952ca072a6276917c252758bbe19091a0d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 24 Aug 2015 16:52:32 +0200 Subject: minor stylistic fixes --- netlib/http/http2/protocol.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index e032c2a0..1d6e0168 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -123,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_start, timestamp_end, ) + # FIXME: We should not do this. request.stream_id = stream_id return request @@ -150,7 +151,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): if include_body: timestamp_end = time.time() - else: # pragma: no cover + else: timestamp_end = None response = http.Response( @@ -274,7 +275,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return True def _handle_unexpected_frame(self, frm): - if self.unhandled_frame_cb is not None: + if self.unhandled_frame_cb: self.unhandled_frame_cb(frm) def _receive_settings(self, hide=False): @@ -364,7 +365,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] def _receive_transmission(self, stream_id=None, include_body=True): - # TODO: include_body is not respected + if not include_body: + raise NotImplementedError() + body_expected = True header_block_fragment = b'' -- cgit v1.2.3 From 21858995aee48c67430c9b6f3965d897b27cd734 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 24 Aug 2015 18:16:34 +0200 Subject: request -> request_method --- netlib/http/http1/protocol.py | 6 +++--- netlib/http/http2/protocol.py | 11 +++++++++-- netlib/http/semantics.py | 9 +++------ 3 files changed, 15 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 107a48d1..6b4489fb 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -136,8 +136,8 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_response( self, - request, - body_size_limit, + request_method, + body_size_limit=None, include_body=True, ): """ @@ -175,7 +175,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): body = self.read_http_body( headers, body_size_limit, - request.method, + request_method, code, False ) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 1d6e0168..b6a147d4 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -68,6 +68,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): body_size_limit=None, allow_empty=False, ): + if body_size_limit is not None: + raise NotImplementedError() + self.perform_connection_preface() timestamp_start = time.time() @@ -130,10 +133,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_response( self, - request='', + request_method='', body_size_limit=None, include_body=True, + stream_id=None, ): + if body_size_limit is not None: + raise NotImplementedError() + self.perform_connection_preface() timestamp_start = time.time() @@ -141,7 +148,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.tcp_handler.rfile.reset_timestamps() stream_id, headers, body = self._receive_transmission( - stream_id=request.stream_id, + stream_id=stream_id, include_body=include_body, ) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e388a344..2b960483 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -345,10 +345,9 @@ class EmptyRequest(Request): host="", port="", path="", - httpversion=None, + httpversion=(0, 0), headers=None, - body="", - stream_id=None + body="" ): super(EmptyRequest, self).__init__( form_in=form_in, @@ -357,12 +356,10 @@ class EmptyRequest(Request): host=host, port=port, path=path, - httpversion=(httpversion or (0, 0)), + httpversion=httpversion, headers=(headers or odict.ODictCaseless()), body=body, ) - if stream_id: - self.stream_id = stream_id class Response(object): -- cgit v1.2.3 From de0ced73f8e14aec8f94ea93c0ba0165026e09fc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 25 Aug 2015 18:33:55 +0200 Subject: fix error messages --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index a0e2ab5e..3a094d9a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -281,7 +281,7 @@ class Reader(_FileLike): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: - raise NetLibError(str(e)) + raise NetLibError(repr(e)) elif isinstance(self.o, SSL.Connection): try: if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): @@ -294,7 +294,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - raise NetLibSSLError(str(e)) + raise NetLibSSLError(repr(e)) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -- cgit v1.2.3 From 3e3b59aa71a596fcddd14e72612067923a0d9b21 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 26 Aug 2015 20:58:00 +0200 Subject: http2: fix priority stream dependency check --- netlib/http/http2/frame.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index 24e6510a..b36b3adf 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -290,9 +290,6 @@ class PriorityFrame(Frame): raise ValueError( 'PRIORITY frames MUST be associated with a stream.') - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - return struct.pack( '!LB', (int( -- cgit v1.2.3 From 982d8000c420937da532d1c584e3ca7a86c5f3e8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 28 Aug 2015 17:35:48 +0200 Subject: wip --- netlib/http/__init__.py | 1 - netlib/http/http2/protocol.py | 4 +--- netlib/tcp.py | 18 +----------------- netlib/utils.py | 2 +- 4 files changed, 3 insertions(+), 22 deletions(-) (limited to 'netlib') diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index b01afc6d..9b4b0e6b 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,3 +1,2 @@ -from . import * from exceptions import * from semantics import * diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b6a147d4..b297e0b8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -34,9 +34,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): HTTP_1_1_REQUIRED=0xd ) - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE =\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" ALPN_PROTO_H2 = 'h2' diff --git a/netlib/tcp.py b/netlib/tcp.py index 3a094d9a..9dfa8d22 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -23,28 +23,12 @@ EINTR = 4 # To enable all SSL methods use: SSLv23 # then add options to disable certain methods # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 - -# Use ONLY for parsing of CLI arguments! -# All code internals should use OpenSSL constants directly! -SSL_VERSIONS = { - 'TLSv1.2': SSL.TLSv1_2_METHOD, - 'TLSv1.1': SSL.TLSv1_1_METHOD, - 'TLSv1': SSL.TLSv1_METHOD, - 'SSLv3': SSL.SSLv3_METHOD, - 'SSLv2': SSL.SSLv2_METHOD, - 'SSLv23': SSL.SSLv23_METHOD, -} - -SSL_DEFAULT_VERSION = 'SSLv23' - -SSL_DEFAULT_METHOD = SSL_VERSIONS[SSL_DEFAULT_VERSION] - +SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_CIPHER_SERVER_PREFERENCE ) - if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION diff --git a/netlib/utils.py b/netlib/utils.py index 31dcd622..d6190673 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -182,7 +182,7 @@ def parse_url(url): return None else: host = netloc - if scheme == "https": + if scheme.endswith("https"): port = 443 else: port = 80 -- cgit v1.2.3 From 1265945f55604f32d99c3dd7c1efd13b3f2ecd9b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 29 Aug 2015 12:30:35 +0200 Subject: move sslversion mapping to netlib --- netlib/tcp.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 9dfa8d22..0d83816b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -32,6 +32,23 @@ SSL_DEFAULT_OPTIONS = ( if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION +""" +Map a reasonable SSL version specification into the format OpenSSL expects. +Don't ask... +https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +""" +sslversion_choices = { + "all": (SSL.SSLv23_METHOD, 0), + # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ + # TLSv1_METHOD would be TLS 1.0 only + "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)), + "SSLv2": (SSL.SSLv2_METHOD, 0), + "SSLv3": (SSL.SSLv3_METHOD, 0), + "TLSv1": (SSL.TLSv1_METHOD, 0), + "TLSv1_1": (SSL.TLSv1_1_METHOD, 0), + "TLSv1_2": (SSL.TLSv1_2_METHOD, 0), +} + class NetLibError(Exception): pass -- cgit v1.2.3 From 4a8fd79e334661c1a11cd1cd28d62e6999b384d9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 29 Aug 2015 20:54:54 +0200 Subject: don't yield prefix and suffix --- netlib/http/http1/protocol.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 6b4489fb..50975818 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -258,9 +258,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_http_body(self, *args, **kwargs): - return "".join( - content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) - ) + return "".join(self.read_http_body_chunked(*args, **kwargs)) def read_http_body_chunked( @@ -308,7 +306,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): while bytes_left: chunk_size = min(bytes_left, max_chunk_size) content = self.tcp_handler.rfile.read(chunk_size) - yield "", content, "" + yield content bytes_left -= chunk_size else: bytes_left = limit or -1 @@ -317,7 +315,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): content = self.tcp_handler.rfile.read(chunk_size) if not content: return - yield "", content, "" + yield content bytes_left -= chunk_size not_done = self.tcp_handler.rfile.read(1) if not_done: @@ -418,7 +416,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' + yield chunk if length == 0: return -- cgit v1.2.3 From 53abf5f4d7c1e6f0712c6473904e5c1a58db0bb9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 3 Sep 2015 21:22:40 +0200 Subject: http2: handle Ping in protocol --- netlib/http/http2/protocol.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b297e0b8..2fbe7705 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -261,16 +261,21 @@ class HTTP2Protocol(semantics.ProtocolMixin): print(frm.human_readable(">>")) def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - if isinstance(frm, frame.DataFrame) and frm.length > 0: - self._update_flow_control_window(frm.stream_id, frm.length) - - return frm + while True: + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + + if isinstance(frm, frame.PingFrame): + raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + continue + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.length > 0: + self._update_flow_control_window(frm.stream_id, frm.length) + return frm def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() -- cgit v1.2.3 From 3ebe5a5147db20036d0762b92898f313b4d2f8d8 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 3 Sep 2015 21:22:55 +0200 Subject: http2: do net let Settings frames escape --- netlib/http/http2/protocol.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 2fbe7705..4328ebdd 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -285,6 +285,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): return True def _handle_unexpected_frame(self, frm): + if isinstance(frm, frame.SettingsFrame): + return if self.unhandled_frame_cb: self.unhandled_frame_cb(frm) -- cgit v1.2.3 From 5f97701958a283fca7188623c3cb4a313456b82c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 5 Sep 2015 13:26:36 +0200 Subject: add new headers class --- netlib/http/semantics.py | 130 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 2b960483..162cdbf5 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,4 +1,5 @@ from __future__ import (absolute_import, print_function, division) +import UserDict import urllib import urlparse @@ -12,8 +13,135 @@ HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class ProtocolMixin(object): +class Headers(UserDict.DictMixin): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header lines can be accessed + >>> h.lines + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + def __init__(self, lines=None, **headers): + """ + For convenience, underscores in header names will be transformed to dashes. + This behaviour does not extend to other methods. + + If ``**headers`` contains multiple keys that have equal ``.lower()``s, + the behavior is undefined. + """ + self.lines = lines or [] + + # content_type -> content-type + headers = {k.replace("_", "-"): v for k, v in headers.iteritems()} + self.update(headers) + + def __str__(self): + return "\r\n".join(": ".join(line) for line in self.lines) + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + else: + return ", ".join(values) + def __setitem__(self, key, value): + idx = self._index(key) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[key] + self.lines.insert(idx, [key, value]) + else: + self.lines.append([key, value]) + + def __delitem__(self, key): + key = key.lower() + self.lines = [ + line for line in self.lines + if key != line[0].lower() + ] + + def _index(self, key): + key = key.lower() + for i, line in enumerate(self): + if line[0].lower() == key: + return i + return None + + def keys(self): + return list(set(line[0] for line in self.lines)) + + def __eq__(self, other): + return self.lines == other.lines + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key, default=None): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + key = key.lower() + values = [line[1] for line in self.lines if line[0].lower() == key] + return values or default + + def set_all(self, key, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + if key in self: + del self[key] + self.lines.extend( + [key, value] for value in values + ) + + +class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover raise NotImplementedError -- cgit v1.2.3 From 3718e59308745e4582f4e8061b4ff6113d9dfc74 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 5 Sep 2015 15:27:48 +0200 Subject: finalize Headers, add tests --- netlib/http/semantics.py | 109 +++++++++++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 41 deletions(-) (limited to 'netlib') diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 162cdbf5..2fadf2c4 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -51,8 +51,8 @@ class Headers(UserDict.DictMixin): Host: example.com Accept: application/text - # For full control, the raw header lines can be accessed - >>> h.lines + # For full control, the raw header fields can be accessed + >>> h.fields # Headers can also be crated from keyword arguments >>> h = Headers(host="example.com", content_type="application/xml") @@ -61,85 +61,112 @@ class Headers(UserDict.DictMixin): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ - def __init__(self, lines=None, **headers): + def __init__(self, fields=None, **headers): """ - For convenience, underscores in header names will be transformed to dashes. - This behaviour does not extend to other methods. - - If ``**headers`` contains multiple keys that have equal ``.lower()``s, - the behavior is undefined. + Args: + fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. """ - self.lines = lines or [] + self.fields = fields or [] # content_type -> content-type - headers = {k.replace("_", "-"): v for k, v in headers.iteritems()} + headers = { + name.replace("_", "-"): value + for name, value in headers.iteritems() + } self.update(headers) def __str__(self): - return "\r\n".join(": ".join(line) for line in self.lines) + return "\r\n".join(": ".join(field) for field in self.fields) - def __getitem__(self, key): - values = self.get_all(key) + def __getitem__(self, name): + values = self.get_all(name) if not values: - raise KeyError(key) + raise KeyError(name) else: return ", ".join(values) - def __setitem__(self, key, value): - idx = self._index(key) + def __setitem__(self, name, value): + idx = self._index(name) # To please the human eye, we insert at the same position the first existing header occured. if idx is not None: - del self[key] - self.lines.insert(idx, [key, value]) + del self[name] + self.fields.insert(idx, [name, value]) else: - self.lines.append([key, value]) - - def __delitem__(self, key): - key = key.lower() - self.lines = [ - line for line in self.lines - if key != line[0].lower() - ] - - def _index(self, key): - key = key.lower() - for i, line in enumerate(self): - if line[0].lower() == key: + self.fields.append([name, value]) + + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: return i return None def keys(self): - return list(set(line[0] for line in self.lines)) + seen = set() + names = [] + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + names.append(name) + return names def __eq__(self, other): - return self.lines == other.lines + if isinstance(other, Headers): + return self.fields == other.fields + return False def __ne__(self, other): return not self.__eq__(other) - def get_all(self, key, default=None): + def get_all(self, name, default=None): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - key = key.lower() - values = [line[1] for line in self.lines if line[0].lower() == key] + name = name.lower() + values = [value for n, value in self.fields if n.lower() == name] return values or default - def set_all(self, key, values): + def set_all(self, name, values): """ Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - if key in self: - del self[key] - self.lines.extend( - [key, value] for value in values + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values ) + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) + class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover -- cgit v1.2.3 From 66ee1f465f6c492d5a4ff5659e6f0346fb243d67 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 5 Sep 2015 18:15:47 +0200 Subject: headers: adjust everything --- netlib/http/authentication.py | 4 +- netlib/http/exceptions.py | 18 ----- netlib/http/http1/protocol.py | 41 ++++++------ netlib/http/http2/protocol.py | 44 ++++++------- netlib/http/semantics.py | 148 ++++++++++++++++++++++-------------------- netlib/tutils.py | 10 +-- netlib/utils.py | 13 ++-- netlib/websockets/protocol.py | 28 ++++---- netlib/wsgi.py | 22 +++---- 9 files changed, 155 insertions(+), 173 deletions(-) (limited to 'netlib') diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 29b9eb3c..fe1f0d14 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -62,10 +62,10 @@ class BasicProxyAuth(NullProxyAuth): del headers[self.AUTH_HEADER] def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) + auth_value = headers.get(self.AUTH_HEADER) if not auth_value: return False - parts = parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 987a7908..8a2bbebc 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,3 @@ -from netlib import odict - - class HttpError(Exception): def __init__(self, code, message): @@ -10,18 +7,3 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass - - -class HttpAuthenticationError(Exception): - - def __init__(self, auth_headers=None): - super(HttpAuthenticationError, self).__init__( - "Proxy Authentication Required" - ) - if isinstance(auth_headers, dict): - auth_headers = odict.ODictCaseless(auth_headers.items()) - self.headers = auth_headers - self.code = 407 - - def __repr__(self): - return "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 50975818..bf33a18e 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -3,8 +3,8 @@ import string import sys import time -from netlib import odict, utils, tcp, http -from netlib.http import semantics +from ... import utils, tcp, http +from .. import semantics, Headers from ..exceptions import * @@ -96,7 +96,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): if headers is None: raise HttpError(400, "Invalid headers") - expect_header = headers.get_first("expect", "").lower() + expect_header = headers.get("expect", "").lower() if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' @@ -232,10 +232,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): Read a set of headers. Stop once a blank line is reached. - Return a ODictCaseless object, or None if headers are invalid. + Return a Header object, or None if headers are invalid. """ ret = [] - name = '' while True: line = self.tcp_handler.rfile.readline() if not line or line == '\r\n' or line == '\n': @@ -254,7 +253,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ret.append([name, value]) else: return None - return odict.ODictCaseless(ret) + return Headers(ret) def read_http_body(self, *args, **kwargs): @@ -272,7 +271,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): """ Read an HTTP message body: - headers: An ODictCaseless object + headers: A Header object limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise @@ -356,7 +355,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None if "content-length" in headers: try: - size = int(headers["content-length"][0]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size @@ -369,9 +368,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") - ] + return "chunked" in headers.get("transfer-encoding", "").lower() def _get_request_line(self): @@ -547,18 +544,20 @@ class HTTP1Protocol(semantics.ProtocolMixin): def _assemble_request_headers(self, request): headers = request.headers.copy() for k in request._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = [utils.hostport(request.scheme, - request.host, - request.port)] + headers["Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) # If content is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if request.body or request.body == "": - headers["Content-Length"] = [str(len(request.body))] + headers["Content-Length"] = str(len(request.body)) - return headers.format() + return str(headers) def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( @@ -575,13 +574,13 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): headers = response.headers.copy() for k in response._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if not preserve_transfer_encoding: - del headers['Transfer-Encoding'] + headers.pop('Transfer-Encoding', None) # If body is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if response.body or response.body == "": - headers["Content-Length"] = [str(len(response.body))] + headers["Content-Length"] = str(len(response.body)) - return headers.format() + return str(headers) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b297e0b8..f3254caa 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -3,7 +3,7 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils, odict +from netlib import http, utils from netlib.http import semantics from . import frame @@ -85,10 +85,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - authority = headers.get_first(':authority', '') - method = headers.get_first(':method', 'GET') - scheme = headers.get_first(':scheme', 'https') - path = headers.get_first(':path', '/') + authority = headers.get(':authority', '') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') host = None port = None @@ -161,7 +161,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - int(headers.get_first(':status')), + int(headers.get(':status', 502)), "", headers, body, @@ -181,16 +181,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if ':authority' not in headers.keys(): - headers.add(':authority', bytes(authority), prepend=True) - if ':scheme' not in headers.keys(): - headers.add(':scheme', bytes(request.scheme), prepend=True) - if ':path' not in headers.keys(): - headers.add(':path', bytes(request.path), prepend=True) - if ':method' not in headers.keys(): - headers.add(':method', bytes(request.method), prepend=True) - - headers = headers.items() + if ':authority' not in headers: + headers.fields.insert(0, (':authority', bytes(authority))) + if ':scheme' not in headers: + headers.fields.insert(0, (':scheme', bytes(request.scheme))) + if ':path' not in headers: + headers.fields.insert(0, (':path', bytes(request.path))) + if ':method' not in headers: + headers.fields.insert(0, (':method', bytes(request.method))) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -206,10 +204,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if ':status' not in headers.keys(): - headers.add(':status', bytes(str(response.status_code)), prepend=True) - - headers = headers.items() + if ':status' not in headers: + headers.fields.insert(0, (':status', bytes(str(response.status_code)))) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -329,7 +325,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: yield frame.ContinuationFrame, i - header_block_fragment = self.encoder.encode(headers) + header_block_fragment = self.encoder.encode(headers.fields) chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) @@ -402,8 +398,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - headers = odict.ODictCaseless() - for header, value in self.decoder.decode(header_block_fragment): - headers.add(header, value) + headers = http.Headers( + [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + ) return stream_id, headers, body diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 2fadf2c4..edf5fc07 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,9 +1,10 @@ from __future__ import (absolute_import, print_function, division) import UserDict +import copy import urllib import urlparse -from .. import utils, odict +from .. import odict from . import cookies, exceptions from netlib import utils, encoding @@ -77,11 +78,11 @@ class Headers(UserDict.DictMixin): headers = { name.replace("_", "-"): value for name, value in headers.iteritems() - } + } self.update(headers) def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" def __getitem__(self, name): values = self.get_all(name) @@ -107,7 +108,7 @@ class Headers(UserDict.DictMixin): self.fields = [ field for field in self.fields if name != field[0].lower() - ] + ] def _index(self, name): name = name.lower() @@ -134,7 +135,7 @@ class Headers(UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) - def get_all(self, name, default=None): + def get_all(self, name, default=[]): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. @@ -156,6 +157,9 @@ class Headers(UserDict.DictMixin): [name, value] for value in values ) + def copy(self): + return Headers(copy.copy(self.fields)) + # Implement the StateObject protocol from mitmproxy def get_state(self, short=False): return tuple(tuple(field) for field in self.fields) @@ -202,23 +206,23 @@ class Request(object): ] def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, + form_out=None ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.form_in = form_in self.method = method @@ -235,8 +239,10 @@ class Request(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -289,30 +295,35 @@ class Request(object): "if-none-match", ] for i in delheaders: - del self.headers[i] + self.headers.pop(i, None) def anticomp(self): """ Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = ["identity"] + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( ', '.join( - e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) def update_host_header(self): """ Update the host header to reflect the current target. """ - self.headers["Host"] = [self.host] + self.headers["Host"] = self.host def get_form(self): """ @@ -321,9 +332,9 @@ class Request(object): indicates non-form data. """ if self.body: - if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return self.get_form_urlencoded() - elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return self.get_form_multipart() return odict.ODict([]) @@ -333,18 +344,12 @@ class Request(object): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_URLENCODED, - True): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return odict.ODict(utils.urldecode(self.body)) return odict.ODict([]) def get_form_multipart(self): - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_MULTIPART, - True): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return odict.ODict( utils.multipartdecode( self.headers, @@ -359,7 +364,7 @@ class Request(object): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.headers["Content-Type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -418,7 +423,7 @@ class Request(object): """ host = None if hostheader: - host = self.headers.get_first("host") + host = self.headers.get("Host") if not host: host = self.host if host: @@ -442,7 +447,7 @@ class Request(object): Returns a possibly empty netlib.odict.ODict object. """ ret = odict.ODict() - for i in self.headers["cookie"]: + for i in self.headers.get_all("cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -452,7 +457,7 @@ class Request(object): headers. """ v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = [v] + self.headers["Cookie"] = v @property def url(self): @@ -491,18 +496,17 @@ class Request(object): class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=(0, 0), + headers=None, + body="" ): super(EmptyRequest, self).__init__( form_in=form_in, @@ -512,7 +516,7 @@ class EmptyRequest(Request): port=port, path=path, httpversion=httpversion, - headers=(headers or odict.ODictCaseless()), + headers=headers, body=body, ) @@ -525,19 +529,19 @@ class Response(object): ] def __init__( - self, - httpversion, - status_code, - msg=None, - headers=None, - body=None, - sslinfo=None, - timestamp_start=None, - timestamp_end=None, + self, + httpversion, + status_code, + msg=None, + headers=None, + body=None, + sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.httpversion = httpversion self.status_code = status_code @@ -550,8 +554,10 @@ class Response(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -567,9 +573,7 @@ class Response(object): return "".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get_first( - "content-type", - "unknown content type"), + contenttype=self.headers.get("content-type", "unknown content type"), size=size) def get_cookies(self): @@ -582,7 +586,7 @@ class Response(object): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers["set-cookie"]: + for header in self.headers.get_all("set-cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -605,7 +609,7 @@ class Response(object): i[1][1] ) ) - self.headers["Set-Cookie"] = values + self.headers.set_all("Set-Cookie", values) @property def content(self): # pragma: no cover diff --git a/netlib/tutils.py b/netlib/tutils.py index 7434c108..951ef3d9 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -5,7 +5,7 @@ import time import shutil from contextlib import contextmanager -from netlib import tcp, utils, odict, http +from netlib import tcp, utils, http def treader(bytes): @@ -73,8 +73,8 @@ def treq(content="content", scheme="http", host="address", port=22): """ @return: libmproxy.protocol.http.HTTPRequest """ - headers = odict.ODictCaseless() - headers["header"] = ["qvalue"] + headers = http.Headers() + headers["header"] = "qvalue" req = http.Request( "relative", "GET", @@ -108,8 +108,8 @@ def tresp(content="message"): @return: libmproxy.protocol.http.HTTPResponse """ - headers = odict.ODictCaseless() - headers["header_response"] = ["svalue"] + headers = http.Headers() + headers["header_response"] = "svalue" resp = http.semantics.Response( (1, 1), diff --git a/netlib/utils.py b/netlib/utils.py index d6190673..aae187da 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -204,11 +204,10 @@ def get_header_tokens(headers, key): follow a pattern where each header line can containe comma-separated tokens, and headers can be set multiple times. """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks + if key not in headers: + return [] + tokens = headers[key].split(",") + return [token.strip() for token in tokens] def hostport(scheme, host, port): @@ -270,11 +269,11 @@ def parse_content_type(c): return ts[0].lower(), ts[1].lower(), d -def multipartdecode(hdrs, content): +def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = hdrs.get_first("content-type") + v = headers.get("content-type") if v: v = parse_content_type(v) if not v: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 6ce32eac..46c02875 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -1,10 +1,5 @@ -from __future__ import absolute_import -import base64 -import hashlib -import os -from netlib import odict -from netlib import utils + # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -18,6 +13,13 @@ from netlib import utils # The magic sha that websocket servers must know to prove they understand # RFC6455 +from __future__ import absolute_import +import base64 +import hashlib +import os +from ..http import Headers +from .. import utils + websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" @@ -66,11 +68,11 @@ class WebsocketsProtocol(object): specified, it is generated, and can be found in sec-websocket-key in the returned header set. - Returns an instance of ODictCaseless + Returns an instance of Headers """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ + return Headers([ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), (HEADER_WEBSOCKET_KEY, key), @@ -82,7 +84,7 @@ class WebsocketsProtocol(object): """ The server response is a valid HTTP 101 response. """ - return odict.ODictCaseless( + return Headers( [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -93,16 +95,16 @@ class WebsocketsProtocol(object): @classmethod def check_client_handshake(self, headers): - if headers.get_first("upgrade", None) != "websocket": + if headers.get("upgrade") != "websocket": return - return headers.get_first(HEADER_WEBSOCKET_KEY) + return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get_first("upgrade", None) != "websocket": + if headers.get("upgrade") != "websocket": return - return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + return headers.get(HEADER_WEBSOCKET_ACCEPT) @classmethod diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 99afe00e..8a98884a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -3,7 +3,7 @@ import cStringIO import urllib import time import traceback -from . import odict, tcp +from . import http, tcp class ClientConn(object): @@ -68,8 +68,8 @@ class WSGIAdaptor(object): 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''), + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. @@ -115,12 +115,12 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n" % state["status"]) - h = state["headers"] - if 'server' not in h: - h["Server"] = [self.sversion] - if 'date' not in h: - h["Date"] = [date_time_string()] - soc.write(h.format()) + headers = state["headers"] + if 'server' not in headers: + headers["Server"] = self.sversion + if 'date' not in headers: + headers["Date"] = date_time_string() + soc.write(str(headers)) soc.write("\r\n") state["headers_sent"] = True if data: @@ -137,7 +137,7 @@ class WSGIAdaptor(object): elif state["status"]: raise AssertionError('Response already started') state["status"] = status - state["headers"] = odict.ODictCaseless(headers) + state["headers"] = http.Headers(headers) return write errs = cStringIO.StringIO() @@ -149,7 +149,7 @@ class WSGIAdaptor(object): write(i) if not state["headers_sent"]: write("") - except Exception: + except Exception as e: try: s = traceback.format_exc() errs.write(s) -- cgit v1.2.3 From fc86bbd03e7806bf5d3dc0d226b607192642c810 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 8 Sep 2015 15:16:25 +0200 Subject: let Headers inherit from object fixes mitmproxy/mitmproxy#753 --- netlib/http/semantics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index edf5fc07..5bb098a7 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -14,7 +14,7 @@ HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class Headers(UserDict.DictMixin): +class Headers(object, UserDict.DictMixin): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -135,7 +135,7 @@ class Headers(UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) - def get_all(self, name, default=[]): + def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. @@ -144,7 +144,7 @@ class Headers(UserDict.DictMixin): """ name = name.lower() values = [value for n, value in self.fields if n.lower() == name] - return values or default + return values def set_all(self, name, values): """ -- cgit v1.2.3 From 32b3c32138847cb1f5b0c1958fc9ad0a49f8810f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 8 Sep 2015 21:31:27 +0200 Subject: add tcp.Address.__hash__ --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 0d83816b..5c9d26de 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -352,6 +352,9 @@ class Address(object): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return hash(self.address) ^ 42 # different hash than the tuple alone. + def close_socket(sock): """ -- cgit v1.2.3 From a5f7752cf18a9c6b34916107abc89bbdf0050566 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 10 Sep 2015 11:30:17 +0200 Subject: add ssl_read_select --- netlib/tcp.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 5c9d26de..e9610099 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -356,6 +356,27 @@ class Address(object): return hash(self.address) ^ 42 # different hash than the tuple alone. +def ssl_read_select(rlist, timeout): + """ + This is a wrapper around select.select() which also works for SSL.Connections + by taking ssl_connection.pending() into account. + + Caveats: + If .pending() > 0 for any of the connections in rlist, we avoid the select syscall + and **will not include any other connections which may or may not be ready**. + + Args: + rlist: wait until ready for reading + + Returns: + subset of rlist which is ready for reading. + """ + return [ + conn for conn in rlist + if isinstance(conn, SSL.Connection) and conn.pending() > 0 + ] or select.select(rlist, (), (), timeout)[0] + + def close_socket(sock): """ Does a hard close of a socket, without emitting a RST. -- cgit v1.2.3 From 92c763f469fdf721f3d981346f8a40e33b06de23 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 10 Sep 2015 12:32:38 +0200 Subject: fix mitmproxy/mitmproxy#759 --- netlib/version_check.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/version_check.py b/netlib/version_check.py index aae4e8c7..1d7e025c 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -3,10 +3,11 @@ Having installed a wrong version of pyOpenSSL or netlib is unfortunately a very common source of error. Check before every start that both versions are somewhat okay. """ -from __future__ import division, absolute_import, print_function, unicode_literals +from __future__ import division, absolute_import, print_function import sys import inspect import os.path + import OpenSSL from . import version @@ -28,19 +29,29 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) + min_version_str = ".".join(str(x) for x in min_version) + try: + v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) + except ValueError: + print( + "Cannot parse pyOpenSSL version: {}" + "mitmproxy requires pyOpenSSL {} or greater.".format( + OpenSSL.__version__, min_version_str + ), + file=fp + ) + return if v < min_version: print( - "You are using an outdated version of pyOpenSSL:" - " mitmproxy requires pyOpenSSL %s or greater." % - str(min_version), + "You are using an outdated version of pyOpenSSL: " + "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. # Report which one we got. pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) print( - "Your pyOpenSSL %s installation is located at %s" % ( + "Your pyOpenSSL {} installation is located at {}".format( OpenSSL.__version__, pyopenssl_path ), file=fp -- cgit v1.2.3 From a38142d5950a899c6e3f854841a45f4785515761 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 11 Sep 2015 01:17:39 +0200 Subject: don't yield empty chunks --- netlib/http/http1/protocol.py | 2 +- netlib/tcp.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index bf33a18e..cf1dffa3 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -413,9 +413,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield chunk if length == 0: return + yield chunk @classmethod def _parse_http_protocol(self, line): diff --git a/netlib/tcp.py b/netlib/tcp.py index e9610099..4a7f6153 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -8,6 +8,7 @@ import time import traceback import certifi +import six import OpenSSL from OpenSSL import SSL @@ -295,7 +296,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - raise NetLibSSLError(repr(e)) + six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2]) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -- cgit v1.2.3 From 997fcde8ce94be9d8decddd4bc783106dbb41ab3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 12 Sep 2015 17:03:09 +0200 Subject: make clean_bin unicode-aware --- netlib/utils.py | 39 +++++++++++++++++++++++++-------------- netlib/websockets/frame.py | 2 +- 2 files changed, 26 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index aae187da..d6774419 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -5,6 +5,8 @@ import urllib import urlparse import string import re +import six +import unicodedata def isascii(s): @@ -20,22 +22,31 @@ def bytes_to_int(i): return int(i.encode('hex'), 16) -def cleanBin(s, fixspacing=False): +def clean_bin(s, keep_spacing=True): """ - Cleans binary data to make it safe to display. If fixspacing is True, - tabs, newlines and so forth will be maintained, if not, they will be - replaced with a placeholder. + Cleans binary data to make it safe to display. + + Args: + keep_spacing: If False, tabs and newlines will also be replaced. """ - parts = [] - for i in s: - o = ord(i) - if (o > 31 and o < 127): - parts.append(i) - elif i in "\n\t" and not fixspacing: - parts.append(i) + if isinstance(s, six.text_type): + if keep_spacing: + keep = u" \n\r\t" + else: + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + for ch in s + ) + else: + if keep_spacing: + keep = b"\n\r\t" else: - parts.append(".") - return "".join(parts) + keep = b"" + return b"".join( + ch if (31 < ord(ch) < 127 or ch in keep) else b"." + for ch in s + ) def hexdump(s): @@ -52,7 +63,7 @@ def hexdump(s): x += " " x += " ".join(" " for i in range(16 - len(part))) parts.append( - (o, x, cleanBin(part, True)) + (o, x, clean_bin(part, False)) ) return parts diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 1c4a03b2..e3ff1405 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -236,7 +236,7 @@ class Frame(object): def human_readable(self): ret = self.header.human_readable() if self.payload: - ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload) return ret def __repr__(self): -- cgit v1.2.3 From 11e7f476bd4bbcd6d072fa3659f628ae3a19705d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 15 Sep 2015 19:12:15 +0200 Subject: wip --- netlib/encoding.py | 8 +- netlib/exceptions.py | 31 ++ netlib/http/__init__.py | 9 +- netlib/http/authentication.py | 4 +- netlib/http/exceptions.py | 9 - netlib/http/http1/__init__.py | 23 +- netlib/http/http1/assemble.py | 105 +++++++ netlib/http/http1/protocol.py | 586 ------------------------------------ netlib/http/http1/read.py | 346 +++++++++++++++++++++ netlib/http/http2/__init__.py | 2 - netlib/http/http2/connections.py | 412 +++++++++++++++++++++++++ netlib/http/http2/frame.py | 633 --------------------------------------- netlib/http/http2/frames.py | 633 +++++++++++++++++++++++++++++++++++++++ netlib/http/http2/protocol.py | 412 ------------------------- netlib/http/models.py | 571 +++++++++++++++++++++++++++++++++++ netlib/http/semantics.py | 632 -------------------------------------- netlib/tcp.py | 8 +- netlib/tutils.py | 70 +++-- netlib/utils.py | 162 ++++++---- netlib/version_check.py | 17 +- netlib/websockets/__init__.py | 4 +- 21 files changed, 2295 insertions(+), 2382 deletions(-) create mode 100644 netlib/exceptions.py delete mode 100644 netlib/http/exceptions.py create mode 100644 netlib/http/http1/assemble.py delete mode 100644 netlib/http/http1/protocol.py create mode 100644 netlib/http/http1/read.py create mode 100644 netlib/http/http2/connections.py delete mode 100644 netlib/http/http2/frame.py create mode 100644 netlib/http/http2/frames.py delete mode 100644 netlib/http/http2/protocol.py create mode 100644 netlib/http/models.py delete mode 100644 netlib/http/semantics.py (limited to 'netlib') diff --git a/netlib/encoding.py b/netlib/encoding.py index f107eb5f..06830f2c 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -2,13 +2,13 @@ Utility functions for decoding response bodies. """ from __future__ import absolute_import -import cStringIO +from io import BytesIO import gzip import zlib __ALL__ = ["ENCODINGS"] -ENCODINGS = set(["identity", "gzip", "deflate"]) +ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): @@ -42,7 +42,7 @@ def identity(content): def decode_gzip(content): - gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content)) + gfile = gzip.GzipFile(fileobj=BytesIO(content)) try: return gfile.read() except (IOError, EOFError): @@ -50,7 +50,7 @@ def decode_gzip(content): def encode_gzip(content): - s = cStringIO.StringIO() + s = BytesIO() gf = gzip.GzipFile(fileobj=s, mode='wb') gf.write(content) gf.close() diff --git a/netlib/exceptions.py b/netlib/exceptions.py new file mode 100644 index 00000000..637be3df --- /dev/null +++ b/netlib/exceptions.py @@ -0,0 +1,31 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception netlib raises shall be a subclass of NetlibException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" +from __future__ import absolute_import, print_function, division + + +class NetlibException(Exception): + """ + Base class for all exceptions thrown by netlib. + """ + def __init__(self, message=None): + super(NetlibException, self).__init__(message) + + +class ReadDisconnect(object): + """Immediate EOF""" + + +class HttpException(NetlibException): + pass + + +class HttpReadDisconnect(HttpException, ReadDisconnect): + pass + +class HttpSyntaxException(HttpException): + pass diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..0b1a0bc5 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,7 @@ -from exceptions import * -from semantics import * +from .models import Request, Response, Headers, CONTENT_MISSING +from . import http1, http2 + +__all__ = [ + "Request", "Response", "Headers", "CONTENT_MISSING" + "http1", "http2" +] diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index fe1f0d14..2055f843 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -19,8 +19,8 @@ def parse_http_basic_auth(s): def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v + v = binascii.b2a_base64(username + b":" + password) + return scheme + b" " + v class NullProxyAuth(object): diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py deleted file mode 100644 index 8a2bbebc..00000000 --- a/netlib/http/exceptions.py +++ /dev/null @@ -1,9 +0,0 @@ -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 6b5043af..4d223f97 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1 +1,22 @@ -from protocol import * +from .read import ( + read_request, read_request_head, + read_response, read_response_head, + read_message_body, read_message_body_chunked, + connection_close, + expected_http_body_size, +) +from .assemble import ( + assemble_request, assemble_request_head, + assemble_response, assemble_response_head, +) + + +__all__ = [ + "read_request", "read_request_head", + "read_response", "read_response_head", + "read_message_body", "read_message_body_chunked", + "connection_close", + "expected_http_body_size", + "assemble_request", "assemble_request_head", + "assemble_response", "assemble_response_head", +] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py new file mode 100644 index 00000000..a3269eed --- /dev/null +++ b/netlib/http/http1/assemble.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import, print_function, division + +from ... import utils +from ...exceptions import HttpException +from .. import CONTENT_MISSING + + +def assemble_request(request): + if request.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_request_head(request) + return head + request.body + + +def assemble_request_head(request): + first_line = _assemble_request_line(request) + headers = _assemble_request_headers(request) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_response(response): + if response.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_response_head(response) + return head + response.body + + +def assemble_response_head(response): + first_line = _assemble_response_line(response) + headers = _assemble_response_headers(response) + return b"%s\r\n%s\r\n" % (first_line, headers) + + + + +def _assemble_request_line(request, form=None): + if form is None: + form = request.form_out + if form == "relative": + return b"%s %s %s" % ( + request.method, + request.path, + request.httpversion + ) + elif form == "authority": + return b"%s %s:%d %s" % ( + request.method, + request.host, + request.port, + request.httpversion + ) + elif form == "absolute": + return b"%s %s://%s:%s%s %s" % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion + ) + else: # pragma: nocover + raise RuntimeError("Invalid request form") + + +def _assemble_request_headers(request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + headers.pop(k, None) + if b"host" not in headers and request.scheme and request.host and request.port: + headers[b"Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == b"": + headers[b"Content-Length"] = str(len(request.body)).encode("ascii") + + return str(headers) + + +def _assemble_response_line(response): + return b"%s %s %s" % ( + response.httpversion, + response.status_code, + response.msg, + ) + + +def _assemble_response_headers(response, preserve_transfer_encoding=False): + # TODO: Remove preserve_transfer_encoding + headers = response.headers.copy() + for k in response._headers_to_strip_off: + headers.pop(k, None) + if not preserve_transfer_encoding: + headers.pop(b"Transfer-Encoding", None) + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + + return bytes(headers) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py deleted file mode 100644 index cf1dffa3..00000000 --- a/netlib/http/http1/protocol.py +++ /dev/null @@ -1,586 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import string -import sys -import time - -from ... import utils, tcp, http -from .. import semantics, Headers -from ..exceptions import * - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP1Protocol(semantics.ProtocolMixin): - - ALPN_PROTO_HTTP1 = 'http/1.1' - - def __init__(self, tcp_handler=None, rfile=None, wfile=None): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request( - self, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - httpversion, host, port, scheme, method, path, headers, body = ( - None, None, None, None, None, None, None, None) - - request_line = self._get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self._parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method == 'CONNECT': - form_in = "authority" - r = self._parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, httpversion = r - path = None - else: - form_in = "absolute" - r = self._parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get("expect", "").lower() - if expect_header == "100-continue" and httpversion == (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - body, - timestamp_start, - timestamp_end, - ) - - def read_response( - self, - request_method, - body_size_limit=None, - include_body=True, - ): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, body may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self._parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None body means the body should be - # read separately - body = None - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - return http.Response( - httpversion, - code, - msg, - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - - def assemble_request(self, request): - assert isinstance(request, semantics.Request) - - if request.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_request_first_line(request) - headers = self._assemble_request_headers(request) - return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - - def assemble_response(self, response): - assert isinstance(response, semantics.Response) - - if response.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_response_first_line(response) - headers = self._assemble_response_headers(response) - return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - - def read_headers(self): - """ - Read a set of headers. - Stop once a blank line is reached. - - Return a Header object, or None if headers are invalid. - """ - ret = [] - while True: - line = self.tcp_handler.rfile.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return Headers(ret) - - - def read_http_body(self, *args, **kwargs): - return "".join(self.read_http_body_chunked(*args, **kwargs)) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: A Header object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self._read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - yield content - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - @classmethod - def expected_http_body_size( - self, - headers, - is_request, - request_method, - response_code, - ): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in headers.get("transfer-encoding", "").lower() - - - def _get_request_line(self): - """ - Get a line, possibly preceded by a blank. - """ - line = self.tcp_handler.rfile.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = self.tcp_handler.rfile.readline() - return line - - def _read_chunked(self, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = self.tcp_handler.rfile.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = self.tcp_handler.rfile.read(length) - suffix = self.tcp_handler.rfile.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - if length == 0: - return - yield chunk - - @classmethod - def _parse_http_protocol(self, line): - """ - Parse an HTTP protocol declaration. - Returns a (major, minor) tuple, or None. - """ - if not line.startswith("HTTP/"): - return None - _, version = line.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - @classmethod - def _parse_init(self, line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = self._parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - @classmethod - def _parse_init_connect(self, line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not utils.is_valid_port(port): - return None - if not utils.is_valid_host(host): - return None - return host, port, httpversion - - @classmethod - def _parse_init_proxy(self, line): - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = utils.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - @classmethod - def _parse_init_http(self, line): - """ - Returns (method, url, httpversion) - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - @classmethod - def connection_close(self, httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - return httpversion != (1, 1) - - @classmethod - def parse_response_line(self, line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - @classmethod - def _assemble_request_first_line(self, request): - return request.legacy_first_line() - - def _assemble_request_headers(self, request): - headers = request.headers.copy() - for k in request._headers_to_strip_off: - headers.pop(k, None) - if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = utils.hostport( - request.scheme, - request.host, - request.port - ) - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if request.body or request.body == "": - headers["Content-Length"] = str(len(request.body)) - - return str(headers) - - def _assemble_response_first_line(self, response): - return 'HTTP/%s.%s %s %s' % ( - response.httpversion[0], - response.httpversion[1], - response.status_code, - response.msg, - ) - - def _assemble_response_headers( - self, - response, - preserve_transfer_encoding=False, - ): - headers = response.headers.copy() - for k in response._headers_to_strip_off: - headers.pop(k, None) - if not preserve_transfer_encoding: - headers.pop('Transfer-Encoding', None) - - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == "": - headers["Content-Length"] = str(len(response.body)) - - return str(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py new file mode 100644 index 00000000..573bc739 --- /dev/null +++ b/netlib/http/http1/read.py @@ -0,0 +1,346 @@ +from __future__ import absolute_import, print_function, division +import time +import sys +import re + +from ... import utils +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from .. import Request, Response, Headers + +ALPN_PROTO_HTTP1 = 'http/1.1' + + +def read_request(rfile, body_size_limit=None): + request = read_request_head(rfile) + request.body = read_message_body(rfile, request, limit=body_size_limit) + request.timestamp_end = time.time() + return request + + +def read_request_head(rfile): + """ + Parse an HTTP request head (request line + headers) from an input stream + + Args: + rfile: The input stream + body_size_limit (bool): Maximum body size + + Returns: + The HTTP request object + + Raises: + HttpReadDisconnect: If no bytes can be read from rfile. + HttpSyntaxException: If the input is invalid. + HttpException: A different error occured. + """ + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + form, method, scheme, host, port, path, http_version = _read_request_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Request( + form, method, scheme, host, port, path, http_version, headers, None, timestamp_start + ) + + +def read_response(rfile, request, body_size_limit=None): + response = read_response_head(rfile) + response.body = read_message_body(rfile, request, response, body_size_limit) + response.timestamp_end = time.time() + return response + + +def read_response_head(rfile): + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + http_version, status_code, message = _read_response_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Response( + http_version, + status_code, + message, + headers, + None, + timestamp_start + ) + + +def read_message_body(*args, **kwargs): + chunks = read_message_body_chunked(*args, **kwargs) + return b"".join(chunks) + + +def read_message_body_chunked(rfile, request, response=None, limit=None, max_chunk_size=None): + """ + Read an HTTP message body: + + Args: + If a request body should be read, only request should be passed. + If a response body should be read, both request and response should be passed. + + Raises: + HttpException + """ + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False + + if not limit or limit < 0: + limit = sys.maxsize + if not max_chunk_size: + max_chunk_size = limit + + expected_size = expected_http_body_size( + headers, is_request, request.method, response_code + ) + + if expected_size is None: + for x in _read_chunked(rfile, limit): + yield x + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpException( + "HTTP Body too large. " + "Limit is {}, content length was advertised as {}".format(limit, expected_size) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + yield content + bytes_left -= chunk_size + else: + bytes_left = limit + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield content + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpException("HTTP body too large. Limit is {}.".format(limit)) + + +def connection_close(http_version, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1. + """ + # At first, check if we have an explicit Connection header. + if b"connection" in headers: + toks = utils.get_header_tokens(headers, "connection") + if b"close" in toks: + return True + elif b"keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return http_version != (1, 1) + + +def expected_http_body_size( + headers, + is_request, + request_method, + response_code, +): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. + + Raises: + HttpSyntaxException, if the content length header is invalid + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + is_empty_response = (not is_request and ( + request_method == b"HEAD" or + 100 <= response_code <= 199 or + (response_code == 200 and request_method == b"CONNECT") or + response_code in (204, 304) + )) + + if is_empty_response: + return 0 + if is_request and headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + return None + if b"content-length" in headers: + try: + size = int(headers[b"content-length"]) + if size < 0: + raise ValueError() + return size + except ValueError: + raise HttpSyntaxException("Unparseable Content Length") + if is_request: + return 0 + return -1 + + +def _get_first_line(rfile): + line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + if not line: + raise HttpReadDisconnect() + return line + + +def _read_request_line(rfile): + line = _get_first_line(rfile) + + try: + method, path, http_version = line.strip().split(b" ") + + if path == b"*" or path.startswith(b"/"): + form = "relative" + path.decode("ascii") # should not raise a ValueError + scheme, host, port = None, None, None + elif method == b"CONNECT": + form = "authority" + host, port = _parse_authority_form(path) + scheme, path = None, None + else: + form = "absolute" + scheme, host, port, path = utils.parse_url(path) + + except ValueError: + raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) + + return form, method, scheme, host, port, path, http_version + + +def _parse_authority_form(hostport): + """ + Returns (host, port) if hostport is a valid authority-form host specification. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + + Raises: + ValueError, if the input is malformed + """ + try: + host, port = hostport.split(b":") + port = int(port) + if not utils.is_valid_host(host) or not utils.is_valid_port(port): + raise ValueError() + except ValueError: + raise ValueError("Invalid host specification: {}".format(hostport)) + + return host, port + + +def _read_response_line(rfile): + line = _get_first_line(rfile) + + try: + + parts = line.strip().split(b" ") + if len(parts) == 2: # handle missing message gracefully + parts.append(b"") + + http_version, status_code, message = parts + status_code = int(status_code) + _check_http_version(http_version) + + except ValueError: + raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) + + return http_version, status_code, message + + +def _check_http_version(http_version): + if not re.match(rb"^HTTP/\d\.\d$", http_version): + raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) + + +def _read_headers(rfile): + """ + Read a set of headers. + Stop once a blank line is reached. + + Returns: + A headers object + + Raises: + HttpSyntaxException + """ + ret = [] + while True: + line = rfile.readline() + if not line or line == b"\r\n" or line == b"\n": + break + if line[0] in b" \t": + if not ret: + raise HttpSyntaxException("Invalid headers") + # continued header + ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + else: + try: + name, value = line.split(b":", 1) + value = value.strip() + ret.append([name, value]) + except ValueError: + raise HttpSyntaxException("Invalid headers") + return Headers(ret) + + +def _read_chunked(rfile, limit): + """ + Read a HTTP body with chunked transfer encoding. + + Args: + rfile: the input file + limit: A positive integer + """ + total = 0 + while True: + line = rfile.readline(128) + if line == b"": + raise HttpException("Connection closed prematurely") + if line != b"\r\n" and line != b"\n": + try: + length = int(line, 16) + except ValueError: + raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) + total += length + if total > limit: + raise HttpException( + "HTTP Body too large. Limit is {}, " + "chunked content longer than {}".format(limit, total) + ) + chunk = rfile.read(length) + suffix = rfile.readline(5) + if suffix != b"\r\n": + raise HttpSyntaxException("Malformed chunked body") + if length == 0: + return + yield chunk diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 5acf7696..e69de29b 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py new file mode 100644 index 00000000..b6d376d3 --- /dev/null +++ b/netlib/http/http2/connections.py @@ -0,0 +1,412 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import time + +from hpack.hpack import Encoder, Decoder +from netlib import http, utils +from netlib.http import semantics +from . import frame + + +class TCPHandler(object): + + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + +class HTTP2Protocol(semantics.ProtocolMixin): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + ALPN_PROTO_H2 = 'h2' + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + unhandled_frame_cb=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) + self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() + self.unhandled_frame_cb = unhandled_frame_cb + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.connection_preface_performed = False + + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): + if body_size_limit is not None: + raise NotImplementedError() + + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission( + include_body=include_body, + ) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + authority = headers.get(':authority', '') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') + host = None + port = None + + if path == '*' or path.startswith("/"): + form_in = "relative" + elif method == 'CONNECT': + form_in = "authority" + if ":" in authority: + host, port = authority.split(":", 1) + else: + host = authority + else: + form_in = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = utils.parse_url(path) + + if host is None: + host = 'localhost' + if port is None: + port = 80 if scheme == 'http' else 443 + port = int(port) + + request = http.Request( + form_in, + method, + scheme, + host, + port, + path, + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + # FIXME: We should not do this. + request.stream_id = stream_id + + return request + + def read_response( + self, + request_method='', + body_size_limit=None, + include_body=True, + stream_id=None, + ): + if body_size_limit is not None: + raise NotImplementedError() + + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission( + stream_id=stream_id, + include_body=include_body, + ) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + int(headers.get(':status', 502)), + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = request.headers.copy() + + if ':authority' not in headers: + headers.fields.insert(0, (':authority', bytes(authority))) + if ':scheme' not in headers: + headers.fields.insert(0, (':scheme', bytes(request.scheme))) + if ':path' not in headers: + headers.fields.insert(0, (':path', bytes(request.path))) + if ':method' not in headers: + headers.fields.insert(0, (':method', bytes(request.method))) + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = response.headers.copy() + + if ':status' not in headers: + headers.fields.insert(0, (':status', bytes(str(response.status_code)))) + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), + self._create_body(response.body, stream_id), + )) + + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + frm = frame.SettingsFrame(state=self, settings={ + frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, + frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, + }) + self.send_frame(frm, hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) # server announces own settings + self._receive_settings(hide=True) # server acks my settings + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + while True: + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + + if isinstance(frm, frame.PingFrame): + raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + continue + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.length > 0: + self._update_flow_control_window(frm.stream_id, frm.length) + return frm + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _handle_unexpected_frame(self, frm): + if isinstance(frm, frame.SettingsFrame): + return + if self.unhandled_frame_cb: + self.unhandled_frame_cb(frm) + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + else: + self._handle_unexpected_frame(frm) + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) + + def _update_flow_control_window(self, stream_id, increment): + frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) + self.send_frame(frm) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) + self.send_frame(frm) + + def _create_headers(self, headers, stream_id, end_stream=True): + def frame_cls(chunks): + for i in chunks: + if i == 0: + yield frame.HeadersFrame, i + else: + yield frame.ContinuationFrame, i + + header_block_fragment = self.encoder.encode(headers.fields) + + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(header_block_fragment), chunk_size) + frms = [frm_cls( + state=self, + flags=frame.Frame.FLAG_NO_FLAGS, + stream_id=stream_id, + header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + + last_flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + last_flags |= frame.Frame.FLAG_END_STREAM + frms[-1].flags = last_flags + + if self.dump_frames: # pragma no cover + for frm in frms: + print(frm.human_readable(">>")) + + return [frm.to_bytes() for frm in frms] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(body), chunk_size) + frms = [frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_NO_FLAGS, + stream_id=stream_id, + payload=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags = frame.Frame.FLAG_END_STREAM + + if self.dump_frames: # pragma no cover + for frm in frms: + print(frm.human_readable(">>")) + + return [frm.to_bytes() for frm in frms] + + def _receive_transmission(self, stream_id=None, include_body=True): + if not include_body: + raise NotImplementedError() + + body_expected = True + + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if ( + (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and + (stream_id is None or frm.stream_id == stream_id) + ): + stream_id = frm.stream_id + header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: + break + else: + self._handle_unexpected_frame(frm) + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: + body += frm.payload + if frm.flags & frame.Frame.FLAG_END_STREAM: + break + else: + self._handle_unexpected_frame(frm) + + headers = http.Headers( + [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + ) + + return stream_id, headers, body diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py deleted file mode 100644 index b36b3adf..00000000 --- a/netlib/http/http2/frame.py +++ /dev/null @@ -1,633 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Size Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http/http2/frames.py b/netlib/http/http2/frames.py new file mode 100644 index 00000000..b36b3adf --- /dev/null +++ b/netlib/http/http2/frames.py @@ -0,0 +1,633 @@ +import sys +import struct +from hpack.hpack import Encoder, Decoder + +from .. import utils + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Size Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py deleted file mode 100644 index b6d376d3..00000000 --- a/netlib/http/http2/protocol.py +++ /dev/null @@ -1,412 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import time - -from hpack.hpack import Encoder, Decoder -from netlib import http, utils -from netlib.http import semantics -from . import frame - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP2Protocol(semantics.ProtocolMixin): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - - ALPN_PROTO_H2 = 'h2' - - def __init__( - self, - tcp_handler=None, - rfile=None, - wfile=None, - is_server=False, - dump_frames=False, - encoder=None, - decoder=None, - unhandled_frame_cb=None, - ): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - self.is_server = is_server - self.dump_frames = dump_frames - self.encoder = encoder or Encoder() - self.decoder = decoder or Decoder() - self.unhandled_frame_cb = unhandled_frame_cb - - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.connection_preface_performed = False - - def read_request( - self, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - if body_size_limit is not None: - raise NotImplementedError() - - self.perform_connection_preface() - - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission( - include_body=include_body, - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - authority = headers.get(':authority', '') - method = headers.get(':method', 'GET') - scheme = headers.get(':scheme', 'https') - path = headers.get(':path', '/') - host = None - port = None - - if path == '*' or path.startswith("/"): - form_in = "relative" - elif method == 'CONNECT': - form_in = "authority" - if ":" in authority: - host, port = authority.split(":", 1) - else: - host = authority - else: - form_in = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = utils.parse_url(path) - - if host is None: - host = 'localhost' - if port is None: - port = 80 if scheme == 'http' else 443 - port = int(port) - - request = http.Request( - form_in, - method, - scheme, - host, - port, - path, - (2, 0), - headers, - body, - timestamp_start, - timestamp_end, - ) - # FIXME: We should not do this. - request.stream_id = stream_id - - return request - - def read_response( - self, - request_method='', - body_size_limit=None, - include_body=True, - stream_id=None, - ): - if body_size_limit is not None: - raise NotImplementedError() - - self.perform_connection_preface() - - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission( - stream_id=stream_id, - include_body=include_body, - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - response = http.Response( - (2, 0), - int(headers.get(':status', 502)), - "", - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - response.stream_id = stream_id - - return response - - def assemble_request(self, request): - assert isinstance(request, semantics.Request) - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = request.headers.copy() - - if ':authority' not in headers: - headers.fields.insert(0, (':authority', bytes(authority))) - if ':scheme' not in headers: - headers.fields.insert(0, (':scheme', bytes(request.scheme))) - if ':path' not in headers: - headers.fields.insert(0, (':path', bytes(request.path))) - if ':method' not in headers: - headers.fields.insert(0, (':method', bytes(request.method))) - - if hasattr(request, 'stream_id'): - stream_id = request.stream_id - else: - stream_id = self._next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), - self._create_body(request.body, stream_id))) - - def assemble_response(self, response): - assert isinstance(response, semantics.Response) - - headers = response.headers.copy() - - if ':status' not in headers: - headers.fields.insert(0, (':status', bytes(str(response.status_code)))) - - if hasattr(response, 'stream_id'): - stream_id = response.stream_id - else: - stream_id = self._next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), - self._create_body(response.body, stream_id), - )) - - def perform_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - if self.is_server: - self.perform_server_connection_preface(force) - else: - self.perform_client_connection_preface(force) - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - frm = frame.SettingsFrame(state=self, settings={ - frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, - frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, - }) - self.send_frame(frm, hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) # server announces own settings - self._receive_settings(hide=True) # server acks my settings - - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - while True: - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - - if isinstance(frm, frame.PingFrame): - raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - continue - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - if isinstance(frm, frame.DataFrame) and frm.length > 0: - self._update_flow_control_window(frm.stream_id, frm.length) - return frm - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _handle_unexpected_frame(self, frm): - if isinstance(frm, frame.SettingsFrame): - return - if self.unhandled_frame_cb: - self.unhandled_frame_cb(frm) - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - else: - self._handle_unexpected_frame(frm) - - def _next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) - self.send_frame(frm, hide) - - def _update_flow_control_window(self, stream_id, increment): - frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) - self.send_frame(frm) - frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) - self.send_frame(frm) - - def _create_headers(self, headers, stream_id, end_stream=True): - def frame_cls(chunks): - for i in chunks: - if i == 0: - yield frame.HeadersFrame, i - else: - yield frame.ContinuationFrame, i - - header_block_fragment = self.encoder.encode(headers.fields) - - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - chunks = range(0, len(header_block_fragment), chunk_size) - frms = [frm_cls( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, - stream_id=stream_id, - header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] - - last_flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - last_flags |= frame.Frame.FLAG_END_STREAM - frms[-1].flags = last_flags - - if self.dump_frames: # pragma no cover - for frm in frms: - print(frm.human_readable(">>")) - - return [frm.to_bytes() for frm in frms] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - chunks = range(0, len(body), chunk_size) - frms = [frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, - stream_id=stream_id, - payload=body[i:i+chunk_size]) for i in chunks] - frms[-1].flags = frame.Frame.FLAG_END_STREAM - - if self.dump_frames: # pragma no cover - for frm in frms: - print(frm.human_readable(">>")) - - return [frm.to_bytes() for frm in frms] - - def _receive_transmission(self, stream_id=None, include_body=True): - if not include_body: - raise NotImplementedError() - - body_expected = True - - header_block_fragment = b'' - body = b'' - - while True: - frm = self.read_frame() - if ( - (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and - (stream_id is None or frm.stream_id == stream_id) - ): - stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: - break - else: - self._handle_unexpected_frame(frm) - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: - break - else: - self._handle_unexpected_frame(frm) - - headers = http.Headers( - [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] - ) - - return stream_id, headers, body diff --git a/netlib/http/models.py b/netlib/http/models.py new file mode 100644 index 00000000..bd5863b1 --- /dev/null +++ b/netlib/http/models.py @@ -0,0 +1,571 @@ +from __future__ import absolute_import, print_function, division +import copy + +from ..odict import ODict +from .. import utils, encoding +from ..utils import always_bytes, always_byte_args +from . import cookies + +import six +from six.moves import urllib +try: + from collections import MutableMapping +except ImportError: + from collections.abc import MutableMapping + +HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = b"multipart/form-data" + +CONTENT_MISSING = 0 + + +class Headers(MutableMapping, object): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header fields can be accessed + >>> h.fields + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + @always_byte_args("ascii") + def __init__(self, fields=None, **headers): + """ + Args: + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. + """ + self.fields = fields or [] + + # content_type -> content-type + headers = { + name.encode("ascii").replace(b"_", b"-"): value + for name, value in six.iteritems(headers) + } + self.update(headers) + + def __bytes__(self): + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + + if six.PY2: + __str__ = __bytes__ + + @always_byte_args("ascii") + def __getitem__(self, name): + values = self.get_all(name) + if not values: + raise KeyError(name) + return b", ".join(values) + + @always_byte_args("ascii") + def __setitem__(self, name, value): + idx = self._index(name) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[name] + self.fields.insert(idx, [name, value]) + else: + self.fields.append([name, value]) + + @always_byte_args("ascii") + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield name + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + #__hash__ = object.__hash__ + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: + return i + return None + + def __eq__(self, other): + if isinstance(other, Headers): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @always_byte_args("ascii") + def get_all(self, name): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + name_lower = name.lower() + values = [value for n, value in self.fields if n.lower() == name_lower] + return values + + def set_all(self, name, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + name = always_bytes(name, "ascii") + values = (always_bytes(value, "ascii") for value in values) + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) + + +class Request(object): + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', + ] + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, + form_out=None + ): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.headers = headers + self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + self.form_out = form_out or form_in + + def __eq__(self, other): + try: + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False + + def __repr__(self): + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) + else: + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + self.headers.pop(i, None) + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = "identity" + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( + ', '.join( + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) + + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = self.host + + def get_form(self): + """ + Retrieves the URL-encoded or multipart form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body: + if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + return self.get_form_urlencoded() + elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + return self.get_form_multipart() + return ODict([]) + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + return ODict(utils.urldecode(self.body)) + return ODict([]) + + def get_form_multipart(self): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + return ODict( + utils.multipartdecode( + self.headers, + self.body)) + return ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = HDR_FORM_URLENCODED + self.body = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split(b"/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.parse.quote(i, safe="") for i in lst] + path = b"/" + b"/".join(lst) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + query = utils.urlencode(odict.lst) + self.url = urllib.parse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def pretty_host(self, hostheader): + """ + Heuristic to get the host of the request. + + Note that pretty_host() does not always return the TCP destination + of the request, e.g. if an upstream proxy is in place + + If hostheader is set to True, the Host: header will be used as + additional (and preferred) data source. This is handy in + transparent mode, where only the IO of the destination is known, + but not the resolved name. This is disabled by default, as an + attacker may spoof the host header to confuse an analyst. + """ + if hostheader and b"Host" in self.headers: + try: + return self.headers[b"Host"].decode("idna") + except ValueError: + pass + if self.host: + return self.host.decode("idna") + + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') + + def get_cookies(self): + """ + Returns a possibly empty netlib.odict.ODict object. + """ + ret = ODict() + for i in self.headers.get_all("cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + def set_cookies(self, odict): + """ + Takes an netlib.odict.ODict object. Over-writes any existing Cookie + headers. + """ + v = cookies.format_cookie_header(odict) + self.headers["Cookie"] = v + + @property + def url(self): + """ + Returns a URL string, constructed from the Request's URL components. + """ + return utils.unparse_url( + self.scheme, + self.host, + self.port, + self.path + ).encode('ascii') + + @url.setter + def url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Raises: + ValueError if the URL was invalid + """ + # TODO: Should handle incoming unicode here. + parts = utils.parse_url(url) + if not parts: + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts + + @property + def content(self): # pragma: no cover + # TODO: remove deprecated getter + return self.body + + @content.setter + def content(self, content): # pragma: no cover + # TODO: remove deprecated setter + self.body = content + + +class Response(object): + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', + ] + + def __init__( + self, + httpversion, + status_code, + msg=None, + headers=None, + body=None, + sslinfo=None, + timestamp_start=None, + timestamp_end=None, + ): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.body = body + self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + try: + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False + + def __repr__(self): + # return "Response(%s - %s)" % (self.status_code, self.msg) + + if self.body: + size = utils.pretty_size(len(self.body)) + else: + size = "content missing" + # TODO: Remove "(unknown content type, content missing)" edge-case + return "".format( + status_code=self.status_code, + msg=self.msg, + contenttype=self.headers.get("content-type", "unknown content type"), + size=size) + + def get_cookies(self): + """ + Get the contents of all Set-Cookie headers. + + Returns a possibly empty ODict, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers.get_all("set-cookie"): + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return ODict(ret) + + def set_cookies(self, odict): + """ + Set the Set-Cookie headers on this response, over-writing existing + headers. + + Accepts an ODict of the same format as that returned by get_cookies. + """ + values = [] + for i in odict.lst: + values.append( + cookies.format_set_cookie_header( + i[0], + i[1][0], + i[1][1] + ) + ) + self.headers.set_all("Set-Cookie", values) + + @property + def content(self): # pragma: no cover + # TODO: remove deprecated getter + return self.body + + @content.setter + def content(self, content): # pragma: no cover + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): # pragma: no cover + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): # pragma: no cover + # TODO: remove deprecated setter + self.status_code = code diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py deleted file mode 100644 index 5bb098a7..00000000 --- a/netlib/http/semantics.py +++ /dev/null @@ -1,632 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import UserDict -import copy -import urllib -import urlparse - -from .. import odict -from . import cookies, exceptions -from netlib import utils, encoding - -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" - -CONTENT_MISSING = 0 - - -class Headers(object, UserDict.DictMixin): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # str(h) returns a HTTP1 header block. - >>> print(h) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - def __init__(self, fields=None, **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. - """ - self.fields = fields or [] - - # content_type -> content-type - headers = { - name.replace("_", "-"): value - for name, value in headers.iteritems() - } - self.update(headers) - - def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" - - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - else: - return ", ".join(values) - - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def keys(self): - seen = set() - names = [] - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - names.append(name) - return names - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name = name.lower() - values = [value for n, value in self.fields if n.lower() == name] - return values - - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. - See: :py:meth:`get_all` - """ - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def copy(self): - return Headers(copy.copy(self.fields)) - - # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): - return tuple(tuple(field) for field in self.fields) - - def load_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) - - -class ProtocolMixin(object): - def read_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def read_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble(self, message): - if isinstance(message, Request): - return self.assemble_request(message) - elif isinstance(message, Response): - return self.assemble_response(message) - else: - raise ValueError("HTTP message not supported.") - - def assemble_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - -class Request(object): - # This list is adopted legacy code. - # We probably don't need to strip off keep-alive. - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade', - ] - - def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None - ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - - self.form_in = form_in - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.httpversion = httpversion - self.headers = headers - self.body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - self.form_out = form_out or form_in - - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False - - def __repr__(self): - # return "Request(%s - %s, %s)" % (self.method, self.host, self.path) - - return "".format( - self.legacy_first_line()[:-9] - ) - - def legacy_first_line(self, form=None): - if form is None: - form = self.form_out - if form == "relative": - return '%s %s HTTP/%s.%s' % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "authority": - return '%s %s:%s HTTP/%s.%s' % ( - self.method, - self.host, - self.port, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "absolute": - return '%s %s://%s:%s%s HTTP/%s.%s' % ( - self.method, - self.scheme, - self.host, - self.port, - self.path, - self.httpversion[0], - self.httpversion[1], - ) - else: - raise exceptions.HttpError(400, "Invalid request form") - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in encoding.ENCODINGS - if e in accept_encoding - ) - ) - - def update_host_header(self): - """ - Update the host header to reflect the current target. - """ - self.headers["Host"] = self.host - - def get_form(self): - """ - Retrieves the URL-encoded or multipart form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): - return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): - return self.get_form_multipart() - return odict.ODict([]) - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): - return odict.ODict(utils.urldecode(self.body)) - return odict.ODict([]) - - def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): - return odict.ODict( - utils.multipartdecode( - self.headers, - self.body)) - return odict.ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["Content-Type"] = HDR_FORM_URLENCODED - self.body = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urlparse.urlparse(self.url) - return [urllib.unquote(i) for i in path.split("/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) - self.url = urlparse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urlparse.urlparse(self.url) - if query: - return odict.ODict(utils.urldecode(query)) - return odict.ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) - query = utils.urlencode(odict.lst) - self.url = urlparse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def pretty_host(self, hostheader): - """ - Heuristic to get the host of the request. - - Note that pretty_host() does not always return the TCP destination - of the request, e.g. if an upstream proxy is in place - - If hostheader is set to True, the Host: header will be used as - additional (and preferred) data source. This is handy in - transparent mode, where only the IO of the destination is known, - but not the resolved name. This is disabled by default, as an - attacker may spoof the host header to confuse an analyst. - """ - host = None - if hostheader: - host = self.headers.get("Host") - if not host: - host = self.host - if host: - try: - return host.encode("idna") - except ValueError: - return host - else: - return None - - def pretty_url(self, hostheader): - if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.pretty_host(hostheader), self.port) - return utils.unparse_url(self.scheme, - self.pretty_host(hostheader), - self.port, - self.path).encode('ascii') - - def get_cookies(self): - """ - Returns a possibly empty netlib.odict.ODict object. - """ - ret = odict.ODict() - for i in self.headers.get_all("cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - def set_cookies(self, odict): - """ - Takes an netlib.odict.ODict object. Over-writes any existing Cookie - headers. - """ - v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = v - - @property - def url(self): - """ - Returns a URL string, constructed from the Request's URL components. - """ - return utils.unparse_url( - self.scheme, - self.host, - self.port, - self.path - ).encode('ascii') - - @url.setter - def url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Returns False if the URL was invalid, True if the request succeeded. - """ - parts = utils.parse_url(url) - if not parts: - raise ValueError("Invalid URL: %s" % url) - self.scheme, self.host, self.port, self.path = parts - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - - -class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" - ): - super(EmptyRequest, self).__init__( - form_in=form_in, - method=method, - scheme=scheme, - host=host, - port=port, - path=path, - httpversion=httpversion, - headers=headers, - body=body, - ) - - -class Response(object): - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc', - ] - - def __init__( - self, - httpversion, - status_code, - msg=None, - headers=None, - body=None, - sslinfo=None, - timestamp_start=None, - timestamp_end=None, - ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - - self.httpversion = httpversion - self.status_code = status_code - self.msg = msg - self.headers = headers - self.body = body - self.sslinfo = sslinfo - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False - - def __repr__(self): - # return "Response(%s - %s)" % (self.status_code, self.msg) - - if self.body: - size = utils.pretty_size(len(self.body)) - else: - size = "content missing" - # TODO: Remove "(unknown content type, content missing)" edge-case - return "".format( - status_code=self.status_code, - msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), - size=size) - - def get_cookies(self): - """ - Get the contents of all Set-Cookie headers. - - Returns a possibly empty ODict, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. - """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return odict.ODict(ret) - - def set_cookies(self, odict): - """ - Set the Set-Cookie headers on this response, over-writing existing - headers. - - Accepts an ODict of the same format as that returned by get_cookies. - """ - values = [] - for i in odict.lst: - values.append( - cookies.format_set_cookie_header( - i[0], - i[1][0], - i[1][1] - ) - ) - self.headers.set_all("Set-Cookie", values) - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - - @property - def code(self): # pragma: no cover - # TODO: remove deprecated getter - return self.status_code - - @code.setter - def code(self, code): # pragma: no cover - # TODO: remove deprecated setter - self.status_code = code diff --git a/netlib/tcp.py b/netlib/tcp.py index 4a7f6153..1eb417b4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -834,14 +834,14 @@ class TCPServer(object): # If a thread has persisted after interpreter exit, the module might be # none. if traceback: - exc = traceback.format_exc() - print('-' * 40, file=fp) + exc = six.text_type(traceback.format_exc()) + print(u'-' * 40, file=fp) print( - "Error in processing of request from %s:%s" % ( + u"Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) - print('-' * 40, file=fp) + print(u'-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/tutils.py b/netlib/tutils.py index 951ef3d9..65c4a313 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -1,9 +1,11 @@ -import cStringIO +from io import BytesIO import tempfile import os import time import shutil from contextlib import contextmanager +import six +import sys from netlib import tcp, utils, http @@ -12,7 +14,7 @@ def treader(bytes): """ Construct a tcp.Read object from bytes. """ - fp = cStringIO.StringIO(bytes) + fp = BytesIO(bytes) return tcp.Reader(fp) @@ -28,7 +30,24 @@ def tmpdir(*args, **kwargs): shutil.rmtree(temp_workdir) -def raises(exc, obj, *args, **kwargs): +def _check_exception(expected, actual, exc_tb): + if isinstance(expected, six.string_types): + if expected.lower() not in str(actual).lower(): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s" % ( + repr(str(expected)), actual + ) + ), exc_tb) + else: + if not isinstance(actual, expected): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s %s" % ( + expected.__name__, actual.__class__.__name__, str(actual) + ) + ), exc_tb) + + +def raises(expected_exception, obj=None, *args, **kwargs): """ Assert that a callable raises a specified exception. @@ -43,28 +62,31 @@ def raises(exc, obj, *args, **kwargs): :kwargs Arguments to be passed to the callable. """ - try: - ret = obj(*args, **kwargs) - except Exception as v: - if isinstance(exc, basestring): - if exc.lower() in str(v).lower(): - return - else: - raise AssertionError( - "Expected %s, but caught %s" % ( - repr(str(exc)), v - ) - ) + if obj is None: + return RaisesContext(expected_exception) + else: + try: + ret = obj(*args, **kwargs) + except Exception as actual: + _check_exception(expected_exception, actual, sys.exc_info()[2]) else: - if isinstance(v, exc): - return - else: - raise AssertionError( - "Expected %s, but caught %s %s" % ( - exc.__name__, v.__class__.__name__, str(v) - ) - ) - raise AssertionError("No exception raised. Return value: {}".format(ret)) + raise AssertionError("No exception raised. Return value: {}".format(ret)) + + +class RaisesContext(object): + def __init__(self, expected_exception): + self.expected_exception = expected_exception + + def __enter__(self): + return + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + raise AssertionError("No exception raised.") + else: + _check_exception(self.expected_exception, exc_val, exc_tb) + return True + test_data = utils.Data(__name__) diff --git a/netlib/utils.py b/netlib/utils.py index d6774419..fb579cac 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,17 +1,17 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division import os.path -import cgi -import urllib -import urlparse -import string import re -import six +import string import unicodedata +import six + +from six.moves import urllib + -def isascii(s): +def isascii(bytes): try: - s.decode("ascii") + bytes.decode("ascii") except ValueError: return False return True @@ -44,8 +44,8 @@ def clean_bin(s, keep_spacing=True): else: keep = b"" return b"".join( - ch if (31 < ord(ch) < 127 or ch in keep) else b"." - for ch in s + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) ) @@ -149,10 +149,7 @@ class Data(object): return fullpath -def is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 255: + return False + if host[-1] == ".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) def parse_url(url): """ - Returns a (scheme, host, port, path) tuple, or None on error. + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError + decode_parse_result(parsed, "ascii") else: - host = netloc - if scheme.endswith("https"): - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + if not is_valid_host(host): - return None - if not isascii(path): - return None + raise ValueError("Invalid Host") if not is_valid_port(port): - return None - return scheme, host, port, path + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path def get_header_tokens(headers, key): @@ -217,7 +240,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(",") + tokens = headers[key].split(b",") return [token.strip() for token in tokens] @@ -228,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return "%s:%s" % (host, port) + return b"%s:%s" % (host, port) def unparse_url(scheme, host, port, path=""): @@ -243,14 +266,14 @@ def urlencode(s): Takes a list of (key, value) tuples and returns a urlencoded string. """ s = [tuple(i) for i in s] - return urllib.urlencode(s, False) + return urllib.parse.urlencode(s, False) def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. """ - return cgi.parse_qsl(s, keep_blank_values=True) + return urllib.parse.parse_qsl(s, keep_blank_values=True) def parse_content_type(c): @@ -267,14 +290,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) + parts = c.split(b";", 1) + ts = parts[0].split(b"/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) + for i in parts[1].split(b";"): + clause = i.split(b"=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -289,7 +312,7 @@ def multipartdecode(headers, content): v = parse_content_type(v) if not v: return [] - boundary = v[2].get("boundary") + boundary = v[2].get(b"boundary") if not boundary: return [] @@ -306,3 +329,20 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] + + +def always_bytes(unicode_or_bytes, encoding): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(encoding) + return unicode_or_bytes + + +def always_byte_args(encoding): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, encoding) for arg in args] + kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator diff --git a/netlib/version_check.py b/netlib/version_check.py index 1d7e025c..9cf27eea 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -7,6 +7,7 @@ from __future__ import division, absolute_import, print_function import sys import inspect import os.path +import six import OpenSSL from . import version @@ -19,8 +20,8 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): # consider major and minor version. if version.IVERSION[:2] != mitmproxy_version[:2]: print( - "You are using mitmproxy %s with netlib %s. " - "Most likely, that won't work - please upgrade!" % ( + u"You are using mitmproxy %s with netlib %s. " + u"Most likely, that won't work - please upgrade!" % ( mitmproxy_version, version.VERSION ), file=fp @@ -29,13 +30,13 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - min_version_str = ".".join(str(x) for x in min_version) + min_version_str = u".".join(six.text_type(x) for x in min_version) try: v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) except ValueError: print( - "Cannot parse pyOpenSSL version: {}" - "mitmproxy requires pyOpenSSL {} or greater.".format( + u"Cannot parse pyOpenSSL version: {}" + u"mitmproxy requires pyOpenSSL {} or greater.".format( OpenSSL.__version__, min_version_str ), file=fp @@ -43,15 +44,15 @@ def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): return if v < min_version: print( - "You are using an outdated version of pyOpenSSL: " - "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), + u"You are using an outdated version of pyOpenSSL: " + u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. # Report which one we got. pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) print( - "Your pyOpenSSL {} installation is located at {}".format( + u"Your pyOpenSSL {} installation is located at {}".format( OpenSSL.__version__, pyopenssl_path ), file=fp diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py index 5acf7696..1c143919 100644 --- a/netlib/websockets/__init__.py +++ b/netlib/websockets/__init__.py @@ -1,2 +1,2 @@ -from frame import * -from protocol import * +from .frame import * +from .protocol import * -- cgit v1.2.3 From a077d8877d210562f703c23e9625e8467c81222d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 16 Sep 2015 00:04:23 +0200 Subject: finish netlib.http.http1 refactor --- netlib/http/__init__.py | 6 +- netlib/http/http1/__init__.py | 4 +- netlib/http/http1/assemble.py | 8 +- netlib/http/http1/read.py | 152 ++++----- netlib/http/http2/connections.py | 4 +- netlib/http/http2/frame.py | 654 +++++++++++++++++++++++++++++++++++++++ netlib/http/http2/frames.py | 633 ------------------------------------- netlib/http/models.py | 2 - netlib/tutils.py | 74 ++--- netlib/utils.py | 6 +- 10 files changed, 779 insertions(+), 764 deletions(-) create mode 100644 netlib/http/http2/frame.py delete mode 100644 netlib/http/http2/frames.py (limited to 'netlib') diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 0b1a0bc5..9303de09 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,7 +1,9 @@ -from .models import Request, Response, Headers, CONTENT_MISSING +from .models import Request, Response, Headers +from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ - "Request", "Response", "Headers", "CONTENT_MISSING" + "Request", "Response", "Headers", + "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", "http1", "http2" ] diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 4d223f97..a72c2e05 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1,7 +1,7 @@ from .read import ( read_request, read_request_head, read_response, read_response_head, - read_message_body, read_message_body_chunked, + read_body, connection_close, expected_http_body_size, ) @@ -14,7 +14,7 @@ from .assemble import ( __all__ = [ "read_request", "read_request_head", "read_response", "read_response_head", - "read_message_body", "read_message_body_chunked", + "read_body", "connection_close", "expected_http_body_size", "assemble_request", "assemble_request_head", diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index a3269eed..47c7e95a 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -31,8 +31,6 @@ def assemble_response_head(response): return b"%s\r\n%s\r\n" % (first_line, headers) - - def _assemble_request_line(request, form=None): if form is None: form = request.form_out @@ -50,7 +48,7 @@ def _assemble_request_line(request, form=None): request.httpversion ) elif form == "absolute": - return b"%s %s://%s:%s%s %s" % ( + return b"%s %s://%s:%d%s %s" % ( request.method, request.scheme, request.host, @@ -78,11 +76,11 @@ def _assemble_request_headers(request): if request.body or request.body == b"": headers[b"Content-Length"] = str(len(request.body)).encode("ascii") - return str(headers) + return bytes(headers) def _assemble_response_line(response): - return b"%s %s %s" % ( + return b"%s %d %s" % ( response.httpversion, response.status_code, response.msg, diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 573bc739..4c423c4c 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -7,12 +7,13 @@ from ... import utils from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException from .. import Request, Response, Headers -ALPN_PROTO_HTTP1 = 'http/1.1' +ALPN_PROTO_HTTP1 = b'http/1.1' def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) - request.body = read_message_body(rfile, request, limit=body_size_limit) + expected_body_size = expected_http_body_size(request) + request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -23,15 +24,14 @@ def read_request_head(rfile): Args: rfile: The input stream - body_size_limit (bool): Maximum body size Returns: - The HTTP request object + The HTTP request object (without body) Raises: - HttpReadDisconnect: If no bytes can be read from rfile. - HttpSyntaxException: If the input is invalid. - HttpException: A different error occured. + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. """ timestamp_start = time.time() if hasattr(rfile, "reset_timestamps"): @@ -51,12 +51,28 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) - response.body = read_message_body(rfile, request, response, body_size_limit) + expected_body_size = expected_http_body_size(request, response) + response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response def read_response_head(rfile): + """ + Parse an HTTP response head (response line + headers) from an input stream + + Args: + rfile: The input stream + + Returns: + The HTTP request object (without body) + + Raises: + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. + """ + timestamp_start = time.time() if hasattr(rfile, "reset_timestamps"): rfile.reset_timestamps() @@ -68,50 +84,33 @@ def read_response_head(rfile): # more accurate timestamp_start timestamp_start = rfile.first_byte_timestamp - return Response( - http_version, - status_code, - message, - headers, - None, - timestamp_start - ) - - -def read_message_body(*args, **kwargs): - chunks = read_message_body_chunked(*args, **kwargs) - return b"".join(chunks) + return Response(http_version, status_code, message, headers, None, timestamp_start) -def read_message_body_chunked(rfile, request, response=None, limit=None, max_chunk_size=None): +def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): """ - Read an HTTP message body: + Read an HTTP message body Args: - If a request body should be read, only request should be passed. - If a response body should be read, both request and response should be passed. + rfile: The input stream + expected_size: The expected body size (see :py:meth:`expected_body_size`) + limit: Maximum body size + max_chunk_size: Maximium chunk size that gets yielded + + Returns: + A generator that yields byte chunks of the content. Raises: - HttpException - """ - if not response: - headers = request.headers - response_code = None - is_request = True - else: - headers = response.headers - response_code = response.status_code - is_request = False + HttpException, if an error occurs + Caveats: + max_chunk_size is not considered if the transfer encoding is chunked. + """ if not limit or limit < 0: limit = sys.maxsize if not max_chunk_size: max_chunk_size = limit - expected_size = expected_http_body_size( - headers, is_request, request.method, response_code - ) - if expected_size is None: for x in _read_chunked(rfile, limit): yield x @@ -125,6 +124,8 @@ def read_message_body_chunked(rfile, request, response=None, limit=None, max_chu while bytes_left: chunk_size = min(bytes_left, max_chunk_size) content = rfile.read(chunk_size) + if len(content) < chunk_size: + raise HttpException("Unexpected EOF") yield content bytes_left -= chunk_size else: @@ -148,10 +149,10 @@ def connection_close(http_version, headers): """ # At first, check if we have an explicit Connection header. if b"connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if b"close" in toks: + tokens = utils.get_header_tokens(headers, "connection") + if b"close" in tokens: return True - elif b"keep-alive" in toks: + elif b"keep-alive" in tokens: return False # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -159,37 +160,41 @@ def connection_close(http_version, headers): return http_version != (1, 1) -def expected_http_body_size( - headers, - is_request, - request_method, - response_code, -): +def expected_http_body_size(request, response=False): """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) - - -1, if all data should be read until end of stream. + Returns: + The expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. Raises: HttpSyntaxException, if the content length header is invalid """ # Determine response size according to # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False - is_empty_response = (not is_request and ( - request_method == b"HEAD" or - 100 <= response_code <= 199 or - (response_code == 200 and request_method == b"CONNECT") or - response_code in (204, 304) - )) + if is_request: + if headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + else: + if request.method.upper() == b"HEAD": + return 0 + if 100 <= response_code <= 199: + return 0 + if response_code == 200 and request.method.upper() == b"CONNECT": + return 0 + if response_code in (204, 304): + return 0 - if is_empty_response: - return 0 - if is_request and headers.get(b"expect", b"").lower() == b"100-continue": - return 0 if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): return None if b"content-length" in headers: @@ -212,18 +217,22 @@ def _get_first_line(rfile): line = rfile.readline() if not line: raise HttpReadDisconnect() - return line + line = line.strip() + try: + line.decode("ascii") + except ValueError: + raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line)) + return line.strip() def _read_request_line(rfile): line = _get_first_line(rfile) try: - method, path, http_version = line.strip().split(b" ") + method, path, http_version = line.split(b" ") if path == b"*" or path.startswith(b"/"): form = "relative" - path.decode("ascii") # should not raise a ValueError scheme, host, port = None, None, None elif method == b"CONNECT": form = "authority" @@ -233,6 +242,7 @@ def _read_request_line(rfile): form = "absolute" scheme, host, port, path = utils.parse_url(path) + _check_http_version(http_version) except ValueError: raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) @@ -253,7 +263,7 @@ def _parse_authority_form(hostport): if not utils.is_valid_host(host) or not utils.is_valid_port(port): raise ValueError() except ValueError: - raise ValueError("Invalid host specification: {}".format(hostport)) + raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) return host, port @@ -263,7 +273,7 @@ def _read_response_line(rfile): try: - parts = line.strip().split(b" ") + parts = line.split(b" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append(b"") @@ -278,7 +288,7 @@ def _read_response_line(rfile): def _check_http_version(http_version): - if not re.match(rb"^HTTP/\d\.\d$", http_version): + if not re.match(br"^HTTP/\d\.\d$", http_version): raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) @@ -313,7 +323,7 @@ def _read_headers(rfile): return Headers(ret) -def _read_chunked(rfile, limit): +def _read_chunked(rfile, limit=sys.maxsize): """ Read a HTTP body with chunked transfer encoding. diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index b6d376d3..036bf68f 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -4,7 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from netlib import http, utils -from netlib.http import semantics +from netlib.http import models as semantics from . import frame @@ -15,7 +15,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(semantics.ProtocolMixin): +class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 00000000..cb2cde99 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,654 @@ +from __future__ import absolute_import, print_function, division +import struct +from hpack.hpack import Encoder, Decoder + +from ...utils import BiDi +from ...exceptions import HttpSyntaxException + + +ERROR_CODES = BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) + +CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + +ALPN_PROTO_H2 = b'h2' + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = 0 + for flag in self.VALID_FLAGS: + valid_flags |= flag + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise HttpSyntaxException( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection") + + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in range(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Size Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/frames.py b/netlib/http/http2/frames.py deleted file mode 100644 index b36b3adf..00000000 --- a/netlib/http/http2/frames.py +++ /dev/null @@ -1,633 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Size Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http/models.py b/netlib/http/models.py index bd5863b1..572d66c9 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -474,7 +474,6 @@ class Response(object): msg=None, headers=None, body=None, - sslinfo=None, timestamp_start=None, timestamp_end=None, ): @@ -487,7 +486,6 @@ class Response(object): self.msg = msg self.headers = headers self.body = body - self.sslinfo = sslinfo self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end diff --git a/netlib/tutils.py b/netlib/tutils.py index 65c4a313..758f8410 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -7,13 +7,15 @@ from contextlib import contextmanager import six import sys -from netlib import tcp, utils, http +from . import utils +from .http import Request, Response, Headers def treader(bytes): """ Construct a tcp.Read object from bytes. """ + from . import tcp # TODO: move to top once cryptography is on Python 3.5 fp = BytesIO(bytes) return tcp.Reader(fp) @@ -91,55 +93,39 @@ class RaisesContext(object): test_data = utils.Data(__name__) -def treq(content="content", scheme="http", host="address", port=22): +def treq(**kwargs): """ - @return: libmproxy.protocol.http.HTTPRequest + Returns: + netlib.http.Request """ - headers = http.Headers() - headers["header"] = "qvalue" - req = http.Request( - "relative", - "GET", - scheme, - host, - port, - "/path", - (1, 1), - headers, - content, - None, - None, + default = dict( + form_in="relative", + method=b"GET", + scheme=b"http", + host=b"address", + port=22, + path=b"/path", + httpversion=b"HTTP/1.1", + headers=Headers(header=b"qvalue"), + body=b"content" ) - return req + default.update(kwargs) + return Request(**default) -def treq_absolute(content="content"): +def tresp(**kwargs): """ - @return: libmproxy.protocol.http.HTTPRequest + Returns: + netlib.http.Response """ - r = treq(content) - r.form_in = r.form_out = "absolute" - r.host = "address" - r.port = 22 - r.scheme = "http" - return r - - -def tresp(content="message"): - """ - @return: libmproxy.protocol.http.HTTPResponse - """ - - headers = http.Headers() - headers["header_response"] = "svalue" - - resp = http.semantics.Response( - (1, 1), - 200, - "OK", - headers, - content, + default = dict( + httpversion=b"HTTP/1.1", + status_code=200, + msg=b"OK", + headers=Headers(header_response=b"svalue"), + body=b"message", timestamp_start=time.time(), - timestamp_end=time.time(), + timestamp_end=time.time() ) - return resp + default.update(kwargs) + return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index fb579cac..a86b8019 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -40,9 +40,9 @@ def clean_bin(s, keep_spacing=True): ) else: if keep_spacing: - keep = b"\n\r\t" + keep = (9, 10, 13) # \t, \n, \r, else: - keep = b"" + keep = () return b"".join( six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." for ch in six.iterbytes(s) @@ -251,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return b"%s:%s" % (host, port) + return b"%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): -- cgit v1.2.3 From 265f31e8782ee9da511ce4b63aa2da00221cbf66 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 16 Sep 2015 18:43:24 +0200 Subject: adjust http1-related code --- netlib/exceptions.py | 1 + netlib/http/__init__.py | 5 ++++- netlib/http/http1/__init__.py | 1 + netlib/http/http1/assemble.py | 4 ++-- netlib/http/http1/read.py | 18 +++++++++++------- netlib/http/http2/__init__.py | 6 ++++++ netlib/http/http2/connections.py | 28 ++++++++++++++++++---------- netlib/http/models.py | 3 +++ netlib/tutils.py | 4 ++-- 9 files changed, 48 insertions(+), 22 deletions(-) (limited to 'netlib') diff --git a/netlib/exceptions.py b/netlib/exceptions.py index 637be3df..e13af473 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -27,5 +27,6 @@ class HttpException(NetlibException): class HttpReadDisconnect(HttpException, ReadDisconnect): pass + class HttpSyntaxException(HttpException): pass diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9303de09..d72884b3 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,9 +1,12 @@ +from __future__ import absolute_import, print_function, division from .models import Request, Response, Headers +from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ "Request", "Response", "Headers", + "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", - "http1", "http2" + "http1", "http2", ] diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index a72c2e05..2d33ff8a 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1,3 +1,4 @@ +from __future__ import absolute_import, print_function, division from .read import ( read_request, read_request_head, read_response, read_response_head, diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 47c7e95a..ace25d79 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -25,9 +25,9 @@ def assemble_response(response): return head + response.body -def assemble_response_head(response): +def assemble_response_head(response, preserve_transfer_encoding=False): first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response) + headers = _assemble_response_headers(response, preserve_transfer_encoding) return b"%s\r\n%s\r\n" % (first_line, headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c423c4c..62025d15 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -6,8 +6,7 @@ import re from ... import utils from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException from .. import Request, Response, Headers - -ALPN_PROTO_HTTP1 = b'http/1.1' +from netlib.tcp import NetLibDisconnect def read_request(rfile, body_size_limit=None): @@ -157,10 +156,10 @@ def connection_close(http_version, headers): # If we don't have a Connection header, HTTP 1.1 connections are assumed to # be persistent - return http_version != (1, 1) + return http_version != b"HTTP/1.1" -def expected_http_body_size(request, response=False): +def expected_http_body_size(request, response=None): """ Returns: The expected body length: @@ -211,10 +210,13 @@ def expected_http_body_size(request, response=False): def _get_first_line(rfile): - line = rfile.readline() - if line == b"\r\n" or line == b"\n": - # Possible leftover from previous message + try: line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + except NetLibDisconnect: + raise HttpReadDisconnect() if not line: raise HttpReadDisconnect() line = line.strip() @@ -317,6 +319,8 @@ def _read_headers(rfile): try: name, value = line.split(b":", 1) value = value.strip() + if not name or not value: + raise ValueError() ret.append([name, value]) except ValueError: raise HttpSyntaxException("Invalid headers") diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index e69de29b..7043d36f 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import, print_function, division +from .connections import HTTP2Protocol + +__all__ = [ + "HTTP2Protocol" +] diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 036bf68f..5220d5d2 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -3,8 +3,8 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils -from netlib.http import models as semantics +from ... import utils +from .. import Headers, Response, Request, ALPN_PROTO_H2 from . import frame @@ -36,8 +36,6 @@ class HTTP2Protocol(object): CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - ALPN_PROTO_H2 = 'h2' - def __init__( self, tcp_handler=None, @@ -62,6 +60,7 @@ class HTTP2Protocol(object): def read_request( self, + __rfile, include_body=True, body_size_limit=None, allow_empty=False, @@ -111,7 +110,7 @@ class HTTP2Protocol(object): port = 80 if scheme == 'http' else 443 port = int(port) - request = http.Request( + request = Request( form_in, method, scheme, @@ -131,6 +130,7 @@ class HTTP2Protocol(object): def read_response( self, + __rfile, request_method='', body_size_limit=None, include_body=True, @@ -159,7 +159,7 @@ class HTTP2Protocol(object): else: timestamp_end = None - response = http.Response( + response = Response( (2, 0), int(headers.get(':status', 502)), "", @@ -172,8 +172,16 @@ class HTTP2Protocol(object): return response + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + def assemble_request(self, request): - assert isinstance(request, semantics.Request) + assert isinstance(request, Request) authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host if self.tcp_handler.address.port != 443: @@ -200,7 +208,7 @@ class HTTP2Protocol(object): self._create_body(request.body, stream_id))) def assemble_response(self, response): - assert isinstance(response, semantics.Response) + assert isinstance(response, Response) headers = response.headers.copy() @@ -275,7 +283,7 @@ class HTTP2Protocol(object): def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: + if alp != ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True @@ -405,7 +413,7 @@ class HTTP2Protocol(object): else: self._handle_unexpected_frame(frm) - headers = http.Headers( + headers = Headers( [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] ) diff --git a/netlib/http/models.py b/netlib/http/models.py index 572d66c9..2d09535c 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -13,6 +13,9 @@ try: except ImportError: from collections.abc import MutableMapping +# TODO: Move somewhere else? +ALPN_PROTO_HTTP1 = b'http/1.1' +ALPN_PROTO_H2 = b'h2' HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" HDR_FORM_MULTIPART = b"multipart/form-data" diff --git a/netlib/tutils.py b/netlib/tutils.py index 758f8410..05791c49 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -37,14 +37,14 @@ def _check_exception(expected, actual, exc_tb): if expected.lower() not in str(actual).lower(): six.reraise(AssertionError, AssertionError( "Expected %s, but caught %s" % ( - repr(str(expected)), actual + repr(expected), repr(actual) ) ), exc_tb) else: if not isinstance(actual, expected): six.reraise(AssertionError, AssertionError( "Expected %s, but caught %s %s" % ( - expected.__name__, actual.__class__.__name__, str(actual) + expected.__name__, actual.__class__.__name__, repr(actual) ) ), exc_tb) -- cgit v1.2.3 From dad9f06cb9403ac88d31d0ba8422034df2bc5078 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 02:14:14 +0200 Subject: organize exceptions, improve content-length handling --- netlib/exceptions.py | 30 +++++++++++- netlib/http/http1/assemble.py | 8 ++-- netlib/http/http1/read.py | 9 ++-- netlib/http/models.py | 24 +++++++++- netlib/tcp.py | 108 +++++++++++++++++++----------------------- 5 files changed, 107 insertions(+), 72 deletions(-) (limited to 'netlib') diff --git a/netlib/exceptions.py b/netlib/exceptions.py index e13af473..e30235af 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -16,7 +16,7 @@ class NetlibException(Exception): super(NetlibException, self).__init__(message) -class ReadDisconnect(object): +class Disconnect(object): """Immediate EOF""" @@ -24,9 +24,35 @@ class HttpException(NetlibException): pass -class HttpReadDisconnect(HttpException, ReadDisconnect): +class HttpReadDisconnect(HttpException, Disconnect): pass class HttpSyntaxException(HttpException): pass + + +class TcpException(NetlibException): + pass + + +class TcpDisconnect(TcpException, Disconnect): + pass + + + + +class TcpReadIncomplete(TcpException): + pass + + +class TcpTimeout(TcpException): + pass + + +class TlsException(NetlibException): + pass + + +class InvalidCertificateException(TlsException): + pass diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index ace25d79..33b9ef25 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -95,9 +95,9 @@ def _assemble_response_headers(response, preserve_transfer_encoding=False): if not preserve_transfer_encoding: headers.pop(b"Transfer-Encoding", None) - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == b"": - headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + # If body is defined (i.e. not None or CONTENT_MISSING), + # we now need to set a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") return bytes(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 62025d15..7f2b7bab 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -4,15 +4,14 @@ import sys import re from ... import utils -from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect from .. import Request, Response, Headers -from netlib.tcp import NetLibDisconnect def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) expected_body_size = expected_http_body_size(request) - request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -51,7 +50,7 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) expected_body_size = expected_http_body_size(request, response) - response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response @@ -215,7 +214,7 @@ def _get_first_line(rfile): if line == b"\r\n" or line == b"\n": # Possible leftover from previous message line = rfile.readline() - except NetLibDisconnect: + except TcpDisconnect: raise HttpReadDisconnect() if not line: raise HttpReadDisconnect() diff --git a/netlib/http/models.py b/netlib/http/models.py index 2d09535c..b4446ecb 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -231,7 +231,7 @@ class Request(object): self.path = path self.httpversion = httpversion self.headers = headers - self.body = body + self._body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end self.form_out = form_out or form_in @@ -452,6 +452,16 @@ class Request(object): raise ValueError("Invalid URL: %s" % url) self.scheme, self.host, self.port, self.path = parts + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers["Content-Length"] = str(len(body)).encode() + @property def content(self): # pragma: no cover # TODO: remove deprecated getter @@ -488,7 +498,7 @@ class Response(object): self.status_code = status_code self.msg = msg self.headers = headers - self.body = body + self._body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end @@ -551,6 +561,16 @@ class Response(object): ) self.headers.set_all("Set-Cookie", values) + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers["Content-Length"] = str(len(body)).encode() + @property def content(self): # pragma: no cover # TODO: remove deprecated getter diff --git a/netlib/tcp.py b/netlib/tcp.py index 1eb417b4..707e11e0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,9 @@ from . import certutils, version_check # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ + TcpTimeout, TcpDisconnect, TcpException + version_check.check_pyopenssl_version() @@ -24,11 +27,17 @@ EINTR = 4 # To enable all SSL methods use: SSLv23 # then add options to disable certain methods # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +SSL_BASIC_OPTIONS = ( + SSL.OP_CIPHER_SERVER_PREFERENCE +) +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION + SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | - SSL.OP_CIPHER_SERVER_PREFERENCE + SSL_BASIC_OPTIONS ) if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION @@ -39,42 +48,17 @@ Don't ask... https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 """ sslversion_choices = { - "all": (SSL.SSLv23_METHOD, 0), + "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ # TLSv1_METHOD would be TLS 1.0 only - "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)), - "SSLv2": (SSL.SSLv2_METHOD, 0), - "SSLv3": (SSL.SSLv3_METHOD, 0), - "TLSv1": (SSL.TLSv1_METHOD, 0), - "TLSv1_1": (SSL.TLSv1_1_METHOD, 0), - "TLSv1_2": (SSL.TLSv1_2_METHOD, 0), + "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), + "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), + "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), + "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), } - -class NetLibError(Exception): - pass - - -class NetLibDisconnect(NetLibError): - pass - - -class NetLibIncomplete(NetLibError): - pass - - -class NetLibTimeout(NetLibError): - pass - - -class NetLibSSLError(NetLibError): - pass - - -class NetLibInvalidCertificateError(NetLibSSLError): - pass - - class SSLKeyLogger(object): def __init__(self, filename): @@ -168,17 +152,17 @@ class Writer(_FileLike): def flush(self): """ - May raise NetLibDisconnect + May raise TcpDisconnect """ if hasattr(self.o, "flush"): try: self.o.flush() except (socket.error, IOError) as v: - raise NetLibDisconnect(str(v)) + raise TcpDisconnect(str(v)) def write(self, v): """ - May raise NetLibDisconnect + May raise TcpDisconnect """ if v: self.first_byte_timestamp = self.first_byte_timestamp or time.time() @@ -191,7 +175,7 @@ class Writer(_FileLike): self.add_log(v[:r]) return r except (SSL.Error, socket.error) as e: - raise NetLibDisconnect(str(e)) + raise TcpDisconnect(str(e)) class Reader(_FileLike): @@ -210,23 +194,29 @@ class Reader(_FileLike): try: data = self.o.read(rlen) except SSL.ZeroReturnError: + # TLS connection was shut down cleanly break - except SSL.WantReadError: + except (SSL.WantWriteError, SSL.WantReadError): + # From the OpenSSL docs: + # If the underlying BIO is non-blocking, SSL_read() will also return when the + # underlying BIO could not satisfy the needs of SSL_read() to continue the + # operation. In this case a call to SSL_get_error with the return value of + # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) continue else: - raise NetLibTimeout + raise TcpTimeout() except socket.timeout: - raise NetLibTimeout - except socket.error: - raise NetLibDisconnect + raise TcpTimeout() + except socket.error as e: + raise TcpDisconnect(str(e)) except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise NetLibSSLError(e.message) + raise TlsException(e.message) except SSL.Error as e: - raise NetLibSSLError(e.message) + raise TlsException(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -260,9 +250,9 @@ class Reader(_FileLike): result = self.read(length) if length != -1 and len(result) != length: if not result: - raise NetLibDisconnect() + raise TcpDisconnect() else: - raise NetLibIncomplete( + raise TcpReadIncomplete( "Expected %s bytes, got %s" % (length, len(result)) ) return result @@ -275,15 +265,15 @@ class Reader(_FileLike): Up to the next N bytes if peeking is successful. Raises: - NetLibError if there was an error with the socket - NetLibSSLError if there was an error with pyOpenSSL. + TcpException if there was an error with the socket + TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ if isinstance(self.o, socket._fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: - raise NetLibError(repr(e)) + raise TcpException(repr(e)) elif isinstance(self.o, SSL.Connection): try: if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): @@ -296,7 +286,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2]) + six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2]) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") @@ -461,7 +451,7 @@ class _Connection(object): try: self.wfile.flush() self.wfile.close() - except NetLibDisconnect: + except TcpDisconnect: pass self.rfile.close() @@ -525,7 +515,7 @@ class _Connection(object): # TODO: maybe change this to with newer pyOpenSSL APIs context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: - raise NetLibError("SSL cipher specification error: %s" % str(v)) + raise TlsException("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: @@ -546,7 +536,7 @@ class _Connection(object): elif alpn_select_callback is not None and alpn_select is None: context.set_alpn_select_callback(alpn_select_callback) elif alpn_select_callback is not None and alpn_select is not None: - raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") + raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") return context @@ -594,7 +584,7 @@ class TCPClient(_Connection): context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error as v: - raise NetLibError("SSL client certificate error: %s" % str(v)) + raise TlsException("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): @@ -618,15 +608,15 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error as v: if self.ssl_verification_error: - raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v)) + raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) else: - raise NetLibError("SSL handshake error: %s" % repr(v)) + raise TlsException("SSL handshake error: %s" % repr(v)) # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on # certificate validation failure verification_mode = sslctx_kwargs.get('verify_options', None) if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed") + raise InvalidCertificateException("SSL handshake error: certificate verify failed") self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) @@ -644,7 +634,7 @@ class TCPClient(_Connection): self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError) as err: - raise NetLibError( + raise TcpException( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection @@ -750,7 +740,7 @@ class BaseHandler(_Connection): try: self.connection.do_handshake() except SSL.Error as v: - raise NetLibError("SSL handshake error: %s" % repr(v)) + raise TlsException("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) -- cgit v1.2.3 From a07e43df8b3988f137b48957f978ad570d9dc782 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 02:39:42 +0200 Subject: http1: add assemble_body function --- netlib/exceptions.py | 2 -- netlib/http/http1/__init__.py | 2 ++ netlib/http/http1/assemble.py | 26 +++++++++++++++----------- 3 files changed, 17 insertions(+), 13 deletions(-) (limited to 'netlib') diff --git a/netlib/exceptions.py b/netlib/exceptions.py index e30235af..05f1054b 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -40,8 +40,6 @@ class TcpDisconnect(TcpException, Disconnect): pass - - class TcpReadIncomplete(TcpException): pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 2d33ff8a..2aa7e26a 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -9,6 +9,7 @@ from .read import ( from .assemble import ( assemble_request, assemble_request_head, assemble_response, assemble_response_head, + assemble_body, ) @@ -20,4 +21,5 @@ __all__ = [ "expected_http_body_size", "assemble_request", "assemble_request_head", "assemble_response", "assemble_response_head", + "assemble_body", ] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 33b9ef25..7252c446 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, print_function, division from ... import utils +import itertools from ...exceptions import HttpException from .. import CONTENT_MISSING @@ -25,12 +26,23 @@ def assemble_response(response): return head + response.body -def assemble_response_head(response, preserve_transfer_encoding=False): +def assemble_response_head(response): first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response, preserve_transfer_encoding) + headers = _assemble_response_headers(response) return b"%s\r\n%s\r\n" % (first_line, headers) +def assemble_body(headers, body_chunks): + if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + for chunk in body_chunks: + if chunk: + yield b"%x\r\n%s\r\n" % (len(chunk), chunk) + yield b"0\r\n\r\n" + else: + for chunk in body_chunks: + yield chunk + + def _assemble_request_line(request, form=None): if form is None: form = request.form_out @@ -87,17 +99,9 @@ def _assemble_response_line(response): ) -def _assemble_response_headers(response, preserve_transfer_encoding=False): - # TODO: Remove preserve_transfer_encoding +def _assemble_response_headers(response): headers = response.headers.copy() for k in response._headers_to_strip_off: headers.pop(k, None) - if not preserve_transfer_encoding: - headers.pop(b"Transfer-Encoding", None) - - # If body is defined (i.e. not None or CONTENT_MISSING), - # we now need to set a content-length header. - if response.body or response.body == b"": - headers[b"Content-Length"] = str(len(response.body)).encode("ascii") return bytes(headers) -- cgit v1.2.3 From 8d71059d77c2dd1d9858d7971dd0b6b4387ed9f4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 15:16:12 +0200 Subject: clean up http message models --- netlib/http/http1/assemble.py | 8 +-- netlib/http/models.py | 159 ++++++++++++++---------------------------- netlib/tutils.py | 4 +- netlib/utils.py | 30 +++----- netlib/websockets/frame.py | 9 +-- netlib/websockets/protocol.py | 3 +- 6 files changed, 74 insertions(+), 139 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 7252c446..b65a6be0 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -50,14 +50,14 @@ def _assemble_request_line(request, form=None): return b"%s %s %s" % ( request.method, request.path, - request.httpversion + request.http_version ) elif form == "authority": return b"%s %s:%d %s" % ( request.method, request.host, request.port, - request.httpversion + request.http_version ) elif form == "absolute": return b"%s %s://%s:%d%s %s" % ( @@ -66,7 +66,7 @@ def _assemble_request_line(request, form=None): request.host, request.port, request.path, - request.httpversion + request.http_version ) else: # pragma: nocover raise RuntimeError("Invalid request form") @@ -93,7 +93,7 @@ def _assemble_request_headers(request): def _assemble_response_line(response): return b"%s %d %s" % ( - response.httpversion, + response.http_version, response.status_code, response.msg, ) diff --git a/netlib/http/models.py b/netlib/http/models.py index b4446ecb..54b8b112 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -193,15 +193,45 @@ class Headers(MutableMapping, object): return cls([list(field) for field in state]) -class Request(object): +class Message(object): + def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): + self.http_version = http_version + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + self.headers = headers + + self._body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers[b"Content-Length"] = str(len(body)).encode() + + content = body + + def __eq__(self, other): + if isinstance(other, Message): + return self.__dict__ == other.__dict__ + return False + + +class Request(Message): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. _headers_to_strip_off = [ - 'Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade', + b'Proxy-Connection', + b'Keep-Alive', + b'Connection', + b'Transfer-Encoding', + b'Upgrade', ] def __init__( @@ -212,16 +242,14 @@ class Request(object): host, port, path, - httpversion, + http_version, headers=None, body=None, timestamp_start=None, timestamp_end=None, form_out=None ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) self.form_in = form_in self.method = method @@ -229,23 +257,8 @@ class Request(object): self.host = host self.port = port self.path = path - self.httpversion = httpversion - self.headers = headers - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end self.form_out = form_out or form_in - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False - def __repr__(self): if self.host and self.port: hostport = "{}:{}".format(self.host, self.port) @@ -262,8 +275,8 @@ class Request(object): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - "if-modified-since", - "if-none-match", + b"if-modified-since", + b"if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -273,16 +286,16 @@ class Request(object): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = "identity" + self.headers[b"accept-encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get("accept-encoding") + accept_encoding = self.headers.get(b"accept-encoding") if accept_encoding: - self.headers["accept-encoding"] = ( + self.headers[b"accept-encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -335,7 +348,7 @@ class Request(object): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers["Content-Type"] = HDR_FORM_URLENCODED + self.headers[b"Content-Type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -452,37 +465,17 @@ class Request(object): raise ValueError("Invalid URL: %s" % url) self.scheme, self.host, self.port, self.path = parts - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["Content-Length"] = str(len(body)).encode() - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - -class Response(object): +class Response(Message): _headers_to_strip_off = [ - 'Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc', + b'Proxy-Connection', + b'Alternate-Protocol', + b'Alt-Svc', ] def __init__( self, - httpversion, + http_version, status_code, msg=None, headers=None, @@ -490,27 +483,9 @@ class Response(object): timestamp_start=None, timestamp_end=None, ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - - self.httpversion = httpversion + super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) self.status_code = status_code self.msg = msg - self.headers = headers - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False def __repr__(self): # return "Response(%s - %s)" % (self.status_code, self.msg) @@ -536,7 +511,7 @@ class Response(object): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all("set-cookie"): + for header in self.headers.get_all(b"set-cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -559,34 +534,4 @@ class Response(object): i[1][1] ) ) - self.headers.set_all("Set-Cookie", values) - - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["Content-Length"] = str(len(body)).encode() - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - - @property - def code(self): # pragma: no cover - # TODO: remove deprecated getter - return self.status_code - - @code.setter - def code(self, code): # pragma: no cover - # TODO: remove deprecated setter - self.status_code = code + self.headers.set_all(b"Set-Cookie", values) diff --git a/netlib/tutils.py b/netlib/tutils.py index 05791c49..b69495a3 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -105,7 +105,7 @@ def treq(**kwargs): host=b"address", port=22, path=b"/path", - httpversion=b"HTTP/1.1", + http_version=b"HTTP/1.1", headers=Headers(header=b"qvalue"), body=b"content" ) @@ -119,7 +119,7 @@ def tresp(**kwargs): netlib.http.Response """ default = dict( - httpversion=b"HTTP/1.1", + http_version=b"HTTP/1.1", status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), diff --git a/netlib/utils.py b/netlib/utils.py index a86b8019..14b428d7 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -17,11 +17,6 @@ def isascii(bytes): return True -# best way to do it in python 2.x -def bytes_to_int(i): - return int(i.encode('hex'), 16) - - def clean_bin(s, keep_spacing=True): """ Cleans binary data to make it safe to display. @@ -51,21 +46,15 @@ def clean_bin(s, keep_spacing=True): def hexdump(s): """ - Returns a set of tuples: - (offset, hex, str) + Returns: + A generator of (offset, hex, str) tuples """ - parts = [] for i in range(0, len(s), 16): - o = "%.10x" % i + offset = b"%.10x" % i part = s[i:i + 16] - x = " ".join("%.2x" % ord(i) for i in part) - if len(part) < 16: - x += " " - x += " ".join(" " for i in range(16 - len(part))) - parts.append( - (o, x, clean_bin(part, False)) - ) - return parts + x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = x.ljust(47) # 16*2 + 15 + yield (offset, x, clean_bin(part, False)) def setbit(byte, offset, value): @@ -80,8 +69,7 @@ def setbit(byte, offset, value): def getbit(byte, offset): mask = 1 << offset - if byte & mask: - return True + return bool(byte & mask) class BiDi(object): @@ -159,7 +147,7 @@ def is_valid_host(host): return False if len(host) > 255: return False - if host[-1] == ".": + if host[-1] == b".": host = host[:-1] return all(_label_valid.match(x) for x in host.split(b".")) @@ -248,7 +236,7 @@ def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. """ - if (port, scheme) in [(80, "http"), (443, "https")]: + if (port, scheme) in [(80, b"http"), (443, b"https")]: return host else: return b"%s:%d" % (host, port) diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index e3ff1405..ceddd273 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,6 +2,7 @@ from __future__ import absolute_import import os import struct import io +import six from .protocol import Masker from netlib import tcp @@ -127,8 +128,8 @@ class FrameHeader(object): """ read a websockets frame header """ - first_byte = utils.bytes_to_int(fp.safe_read(1)) - second_byte = utils.bytes_to_int(fp.safe_read(1)) + first_byte = six.byte2int(fp.safe_read(1)) + second_byte = six.byte2int(fp.safe_read(1)) fin = utils.getbit(first_byte, 7) rsv1 = utils.getbit(first_byte, 6) @@ -145,9 +146,9 @@ class FrameHeader(object): if length_code <= 125: payload_length = length_code elif length_code == 126: - payload_length = utils.bytes_to_int(fp.safe_read(2)) + payload_length, = struct.unpack("!H", fp.safe_read(2)) elif length_code == 127: - payload_length = utils.bytes_to_int(fp.safe_read(8)) + payload_length, = struct.unpack("!Q", fp.safe_read(8)) # masking key only present if mask bit set if mask_bit == 1: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 46c02875..68d827a5 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,6 +17,7 @@ from __future__ import absolute_import import base64 import hashlib import os +import six from ..http import Headers from .. import utils @@ -40,7 +41,7 @@ class Masker(object): def __init__(self, key): self.key = key - self.masks = [utils.bytes_to_int(byte) for byte in key] + self.masks = [six.byte2int(byte) for byte in key] self.offset = 0 def mask(self, offset, data): -- cgit v1.2.3 From d798ed955dab4681a5285024b3648b1a3f13c24e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 16:31:50 +0200 Subject: python3++ --- netlib/encoding.py | 20 ++++++++++++-------- netlib/http/models.py | 48 ++++++++++++++++++++++++------------------------ netlib/odict.py | 25 +++---------------------- netlib/tutils.py | 4 +--- netlib/utils.py | 22 +++++++++++----------- 5 files changed, 51 insertions(+), 68 deletions(-) (limited to 'netlib') diff --git a/netlib/encoding.py b/netlib/encoding.py index 06830f2c..8ac59905 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,28 +5,30 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib +from .utils import always_byte_args -__ALL__ = ["ENCODINGS"] -ENCODINGS = {"identity", "gzip", "deflate"} +ENCODINGS = {b"identity", b"gzip", b"deflate"} +@always_byte_args("ascii", "ignore") def decode(e, content): encoding_map = { - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, + b"identity": identity, + b"gzip": decode_gzip, + b"deflate": decode_deflate, } if e not in encoding_map: return None return encoding_map[e](content) +@always_byte_args("ascii", "ignore") def encode(e, content): encoding_map = { - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, + b"identity": identity, + b"gzip": encode_gzip, + b"deflate": encode_deflate, } if e not in encoding_map: return None @@ -80,3 +82,5 @@ def encode_deflate(content): Returns compressed content, always including zlib header and checksum. """ return zlib.compress(content) + +__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/http/models.py b/netlib/http/models.py index 54b8b112..bc681de3 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -136,7 +136,7 @@ class Headers(MutableMapping, object): def __len__(self): return len(set(name.lower() for name, _ in self.fields)) - #__hash__ = object.__hash__ + # __hash__ = object.__hash__ def _index(self, name): name = name.lower() @@ -227,11 +227,11 @@ class Request(Message): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. _headers_to_strip_off = [ - b'Proxy-Connection', - b'Keep-Alive', - b'Connection', - b'Transfer-Encoding', - b'Upgrade', + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', ] def __init__( @@ -275,8 +275,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"if-modified-since", - b"if-none-match", + b"If-Modified-Since", + b"If-None-Match", ] for i in delheaders: self.headers.pop(i, None) @@ -286,16 +286,16 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers[b"accept-encoding"] = b"identity" + self.headers["Accept-Encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get(b"accept-encoding") + accept_encoding = self.headers.get(b"Accept-Encoding") if accept_encoding: - self.headers[b"accept-encoding"] = ( + self.headers["Accept-Encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -316,9 +316,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): return self.get_form_multipart() return ODict([]) @@ -328,12 +328,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): return ODict( utils.multipartdecode( self.headers, @@ -405,9 +405,9 @@ class Request(Message): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - if hostheader and b"Host" in self.headers: + if hostheader and "Host" in self.headers: try: - return self.headers[b"Host"].decode("idna") + return self.headers["Host"].decode("idna") except ValueError: pass if self.host: @@ -426,7 +426,7 @@ class Request(Message): Returns a possibly empty netlib.odict.ODict object. """ ret = ODict() - for i in self.headers.get_all("cookie"): + for i in self.headers.get_all("Cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -468,9 +468,9 @@ class Request(Message): class Response(Message): _headers_to_strip_off = [ - b'Proxy-Connection', - b'Alternate-Protocol', - b'Alt-Svc', + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', ] def __init__( @@ -498,7 +498,7 @@ class Response(Message): return "".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), + contenttype=self.headers.get("Content-Type", "unknown content type"), size=size) def get_cookies(self): @@ -511,7 +511,7 @@ class Response(Message): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all(b"set-cookie"): + for header in self.headers.get_all("Set-Cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -534,4 +534,4 @@ class Response(Message): i[1][1] ) ) - self.headers.set_all(b"Set-Cookie", values) + self.headers.set_all("Set-Cookie", values) diff --git a/netlib/odict.py b/netlib/odict.py index 11d5d52a..1124b23a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,6 +1,7 @@ from __future__ import (absolute_import, print_function, division) import re import copy +import six def safe_subn(pattern, repl, target, *args, **kwargs): @@ -67,10 +68,10 @@ class ODict(object): Sets the values for key k. If there are existing values for this key, they are cleared. """ - if isinstance(valuelist, basestring): + if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type): raise ValueError( "Expected list of values instead of string. " - "Example: odict['Host'] = ['www.example.com']" + "Example: odict[b'Host'] = [b'www.example.com']" ) kc = self._kconv(k) new = [] @@ -134,13 +135,6 @@ class ODict(object): def __repr__(self): return repr(self.lst) - def format(self): - elements = [] - for itm in self.lst: - elements.append(itm[0] + ": " + str(itm[1])) - elements.append("") - return "\r\n".join(elements) - def in_any(self, key, value, caseless=False): """ Do any of the values matching key contain value? @@ -156,19 +150,6 @@ class ODict(object): return True return False - def match_re(self, expr): - """ - Match the regular expression against each (key, value) pair. For - each pair a string of the following format is matched against: - - "key: value" - """ - for k, v in self.lst: - s = "%s: %s" % (k, v) - if re.search(expr, s): - return True - return False - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both keys and diff --git a/netlib/tutils.py b/netlib/tutils.py index b69495a3..746e1488 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -123,9 +123,7 @@ def tresp(**kwargs): status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), - body=b"message", - timestamp_start=time.time(), - timestamp_end=time.time() + body=b"message" ) default.update(kwargs) return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 14b428d7..6fed44b6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -246,7 +246,7 @@ def unparse_url(scheme, host, port, path=""): """ Returns a URL string, constructed from the specified compnents. """ - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path) def urlencode(s): @@ -295,7 +295,7 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get("content-type") + v = headers.get(b"Content-Type") if v: v = parse_content_type(v) if not v: @@ -304,33 +304,33 @@ def multipartdecode(headers, content): if not boundary: return [] - rx = re.compile(r'\bname="([^"]+)"') + rx = re.compile(br'\bname="([^"]+)"') r = [] - for i in content.split("--" + boundary): + for i in content.split(b"--" + boundary): parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != "--": + if len(parts) > 1 and parts[0][0:2] != b"--": match = rx.search(parts[1]) if match: key = match.group(1) - value = "".join(parts[3 + parts[2:].index(""):]) + value = b"".join(parts[3 + parts[2:].index(b""):]) r.append((key, value)) return r return [] -def always_bytes(unicode_or_bytes, encoding): +def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(encoding) + return unicode_or_bytes.encode(*encode_args) return unicode_or_bytes -def always_byte_args(encoding): +def always_byte_args(*encode_args): """Decorator that transparently encodes all arguments passed as unicode""" def decorator(fun): def _fun(*args, **kwargs): - args = [always_bytes(arg, encoding) for arg in args] - kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} return fun(*args, **kwargs) return _fun return decorator -- cgit v1.2.3 From 266b80238db34cfa91f9018c951394492bbde593 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 17:29:55 +0200 Subject: fix tests --- netlib/tutils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tutils.py b/netlib/tutils.py index 746e1488..4903d63b 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -123,7 +123,9 @@ def tresp(**kwargs): status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), - body=b"message" + body=b"message", + timestamp_start=time.time(), + timestamp_end=time.time(), ) default.update(kwargs) return Response(**default) -- cgit v1.2.3 From 7b6b15754754b45552d0872d36f3f30f5fa1a783 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Sep 2015 15:35:02 +0200 Subject: properly handle SNI IPs fixes mitmproxy/mitmproxy#772 We must use the ipaddress package here, because that's what cryptography uses. If we opt for something else, we have nasty namespace conflicts. --- netlib/certutils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index cc143a50..c3b795ac 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -4,6 +4,7 @@ import ssl import time import datetime import itertools +import ipaddress from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -85,8 +86,13 @@ def dummy_cert(privkey, cacert, commonname, sans): """ ss = [] for i in sans: - ss.append("DNS: %s" % i) - ss = ", ".join(ss) + try: + ipaddress.ip_address(i.decode("ascii")) + except ValueError: + ss.append(b"DNS: %s" % i) + else: + ss.append(b"IP: %s" % i) + ss = b", ".join(ss) cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600 * 48) @@ -335,6 +341,7 @@ class CertStore(object): class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore # other types. + # TODO: We should also handle iPAddresses. componentType = namedtype.NamedTypes( namedtype.NamedType('dNSName', char.IA5String().subtype( implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) -- cgit v1.2.3 From d1904c2f52dfc7409ae275bb081f23635c94acc9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Sep 2015 15:38:31 +0200 Subject: python3++ --- netlib/certutils.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index c3b795ac..9193b757 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,7 +12,7 @@ import OpenSSL DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. -DEFAULT_DHPARAM = """ +DEFAULT_DHPARAM = b""" -----BEGIN DH PARAMETERS----- MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv @@ -43,29 +43,29 @@ def create_ca(o, cn, exp): cert.set_pubkey(key) cert.add_extensions([ OpenSSL.crypto.X509Extension( - "basicConstraints", + b"basicConstraints", True, - "CA:TRUE" + b"CA:TRUE" ), OpenSSL.crypto.X509Extension( - "nsCertType", + b"nsCertType", False, - "sslCA" + b"sslCA" ), OpenSSL.crypto.X509Extension( - "extendedKeyUsage", + b"extendedKeyUsage", False, - "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" ), OpenSSL.crypto.X509Extension( - "keyUsage", + b"keyUsage", True, - "keyCertSign, cRLSign" + b"keyCertSign, cRLSign" ), OpenSSL.crypto.X509Extension( - "subjectKeyIdentifier", + b"subjectKeyIdentifier", False, - "hash", + b"hash", subject=cert ), ]) @@ -103,7 +103,7 @@ def dummy_cert(privkey, cacert, commonname, sans): if ss: cert.set_version(2) cert.add_extensions( - [OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) + [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha256") return SSLCert(cert) @@ -291,14 +291,14 @@ class CertStore(object): @staticmethod def asterisk_forms(dn): - parts = dn.split(".") + parts = dn.split(b".") parts.reverse() - curr_dn = "" - dn_forms = ["*"] + curr_dn = b"" + dn_forms = [b"*"] for part in parts[:-1]: - curr_dn = "." + part + curr_dn # .example.com - dn_forms.append("*" + curr_dn) # *.example.com - if parts[-1] != "*": + curr_dn = b"." + part + curr_dn # .example.com + dn_forms.append(b"*" + curr_dn) # *.example.com + if parts[-1] != b"*": dn_forms.append(parts[-1] + curr_dn) return dn_forms @@ -430,7 +430,7 @@ class SSLCert(object): def cn(self): c = None for i in self.subject: - if i[0] == "CN": + if i[0] == b"CN": c = i[1] return c @@ -439,7 +439,7 @@ class SSLCert(object): altnames = [] for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) - if ext.get_short_name() == "subjectAltName": + if ext.get_short_name() == b"subjectAltName": try: dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) except PyAsn1Error: -- cgit v1.2.3 From 551d9f11e571eac495674f1c23cfd0dfa8af2cb7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Sep 2015 18:05:50 +0200 Subject: experimental: don't interfere with headers --- netlib/http/http1/assemble.py | 20 +++++--------------- netlib/http/models.py | 21 ++++----------------- 2 files changed, 9 insertions(+), 32 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index b65a6be0..c2b60a0f 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -10,7 +10,8 @@ def assemble_request(request): if request.body == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - return head + request.body + body = b"".join(assemble_body(request.headers, [request.body])) + return head + body def assemble_request_head(request): @@ -23,7 +24,8 @@ def assemble_response(response): if response.body == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - return head + response.body + body = b"".join(assemble_body(response.headers, [response.body])) + return head + body def assemble_response_head(response): @@ -74,20 +76,12 @@ def _assemble_request_line(request, form=None): def _assemble_request_headers(request): headers = request.headers.copy() - for k in request._headers_to_strip_off: - headers.pop(k, None) if b"host" not in headers and request.scheme and request.host and request.port: headers[b"Host"] = utils.hostport( request.scheme, request.host, request.port ) - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if request.body or request.body == b"": - headers[b"Content-Length"] = str(len(request.body)).encode("ascii") - return bytes(headers) @@ -100,8 +94,4 @@ def _assemble_response_line(response): def _assemble_response_headers(response): - headers = response.headers.copy() - for k in response._headers_to_strip_off: - headers.pop(k, None) - - return bytes(headers) + return bytes(response.headers) diff --git a/netlib/http/models.py b/netlib/http/models.py index bc681de3..ff854b13 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -92,7 +92,10 @@ class Headers(MutableMapping, object): self.update(headers) def __bytes__(self): - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + if self.fields: + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + else: + return b"" if six.PY2: __str__ = __bytes__ @@ -224,16 +227,6 @@ class Message(object): class Request(Message): - # This list is adopted legacy code. - # We probably don't need to strip off keep-alive. - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade', - ] - def __init__( self, form_in, @@ -467,12 +460,6 @@ class Request(Message): class Response(Message): - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc', - ] - def __init__( self, http_version, -- cgit v1.2.3 From 91cdd78201497e89b9a17275a484d461f0143137 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 19 Sep 2015 11:59:40 +0200 Subject: improve http error messages --- netlib/http/http1/read.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 7f2b7bab..c6760ff3 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -215,9 +215,9 @@ def _get_first_line(rfile): # Possible leftover from previous message line = rfile.readline() except TcpDisconnect: - raise HttpReadDisconnect() + raise HttpReadDisconnect("Remote disconnected") if not line: - raise HttpReadDisconnect() + raise HttpReadDisconnect("Remote disconnected") line = line.strip() try: line.decode("ascii") @@ -227,7 +227,11 @@ def _get_first_line(rfile): def _read_request_line(rfile): - line = _get_first_line(rfile) + try: + line = _get_first_line(rfile) + except HttpReadDisconnect: + # We want to provide a better error message. + raise HttpReadDisconnect("Client disconnected") try: method, path, http_version = line.split(b" ") @@ -270,7 +274,11 @@ def _parse_authority_form(hostport): def _read_response_line(rfile): - line = _get_first_line(rfile) + try: + line = _get_first_line(rfile) + except HttpReadDisconnect: + # We want to provide a better error message. + raise HttpReadDisconnect("Server disconnected") try: -- cgit v1.2.3 From 3f1ca556d14ce71331b8dbc69be4db670863271a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 18:12:55 +0200 Subject: python3++ --- netlib/certutils.py | 11 ++++++----- netlib/tcp.py | 4 ++-- netlib/wsgi.py | 31 +++++++++++++++++-------------- 3 files changed, 25 insertions(+), 21 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index 9193b757..df793537 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,6 +5,8 @@ import time import datetime import itertools import ipaddress + +import sys from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -184,7 +186,7 @@ class CertStore(object): with open(path, "wb") as f: f.write(DEFAULT_DHPARAM) - bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") + bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( @@ -318,10 +320,9 @@ class CertStore(object): potential_keys.append((commonname, tuple(sans))) name = next( - itertools.ifilter( - lambda key: key in self.certs, - potential_keys), - None) + filter(lambda key: key in self.certs, potential_keys), + None + ) if name: entry = self.certs[name] else: diff --git a/netlib/tcp.py b/netlib/tcp.py index 707e11e0..6dcc8c72 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -76,7 +76,7 @@ class SSLKeyLogger(object): d = os.path.dirname(self.filename) if not os.path.isdir(d): os.makedirs(d) - self.f = open(self.filename, "ab") + self.f = open(self.filename, "a") self.f.write("\r\n") client_random = connection.client_random().encode("hex") masterkey = connection.master_key().encode("hex") @@ -184,7 +184,7 @@ class Reader(_FileLike): """ If length is -1, we read until connection closes. """ - result = '' + result = b'' start = time.time() while length == -1 or length > 0: if length == -1 or length > self.BLOCKSIZE: diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 8a98884a..fba9f388 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,8 +1,11 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO +from io import BytesIO import urllib import time import traceback + +import six + from . import http, tcp @@ -58,7 +61,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.body or ""), + 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, @@ -91,17 +94,17 @@ class WSGIAdaptor(object): Make a best-effort attempt to write an error page. If headers are already sent, we just bung the error into the page. """ - c = """ + c = b"""

Internal Server Error

%s"
- """ % s + """.strip() % s if not headers_sent: - soc.write("HTTP/1.1 500 Internal Server Error\r\n") - soc.write("Content-Type: text/html\r\n") - soc.write("Content-Length: %s\r\n" % len(c)) - soc.write("\r\n") + soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") + soc.write(b"Content-Type: text/html\r\n") + soc.write(b"Content-Length: %s\r\n" % len(c)) + soc.write(b"\r\n") soc.write(c) def serve(self, request, soc, **env): @@ -114,14 +117,14 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write("HTTP/1.1 %s\r\n" % state["status"]) + soc.write(b"HTTP/1.1 %s\r\n" % state["status"]) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion if 'date' not in headers: headers["Date"] = date_time_string() - soc.write(str(headers)) - soc.write("\r\n") + soc.write(bytes(headers)) + soc.write(b"\r\n") state["headers_sent"] = True if data: soc.write(data) @@ -131,7 +134,7 @@ class WSGIAdaptor(object): if exc_info: try: if state["headers_sent"]: - raise exc_info[0], exc_info[1], exc_info[2] + six.reraise(*exc_info) finally: exc_info = None elif state["status"]: @@ -140,7 +143,7 @@ class WSGIAdaptor(object): state["headers"] = http.Headers(headers) return write - errs = cStringIO.StringIO() + errs = BytesIO() try: dataiter = self.app( self.make_environ(request, errs, **env), start_response @@ -148,7 +151,7 @@ class WSGIAdaptor(object): for i in dataiter: write(i) if not state["headers_sent"]: - write("") + write(b"") except Exception as e: try: s = traceback.format_exc() -- cgit v1.2.3 From 693cdfc6d75e460a00585ccc9b734b80d6eba74d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 19:40:09 +0200 Subject: python3++ --- netlib/certutils.py | 6 +++--- netlib/socks.py | 22 +++++++++++++--------- netlib/utils.py | 6 ++++++ 3 files changed, 22 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index df793537..b3ddcbe4 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,7 +3,7 @@ import os import ssl import time import datetime -import itertools +from six.moves import filter import ipaddress import sys @@ -396,12 +396,12 @@ class SSLCert(object): @property def notbefore(self): t = self.x509.get_notBefore() - return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") @property def notafter(self): t = self.x509.get_notAfter() - return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") @property def has_expired(self): diff --git a/netlib/socks.py b/netlib/socks.py index d38b88c8..51ad1c63 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -1,7 +1,7 @@ from __future__ import (absolute_import, print_function, division) -import socket import struct import array +import ipaddress from . import tcp, utils @@ -133,19 +133,23 @@ class Message(object): def from_file(cls, f): ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) if rsv != 0x00: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, - "Socks Request: Invalid reserved byte: %s" % rsv) - + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv + ) if atyp == ATYP.IPV4_ADDRESS: # We use tnoa here as ntop is not commonly available on Windows. - host = socket.inet_ntoa(f.safe_read(4)) + host = ipaddress.IPv4Address(f.safe_read(4)).compressed use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16)) + host = ipaddress.IPv6Address(f.safe_read(16)).compressed use_ipv6 = True elif atyp == ATYP.DOMAINNAME: length, = struct.unpack("!B", f.safe_read(1)) host = f.safe_read(length) + if not utils.is_valid_host(host): + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) + host = host.decode("idna") use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, @@ -158,12 +162,12 @@ class Message(object): def to_file(self, f): f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) if self.atyp == ATYP.IPV4_ADDRESS: - f.write(socket.inet_aton(self.addr.host)) + f.write(ipaddress.IPv4Address(self.addr.host).packed) elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(socket.inet_pton(socket.AF_INET6, self.addr.host)) + f.write(ipaddress.IPv6Address(self.addr.host).packed) elif self.atyp == ATYP.DOMAINNAME: f.write(struct.pack("!B", len(self.addr.host))) - f.write(self.addr.host) + f.write(self.addr.host.encode("idna")) else: raise SocksError( REP.ADDRESS_TYPE_NOT_SUPPORTED, diff --git a/netlib/utils.py b/netlib/utils.py index 6fed44b6..799b0d42 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -141,6 +141,12 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? Date: Sun, 20 Sep 2015 19:56:45 +0200 Subject: python3++ --- netlib/tcp.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 6dcc8c72..f6f7d06f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,8 @@ import threading import time import traceback +from six.moves import range + import certifi import six import OpenSSL @@ -227,7 +229,7 @@ class Reader(_FileLike): return result def readline(self, size=None): - result = '' + result = b'' bytes_read = 0 while True: if size is not None and bytes_read >= size: @@ -399,7 +401,7 @@ def close_socket(sock): sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely - for _ in xrange(1024 ** 3 // 4096): + for _ in range(1024 ** 3 // 4096): # may raise a timeout/disconnect exception. if not sock.recv(4096): break @@ -649,7 +651,7 @@ class TCPClient(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return "" + return b"" class BaseHandler(_Connection): -- cgit v1.2.3 From daebd1bd275a398d42cc4dbfe5c6399c7fe3b3a0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 20:35:45 +0200 Subject: python3++ --- netlib/http/authentication.py | 4 ++-- netlib/tcp.py | 28 ++++++++++++---------------- 2 files changed, 14 insertions(+), 18 deletions(-) (limited to 'netlib') diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 2055f843..5831660b 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -12,7 +12,7 @@ def parse_http_basic_auth(s): user = binascii.a2b_base64(words[1]) except binascii.Error: return None - parts = user.split(':') + parts = user.split(b':') if len(parts) != 2: return None return scheme, parts[0], parts[1] @@ -69,7 +69,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower() != 'basic': + if scheme.lower() != b'basic': return False if not self.password_manager.test(username, password): return False diff --git a/netlib/tcp.py b/netlib/tcp.py index f6f7d06f..40ffbd48 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback +import binascii from six.moves import range import certifi @@ -78,14 +79,11 @@ class SSLKeyLogger(object): d = os.path.dirname(self.filename) if not os.path.isdir(d): os.makedirs(d) - self.f = open(self.filename, "a") - self.f.write("\r\n") - client_random = connection.client_random().encode("hex") - masterkey = connection.master_key().encode("hex") - self.f.write( - "CLIENT_RANDOM {} {}\r\n".format( - client_random, - masterkey)) + self.f = open(self.filename, "ab") + self.f.write(b"\r\n") + client_random = binascii.hexlify(connection.client_random()) + masterkey = binascii.hexlify(connection.master_key()) + self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) self.f.flush() def close(self): @@ -140,7 +138,7 @@ class _FileLike(object): """ if not self.is_logging(): raise ValueError("Not logging!") - return "".join(self._log) + return b"".join(self._log) def add_log(self, v): if self.is_logging(): @@ -216,9 +214,9 @@ class Reader(_FileLike): except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise TlsException(e.message) + raise TlsException(str(e)) except SSL.Error as e: - raise TlsException(e.message) + raise TlsException(str(e)) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -240,7 +238,7 @@ class Reader(_FileLike): break else: result += ch - if ch == '\n': + if ch == b'\n': break return result @@ -757,7 +755,7 @@ class BaseHandler(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return "" + return b"" class TCPServer(object): @@ -829,9 +827,7 @@ class TCPServer(object): exc = six.text_type(traceback.format_exc()) print(u'-' * 40, file=fp) print( - u"Error in processing of request from %s:%s" % ( - client_address.host, client_address.port - ), file=fp) + u"Error in processing of request from %s" % repr(client_address), file=fp) print(exc, file=fp) print(u'-' * 40, file=fp) -- cgit v1.2.3 From 73586b1be95d97f0be76e85223b53d1f4ed697d6 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 00:44:17 +0200 Subject: python 3++ --- netlib/encoding.py | 16 ++++----- netlib/http/models.py | 30 ++++++++--------- netlib/tutils.py | 5 ++- netlib/utils.py | 53 +++++++++++++++++++---------- netlib/websockets/frame.py | 77 ++++++++++++++++++++++++++++--------------- netlib/websockets/protocol.py | 52 ++++++++++++++++------------- netlib/wsgi.py | 55 ++++++++++++++++--------------- 7 files changed, 168 insertions(+), 120 deletions(-) (limited to 'netlib') diff --git a/netlib/encoding.py b/netlib/encoding.py index 8ac59905..4c11273b 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -8,27 +8,25 @@ import zlib from .utils import always_byte_args -ENCODINGS = {b"identity", b"gzip", b"deflate"} +ENCODINGS = {"identity", "gzip", "deflate"} -@always_byte_args("ascii", "ignore") def decode(e, content): encoding_map = { - b"identity": identity, - b"gzip": decode_gzip, - b"deflate": decode_deflate, + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, } if e not in encoding_map: return None return encoding_map[e](content) -@always_byte_args("ascii", "ignore") def encode(e, content): encoding_map = { - b"identity": identity, - b"gzip": encode_gzip, - b"deflate": encode_deflate, + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, } if e not in encoding_map: return None diff --git a/netlib/http/models.py b/netlib/http/models.py index ff854b13..3c360a37 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -3,7 +3,7 @@ import copy from ..odict import ODict from .. import utils, encoding -from ..utils import always_bytes, always_byte_args +from ..utils import always_bytes, always_byte_args, native from . import cookies import six @@ -254,7 +254,7 @@ class Request(Message): def __repr__(self): if self.host and self.port: - hostport = "{}:{}".format(self.host, self.port) + hostport = "{}:{}".format(native(self.host,"idna"), self.port) else: hostport = "" path = self.path or "" @@ -279,14 +279,14 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["Accept-Encoding"] = b"identity" + self.headers["Accept-Encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get(b"Accept-Encoding") + accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii") if accept_encoding: self.headers["Accept-Encoding"] = ( ', '.join( @@ -309,9 +309,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): + if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +321,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): return ODict( utils.multipartdecode( self.headers, @@ -351,7 +351,7 @@ class Request(Message): Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split(b"/") if i] + return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i] def set_path_components(self, lst): """ @@ -360,7 +360,7 @@ class Request(Message): Components are quoted. """ lst = [urllib.parse.quote(i, safe="") for i in lst] - path = b"/" + b"/".join(lst) + path = always_bytes("/" + "/".join(lst)) scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] @@ -408,11 +408,11 @@ class Request(Message): def pretty_url(self, hostheader): if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.pretty_host(hostheader), self.port) + return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port) return utils.unparse_url(self.scheme, self.pretty_host(hostheader), self.port, - self.path).encode('ascii') + self.path) def get_cookies(self): """ @@ -420,7 +420,7 @@ class Request(Message): """ ret = ODict() for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) + ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) return ret def set_cookies(self, odict): @@ -441,7 +441,7 @@ class Request(Message): self.host, self.port, self.path - ).encode('ascii') + ) @url.setter def url(self, url): @@ -499,7 +499,7 @@ class Response(Message): """ ret = [] for header in self.headers.get_all("Set-Cookie"): - v = cookies.parse_set_cookie_header(header) + v = cookies.parse_set_cookie_header(native(header, "ascii")) if v: name, value, attrs = v ret.append([name, [value, attrs]]) diff --git a/netlib/tutils.py b/netlib/tutils.py index 4903d63b..1665a792 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -7,7 +7,7 @@ from contextlib import contextmanager import six import sys -from . import utils +from . import utils, tcp from .http import Request, Response, Headers @@ -15,7 +15,6 @@ def treader(bytes): """ Construct a tcp.Read object from bytes. """ - from . import tcp # TODO: move to top once cryptography is on Python 3.5 fp = BytesIO(bytes) return tcp.Reader(fp) @@ -106,7 +105,7 @@ def treq(**kwargs): port=22, path=b"/path", http_version=b"HTTP/1.1", - headers=Headers(header=b"qvalue"), + headers=Headers(header="qvalue"), body=b"content" ) default.update(kwargs) diff --git a/netlib/utils.py b/netlib/utils.py index 799b0d42..8d11bd5b 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -9,6 +9,41 @@ import six from six.moves import urllib +def always_bytes(unicode_or_bytes, *encode_args): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(*encode_args) + return unicode_or_bytes + + +def always_byte_args(*encode_args): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator + + +def native(s, encoding="latin-1"): + """ + Convert :py:class:`bytes` or :py:class:`unicode` to the native + :py:class:`str` type, using latin1 encoding if conversion is necessary. + + https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types + """ + if not isinstance(s, (six.binary_type, six.text_type)): + raise TypeError("%r is neither bytes nor unicode" % s) + if six.PY3: + if isinstance(s, six.binary_type): + return s.decode(encoding) + else: + if isinstance(s, six.text_type): + return s.encode(encoding) + return s + + def isascii(bytes): try: bytes.decode("ascii") @@ -238,6 +273,7 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] +@always_byte_args() def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. @@ -323,20 +359,3 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] - - -def always_bytes(unicode_or_bytes, *encode_args): - if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes - - -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index ceddd273..55eeaf41 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,13 +2,14 @@ from __future__ import absolute_import import os import struct import io +import warnings + import six from .protocol import Masker from netlib import tcp from netlib import utils -DEFAULT = object() MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) @@ -33,9 +34,9 @@ class FrameHeader(object): rsv1=False, rsv2=False, rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT + masking_key=None, + mask=None, + length_code=None ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -46,18 +47,18 @@ class FrameHeader(object): self.rsv2 = rsv2 self.rsv3 = rsv3 - if length_code is DEFAULT: + if length_code is None: self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code - if mask is DEFAULT and masking_key is DEFAULT: + if mask is None and masking_key is None: self.mask = False - self.masking_key = "" - elif mask is DEFAULT: + self.masking_key = b"" + elif mask is None: self.mask = 1 self.masking_key = masking_key - elif masking_key is DEFAULT: + elif masking_key is None: self.mask = mask self.masking_key = os.urandom(4) else: @@ -81,7 +82,7 @@ class FrameHeader(object): else: return 127 - def human_readable(self): + def __repr__(self): vals = [ "ws frame:", OPCODE.get_name(self.opcode, hex(self.opcode)).lower() @@ -98,7 +99,11 @@ class FrameHeader(object): vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) - def to_bytes(self): + def human_readable(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return repr(self) + + def __bytes__(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) first_byte = utils.setbit(first_byte, 5, self.rsv2) @@ -107,7 +112,7 @@ class FrameHeader(object): second_byte = utils.setbit(self.length_code, 7, self.mask) - b = chr(first_byte) + chr(second_byte) + b = six.int2byte(first_byte) + six.int2byte(second_byte) if self.payload_length < 126: pass @@ -119,10 +124,17 @@ class FrameHeader(object): # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: + if self.masking_key: b += self.masking_key return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + @classmethod def from_file(cls, fp): """ @@ -154,7 +166,7 @@ class FrameHeader(object): if mask_bit == 1: masking_key = fp.safe_read(4) else: - masking_key = None + masking_key = False return cls( fin=fin, @@ -169,7 +181,9 @@ class FrameHeader(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, FrameHeader): + return bytes(self) == bytes(other) + return False class Frame(object): @@ -200,7 +214,7 @@ class Frame(object): +---------------------------------------------------------------+ """ - def __init__(self, payload="", **kwargs): + def __init__(self, payload=b"", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @@ -216,7 +230,7 @@ class Frame(object): masking_key = os.urandom(4) else: mask_bit = 0 - masking_key = None + masking_key = False return cls( message, @@ -234,28 +248,37 @@ class Frame(object): """ return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - def human_readable(self): - ret = self.header.human_readable() + def __repr__(self): + ret = repr(self.header) if self.payload: - ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload) + ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii") return ret - def __repr__(self): - return self.header.human_readable() + def human_readable(self): + warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) + return repr(self) - def to_bytes(self): + def __bytes__(self): """ Serialize the frame to wire format. Returns a string. """ - b = self.header.to_bytes() + b = bytes(self.header) if self.header.masking_key: b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + def to_file(self, writer): - writer.write(self.to_bytes()) + warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) + writer.write(bytes(self)) writer.flush() @classmethod @@ -286,4 +309,6 @@ class Frame(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, Frame): + return bytes(self) == bytes(other) + return False \ No newline at end of file diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 68d827a5..778fe7e7 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,11 +17,12 @@ from __future__ import absolute_import import base64 import hashlib import os + +import binascii import six from ..http import Headers -from .. import utils -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" HEADER_WEBSOCKET_KEY = 'sec-websocket-key' @@ -41,14 +42,21 @@ class Masker(object): def __init__(self, key): self.key = key - self.masks = [six.byte2int(byte) for byte in key] self.offset = 0 def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 + result = bytearray(data) + if six.PY2: + for i in range(len(data)): + result[i] ^= ord(self.key[offset % 4]) + offset += 1 + result = str(result) + else: + + for i in range(len(data)): + result[i] ^= self.key[offset % 4] + offset += 1 + result = bytes(result) return result def __call__(self, data): @@ -73,37 +81,35 @@ class WebsocketsProtocol(object): """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return Headers([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_KEY, key), - (HEADER_WEBSOCKET_VERSION, version) - ]) + return Headers(**{ + HEADER_WEBSOCKET_KEY: key, + HEADER_WEBSOCKET_VERSION: version, + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def server_handshake_headers(self, key): """ The server response is a valid HTTP 101 response. """ - return Headers( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) - ] - ) + return Headers(**{ + HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @@ -111,5 +117,5 @@ class WebsocketsProtocol(object): @classmethod def create_server_nonce(self, client_nonce): return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index fba9f388..8fb09008 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,14 +1,15 @@ from __future__ import (absolute_import, print_function, division) -from io import BytesIO +from io import BytesIO, StringIO import urllib import time import traceback import six +from six.moves import urllib +from netlib.utils import always_bytes, native from . import http, tcp - class ClientConn(object): def __init__(self, address): @@ -24,9 +25,10 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, body): + def __init__(self, scheme, method, path, http_version, headers, body): self.scheme, self.method, self.path = scheme, method, path self.headers, self.body = headers, body + self.http_version = http_version def date_time_string(): @@ -53,38 +55,38 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): - if '?' in flow.request.path: - path_info, query = flow.request.path.split('?', 1) + path = native(flow.request.path) + if '?' in path: + path_info, query = native(path).split('?', 1) else: - path_info = flow.request.path + path_info = path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.url_scheme': native(flow.request.scheme), 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': flow.request.method, + 'REQUEST_METHOD': native(flow.request.method), 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), + 'PATH_INFO': urllib.parse.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''), - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''), + 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')), + 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), - # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", + 'SERVER_PROTOCOL': native(flow.request.http_version), } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"], environ[ - "REMOTE_PORT"] = flow.client_conn.address() + environ["REMOTE_ADDR"] = native(flow.client_conn.address.host) + environ["REMOTE_PORT"] = flow.client_conn.address.port for key, value in flow.request.headers.items(): - key = 'HTTP_' + key.upper().replace('-', '_') + key = 'HTTP_' + native(key).upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value return environ @@ -99,7 +101,7 @@ class WSGIAdaptor(object):

Internal Server Error

%s"
- """.strip() % s + """.strip() % s.encode() if not headers_sent: soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") soc.write(b"Content-Type: text/html\r\n") @@ -117,7 +119,7 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write(b"HTTP/1.1 %s\r\n" % state["status"]) + soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode()) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion @@ -132,18 +134,17 @@ class WSGIAdaptor(object): def start_response(status, headers, exc_info=None): if exc_info: - try: - if state["headers_sent"]: - six.reraise(*exc_info) - finally: - exc_info = None + if state["headers_sent"]: + six.reraise(*exc_info) elif state["status"]: raise AssertionError('Response already started') state["status"] = status - state["headers"] = http.Headers(headers) - return write + state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers]) + if exc_info: + self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) + state["headers_sent"] = True - errs = BytesIO() + errs = six.BytesIO() try: dataiter = self.app( self.make_environ(request, errs, **env), start_response @@ -155,7 +156,7 @@ class WSGIAdaptor(object): except Exception as e: try: s = traceback.format_exc() - errs.write(s) + errs.write(s.encode("utf-8", "replace")) self.error_page(soc, state["headers_sent"], s) except Exception: # pragma: no cover pass -- cgit v1.2.3 From 1ff8f294b459e03e113acb417678a6fd782c2685 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 18:34:43 +0200 Subject: minor encoding fixes --- netlib/utils.py | 6 +++--- netlib/wsgi.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index 8d11bd5b..b9848038 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -26,7 +26,7 @@ def always_byte_args(*encode_args): return decorator -def native(s, encoding="latin-1"): +def native(s, *encoding_opts): """ Convert :py:class:`bytes` or :py:class:`unicode` to the native :py:class:`str` type, using latin1 encoding if conversion is necessary. @@ -37,10 +37,10 @@ def native(s, encoding="latin-1"): raise TypeError("%r is neither bytes nor unicode" % s) if six.PY3: if isinstance(s, six.binary_type): - return s.decode(encoding) + return s.decode(*encoding_opts) else: if isinstance(s, six.text_type): - return s.encode(encoding) + return s.encode(*encoding_opts) return s diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 8fb09008..4fcd5178 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -55,38 +55,38 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): - path = native(flow.request.path) + path = native(flow.request.path, "latin-1") if '?' in path: - path_info, query = native(path).split('?', 1) + path_info, query = native(path, "latin-1").split('?', 1) else: path_info = path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': native(flow.request.scheme), + 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': native(flow.request.method), + 'REQUEST_METHOD': native(flow.request.method, "latin-1"), 'SCRIPT_NAME': '', 'PATH_INFO': urllib.parse.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')), - 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')), + 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"), + 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), - 'SERVER_PROTOCOL': native(flow.request.http_version), + 'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"), } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"] = native(flow.client_conn.address.host) + environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1") environ["REMOTE_PORT"] = flow.client_conn.address.port for key, value in flow.request.headers.items(): - key = 'HTTP_' + native(key).upper().replace('-', '_') + key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value return environ -- cgit v1.2.3 From e9fe45f3f404bb1c762dfb13477072c06d4b74ec Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 21 Sep 2015 18:38:50 +0200 Subject: backport changes --- netlib/http/models.py | 36 ++++++++++++++++++------------------ netlib/tcp.py | 1 + 2 files changed, 19 insertions(+), 18 deletions(-) (limited to 'netlib') diff --git a/netlib/http/models.py b/netlib/http/models.py index 3c360a37..512a764d 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -216,7 +216,7 @@ class Message(object): def body(self, body): self._body = body if isinstance(body, bytes): - self.headers[b"Content-Length"] = str(len(body)).encode() + self.headers[b"content-length"] = str(len(body)).encode() content = body @@ -268,8 +268,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"If-Modified-Since", - b"If-None-Match", + b"if-modified-since", + b"if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -279,16 +279,16 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["Accept-Encoding"] = "identity" + self.headers["accept-encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii") + accept_encoding = native(self.headers.get("accept-encoding"), "ascii") if accept_encoding: - self.headers["Accept-Encoding"] = ( + self.headers["accept-encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -300,7 +300,7 @@ class Request(Message): """ Update the host header to reflect the current target. """ - self.headers["Host"] = self.host + self.headers["host"] = self.host def get_form(self): """ @@ -309,9 +309,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): + if HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +321,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): return ODict( utils.multipartdecode( self.headers, @@ -341,7 +341,7 @@ class Request(Message): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers[b"Content-Type"] = HDR_FORM_URLENCODED + self.headers[b"content-type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -398,9 +398,9 @@ class Request(Message): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - if hostheader and "Host" in self.headers: + if hostheader and "host" in self.headers: try: - return self.headers["Host"].decode("idna") + return self.headers["host"].decode("idna") except ValueError: pass if self.host: @@ -429,7 +429,7 @@ class Request(Message): headers. """ v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = v + self.headers["cookie"] = v @property def url(self): @@ -485,7 +485,7 @@ class Response(Message): return "".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get("Content-Type", "unknown content type"), + contenttype=self.headers.get("content-type", "unknown content type"), size=size) def get_cookies(self): @@ -498,7 +498,7 @@ class Response(Message): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all("Set-Cookie"): + for header in self.headers.get_all("set-cookie"): v = cookies.parse_set_cookie_header(native(header, "ascii")) if v: name, value, attrs = v @@ -521,4 +521,4 @@ class Response(Message): i[1][1] ) ) - self.headers.set_all("Set-Cookie", values) + self.headers.set_all("set-cookie", values) diff --git a/netlib/tcp.py b/netlib/tcp.py index 40ffbd48..b751d71f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -279,6 +279,7 @@ class Reader(_FileLike): if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): return self.o.recv(length, socket.MSG_PEEK) else: + # TODO: remove once a new version is released # Polyfill for pyOpenSSL <= 0.15.1 # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 buf = SSL._ffi.new("char[]", length) -- cgit v1.2.3 From 9fbeac50ce3f6ae49b0f0270c508b6e81a1eaf17 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 22:49:39 +0200 Subject: revert websocket changes from 73586b1b The DEFAULT construct is very weird, but with None we apparently break pathod in some difficult-to-debug ways. Revisit once we do more here. --- netlib/websockets/frame.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 55eeaf41..fce2c9d3 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -14,6 +14,8 @@ from netlib import utils MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) +DEFAULT=object() + OPCODE = utils.BiDi( CONTINUE=0x00, TEXT=0x01, @@ -34,9 +36,9 @@ class FrameHeader(object): rsv1=False, rsv2=False, rsv3=False, - masking_key=None, - mask=None, - length_code=None + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -47,18 +49,18 @@ class FrameHeader(object): self.rsv2 = rsv2 self.rsv3 = rsv3 - if length_code is None: + if length_code is DEFAULT: self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code - if mask is None and masking_key is None: + if mask is DEFAULT and masking_key is DEFAULT: self.mask = False self.masking_key = b"" - elif mask is None: + elif mask is DEFAULT: self.mask = 1 self.masking_key = masking_key - elif masking_key is None: + elif masking_key is DEFAULT: self.mask = mask self.masking_key = os.urandom(4) else: @@ -166,7 +168,7 @@ class FrameHeader(object): if mask_bit == 1: masking_key = fp.safe_read(4) else: - masking_key = False + masking_key = None return cls( fin=fin, @@ -230,7 +232,7 @@ class Frame(object): masking_key = os.urandom(4) else: mask_bit = 0 - masking_key = False + masking_key = None return cls( message, @@ -311,4 +313,4 @@ class Frame(object): def __eq__(self, other): if isinstance(other, Frame): return bytes(self) == bytes(other) - return False \ No newline at end of file + return False -- cgit v1.2.3 From f93752277395d201fabefed8fae6d412f13da699 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 22 Sep 2015 01:48:35 +0200 Subject: Headers: return str on all Python versions --- netlib/http/__init__.py | 6 +- netlib/http/authentication.py | 10 +- netlib/http/headers.py | 205 ++++++++++++++++++++++++++++++++++++++++ netlib/http/http1/assemble.py | 6 +- netlib/http/http1/read.py | 14 +-- netlib/http/models.py | 215 ++++-------------------------------------- netlib/utils.py | 17 ++-- netlib/websockets/protocol.py | 14 ++- 8 files changed, 257 insertions(+), 230 deletions(-) create mode 100644 netlib/http/headers.py (limited to 'netlib') diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index d72884b3..0ccf6b32 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, print_function, division -from .models import Request, Response, Headers +from .headers import Headers +from .models import Request, Response from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ - "Request", "Response", "Headers", + "Headers", + "Request", "Response", "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", "http1", "http2", diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 5831660b..d769abe5 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -9,18 +9,18 @@ def parse_http_basic_auth(s): return None scheme = words[0] try: - user = binascii.a2b_base64(words[1]) + user = binascii.a2b_base64(words[1]).decode("utf8", "replace") except binascii.Error: return None - parts = user.split(b':') + parts = user.split(':') if len(parts) != 2: return None return scheme, parts[0], parts[1] def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + b":" + password) - return scheme + b" " + v + v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") + return scheme + " " + v class NullProxyAuth(object): @@ -69,7 +69,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower() != b'basic': + if scheme.lower() != 'basic': return False if not self.password_manager.test(username, password): return False diff --git a/netlib/http/headers.py b/netlib/http/headers.py new file mode 100644 index 00000000..1511ea2d --- /dev/null +++ b/netlib/http/headers.py @@ -0,0 +1,205 @@ +""" + +Unicode Handling +---------------- +See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ +""" +from __future__ import absolute_import, print_function, division +import copy +try: + from collections.abc import MutableMapping +except ImportError: # Workaround for Python < 3.3 + from collections import MutableMapping + + +import six + +from netlib.utils import always_byte_args + +if six.PY2: + _native = lambda x: x + _asbytes = lambda x: x + _always_byte_args = lambda x: x +else: + # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _asbytes = lambda x: x.encode("utf-8", "surrogateescape") + _always_byte_args = always_byte_args("utf-8", "surrogateescape") + + +class Headers(MutableMapping, object): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header fields can be accessed + >>> h.fields + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + @_always_byte_args + def __init__(self, fields=None, **headers): + """ + Args: + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. + """ + self.fields = fields or [] + + for name, value in self.fields: + if not isinstance(name, bytes) or not isinstance(value, bytes): + raise ValueError("Headers passed as fields must be bytes.") + + # content_type -> content-type + headers = { + _asbytes(name).replace(b"_", b"-"): value + for name, value in six.iteritems(headers) + } + self.update(headers) + + def __bytes__(self): + if self.fields: + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + else: + return b"" + + if six.PY2: + __str__ = __bytes__ + + @_always_byte_args + def __getitem__(self, name): + values = self.get_all(name) + if not values: + raise KeyError(name) + return ", ".join(values) + + @_always_byte_args + def __setitem__(self, name, value): + idx = self._index(name) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[name] + self.fields.insert(idx, [name, value]) + else: + self.fields.append([name, value]) + + @_always_byte_args + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield _native(name) + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + # __hash__ = object.__hash__ + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: + return i + return None + + def __eq__(self, other): + if isinstance(other, Headers): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @_always_byte_args + def get_all(self, name): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + name_lower = name.lower() + values = [_native(value) for n, value in self.fields if n.lower() == name_lower] + return values + + @_always_byte_args + def set_all(self, name, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + values = map(_asbytes, values) # _always_byte_args does not fix lists + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) \ No newline at end of file diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index c2b60a0f..88aeac05 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -35,7 +35,7 @@ def assemble_response_head(response): def assemble_body(headers, body_chunks): - if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + if "chunked" in headers.get("transfer-encoding", "").lower(): for chunk in body_chunks: if chunk: yield b"%x\r\n%s\r\n" % (len(chunk), chunk) @@ -76,8 +76,8 @@ def _assemble_request_line(request, form=None): def _assemble_request_headers(request): headers = request.headers.copy() - if b"host" not in headers and request.scheme and request.host and request.port: - headers[b"Host"] = utils.hostport( + if "host" not in headers and request.scheme and request.host and request.port: + headers["host"] = utils.hostport( request.scheme, request.host, request.port diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index c6760ff3..4c898348 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -146,11 +146,11 @@ def connection_close(http_version, headers): according to RFC 2616 Section 8.1. """ # At first, check if we have an explicit Connection header. - if b"connection" in headers: + if "connection" in headers: tokens = utils.get_header_tokens(headers, "connection") - if b"close" in tokens: + if "close" in tokens: return True - elif b"keep-alive" in tokens: + elif "keep-alive" in tokens: return False # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -181,7 +181,7 @@ def expected_http_body_size(request, response=None): is_request = False if is_request: - if headers.get(b"expect", b"").lower() == b"100-continue": + if headers.get("expect", "").lower() == "100-continue": return 0 else: if request.method.upper() == b"HEAD": @@ -193,11 +193,11 @@ def expected_http_body_size(request, response=None): if response_code in (204, 304): return 0 - if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + if "chunked" in headers.get("transfer-encoding", "").lower(): return None - if b"content-length" in headers: + if "content-length" in headers: try: - size = int(headers[b"content-length"]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size diff --git a/netlib/http/models.py b/netlib/http/models.py index 512a764d..55664533 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -1,201 +1,22 @@ -from __future__ import absolute_import, print_function, division -import copy + from ..odict import ODict from .. import utils, encoding -from ..utils import always_bytes, always_byte_args, native +from ..utils import always_bytes, native from . import cookies +from .headers import Headers -import six from six.moves import urllib -try: - from collections import MutableMapping -except ImportError: - from collections.abc import MutableMapping # TODO: Move somewhere else? ALPN_PROTO_HTTP1 = b'http/1.1' ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = b"multipart/form-data" +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class Headers(MutableMapping, object): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # str(h) returns a HTTP1 header block. - >>> print(h) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - @always_byte_args("ascii") - def __init__(self, fields=None, **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header tuples, - e.g. ``[("Host","example.com")]``. All names and values must be bytes. - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. - """ - self.fields = fields or [] - - # content_type -> content-type - headers = { - name.encode("ascii").replace(b"_", b"-"): value - for name, value in six.iteritems(headers) - } - self.update(headers) - - def __bytes__(self): - if self.fields: - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" - else: - return b"" - - if six.PY2: - __str__ = __bytes__ - - @always_byte_args("ascii") - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return b", ".join(values) - - @always_byte_args("ascii") - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @always_byte_args("ascii") - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] - - def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield name - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @always_byte_args("ascii") - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name_lower = name.lower() - values = [value for n, value in self.fields if n.lower() == name_lower] - return values - - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. - See: :py:meth:`get_all` - """ - name = always_bytes(name, "ascii") - values = (always_bytes(value, "ascii") for value in values) - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def copy(self): - return Headers(copy.copy(self.fields)) - - # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): - return tuple(tuple(field) for field in self.fields) - - def load_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) - - class Message(object): def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): self.http_version = http_version @@ -216,7 +37,7 @@ class Message(object): def body(self, body): self._body = body if isinstance(body, bytes): - self.headers[b"content-length"] = str(len(body)).encode() + self.headers["content-length"] = str(len(body)).encode() content = body @@ -268,8 +89,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"if-modified-since", - b"if-none-match", + "if-modified-since", + "if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -279,14 +100,14 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = b"identity" + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = native(self.headers.get("accept-encoding"), "ascii") + accept_encoding = self.headers.get("accept-encoding") if accept_encoding: self.headers["accept-encoding"] = ( ', '.join( @@ -309,9 +130,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): + if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +142,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): return ODict( utils.multipartdecode( self.headers, @@ -341,7 +162,7 @@ class Request(Message): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers[b"content-type"] = HDR_FORM_URLENCODED + self.headers["content-type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -400,7 +221,7 @@ class Request(Message): """ if hostheader and "host" in self.headers: try: - return self.headers["host"].decode("idna") + return self.headers["host"] except ValueError: pass if self.host: @@ -420,7 +241,7 @@ class Request(Message): """ ret = ODict() for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) + ret.extend(cookies.parse_cookie_header(i)) return ret def set_cookies(self, odict): @@ -499,7 +320,7 @@ class Response(Message): """ ret = [] for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(native(header, "ascii")) + v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v ret.append([name, [value, attrs]]) diff --git a/netlib/utils.py b/netlib/utils.py index b9848038..d5b30128 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -269,7 +269,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(b",") + tokens = headers[key].split(",") return [token.strip() for token in tokens] @@ -320,14 +320,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(b";", 1) - ts = parts[0].split(b"/", 1) + parts = c.split(";", 1) + ts = parts[0].split("/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(b";"): - clause = i.split(b"=", 1) + for i in parts[1].split(";"): + clause = i.split("=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -337,13 +337,14 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get(b"Content-Type") + v = headers.get("Content-Type") if v: v = parse_content_type(v) if not v: return [] - boundary = v[2].get(b"boundary") - if not boundary: + try: + boundary = v[2]["boundary"].encode("ascii") + except (KeyError, UnicodeError): return [] rx = re.compile(br'\bname="([^"]+)"') diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 778fe7e7..e62f8df6 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -80,7 +80,7 @@ class WebsocketsProtocol(object): Returns an instance of Headers """ if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') + key = base64.b64encode(os.urandom(16)).decode('ascii') return Headers(**{ HEADER_WEBSOCKET_KEY: key, HEADER_WEBSOCKET_VERSION: version, @@ -95,27 +95,25 @@ class WebsocketsProtocol(object): """ return Headers(**{ HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), - "Connection": "Upgrade", - "Upgrade": "websocket", + "connection": "Upgrade", + "upgrade": "websocket", }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != b"websocket": + if headers.get("upgrade") != "websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != b"websocket": + if headers.get("upgrade") != "websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @classmethod def create_server_nonce(self, client_nonce): - return base64.b64encode( - binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) - ) + return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest()) -- cgit v1.2.3 From c7b83225001505b32905376703ec7ddaf200af44 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 22 Sep 2015 01:56:09 +0200 Subject: also accept bytes as arguments --- netlib/http/headers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'netlib') diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 1511ea2d..613beb4f 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,16 +14,16 @@ except ImportError: # Workaround for Python < 3.3 import six -from netlib.utils import always_byte_args +from netlib.utils import always_byte_args, always_bytes if six.PY2: _native = lambda x: x - _asbytes = lambda x: x + _always_bytes = lambda x: x _always_byte_args = lambda x: x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") - _asbytes = lambda x: x.encode("utf-8", "surrogateescape") + _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") _always_byte_args = always_byte_args("utf-8", "surrogateescape") @@ -95,9 +95,9 @@ class Headers(MutableMapping, object): # content_type -> content-type headers = { - _asbytes(name).replace(b"_", b"-"): value + _always_bytes(name).replace(b"_", b"-"): value for name, value in six.iteritems(headers) - } + } self.update(headers) def __bytes__(self): @@ -183,7 +183,7 @@ class Headers(MutableMapping, object): Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - values = map(_asbytes, values) # _always_byte_args does not fix lists + values = map(_always_bytes, values) # _always_byte_args does not fix lists if name in self: del self[name] self.fields.extend( -- cgit v1.2.3 From 45f2ea33b2fdb67ca89e7eedd860ebe683770497 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Sep 2015 18:24:18 +0200 Subject: minor fixes --- netlib/utils.py | 2 +- netlib/websockets/protocol.py | 30 +++++++++++++----------------- 2 files changed, 14 insertions(+), 18 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index d5b30128..6f6d1ea0 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -337,7 +337,7 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get("Content-Type") + v = headers.get("content-type") if v: v = parse_content_type(v) if not v: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index e62f8df6..1e95fa1c 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -25,10 +25,6 @@ from ..http import Headers websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" -HEADER_WEBSOCKET_KEY = 'sec-websocket-key' -HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' -HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' - class Masker(object): @@ -81,37 +77,37 @@ class WebsocketsProtocol(object): """ if not key: key = base64.b64encode(os.urandom(16)).decode('ascii') - return Headers(**{ - HEADER_WEBSOCKET_KEY: key, - HEADER_WEBSOCKET_VERSION: version, - "Connection": "Upgrade", - "Upgrade": "websocket", - }) + return Headers( + sec_websocket_key=key, + sec_websocket_version=version, + connection="Upgrade", + upgrade="websocket", + ) @classmethod def server_handshake_headers(self, key): """ The server response is a valid HTTP 101 response. """ - return Headers(**{ - HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), - "connection": "Upgrade", - "upgrade": "websocket", - }) + return Headers( + sec_websocket_accept=self.create_server_nonce(key), + connection="Upgrade", + upgrade="websocket" + ) @classmethod def check_client_handshake(self, headers): if headers.get("upgrade") != "websocket": return - return headers.get(HEADER_WEBSOCKET_KEY) + return headers.get("sec-websocket-key") @classmethod def check_server_handshake(self, headers): if headers.get("upgrade") != "websocket": return - return headers.get(HEADER_WEBSOCKET_ACCEPT) + return headers.get("sec-websocket-accept") @classmethod -- cgit v1.2.3 From 106f7046d3862cb0e3cbb4f38335af0330b4e7e3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 00:39:04 +0200 Subject: refactor request model --- netlib/http/__init__.py | 5 +- netlib/http/headers.py | 2 +- netlib/http/http1/assemble.py | 65 ++++---- netlib/http/http1/read.py | 8 +- netlib/http/message.py | 146 ++++++++++++++++++ netlib/http/models.py | 233 ---------------------------- netlib/http/request.py | 351 ++++++++++++++++++++++++++++++++++++++++++ netlib/http/response.py | 3 + netlib/tutils.py | 4 +- netlib/utils.py | 15 +- 10 files changed, 557 insertions(+), 275 deletions(-) create mode 100644 netlib/http/message.py create mode 100644 netlib/http/request.py create mode 100644 netlib/http/response.py (limited to 'netlib') diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 0ccf6b32..e8c7ba20 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,12 +1,15 @@ from __future__ import absolute_import, print_function, division from .headers import Headers -from .models import Request, Response +from .message import decoded +from .request import Request +from .models import Response from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ "Headers", + "decoded", "Request", "Response", "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 613beb4f..47ea923b 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -27,7 +27,7 @@ else: _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, object): +class Headers(MutableMapping): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 88aeac05..864f6017 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -7,24 +7,24 @@ from .. import CONTENT_MISSING def assemble_request(request): - if request.body == CONTENT_MISSING: + if request.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - body = b"".join(assemble_body(request.headers, [request.body])) + body = b"".join(assemble_body(request.headers, [request.data.content])) return head + body def assemble_request_head(request): - first_line = _assemble_request_line(request) - headers = _assemble_request_headers(request) + first_line = _assemble_request_line(request.data) + headers = _assemble_request_headers(request.data) return b"%s\r\n%s\r\n" % (first_line, headers) def assemble_response(response): - if response.body == CONTENT_MISSING: + if response.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - body = b"".join(assemble_body(response.headers, [response.body])) + body = b"".join(assemble_body(response.headers, [response.content])) return head + body @@ -45,42 +45,49 @@ def assemble_body(headers, body_chunks): yield chunk -def _assemble_request_line(request, form=None): - if form is None: - form = request.form_out +def _assemble_request_line(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + form = request_data.first_line_format if form == "relative": return b"%s %s %s" % ( - request.method, - request.path, - request.http_version + request_data.method, + request_data.path, + request_data.http_version ) elif form == "authority": return b"%s %s:%d %s" % ( - request.method, - request.host, - request.port, - request.http_version + request_data.method, + request_data.host, + request_data.port, + request_data.http_version ) elif form == "absolute": return b"%s %s://%s:%d%s %s" % ( - request.method, - request.scheme, - request.host, - request.port, - request.path, - request.http_version + request_data.method, + request_data.scheme, + request_data.host, + request_data.port, + request_data.path, + request_data.http_version ) - else: # pragma: nocover + else: raise RuntimeError("Invalid request form") -def _assemble_request_headers(request): - headers = request.headers.copy() - if "host" not in headers and request.scheme and request.host and request.port: +def _assemble_request_headers(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + headers = request_data.headers.copy() + if "host" not in headers and request_data.scheme and request_data.host and request_data.port: headers["host"] = utils.hostport( - request.scheme, - request.host, - request.port + request_data.scheme, + request_data.host, + request_data.port ) return bytes(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c898348..76721e06 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -11,7 +11,7 @@ from .. import Request, Response, Headers def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) expected_body_size = expected_http_body_size(request) - request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -155,7 +155,7 @@ def connection_close(http_version, headers): # If we don't have a Connection header, HTTP 1.1 connections are assumed to # be persistent - return http_version != b"HTTP/1.1" + return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case. def expected_http_body_size(request, response=None): @@ -184,11 +184,11 @@ def expected_http_body_size(request, response=None): if headers.get("expect", "").lower() == "100-continue": return 0 else: - if request.method.upper() == b"HEAD": + if request.method.upper() == "HEAD": return 0 if 100 <= response_code <= 199: return 0 - if response_code == 200 and request.method.upper() == b"CONNECT": + if response_code == 200 and request.method.upper() == "CONNECT": return 0 if response_code in (204, 304): return 0 diff --git a/netlib/http/message.py b/netlib/http/message.py new file mode 100644 index 00000000..20497bd5 --- /dev/null +++ b/netlib/http/message.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six + +from .. import encoding, utils + +if six.PY2: + _native = lambda x: x + _always_bytes = lambda x: x +else: + # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") + + +class Message(object): + def __init__(self, data): + self.data = data + + def __eq__(self, other): + if isinstance(other, Message): + return self.data == other.data + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def http_version(self): + """ + Version string, e.g. "HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + + @property + def headers(self): + """ + Message headers object + + Returns: + netlib.http.Headers + """ + return self.data.headers + + @headers.setter + def headers(self, h): + self.data.headers = h + + @property + def timestamp_start(self): + """ + First byte timestamp + """ + return self.data.timestamp_start + + @timestamp_start.setter + def timestamp_start(self, timestamp_start): + self.data.timestamp_start = timestamp_start + + @property + def timestamp_end(self): + """ + Last byte timestamp + """ + return self.data.timestamp_end + + @timestamp_end.setter + def timestamp_end(self, timestamp_end): + self.data.timestamp_end = timestamp_end + + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def text(self): + """ + The decoded HTTP message body. + Decoded contents are not cached, so this method is relatively expensive to call. + + See also: :py:attr:`content`, :py:class:`decoded` + """ + # This attribute should be called text, because that's what requests does. + raise NotImplementedError() + + @text.setter + def text(self, text): + raise NotImplementedError() + + @property + def body(self): + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + return self.content + + @body.setter + def body(self, body): + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + self.content = body + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + + .. code-block:: python + + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, message): + self.message = message + ce = message.headers.get("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + if not self.message.decode(): + self.ce = None + + def __exit__(self, type, value, tb): + if self.ce: + self.message.encode(self.ce) \ No newline at end of file diff --git a/netlib/http/models.py b/netlib/http/models.py index 55664533..40f6e98c 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -47,239 +47,6 @@ class Message(object): return False -class Request(Message): - def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - http_version, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None - ): - super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) - - self.form_in = form_in - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.form_out = form_out or form_in - - def __repr__(self): - if self.host and self.port: - hostport = "{}:{}".format(native(self.host,"idna"), self.port) - else: - hostport = "" - path = self.path or "" - return "HTTPRequest({} {}{})".format( - self.method, hostport, path - ) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in encoding.ENCODINGS - if e in accept_encoding - ) - ) - - def update_host_header(self): - """ - Update the host header to reflect the current target. - """ - self.headers["host"] = self.host - - def get_form(self): - """ - Retrieves the URL-encoded or multipart form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): - return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): - return self.get_form_multipart() - return ODict([]) - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): - return ODict(utils.urldecode(self.body)) - return ODict([]) - - def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): - return ODict( - utils.multipartdecode( - self.headers, - self.body)) - return ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["content-type"] = HDR_FORM_URLENCODED - self.body = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.parse.quote(i, safe="") for i in lst] - path = always_bytes("/" + "/".join(lst)) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - query = utils.urlencode(odict.lst) - self.url = urllib.parse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def pretty_host(self, hostheader): - """ - Heuristic to get the host of the request. - - Note that pretty_host() does not always return the TCP destination - of the request, e.g. if an upstream proxy is in place - - If hostheader is set to True, the Host: header will be used as - additional (and preferred) data source. This is handy in - transparent mode, where only the IO of the destination is known, - but not the resolved name. This is disabled by default, as an - attacker may spoof the host header to confuse an analyst. - """ - if hostheader and "host" in self.headers: - try: - return self.headers["host"] - except ValueError: - pass - if self.host: - return self.host.decode("idna") - - def pretty_url(self, hostheader): - if self.form_out == "authority": # upstream proxy mode - return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port) - return utils.unparse_url(self.scheme, - self.pretty_host(hostheader), - self.port, - self.path) - - def get_cookies(self): - """ - Returns a possibly empty netlib.odict.ODict object. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - def set_cookies(self, odict): - """ - Takes an netlib.odict.ODict object. Over-writes any existing Cookie - headers. - """ - v = cookies.format_cookie_header(odict) - self.headers["cookie"] = v - - @property - def url(self): - """ - Returns a URL string, constructed from the Request's URL components. - """ - return utils.unparse_url( - self.scheme, - self.host, - self.port, - self.path - ) - - @url.setter - def url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Raises: - ValueError if the URL was invalid - """ - # TODO: Should handle incoming unicode here. - parts = utils.parse_url(url) - if not parts: - raise ValueError("Invalid URL: %s" % url) - self.scheme, self.host, self.port, self.path = parts - - class Response(Message): def __init__( self, diff --git a/netlib/http/request.py b/netlib/http/request.py new file mode 100644 index 00000000..6830ca40 --- /dev/null +++ b/netlib/http/request.py @@ -0,0 +1,351 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six +from six.moves import urllib + +from netlib import utils +from netlib.http import cookies +from netlib.odict import ODict +from .. import encoding +from .headers import Headers +from .message import Message, _native, _always_bytes + + +class RequestData(object): + def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.first_line_format = first_line_format + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.http_version = http_version + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + if isinstance(other, RequestData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class Request(Message): + """ + An HTTP request. + """ + def __init__(self, *args, **kwargs): + data = RequestData(*args, **kwargs) + super(Request, self).__init__(data) + + def __repr__(self): + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) + else: + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) + + @property + def first_line_format(self): + """ + HTTP request form as defined in `RFC7230 `_. + + origin-form and asterisk-form are subsumed as "relative". + """ + return self.data.first_line_format + + @first_line_format.setter + def first_line_format(self, first_line_format): + self.data.first_line_format = first_line_format + + @property + def method(self): + """ + HTTP request method, e.g. "GET". + """ + return _native(self.data.method) + + @method.setter + def method(self, method): + self.data.method = _always_bytes(method) + + @property + def scheme(self): + """ + HTTP request scheme, which should be "http" or "https". + """ + return _native(self.data.scheme) + + @scheme.setter + def scheme(self, scheme): + self.data.scheme = _always_bytes(scheme) + + @property + def host(self): + """ + Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1") + or inferred from the proxy mode (e.g. an IP in transparent mode). + """ + + if six.PY2: + return self.data.host + + if not self.data.host: + return self.data.host + try: + return self.data.host.decode("idna") + except UnicodeError: + return self.data.host.decode("utf8", "surrogateescape") + + @host.setter + def host(self, host): + if isinstance(host, six.text_type): + try: + # There's no non-strict mode for IDNA encoding. + # We don't want this operation to fail though, so we try + # utf8 as a last resort. + host = host.encode("idna", "strict") + except UnicodeError: + host = host.encode("utf8", "surrogateescape") + + self.data.host = host + + # Update host header + if "host" in self.headers: + if host: + self.headers["host"] = host + else: + self.headers.pop("host") + + @property + def port(self): + """ + Target port + """ + return self.data.port + + @port.setter + def port(self, port): + self.data.port = port + + @property + def path(self): + """ + HTTP request path, e.g. "/index.html". + Guaranteed to start with a slash. + """ + return _native(self.data.path) + + @path.setter + def path(self, path): + self.data.path = _always_bytes(path) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + self.headers.pop(i, None) + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = "identity" + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( + ', '.join( + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) + + @property + def urlencoded_form(self): + """ + The URL-encoded form data as an ODict object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.urldecode(self.content)) + return None + + @urlencoded_form.setter + def urlencoded_form(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the appropriate content-type header. + This will overwrite the existing content if there is one. + """ + self.headers["content-type"] = "application/x-www-form-urlencoded" + self.content = utils.urlencode(odict.lst) + + @property + def multipart_form(self): + """ + The multipart form data as an ODict object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.multipartdecode(self.headers,self.content)) + return None + + @multipart_form.setter + def multipart_form(self): + raise NotImplementedError() + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def query(self): + """ + The request query string as an ODict object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def cookies(self): + """ + The request cookies. + An empty ODict object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + # Legacy + + def get_cookies(self): + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + def get_query(self): + warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) + return self.query or ODict([]) + + def set_query(self, odict): + warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) + self.query = odict + + def get_path_components(self): + warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) + return self.path_components + + def set_path_components(self, lst): + warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) + self.path_components = lst + + def get_form_urlencoded(self): + warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + return self.urlencoded_form or ODict([]) + + def set_form_urlencoded(self, odict): + warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + self.urlencoded_form = odict + + def get_form_multipart(self): + warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) + return self.multipart_form or ODict([]) + + @property + def form_in(self): + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_in.setter + def form_in(self, form_in): + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_in + + @property + def form_out(self): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_out.setter + def form_out(self, form_out): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_out \ No newline at end of file diff --git a/netlib/http/response.py b/netlib/http/response.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/netlib/http/response.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/netlib/tutils.py b/netlib/tutils.py index 1665a792..ff63c33c 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -98,7 +98,7 @@ def treq(**kwargs): netlib.http.Request """ default = dict( - form_in="relative", + first_line_format="relative", method=b"GET", scheme=b"http", host=b"address", @@ -106,7 +106,7 @@ def treq(**kwargs): path=b"/path", http_version=b"HTTP/1.1", headers=Headers(header="qvalue"), - body=b"content" + content=b"content" ) default.update(kwargs) return Request(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 6f6d1ea0..3ec60890 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -273,22 +273,27 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] -@always_byte_args() def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. """ - if (port, scheme) in [(80, b"http"), (443, b"https")]: + if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: return host else: - return b"%s:%d" % (host, port) + if isinstance(host, six.binary_type): + return b"%s:%d" % (host, port) + else: + return "%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): """ - Returns a URL string, constructed from the specified compnents. + Returns a URL string, constructed from the specified components. + + Args: + All args must be str. """ - return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path) + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) def urlencode(s): -- cgit v1.2.3 From 49ea8fc0ebcfe4861f099200044a553f092faec7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 17:39:50 +0200 Subject: refactor response model --- netlib/http/__init__.py | 15 ++-- netlib/http/headers.py | 26 +++---- netlib/http/http1/assemble.py | 16 ++-- netlib/http/http1/read.py | 2 +- netlib/http/http2/connections.py | 4 +- netlib/http/http2/frame.py | 3 - netlib/http/message.py | 64 +++++++++------- netlib/http/models.py | 112 ---------------------------- netlib/http/request.py | 155 +++++++++++++++++++++------------------ netlib/http/response.py | 124 ++++++++++++++++++++++++++++++- netlib/tutils.py | 6 +- netlib/wsgi.py | 6 +- 12 files changed, 277 insertions(+), 256 deletions(-) delete mode 100644 netlib/http/models.py (limited to 'netlib') diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index e8c7ba20..fd632cd5 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,17 +1,14 @@ from __future__ import absolute_import, print_function, division -from .headers import Headers -from .message import decoded from .request import Request -from .models import Response -from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 -from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING +from .response import Response +from .headers import Headers +from .message import decoded, CONTENT_MISSING from . import http1, http2 __all__ = [ + "Request", + "Response", "Headers", - "decoded", - "Request", "Response", - "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", - "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", + "decoded", "CONTENT_MISSING", "http1", "http2", ] diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 47ea923b..c79c3344 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -36,12 +36,8 @@ class Headers(MutableMapping): .. code-block:: python - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) + # Create headers with keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") # Headers mostly behave like a normal dict. >>> h["Host"] @@ -51,6 +47,13 @@ class Headers(MutableMapping): >>> h["host"] "example.com" + # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + >>> h = Headers([ + [b"Host",b"example.com"], + [b"Accept",b"text/html"], + [b"accept",b"application/xml"] + ]) + # Multiple headers are folded into a single header as per RFC7230 >>> h["Accept"] "text/html, application/xml" @@ -60,17 +63,14 @@ class Headers(MutableMapping): >>> h["Accept"] "application/text" - # str(h) returns a HTTP1 header block. - >>> print(h) + # bytes(h) returns a HTTP1 header block. + >>> print(bytes(h)) Host: example.com Accept: application/text # For full control, the raw header fields can be accessed >>> h.fields - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - Caveats: For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ @@ -79,8 +79,8 @@ class Headers(MutableMapping): def __init__(self, fields=None, **headers): """ Args: - fields: (optional) list of ``(name, value)`` header tuples, - e.g. ``[("Host","example.com")]``. All names and values must be bytes. + fields: (optional) list of ``(name, value)`` header byte tuples, + e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. **headers: Additional headers to set. Will overwrite existing values from `fields`. For convenience, underscores in header names will be transformed to dashes - this behaviour does not extend to other methods. diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 864f6017..785ee8d3 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -10,7 +10,7 @@ def assemble_request(request): if request.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - body = b"".join(assemble_body(request.headers, [request.data.content])) + body = b"".join(assemble_body(request.data.headers, [request.data.content])) return head + body @@ -24,13 +24,13 @@ def assemble_response(response): if response.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - body = b"".join(assemble_body(response.headers, [response.content])) + body = b"".join(assemble_body(response.data.headers, [response.data.content])) return head + body def assemble_response_head(response): - first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response) + first_line = _assemble_response_line(response.data) + headers = _assemble_response_headers(response.data) return b"%s\r\n%s\r\n" % (first_line, headers) @@ -92,11 +92,11 @@ def _assemble_request_headers(request_data): return bytes(headers) -def _assemble_response_line(response): +def _assemble_response_line(response_data): return b"%s %d %s" % ( - response.http_version, - response.status_code, - response.msg, + response_data.http_version, + response_data.status_code, + response_data.reason, ) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 76721e06..0d5e7f4b 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -50,7 +50,7 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) expected_body_size = expected_http_body_size(request, response) - response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 5220d5d2..c493abe6 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -4,7 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from ... import utils -from .. import Headers, Response, Request, ALPN_PROTO_H2 +from .. import Headers, Response, Request from . import frame @@ -283,7 +283,7 @@ class HTTP2Protocol(object): def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != ALPN_PROTO_H2: + if alp != b'h2': raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index cb2cde99..188629d4 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -25,9 +25,6 @@ ERROR_CODES = BiDi( CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" -ALPN_PROTO_H2 = b'h2' - - class Frame(object): """ diff --git a/netlib/http/message.py b/netlib/http/message.py index 20497bd5..ee138746 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -6,11 +6,14 @@ import six from .. import encoding, utils + +CONTENT_MISSING = 0 + if six.PY2: _native = lambda x: x _always_bytes = lambda x: x else: - # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") @@ -27,17 +30,6 @@ class Message(object): def __ne__(self, other): return not self.__eq__(other) - @property - def http_version(self): - """ - Version string, e.g. "HTTP/1.1" - """ - return _native(self.data.http_version) - - @http_version.setter - def http_version(self, http_version): - self.data.http_version = _always_bytes(http_version) - @property def headers(self): """ @@ -52,6 +44,32 @@ class Message(object): def headers(self, h): self.data.headers = h + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def http_version(self): + """ + Version string, e.g. "HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + @property def timestamp_start(self): """ @@ -74,26 +92,14 @@ class Message(object): def timestamp_end(self, timestamp_end): self.data.timestamp_end = timestamp_end - @property - def content(self): - """ - The raw (encoded) HTTP message body - - See also: :py:attr:`text` - """ - return self.data.content - - @content.setter - def content(self, content): - self.data.content = content - if isinstance(content, bytes): - self.headers["content-length"] = str(len(content)) - @property def text(self): """ The decoded HTTP message body. - Decoded contents are not cached, so this method is relatively expensive to call. + Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + + .. note:: + This is not implemented yet. See also: :py:attr:`content`, :py:class:`decoded` """ @@ -104,6 +110,8 @@ class Message(object): def text(self, text): raise NotImplementedError() + # Legacy + @property def body(self): warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) diff --git a/netlib/http/models.py b/netlib/http/models.py deleted file mode 100644 index 40f6e98c..00000000 --- a/netlib/http/models.py +++ /dev/null @@ -1,112 +0,0 @@ - - -from ..odict import ODict -from .. import utils, encoding -from ..utils import always_bytes, native -from . import cookies -from .headers import Headers - -from six.moves import urllib - -# TODO: Move somewhere else? -ALPN_PROTO_HTTP1 = b'http/1.1' -ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" - -CONTENT_MISSING = 0 - - -class Message(object): - def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): - self.http_version = http_version - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - self.headers = headers - - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["content-length"] = str(len(body)).encode() - - content = body - - def __eq__(self, other): - if isinstance(other, Message): - return self.__dict__ == other.__dict__ - return False - - -class Response(Message): - def __init__( - self, - http_version, - status_code, - msg=None, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - ): - super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) - self.status_code = status_code - self.msg = msg - - def __repr__(self): - # return "Response(%s - %s)" % (self.status_code, self.msg) - - if self.body: - size = utils.pretty_size(len(self.body)) - else: - size = "content missing" - # TODO: Remove "(unknown content type, content missing)" edge-case - return "".format( - status_code=self.status_code, - msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), - size=size) - - def get_cookies(self): - """ - Get the contents of all Set-Cookie headers. - - Returns a possibly empty ODict, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. - """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) - - def set_cookies(self, odict): - """ - Set the Set-Cookie headers on this response, over-writing existing - headers. - - Accepts an ODict of the same format as that returned by get_cookies. - """ - values = [] - for i in odict.lst: - values.append( - cookies.format_set_cookie_header( - i[0], - i[1][0], - i[1][1] - ) - ) - self.headers.set_all("set-cookie", values) diff --git a/netlib/http/request.py b/netlib/http/request.py index 6830ca40..f8a3b5b9 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -55,7 +55,7 @@ class Request(Message): else: hostport = "" path = self.path or "" - return "HTTPRequest({} {}{})".format( + return "Request({} {}{})".format( self.method, hostport, path ) @@ -97,7 +97,8 @@ class Request(Message): @property def host(self): """ - Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1") + Target host. This may be parsed from the raw request + (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) or inferred from the proxy mode (e.g. an IP in transparent mode). """ @@ -154,6 +155,83 @@ class Request(Message): def path(self, path): self.data.path = _always_bytes(path) + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + """ + Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. + This is useful in transparent mode where :py:attr:`host` is only an IP address, + but may not reflect the actual destination as the Host header could be spoofed. + """ + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + """ + Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. + """ + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + @property + def query(self): + """ + The request query string as an :py:class:`ODict` object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def cookies(self): + """ + The request cookies. + An empty :py:class:`ODict` object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + def anticache(self): """ Modifies this request to remove headers that might produce a cached @@ -191,7 +269,7 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an ODict object. + The URL-encoded form data as an :py:class:`ODict` object. None if there is no data or the content-type indicates non-form data. """ is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() @@ -211,7 +289,7 @@ class Request(Message): @property def multipart_form(self): """ - The multipart form data as an ODict object. + The multipart form data as an :py:class:`ODict` object. None if there is no data or the content-type indicates non-form data. """ is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() @@ -223,75 +301,6 @@ class Request(Message): def multipart_form(self): raise NotImplementedError() - @property - def path_components(self): - """ - The URL's path components as a list of strings. - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] - - @path_components.setter - def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) - path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) - - @property - def query(self): - """ - The request query string as an ODict object. - None, if there is no query. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None - - @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) - - @property - def cookies(self): - """ - The request cookies. - An empty ODict object if the cookie monster ate them all. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) - - @property - def url(self): - """ - The URL string, constructed from the request's URL components - """ - return utils.unparse_url(self.scheme, self.host, self.port, self.path) - - @url.setter - def url(self, url): - self.scheme, self.host, self.port, self.path = utils.parse_url(url) - - @property - def pretty_host(self): - return self.headers.get("host", self.host) - - @property - def pretty_url(self): - if self.first_line_format == "authority": - return "%s:%d" % (self.pretty_host, self.port) - return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) - # Legacy def get_cookies(self): diff --git a/netlib/http/response.py b/netlib/http/response.py index 02fac3df..7d64243d 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,3 +1,125 @@ from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +import warnings + +from . import cookies +from .headers import Headers +from .message import Message, _native, _always_bytes +from .. import utils +from ..odict import ODict + + +class ResponseData(object): + def __init__(self, http_version, status_code, reason=None, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + if isinstance(other, ResponseData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class Response(Message): + """ + An HTTP response. + """ + def __init__(self, *args, **kwargs): + data = ResponseData(*args, **kwargs) + super(Response, self).__init__(data) + + def __repr__(self): + if self.content: + details = "{}, {}".format( + self.headers.get("content-type", "unknown content type"), + utils.pretty_size(len(self.content)) + ) + else: + details = "content missing" + return "Response({status_code} {reason}, {details})".format( + status_code=self.status_code, + reason=self.reason, + details=details + ) + + @property + def status_code(self): + """ + HTTP Status Code, e.g. ``200``. + """ + return self.data.status_code + + @status_code.setter + def status_code(self, status_code): + self.data.status_code = status_code + + @property + def reason(self): + """ + HTTP Reason Phrase, e.g. "Not Found". + This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. + """ + return _native(self.data.reason) + + @reason.setter + def reason(self, reason): + self.data.reason = _always_bytes(reason) + + @property + def cookies(self): + """ + Get the contents of all Set-Cookie headers. + + A possibly empty :py:class:`ODict`, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers.get_all("set-cookie"): + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return ODict(ret) + + @cookies.setter + def cookies(self, odict): + values = [] + for i in odict.lst: + header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) + values.append(header) + self.headers.set_all("set-cookie", values) + + # Legacy + + def get_cookies(self): + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + @property + def msg(self): + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + return self.reason + + @msg.setter + def msg(self, reason): + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + self.reason = reason diff --git a/netlib/tutils.py b/netlib/tutils.py index ff63c33c..e16f1a76 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -120,9 +120,9 @@ def tresp(**kwargs): default = dict( http_version=b"HTTP/1.1", status_code=200, - msg=b"OK", - headers=Headers(header_response=b"svalue"), - body=b"message", + reason=b"OK", + headers=Headers(header_response="svalue"), + content=b"message", timestamp_start=time.time(), timestamp_end=time.time(), ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 4fcd5178..df248a19 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -25,9 +25,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, http_version, headers, body): + def __init__(self, scheme, method, path, http_version, headers, content): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.body = headers, body + self.headers, self.content = headers, content self.http_version = http_version @@ -64,7 +64,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), - 'wsgi.input': BytesIO(flow.request.body or b""), + 'wsgi.input': BytesIO(flow.request.content or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, -- cgit v1.2.3 From 466888b01a361e46fb3d4e66afa2c6a0fd168c8e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 20:07:11 +0200 Subject: improve request tests, coverage++ --- netlib/encoding.py | 4 ++++ netlib/http/headers.py | 8 ++++---- netlib/http/message.py | 42 +++++++++++++++++++++++++++++++++++++----- netlib/http/request.py | 28 ++++++++++++++-------------- netlib/http/response.py | 8 ++++---- netlib/http/status_codes.py | 4 +++- 6 files changed, 66 insertions(+), 28 deletions(-) (limited to 'netlib') diff --git a/netlib/encoding.py b/netlib/encoding.py index 4c11273b..14479e00 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -12,6 +12,8 @@ ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): + if not isinstance(content, bytes): + return None encoding_map = { "identity": identity, "gzip": decode_gzip, @@ -23,6 +25,8 @@ def decode(e, content): def encode(e, content): + if not isinstance(content, bytes): + return None encoding_map = { "identity": identity, "gzip": encode_gzip, diff --git a/netlib/http/headers.py b/netlib/http/headers.py index c79c3344..f64e6200 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -8,15 +8,15 @@ from __future__ import absolute_import, print_function, division import copy try: from collections.abc import MutableMapping -except ImportError: # Workaround for Python < 3.3 - from collections import MutableMapping +except ImportError: # pragma: nocover + from collections import MutableMapping # Workaround for Python < 3.3 import six from netlib.utils import always_byte_args, always_bytes -if six.PY2: +if six.PY2: # pragma: nocover _native = lambda x: x _always_bytes = lambda x: x _always_byte_args = lambda x: x @@ -106,7 +106,7 @@ class Headers(MutableMapping): else: return b"" - if six.PY2: + if six.PY2: # pragma: nocover __str__ = __bytes__ @_always_byte_args diff --git a/netlib/http/message.py b/netlib/http/message.py index ee138746..7cb18f52 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -9,7 +9,7 @@ from .. import encoding, utils CONTENT_MISSING = 0 -if six.PY2: +if six.PY2: # pragma: nocover _native = lambda x: x _always_bytes = lambda x: x else: @@ -110,15 +110,48 @@ class Message(object): def text(self, text): raise NotImplementedError() + def decode(self): + """ + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + ce = self.headers.get("content-encoding") + data = encoding.decode(ce, self.content) + if data is None: + return False + self.content = data + self.headers.pop("content-encoding", None) + return True + + def encode(self, e): + """ + Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + data = encoding.encode(e, self.content) + if data is None: + return False + self.content = data + self.headers["content-encoding"] = e + return True + # Legacy @property - def body(self): + def body(self): # pragma: nocover warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) return self.content @body.setter - def body(self, body): + def body(self, body): # pragma: nocover warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) self.content = body @@ -146,8 +179,7 @@ class decoded(object): def __enter__(self): if self.ce: - if not self.message.decode(): - self.ce = None + self.message.decode() def __exit__(self, type, value, tb): if self.ce: diff --git a/netlib/http/request.py b/netlib/http/request.py index f8a3b5b9..325c0080 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -102,7 +102,7 @@ class Request(Message): or inferred from the proxy mode (e.g. an IP in transparent mode). """ - if six.PY2: + if six.PY2: # pragma: nocover return self.data.host if not self.data.host: @@ -303,58 +303,58 @@ class Request(Message): # Legacy - def get_cookies(self): + def get_cookies(self): # pragma: nocover warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) return self.cookies - def set_cookies(self, odict): + def set_cookies(self, odict): # pragma: nocover warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) self.cookies = odict - def get_query(self): + def get_query(self): # pragma: nocover warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) return self.query or ODict([]) - def set_query(self, odict): + def set_query(self, odict): # pragma: nocover warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) self.query = odict - def get_path_components(self): + def get_path_components(self): # pragma: nocover warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) return self.path_components - def set_path_components(self, lst): + def set_path_components(self, lst): # pragma: nocover warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) self.path_components = lst - def get_form_urlencoded(self): + def get_form_urlencoded(self): # pragma: nocover warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) return self.urlencoded_form or ODict([]) - def set_form_urlencoded(self, odict): + def set_form_urlencoded(self, odict): # pragma: nocover warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) self.urlencoded_form = odict - def get_form_multipart(self): + def get_form_multipart(self): # pragma: nocover warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) return self.multipart_form or ODict([]) @property - def form_in(self): + def form_in(self): # pragma: nocover warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) return self.first_line_format @form_in.setter - def form_in(self, form_in): + def form_in(self, form_in): # pragma: nocover warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) self.first_line_format = form_in @property - def form_out(self): + def form_out(self): # pragma: nocover warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) return self.first_line_format @form_out.setter - def form_out(self, form_out): + def form_out(self, form_out): # pragma: nocover warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) self.first_line_format = form_out \ No newline at end of file diff --git a/netlib/http/response.py b/netlib/http/response.py index 7d64243d..db31d2b9 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -106,20 +106,20 @@ class Response(Message): # Legacy - def get_cookies(self): + def get_cookies(self): # pragma: nocover warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) return self.cookies - def set_cookies(self, odict): + def set_cookies(self, odict): # pragma: nocover warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) self.cookies = odict @property - def msg(self): + def msg(self): # pragma: nocover warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) return self.reason @msg.setter - def msg(self, reason): + def msg(self, reason): # pragma: nocover warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) self.reason = reason diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py index dc09f465..8a4dc1f5 100644 --- a/netlib/http/status_codes.py +++ b/netlib/http/status_codes.py @@ -1,4 +1,4 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division CONTINUE = 100 SWITCHING = 101 @@ -37,6 +37,7 @@ REQUEST_URI_TOO_LONG = 414 UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 EXPECTATION_FAILED = 417 +IM_A_TEAPOT = 418 INTERNAL_SERVER_ERROR = 500 NOT_IMPLEMENTED = 501 @@ -91,6 +92,7 @@ RESPONSES = { UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", EXPECTATION_FAILED: "Expectation Failed", + IM_A_TEAPOT: "I'm a teapot", # 500 INTERNAL_SERVER_ERROR: "Internal Server Error", -- cgit v1.2.3 From 23d13e4c1282bc46c54222479c3b83032dad3335 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 27 Sep 2015 00:49:41 +0200 Subject: test response model, push coverage to 100% branch cov --- netlib/http/cookies.py | 1 + netlib/http/message.py | 10 ++++++++++ netlib/http/request.py | 12 ++---------- netlib/http/response.py | 14 +++----------- 4 files changed, 16 insertions(+), 21 deletions(-) (limited to 'netlib') diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 78b03a83..18544b5e 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -58,6 +58,7 @@ def _read_quoted_string(s, start): escaping = False ret = [] # Skip the first quote + i = start # initialize in case the loop doesn't run. for i in range(start + 1, len(s)): if escaping: ret.append(s[i]) diff --git a/netlib/http/message.py b/netlib/http/message.py index 7cb18f52..e4e799ca 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -18,6 +18,16 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") +class MessageData(object): + def __eq__(self, other): + if isinstance(other, MessageData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + class Message(object): def __init__(self, data): self.data = data diff --git a/netlib/http/request.py b/netlib/http/request.py index 325c0080..095b5945 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,10 +10,10 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes +from .message import Message, _native, _always_bytes, MessageData -class RequestData(object): +class RequestData(MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, timestamp_start=None, timestamp_end=None): if not headers: @@ -32,14 +32,6 @@ class RequestData(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): - if isinstance(other, RequestData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - class Request(Message): """ diff --git a/netlib/http/response.py b/netlib/http/response.py index db31d2b9..66e5ded6 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -4,12 +4,12 @@ import warnings from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes +from .message import Message, _native, _always_bytes, MessageData from .. import utils from ..odict import ODict -class ResponseData(object): +class ResponseData(MessageData): def __init__(self, http_version, status_code, reason=None, headers=None, content=None, timestamp_start=None, timestamp_end=None): if not headers: @@ -24,14 +24,6 @@ class ResponseData(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): - if isinstance(other, ResponseData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - class Response(Message): """ @@ -48,7 +40,7 @@ class Response(Message): utils.pretty_size(len(self.content)) ) else: - details = "content missing" + details = "no content" return "Response({status_code} {reason}, {details})".format( status_code=self.status_code, reason=self.reason, -- cgit v1.2.3 From 87566da3babcc827e9dae0f2e9ab9154c353aa11 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 28 Sep 2015 11:18:00 +0200 Subject: fix mitmproxy/mitmproxy#784 --- netlib/http/http1/read.py | 5 ----- netlib/utils.py | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c898348..73c7deed 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -218,11 +218,6 @@ def _get_first_line(rfile): raise HttpReadDisconnect("Remote disconnected") if not line: raise HttpReadDisconnect("Remote disconnected") - line = line.strip() - try: - line.decode("ascii") - except ValueError: - raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line)) return line.strip() diff --git a/netlib/utils.py b/netlib/utils.py index 6f6d1ea0..8b9548ed 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -237,8 +237,9 @@ def parse_url(url): if isinstance(url, six.binary_type): host = parsed.hostname - # this should not raise a ValueError - decode_parse_result(parsed, "ascii") + # this should not raise a ValueError, + # but we try to be very forgiving here and accept just everything. + # decode_parse_result(parsed, "ascii") else: host = parsed.hostname.encode("idna") parsed = encode_parse_result(parsed, "ascii") -- cgit v1.2.3 From 5af9df326aef1cf72be7fd5390df239fb6b906c7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 1 Nov 2015 18:15:30 +0100 Subject: fix certificate verification This commit fixes netlib's optional (turned off by default) certificate verification, which previously did not validate the cert's host name. As it turns out, verifying the connection's host name on an intercepting proxy is not really straightforward - if we receive a connection in transparent mode without SNI, we have no clue which hosts the client intends to connect to. There are two basic approaches to solve this problem: 1. Exactly mirror the host names presented by the server in the spoofed certificate presented to the client. 2. Require the client to send the TLS Server Name Indication extension. While this does not work with older clients, we can validate the hostname on the proxy. Approach 1 is problematic in mitmproxy's use case, as we may want to deliberately divert connections without the client's knowledge. As a consequence, we opt for approach 2. While mitmproxy does now require a SNI value to be sent by the client if certificate verification is turned on, we retain our ability to present certificates to the client which are accepted with a maximum likelihood. --- netlib/certutils.py | 5 +++++ netlib/tcp.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 7 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index b3ddcbe4..93366a99 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -437,6 +437,11 @@ class SSLCert(object): @property def altnames(self): + """ + Returns: + All DNS altnames. + """ + # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. altnames = [] for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) diff --git a/netlib/tcp.py b/netlib/tcp.py index b751d71f..33776fc4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -11,6 +11,7 @@ import binascii from six.moves import range import certifi +from backports import ssl_match_hostname import six import OpenSSL from OpenSSL import SSL @@ -597,9 +598,14 @@ class TCPClient(_Connection): ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool ca_pemfile: Path to a PEM formatted trusted CA certificate """ + verification_mode = sslctx_kwargs.get('verify_options', None) + if verification_mode == SSL.VERIFY_PEER and not sni: + raise TlsException("Cannot validate certificate hostname without SNI") + context = self.create_ssl_context( alpn_protos=alpn_protos, - **sslctx_kwargs) + **sslctx_kwargs + ) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni @@ -612,15 +618,32 @@ class TCPClient(_Connection): raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) else: raise TlsException("SSL handshake error: %s" % repr(v)) + else: + # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on + # certificate validation failure + if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None: + raise InvalidCertificateException("SSL handshake error: certificate verify failed") - # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on - # certificate validation failure - verification_mode = sslctx_kwargs.get('verify_options', None) - if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise InvalidCertificateException("SSL handshake error: certificate verify failed") + self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) + + # Validate TLS Hostname + try: + crt = dict( + subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] + ) + if self.cert.cn: + crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] + if sni: + hostname = sni.decode("ascii", "strict") + else: + hostname = "no-hostname" + ssl_match_hostname.match_hostname(crt, hostname) + except (ValueError, ssl_match_hostname.CertificateError) as e: + self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname") + if verification_mode == SSL.VERIFY_PEER: + raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e))) self.ssl_established = True - self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) -- cgit v1.2.3 From 9d36f8e43fc7a3b3c4bf10a8c1b9819da8999dad Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 1 Nov 2015 18:20:00 +0100 Subject: minor fixes --- netlib/http/request.py | 2 ++ netlib/tcp.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http/request.py b/netlib/http/request.py index 92d99532..5ebf21a5 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -92,6 +92,8 @@ class Request(Message): Target host. This may be parsed from the raw request (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) or inferred from the proxy mode (e.g. an IP in transparent mode). + + Setting the host attribute also updates the host header, if present. """ if six.PY2: # pragma: nocover diff --git a/netlib/tcp.py b/netlib/tcp.py index b751d71f..ef5ab4b6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -559,8 +559,6 @@ class TCPClient(_Connection): @address.setter def address(self, address): - if self.connection: - raise RuntimeError("Cannot change server address after establishing connection") if address: self.__address = Address.wrap(address) else: -- cgit v1.2.3 From 9d12425d5ee942ee3d954a9324c31b74f466d520 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 4 Nov 2015 11:28:02 +0100 Subject: Set default cert expiry to <39 months This sould fix mitmproxy/mitmproxy#815 --- netlib/certutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index b3ddcbe4..69530245 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,7 +12,8 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 +# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 +DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = b""" -----BEGIN DH PARAMETERS----- -- cgit v1.2.3 From 3e2eb3fef166822bfad0d2200dadffe541efbc38 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 6 Nov 2015 13:51:15 +1300 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 044fde2c..d2c3c369 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 13, 2) +IVERSION = (0, 14, 0) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 9cab9ee5d6f39b658c1e9260950cc3575d3ad9db Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 7 Nov 2015 09:30:49 +1300 Subject: Bump version for next release cycle --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index d2c3c369..e836dbe3 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 14, 0) +IVERSION = (0, 14, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 823718348598efb324298ca29ad4cb7d5097c084 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 11:32:02 -0600 Subject: Porting netlib to python3.4 Updated utils.py using 2to3-3.4 Updated hexdump to use .format() with .encode() to support python 3.4 Python 3.5 supports .format() on bytes objects, but 3.4 is the current default on Ubuntu. samc$ py.test netlib/test/test_utils.py = test session starts = platform darwin -- Python 3.4.1, pytest-2.8.2, py-1.4.30, pluggy-0.3.1 rootdir: /Users/samc/src/python/netlib, inifile: collected 11 items netlib/test/test_utils.py ........... = 11 passed in 0.19 seconds = --- netlib/utils.py | 16 +-- netlib/utils.py.bak | 368 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 376 insertions(+), 8 deletions(-) create mode 100644 netlib/utils.py.bak (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index acc7ccd4..62f17012 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,4 @@ -from __future__ import absolute_import, print_function, division + import os.path import re import string @@ -61,11 +61,11 @@ def clean_bin(s, keep_spacing=True): """ if isinstance(s, six.text_type): if keep_spacing: - keep = u" \n\r\t" + keep = " \n\r\t" else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + keep = " " + return "".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else "." for ch in s ) else: @@ -85,9 +85,9 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = b"%.10x" % i + offset = "{:0=10x}".format(i).encode() part = s[i:i + 16] - x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 yield (offset, x, clean_bin(part, False)) @@ -122,7 +122,7 @@ class BiDi(object): def __init__(self, **kwargs): self.names = kwargs self.values = {} - for k, v in kwargs.items(): + for k, v in list(kwargs.items()): self.values[v] = k if len(self.names) != len(self.values): raise ValueError("Duplicate values not allowed.") diff --git a/netlib/utils.py.bak b/netlib/utils.py.bak new file mode 100644 index 00000000..acc7ccd4 --- /dev/null +++ b/netlib/utils.py.bak @@ -0,0 +1,368 @@ +from __future__ import absolute_import, print_function, division +import os.path +import re +import string +import unicodedata + +import six + +from six.moves import urllib + + +def always_bytes(unicode_or_bytes, *encode_args): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(*encode_args) + return unicode_or_bytes + + +def always_byte_args(*encode_args): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator + + +def native(s, *encoding_opts): + """ + Convert :py:class:`bytes` or :py:class:`unicode` to the native + :py:class:`str` type, using latin1 encoding if conversion is necessary. + + https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types + """ + if not isinstance(s, (six.binary_type, six.text_type)): + raise TypeError("%r is neither bytes nor unicode" % s) + if six.PY3: + if isinstance(s, six.binary_type): + return s.decode(*encoding_opts) + else: + if isinstance(s, six.text_type): + return s.encode(*encoding_opts) + return s + + +def isascii(bytes): + try: + bytes.decode("ascii") + except ValueError: + return False + return True + + +def clean_bin(s, keep_spacing=True): + """ + Cleans binary data to make it safe to display. + + Args: + keep_spacing: If False, tabs and newlines will also be replaced. + """ + if isinstance(s, six.text_type): + if keep_spacing: + keep = u" \n\r\t" + else: + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + for ch in s + ) + else: + if keep_spacing: + keep = (9, 10, 13) # \t, \n, \r, + else: + keep = () + return b"".join( + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) + ) + + +def hexdump(s): + """ + Returns: + A generator of (offset, hex, str) tuples + """ + for i in range(0, len(s), 16): + offset = b"%.10x" % i + part = s[i:i + 16] + x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = x.ljust(47) # 16*2 + 15 + yield (offset, x, clean_bin(part, False)) + + +def setbit(byte, offset, value): + """ + Set a bit in a byte to 1 if value is truthy, 0 if not. + """ + if value: + return byte | (1 << offset) + else: + return byte & ~(1 << offset) + + +def getbit(byte, offset): + mask = 1 << offset + return bool(byte & mask) + + +class BiDi(object): + + """ + A wee utility class for keeping bi-directional mappings, like field + constants in protocols. Names are attributes on the object, dict-like + access maps values to names: + + CONST = BiDi(a=1, b=2) + assert CONST.a == 1 + assert CONST.get_name(1) == "a" + """ + + def __init__(self, **kwargs): + self.names = kwargs + self.values = {} + for k, v in kwargs.items(): + self.values[v] = k + if len(self.names) != len(self.values): + raise ValueError("Duplicate values not allowed.") + + def __getattr__(self, k): + if k in self.names: + return self.names[k] + raise AttributeError("No such attribute: %s", k) + + def get_name(self, n, default=None): + return self.values.get(n, default) + + +def pretty_size(size): + suffixes = [ + ("B", 2 ** 10), + ("kB", 2 ** 20), + ("MB", 2 ** 30), + ] + for suf, lim in suffixes: + if size >= lim: + continue + else: + x = round(size / float(lim / 2 ** 10), 2) + if x == int(x): + x = int(x) + return str(x) + suf + + +class Data(object): + + def __init__(self, name): + m = __import__(name) + dirname, _ = os.path.split(m.__file__) + self.dirname = os.path.abspath(dirname) + + def path(self, path): + """ + Returns a path to the package data housed at 'path' under this + module.Path can be a path to a file, or to a directory. + + This function will raise ValueError if the path does not exist. + """ + fullpath = os.path.join(self.dirname, '../test/', path) + if not os.path.exists(fullpath): + raise ValueError("dataPath: %s does not exist." % fullpath) + return fullpath + + +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 255: + return False + if host[-1] == b".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) + + +def parse_url(url): + """ + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII + + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. + """ + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError, + # but we try to be very forgiving here and accept just everything. + # decode_parse_result(parsed, "ascii") + else: + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + + if not is_valid_host(host): + raise ValueError("Invalid Host") + if not is_valid_port(port): + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + if key not in headers: + return [] + tokens = headers[key].split(",") + return [token.strip() for token in tokens] + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: + return host + else: + if isinstance(host, six.binary_type): + return b"%s:%d" % (host, port) + else: + return "%s:%d" % (host, port) + + +def unparse_url(scheme, host, port, path=""): + """ + Returns a URL string, constructed from the specified components. + + Args: + All args must be str. + """ + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + + +def urlencode(s): + """ + Takes a list of (key, value) tuples and returns a urlencoded string. + """ + s = [tuple(i) for i in s] + return urllib.parse.urlencode(s, False) + + +def urldecode(s): + """ + Takes a urlencoded string and returns a list of (key, value) tuples. + """ + return urllib.parse.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(headers, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = headers.get("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + try: + boundary = v[2]["boundary"].encode("ascii") + except (KeyError, UnicodeError): + return [] + + rx = re.compile(br'\bname="([^"]+)"') + r = [] + + for i in content.split(b"--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != b"--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = b"".join(parts[3 + parts[2:].index(b""):]) + r.append((key, value)) + return r + return [] -- cgit v1.2.3 From 2d48f12332ff380db3ab66c8f436f78a62b2cd91 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 19:41:42 -0600 Subject: Revert "Porting netlib to python3.4" This reverts commit 823718348598efb324298ca29ad4cb7d5097c084. --- netlib/utils.py | 16 +-- netlib/utils.py.bak | 368 ---------------------------------------------------- 2 files changed, 8 insertions(+), 376 deletions(-) delete mode 100644 netlib/utils.py.bak (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index 62f17012..acc7ccd4 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,4 @@ - +from __future__ import absolute_import, print_function, division import os.path import re import string @@ -61,11 +61,11 @@ def clean_bin(s, keep_spacing=True): """ if isinstance(s, six.text_type): if keep_spacing: - keep = " \n\r\t" + keep = u" \n\r\t" else: - keep = " " - return "".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else "." + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." for ch in s ) else: @@ -85,9 +85,9 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = "{:0=10x}".format(i).encode() + offset = b"%.10x" % i part = s[i:i + 16] - x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) + x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 yield (offset, x, clean_bin(part, False)) @@ -122,7 +122,7 @@ class BiDi(object): def __init__(self, **kwargs): self.names = kwargs self.values = {} - for k, v in list(kwargs.items()): + for k, v in kwargs.items(): self.values[v] = k if len(self.names) != len(self.values): raise ValueError("Duplicate values not allowed.") diff --git a/netlib/utils.py.bak b/netlib/utils.py.bak deleted file mode 100644 index acc7ccd4..00000000 --- a/netlib/utils.py.bak +++ /dev/null @@ -1,368 +0,0 @@ -from __future__ import absolute_import, print_function, division -import os.path -import re -import string -import unicodedata - -import six - -from six.moves import urllib - - -def always_bytes(unicode_or_bytes, *encode_args): - if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes - - -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - -def native(s, *encoding_opts): - """ - Convert :py:class:`bytes` or :py:class:`unicode` to the native - :py:class:`str` type, using latin1 encoding if conversion is necessary. - - https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types - """ - if not isinstance(s, (six.binary_type, six.text_type)): - raise TypeError("%r is neither bytes nor unicode" % s) - if six.PY3: - if isinstance(s, six.binary_type): - return s.decode(*encoding_opts) - else: - if isinstance(s, six.text_type): - return s.encode(*encoding_opts) - return s - - -def isascii(bytes): - try: - bytes.decode("ascii") - except ValueError: - return False - return True - - -def clean_bin(s, keep_spacing=True): - """ - Cleans binary data to make it safe to display. - - Args: - keep_spacing: If False, tabs and newlines will also be replaced. - """ - if isinstance(s, six.text_type): - if keep_spacing: - keep = u" \n\r\t" - else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." - for ch in s - ) - else: - if keep_spacing: - keep = (9, 10, 13) # \t, \n, \r, - else: - keep = () - return b"".join( - six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." - for ch in six.iterbytes(s) - ) - - -def hexdump(s): - """ - Returns: - A generator of (offset, hex, str) tuples - """ - for i in range(0, len(s), 16): - offset = b"%.10x" % i - part = s[i:i + 16] - x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) - x = x.ljust(47) # 16*2 + 15 - yield (offset, x, clean_bin(part, False)) - - -def setbit(byte, offset, value): - """ - Set a bit in a byte to 1 if value is truthy, 0 if not. - """ - if value: - return byte | (1 << offset) - else: - return byte & ~(1 << offset) - - -def getbit(byte, offset): - mask = 1 << offset - return bool(byte & mask) - - -class BiDi(object): - - """ - A wee utility class for keeping bi-directional mappings, like field - constants in protocols. Names are attributes on the object, dict-like - access maps values to names: - - CONST = BiDi(a=1, b=2) - assert CONST.a == 1 - assert CONST.get_name(1) == "a" - """ - - def __init__(self, **kwargs): - self.names = kwargs - self.values = {} - for k, v in kwargs.items(): - self.values[v] = k - if len(self.names) != len(self.values): - raise ValueError("Duplicate values not allowed.") - - def __getattr__(self, k): - if k in self.names: - return self.names[k] - raise AttributeError("No such attribute: %s", k) - - def get_name(self, n, default=None): - return self.values.get(n, default) - - -def pretty_size(size): - suffixes = [ - ("B", 2 ** 10), - ("kB", 2 ** 20), - ("MB", 2 ** 30), - ] - for suf, lim in suffixes: - if size >= lim: - continue - else: - x = round(size / float(lim / 2 ** 10), 2) - if x == int(x): - x = int(x) - return str(x) + suf - - -class Data(object): - - def __init__(self, name): - m = __import__(name) - dirname, _ = os.path.split(m.__file__) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module.Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, '../test/', path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath - - -_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 255: - return False - if host[-1] == b".": - host = host[:-1] - return all(_label_valid.match(x) for x in host.split(b".")) - - -def is_valid_port(port): - return 0 <= port <= 65535 - - -# PY2 workaround -def decode_parse_result(result, enc): - if hasattr(result, "decode"): - return result.decode(enc) - else: - return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) - - -# PY2 workaround -def encode_parse_result(result, enc): - if hasattr(result, "encode"): - return result.encode(enc) - else: - return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) - - -def parse_url(url): - """ - URL-parsing function that checks that - - port is an integer 0-65535 - - host is a valid IDNA-encoded hostname with no null-bytes - - path is valid ASCII - - Args: - A URL (as bytes or as unicode) - - Returns: - A (scheme, host, port, path) tuple - - Raises: - ValueError, if the URL is not properly formatted. - """ - parsed = urllib.parse.urlparse(url) - - if not parsed.hostname: - raise ValueError("No hostname given") - - if isinstance(url, six.binary_type): - host = parsed.hostname - - # this should not raise a ValueError, - # but we try to be very forgiving here and accept just everything. - # decode_parse_result(parsed, "ascii") - else: - host = parsed.hostname.encode("idna") - parsed = encode_parse_result(parsed, "ascii") - - port = parsed.port - if not port: - port = 443 if parsed.scheme == b"https" else 80 - - full_path = urllib.parse.urlunparse( - (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) - ) - if not full_path.startswith(b"/"): - full_path = b"/" + full_path - - if not is_valid_host(host): - raise ValueError("Invalid Host") - if not is_valid_port(port): - raise ValueError("Invalid Port") - - return parsed.scheme, host, port, full_path - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - if key not in headers: - return [] - tokens = headers[key].split(",") - return [token.strip() for token in tokens] - - -def hostport(scheme, host, port): - """ - Returns the host component, with a port specifcation if needed. - """ - if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: - return host - else: - if isinstance(host, six.binary_type): - return b"%s:%d" % (host, port) - else: - return "%s:%d" % (host, port) - - -def unparse_url(scheme, host, port, path=""): - """ - Returns a URL string, constructed from the specified components. - - Args: - All args must be str. - """ - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) - - -def urlencode(s): - """ - Takes a list of (key, value) tuples and returns a urlencoded string. - """ - s = [tuple(i) for i in s] - return urllib.parse.urlencode(s, False) - - -def urldecode(s): - """ - Takes a urlencoded string and returns a list of (key, value) tuples. - """ - return urllib.parse.parse_qsl(s, keep_blank_values=True) - - -def parse_content_type(c): - """ - A simple parser for content-type values. Returns a (type, subtype, - parameters) tuple, where type and subtype are strings, and parameters - is a dict. If the string could not be parsed, return None. - - E.g. the following string: - - text/html; charset=UTF-8 - - Returns: - - ("text", "html", {"charset": "UTF-8"}) - """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) - if len(ts) != 2: - return None - d = {} - if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) - if len(clause) == 2: - d[clause[0].strip()] = clause[1].strip() - return ts[0].lower(), ts[1].lower(), d - - -def multipartdecode(headers, content): - """ - Takes a multipart boundary encoded string and returns list of (key, value) tuples. - """ - v = headers.get("content-type") - if v: - v = parse_content_type(v) - if not v: - return [] - try: - boundary = v[2]["boundary"].encode("ascii") - except (KeyError, UnicodeError): - return [] - - rx = re.compile(br'\bname="([^"]+)"') - r = [] - - for i in content.split(b"--" + boundary): - parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != b"--": - match = rx.search(parts[1]) - if match: - key = match.group(1) - value = b"".join(parts[3 + parts[2:].index(b""):]) - r.append((key, value)) - return r - return [] -- cgit v1.2.3 From 6689a342ae68c75bd52d81ee1959b1946739eca4 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 19:53:51 -0600 Subject: Porting to Python 3.4 Fixed byte string formatting for hexdump. = test session starts = platform darwin -- Python 3.4.1, pytest-2.8.2, py-1.4.30, pluggy-0.3.1 rootdir: /Users/samc/src/python/netlib, inifile: collected 11 items netlib/test/test_utils.py ........... = 11 passed in 0.23 seconds = --- netlib/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index acc7ccd4..66225897 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -85,9 +85,9 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = b"%.10x" % i + offset = "{:0=10x}".format(i).encode() part = s[i:i + 16] - x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 yield (offset, x, clean_bin(part, False)) -- cgit v1.2.3 From 2bd7bcb3711a20b6a166710f2c7d989d8ae5fcc8 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 20:27:10 -0600 Subject: Porting to Python 3.4 Updated wsgi to support Python 3.4 byte strings. Updated test_wsgi to remove py.test warning for TestApp having an __init__ constructor. samc$ sudo py.test netlib/test/test_wsgi.py -r w = test session starts = platform darwin -- Python 3.4.1, pytest-2.8.2, py-1.4.30, pluggy-0.3.1 rootdir: /Users/samc/src/python/netlib, inifile: collected 6 items netlib/test/test_wsgi.py ...... = 6 passed in 0.20 seconds = --- netlib/wsgi.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index df248a19..d6dfae5d 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -96,16 +96,17 @@ class WSGIAdaptor(object): Make a best-effort attempt to write an error page. If headers are already sent, we just bung the error into the page. """ - c = b""" + c = """

Internal Server Error

-
%s"
+
{err}"
- """.strip() % s.encode() + """.format(err=s).strip().encode() + if not headers_sent: soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") soc.write(b"Content-Type: text/html\r\n") - soc.write(b"Content-Length: %s\r\n" % len(c)) + soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode()) soc.write(b"\r\n") soc.write(c) @@ -119,7 +120,7 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode()) + soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode()) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion -- cgit v1.2.3 From c1385c9a176b8d8113f05cb5e920392016bda0cd Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 17 Nov 2015 04:51:20 +1100 Subject: Fix to ignore empty header value. According to Augmented BNF in the following RFCs http://tools.ietf.org/html/rfc5234#section-3.6 http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.1 field-value = *( field-content | LWS ) http://tools.ietf.org/html/rfc7230#section-3.2 field-value = *( field-content / obs-fold ) ... the HTTP message header `field-value` is allowed to be empty. --- netlib/http/http1/read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 0f6de26c..6e3a1b93 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -321,7 +321,7 @@ def _read_headers(rfile): try: name, value = line.split(b":", 1) value = value.strip() - if not name or not value: + if not name: raise ValueError() ret.append([name, value]) except ValueError: -- cgit v1.2.3 From 71834421bbf63e89eb923b888ea97db437c59ea5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 3 Dec 2015 18:13:24 +0100 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index e836dbe3..aa4ba641 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 14, 1) +IVERSION = (0, 15) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From d1e6b5366c97dd31c9b9606db2bb7a8520cfbd2c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Dec 2015 16:00:50 +0100 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index aa4ba641..7a68ca39 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 15) +IVERSION = (0, 15, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 4bb9f3d35b02bb076dd8df133288492c24295c8a Mon Sep 17 00:00:00 2001 From: Sandor Nemes Date: Fri, 8 Jan 2016 18:04:47 +0100 Subject: Added getter/setter for TCPClient source_address --- netlib/tcp.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 8e46d4f6..e5e9ec1a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -565,6 +565,17 @@ class TCPClient(_Connection): else: self.__address = None + @property + def source_address(self): + return self.__source_address + + @source_address.setter + def source_address(self, source_address): + if source_address: + self.__source_address = Address.wrap(source_address) + else: + self.__source_address = None + def close(self): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, -- cgit v1.2.3 From b8e8c4d68222c9292daf23e6ace55351fcef1af6 Mon Sep 17 00:00:00 2001 From: Sandor Nemes Date: Mon, 11 Jan 2016 08:10:36 +0100 Subject: Simplified setting the source_address in the TCPClient constructor --- netlib/tcp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index e5e9ec1a..8902b9dc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -548,8 +548,7 @@ class TCPClient(_Connection): def __init__(self, address, source_address=None): super(TCPClient, self).__init__(None) self.address = address - self.source_address = Address.wrap( - source_address) if source_address else None + self.source_address = source_address self.cert = None self.ssl_verification_error = None self.sni = None -- cgit v1.2.3 From 1b487539b1f3ea183eaed26ae756d0cc7d3ec3ea Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 24 Jan 2016 23:24:59 +0100 Subject: move tservers to netlib module --- netlib/tservers.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 netlib/tservers.py (limited to 'netlib') diff --git a/netlib/tservers.py b/netlib/tservers.py new file mode 100644 index 00000000..44ef8063 --- /dev/null +++ b/netlib/tservers.py @@ -0,0 +1,109 @@ +from __future__ import (absolute_import, print_function, division) + +import threading +from six.moves import queue +from io import StringIO +import OpenSSL + +from netlib import tcp +from netlib import tutils + + +class ServerThread(threading.Thread): + + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase(object): + ssl = None + handler = None + addr = ("localhost", 0) + + @classmethod + def setup_class(cls): + cls.q = queue.Queue() + s = cls.makeserver() + cls.port = s.address.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def makeserver(cls): + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) + + @classmethod + def teardown_class(cls): + cls.server.shutdown() + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + + def __init__(self, ssl, q, handler_klass, addr): + """ + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only + """ + tcp.TCPServer.__init__(self, addr) + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_client_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl is not None: + cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): + method = OpenSSL.SSL.SSLv3_METHOD + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 + else: + method = OpenSSL.SSL.SSLv23_METHOD + options = None + h.convert_to_ssl( + cert, key, + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl.get("request_client_cert", None), + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) + ) + h.handle() + h.finish() + + def handle_error(self, connection, client_address, fp=None): + s = StringIO() + tcp.TCPServer.handle_error(self, connection, client_address, s) + self.q.put(s.getvalue()) -- cgit v1.2.3 From d253ebc142d80708a1bdc065d3db05d1394e3819 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 30 Jan 2016 22:03:24 +0100 Subject: fix test request and response headers --- netlib/http/message.py | 2 +- netlib/tutils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/message.py b/netlib/http/message.py index e4e799ca..28f55fa2 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -193,4 +193,4 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: - self.message.encode(self.ce) \ No newline at end of file + self.message.encode(self.ce) diff --git a/netlib/tutils.py b/netlib/tutils.py index e16f1a76..14b4ef06 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -105,7 +105,7 @@ def treq(**kwargs): port=22, path=b"/path", http_version=b"HTTP/1.1", - headers=Headers(header="qvalue"), + headers=Headers(header="qvalue", content_length="7"), content=b"content" ) default.update(kwargs) @@ -121,7 +121,7 @@ def tresp(**kwargs): http_version=b"HTTP/1.1", status_code=200, reason=b"OK", - headers=Headers(header_response="svalue"), + headers=Headers(header_response="svalue", content_length="7"), content=b"message", timestamp_start=time.time(), timestamp_end=time.time(), -- cgit v1.2.3 From 280b491ab2b743f75483e2916e5344b22d4136e1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 31 Jan 2016 12:15:44 +0100 Subject: migrate to hyperframe --- netlib/http/http2/connections.py | 81 ++--- netlib/http/http2/frame.py | 651 --------------------------------------- netlib/utils.py | 21 +- 3 files changed, 62 insertions(+), 691 deletions(-) delete mode 100644 netlib/http/http2/frame.py (limited to 'netlib') diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index c493abe6..c963f7c4 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -5,7 +5,8 @@ import time from hpack.hpack import Encoder, Decoder from ... import utils from .. import Headers, Response, Request -from . import frame + +from hyperframe import frame class TCPHandler(object): @@ -36,6 +37,15 @@ class HTTP2Protocol(object): CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + HTTP2_DEFAULT_SETTINGS = { + frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, + frame.SettingsFrame.ENABLE_PUSH: 1, + frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, + frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, + frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, + } + def __init__( self, tcp_handler=None, @@ -54,7 +64,7 @@ class HTTP2Protocol(object): self.decoder = decoder or Decoder() self.unhandled_frame_cb = unhandled_frame_cb - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None self.connection_preface_performed = False @@ -240,9 +250,9 @@ class HTTP2Protocol(object): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - frm = frame.SettingsFrame(state=self, settings={ - frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, - frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, + frm = frame.SettingsFrame(settings={ + frame.SettingsFrame.ENABLE_PUSH: 0, + frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, }) self.send_frame(frm, hide=True) self._receive_settings(hide=True) @@ -253,12 +263,12 @@ class HTTP2Protocol(object): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - self.send_frame(frame.SettingsFrame(state=self), hide=True) + self.send_frame(frame.SettingsFrame(), hide=True) self._receive_settings(hide=True) # server announces own settings self._receive_settings(hide=True) # server acks my settings def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() + raw_bytes = frm.serialize() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() if not hide and self.dump_frames: # pragma no cover @@ -266,19 +276,19 @@ class HTTP2Protocol(object): def read_frame(self, hide=False): while True: - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + frm = utils.http2_read_frame(self.tcp_handler.rfile) if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) if isinstance(frm, frame.PingFrame): - raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() + raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() continue - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags: self._apply_settings(frm.settings, hide) - if isinstance(frm, frame.DataFrame) and frm.length > 0: - self._update_flow_control_window(frm.stream_id, frm.length) + if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0: + self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length) return frm def check_alpn(self): @@ -321,15 +331,13 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) + frm = frame.SettingsFrame(flags=['ACK']) self.send_frame(frm, hide) def _update_flow_control_window(self, stream_id, increment): - frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) + frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment) self.send_frame(frm) - frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) self.send_frame(frm) def _create_headers(self, headers, stream_id, end_stream=True): @@ -342,43 +350,40 @@ class HTTP2Protocol(object): header_block_fragment = self.encoder.encode(headers.fields) - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) frms = [frm_cls( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, + flags=[], stream_id=stream_id, - header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + data=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] - last_flags = frame.Frame.FLAG_END_HEADERS + frms[-1].flags.add('END_HEADERS') if end_stream: - last_flags |= frame.Frame.FLAG_END_STREAM - frms[-1].flags = last_flags + frms[0].flags.add('END_STREAM') if self.dump_frames: # pragma no cover for frm in frms: print(frm.human_readable(">>")) - return [frm.to_bytes() for frm in frms] + return [frm.serialize() for frm in frms] def _create_body(self, body, stream_id): if body is None or len(body) == 0: return b'' - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] chunks = range(0, len(body), chunk_size) frms = [frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, + flags=[], stream_id=stream_id, - payload=body[i:i+chunk_size]) for i in chunks] - frms[-1].flags = frame.Frame.FLAG_END_STREAM + data=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags.add('END_STREAM') if self.dump_frames: # pragma no cover for frm in frms: print(frm.human_readable(">>")) - return [frm.to_bytes() for frm in frms] + return [frm.serialize() for frm in frms] def _receive_transmission(self, stream_id=None, include_body=True): if not include_body: @@ -386,7 +391,7 @@ class HTTP2Protocol(object): body_expected = True - header_block_fragment = b'' + header_blocks = b'' body = b'' while True: @@ -396,10 +401,10 @@ class HTTP2Protocol(object): (stream_id is None or frm.stream_id == stream_id) ): stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: + header_blocks += frm.data + if 'END_STREAM' in frm.flags: body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: + if 'END_HEADERS' in frm.flags: break else: self._handle_unexpected_frame(frm) @@ -407,14 +412,14 @@ class HTTP2Protocol(object): while body_expected: frm = self.read_frame() if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: + body += frm.data + if 'END_STREAM' in frm.flags: break else: self._handle_unexpected_frame(frm) headers = Headers( - [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + [[str(k), str(v)] for k, v in self.decoder.decode(header_blocks)] ) return stream_id, headers, body diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py deleted file mode 100644 index 188629d4..00000000 --- a/netlib/http/http2/frame.py +++ /dev/null @@ -1,651 +0,0 @@ -from __future__ import absolute_import, print_function, division -import struct -from hpack.hpack import Encoder, Decoder - -from ...utils import BiDi -from ...exceptions import HttpSyntaxException - - -ERROR_CODES = BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd -) - -CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = 0 - for flag in self.VALID_FLAGS: - valid_flags |= flag - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise HttpSyntaxException( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection") - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in range(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Size Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/utils.py b/netlib/utils.py index 66225897..c537754a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -2,12 +2,12 @@ from __future__ import absolute_import, print_function, division import os.path import re import string +import codecs import unicodedata - import six from six.moves import urllib - +import hyperframe def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): @@ -366,3 +366,20 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] + + +def http2_read_raw_frame(rfile): + field = rfile.peek(3) + length = int(codecs.encode(field, 'hex_codec'), 16) + + if length == 4740180: + raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) + + raw = rfile.safe_read(9 + length) + return raw + +def http2_read_frame(rfile): + raw = http2_read_raw_frame(rfile) + frame, length = hyperframe.frame.Frame.parse_frame_header(raw[:9]) + frame.parse_body(memoryview(raw[9:])) + return frame -- cgit v1.2.3 From e98c729bb9b0d3debde6f61c948108bdc9dbafbe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 31 Jan 2016 14:16:03 +0100 Subject: test on python3 --- netlib/http/http2/connections.py | 40 +++++++++++++++++++++++----------------- netlib/utils.py | 14 +++++++------- 2 files changed, 30 insertions(+), 24 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index c963f7c4..91133121 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -8,6 +8,11 @@ from .. import Headers, Response, Request from hyperframe import frame +# TODO: remove once hyperframe released a new version > 3.1.1 +# wrapper for deprecated name in old hyperframe release +frame.SettingsFrame.MAX_FRAME_SIZE = frame.SettingsFrame.SETTINGS_MAX_FRAME_SIZE +frame.SettingsFrame.MAX_HEADER_LIST_SIZE = frame.SettingsFrame.SETTINGS_MAX_HEADER_LIST_SIZE + class TCPHandler(object): @@ -35,7 +40,7 @@ class HTTP2Protocol(object): HTTP_1_1_REQUIRED=0xd ) - CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' HTTP2_DEFAULT_SETTINGS = { frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, @@ -94,7 +99,7 @@ class HTTP2Protocol(object): timestamp_end = time.time() - authority = headers.get(':authority', '') + authority = headers.get(':authority', b'') method = headers.get(':method', 'GET') scheme = headers.get(':scheme', 'https') path = headers.get(':path', '/') @@ -113,6 +118,8 @@ class HTTP2Protocol(object): form_in = "absolute" # FIXME: verify if path or :host contains what we need scheme, host, port, _ = utils.parse_url(path) + scheme = scheme.decode('ascii') + host = host.decode('ascii') if host is None: host = 'localhost' @@ -122,18 +129,17 @@ class HTTP2Protocol(object): request = Request( form_in, - method, - scheme, - host, + method.encode('ascii'), + scheme.encode('ascii'), + host.encode('ascii'), port, - path, - (2, 0), + path.encode('ascii'), + b'2.0', headers, body, timestamp_start, timestamp_end, ) - # FIXME: We should not do this. request.stream_id = stream_id return request @@ -141,7 +147,7 @@ class HTTP2Protocol(object): def read_response( self, __rfile, - request_method='', + request_method=b'', body_size_limit=None, include_body=True, stream_id=None, @@ -170,9 +176,9 @@ class HTTP2Protocol(object): timestamp_end = None response = Response( - (2, 0), + b'2.0', int(headers.get(':status', 502)), - "", + b'', headers, body, timestamp_start=timestamp_start, @@ -200,13 +206,13 @@ class HTTP2Protocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.fields.insert(0, (':authority', bytes(authority))) + headers.fields.insert(0, (b':authority', authority.encode('ascii'))) if ':scheme' not in headers: - headers.fields.insert(0, (':scheme', bytes(request.scheme))) + headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) if ':path' not in headers: - headers.fields.insert(0, (':path', bytes(request.path))) + headers.fields.insert(0, (b':path', request.path.encode('ascii'))) if ':method' not in headers: - headers.fields.insert(0, (':method', bytes(request.method))) + headers.fields.insert(0, (b':method', request.method.encode('ascii'))) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -223,7 +229,7 @@ class HTTP2Protocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.fields.insert(0, (':status', bytes(str(response.status_code)))) + headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -419,7 +425,7 @@ class HTTP2Protocol(object): self._handle_unexpected_frame(frm) headers = Headers( - [[str(k), str(v)] for k, v in self.decoder.decode(header_blocks)] + [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] ) return stream_id, headers, body diff --git a/netlib/utils.py b/netlib/utils.py index c537754a..1c1b617a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -369,17 +369,17 @@ def multipartdecode(headers, content): def http2_read_raw_frame(rfile): - field = rfile.peek(3) - length = int(codecs.encode(field, 'hex_codec'), 16) + header = rfile.safe_read(9) + length = int(codecs.encode(header[:3], 'hex_codec'), 16) if length == 4740180: raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) - raw = rfile.safe_read(9 + length) - return raw + body = rfile.safe_read(length) + return [header, body] def http2_read_frame(rfile): - raw = http2_read_raw_frame(rfile) - frame, length = hyperframe.frame.Frame.parse_frame_header(raw[:9]) - frame.parse_body(memoryview(raw[9:])) + header, body = http2_read_raw_frame(rfile) + frame, length = hyperframe.frame.Frame.parse_frame_header(header) + frame.parse_body(memoryview(body)) return frame -- cgit v1.2.3 From bda49dd178fee1361f3585bd7efad67883298e5a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 19:38:14 +0100 Subject: fix #113, make Reader.peek() work on Python 3 --- netlib/tcp.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 8902b9dc..57a9b737 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl version_check.check_pyopenssl_version() +if six.PY2: + socket_fileobject = socket._fileobject +else: + socket_fileobject = socket.SocketIO EINTR = 4 @@ -270,7 +274,7 @@ class Reader(_FileLike): TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket_fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: @@ -423,8 +427,17 @@ class _Connection(object): def __init__(self, connection): if connection: self.connection = connection - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) else: self.connection = None self.rfile = None @@ -663,8 +676,15 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + + # See _Connection.__init__ why we do this dance. + if six.PY2: + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(connection, "rb")) + self.wfile = Writer(socket.SocketIO(connection, "wb")) + except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % -- cgit v1.2.3 From a3af0ce71d5b4368f1d9de8d17ff5e20086edcc4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 20:10:18 +0100 Subject: tests++ --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 57a9b737..1523370b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -272,7 +272,7 @@ class Reader(_FileLike): Raises: TcpException if there was an error with the socket TlsException if there was an error with pyOpenSSL. - NotImplementedError if the underlying file object is not a (pyOpenSSL) socket + NotImplementedError if the underlying file object is not a [pyOpenSSL] socket """ if isinstance(self.o, socket_fileobject): try: -- cgit v1.2.3 From 931b5459e92ec237914d7cca9034c5a348033bdb Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 20:19:34 +0100 Subject: remove code duplication --- netlib/tcp.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 1523370b..682db29a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -424,20 +424,26 @@ class _Connection(object): rbufsize = -1 wbufsize = -1 + def _makefile(self): + """ + Set up .rfile and .wfile attributes from .connection + """ + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + def __init__(self, connection): if connection: self.connection = connection - # Ideally, we would use the Buffered IO in Python 3 by default. - # Unfortunately, the implementation of .peek() is broken for n>1 bytes, - # as it may just return what's left in the buffer and not all the bytes we want. - # As a workaround, we just use unbuffered sockets directly. - # https://mail.python.org/pipermail/python-dev/2009-June/089986.html - if six.PY2: - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(self.connection, "rb")) - self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + self._makefile() else: self.connection = None self.rfile = None @@ -676,20 +682,12 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - - # See _Connection.__init__ why we do this dance. - if six.PY2: - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(connection, "rb")) - self.wfile = Writer(socket.SocketIO(connection, "wb")) - except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection + self._makefile() def settimeout(self, n): self.connection.settimeout(n) -- cgit v1.2.3 From e222858f01095c61178590123eea7b49b5d7853b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 2 Feb 2016 17:39:49 +0100 Subject: bump dependency and remove deprecated fields --- netlib/http/http2/connections.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 91133121..5e877286 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -8,11 +8,6 @@ from .. import Headers, Response, Request from hyperframe import frame -# TODO: remove once hyperframe released a new version > 3.1.1 -# wrapper for deprecated name in old hyperframe release -frame.SettingsFrame.MAX_FRAME_SIZE = frame.SettingsFrame.SETTINGS_MAX_FRAME_SIZE -frame.SettingsFrame.MAX_HEADER_LIST_SIZE = frame.SettingsFrame.SETTINGS_MAX_HEADER_LIST_SIZE - class TCPHandler(object): -- cgit v1.2.3 From a188ae5ac55c4f9564d7590c827be9a7eb9afba4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 2 Feb 2016 18:15:55 +0100 Subject: allow creation of certs without CN --- netlib/certutils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index e6d71c39..a0111381 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -101,7 +101,8 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) - cert.get_subject().CN = commonname + if commonname is not None: + cert.get_subject().CN = commonname cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) @@ -294,6 +295,8 @@ class CertStore(object): @staticmethod def asterisk_forms(dn): + if dn is None: + return [] parts = dn.split(b".") parts.reverse() curr_dn = b"" -- cgit v1.2.3 From 8f8796f9d9d49e1e968cb8c48b09f26b2a11dcb2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 00:40:55 +0100 Subject: expose OpenSSL's HAS_ALPN --- netlib/tcp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/tcp.py b/netlib/tcp.py index 682db29a..85b4b0e2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -31,6 +31,7 @@ else: socket_fileobject = socket.SocketIO EINTR = 4 +HAS_ALPN = OpenSSL._util.lib.Cryptography_HAS_ALPN # To enable all SSL methods use: SSLv23 # then add options to disable certain methods @@ -542,7 +543,7 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - if OpenSSL._util.lib.Cryptography_HAS_ALPN: + if HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols context.set_alpn_protos(alpn_protos) @@ -696,7 +697,7 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + if HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: return b"" @@ -802,7 +803,7 @@ class BaseHandler(_Connection): self.connection.settimeout(n) def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + if HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: return b"" -- cgit v1.2.3 From 4873547de3c65ba7c14cace4bca7b17368b2900d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 02:10:48 +0100 Subject: minor fixes --- netlib/http/headers.py | 2 +- netlib/http/request.py | 2 +- netlib/odict.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'netlib') diff --git a/netlib/http/headers.py b/netlib/http/headers.py index f64e6200..6eb9db92 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -194,7 +194,7 @@ class Headers(MutableMapping): return Headers(copy.copy(self.fields)) # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): + def get_state(self): return tuple(tuple(field) for field in self.fields) def load_state(self, state): diff --git a/netlib/http/request.py b/netlib/http/request.py index 5ebf21a5..6dabb189 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -292,7 +292,7 @@ class Request(Message): return None @multipart_form.setter - def multipart_form(self): + def multipart_form(self, value): raise NotImplementedError() # Legacy diff --git a/netlib/odict.py b/netlib/odict.py index 1124b23a..90317e5e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -169,7 +169,7 @@ class ODict(object): return count # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): + def get_state(self): return [tuple(i) for i in self.lst] def load_state(self, state): -- cgit v1.2.3 From fe0ed63c4a3486402f65638b476149ebba752055 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:16:58 +0100 Subject: add Serializable ABC --- netlib/certutils.py | 22 +++++++++++++++++----- netlib/http/headers.py | 7 +++---- netlib/http/message.py | 33 ++++++++++++++++++++++++++++++--- netlib/http/request.py | 5 ++--- netlib/http/response.py | 5 ++--- netlib/odict.py | 10 ++++++---- netlib/tcp.py | 17 ++++++++++++++++- netlib/utils.py | 26 +++++++++++++++++++++++++- 8 files changed, 101 insertions(+), 24 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index a0111381..ecdc0624 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -13,6 +13,8 @@ from pyasn1.error import PyAsn1Error import OpenSSL # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 +from netlib.utils import Serializable + DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = b""" @@ -361,7 +363,7 @@ class _GeneralNames(univ.SequenceOf): constraint.ValueSizeConstraint(1, 1024) -class SSLCert(object): +class SSLCert(Serializable): def __init__(self, cert): """ @@ -375,15 +377,25 @@ class SSLCert(object): def __ne__(self, other): return not self.__eq__(other) + def get_state(self): + return self.to_pem() + + def set_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + @classmethod + def from_state(cls, state): + cls.from_pem(state) + @classmethod - def from_pem(klass, txt): + def from_pem(cls, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - return klass(x509) + return cls(x509) @classmethod - def from_der(klass, der): + def from_der(cls, der): pem = ssl.DER_cert_to_PEM_cert(der) - return klass.from_pem(pem) + return cls.from_pem(pem) def to_pem(self): return OpenSSL.crypto.dump_certificate( diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 6eb9db92..78404796 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,7 +14,7 @@ except ImportError: # pragma: nocover import six -from netlib.utils import always_byte_args, always_bytes +from netlib.utils import always_byte_args, always_bytes, Serializable if six.PY2: # pragma: nocover _native = lambda x: x @@ -27,7 +27,7 @@ else: _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping): +class Headers(MutableMapping, Serializable): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -193,11 +193,10 @@ class Headers(MutableMapping): def copy(self): return Headers(copy.copy(self.fields)) - # Implement the StateObject protocol from mitmproxy def get_state(self): return tuple(tuple(field) for field in self.fields) - def load_state(self, state): + def set_state(self, state): self.fields = [list(field) for field in state] @classmethod diff --git a/netlib/http/message.py b/netlib/http/message.py index 28f55fa2..3d65f93e 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,9 +4,10 @@ import warnings import six +from netlib.utils import Serializable +from .headers import Headers from .. import encoding, utils - CONTENT_MISSING = 0 if six.PY2: # pragma: nocover @@ -18,7 +19,7 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(object): +class MessageData(Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -27,8 +28,24 @@ class MessageData(object): def __ne__(self, other): return not self.__eq__(other) + def set_state(self, state): + for k, v in state.items(): + if k == "headers": + v = Headers.from_state(v) + setattr(self, k, v) + + def get_state(self): + state = vars(self).copy() + state["headers"] = state["headers"].get_state() + return state + + @classmethod + def from_state(cls, state): + state["headers"] = Headers.from_state(state["headers"]) + return cls(**state) + -class Message(object): +class Message(Serializable): def __init__(self, data): self.data = data @@ -40,6 +57,16 @@ class Message(object): def __ne__(self, other): return not self.__eq__(other) + def get_state(self): + return self.data.get_state() + + def set_state(self, state): + self.data.set_state(state) + + @classmethod + def from_state(cls, state): + return cls(**state) + @property def headers(self): """ diff --git a/netlib/http/request.py b/netlib/http/request.py index 6dabb189..0e0f88ce 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData class RequestData(MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, timestamp_start=None, timestamp_end=None): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + if not isinstance(headers, Headers): + headers = Headers(headers) self.first_line_format = first_line_format self.method = method diff --git a/netlib/http/response.py b/netlib/http/response.py index 66e5ded6..8f4d6215 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -12,9 +12,8 @@ from ..odict import ODict class ResponseData(MessageData): def __init__(self, http_version, status_code, reason=None, headers=None, content=None, timestamp_start=None, timestamp_end=None): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + if not isinstance(headers, Headers): + headers = Headers(headers) self.http_version = http_version self.status_code = status_code diff --git a/netlib/odict.py b/netlib/odict.py index 90317e5e..1e6e381a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -3,6 +3,8 @@ import re import copy import six +from .utils import Serializable + def safe_subn(pattern, repl, target, *args, **kwargs): """ @@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict(object): +class ODict(Serializable): """ A dictionary-like object for managing ordered (key, value) data. Think @@ -172,12 +174,12 @@ class ODict(object): def get_state(self): return [tuple(i) for i in self.lst] - def load_state(self, state): + def set_state(self, state): self.lst = [list(i) for i in state] @classmethod - def from_state(klass, state): - return klass([list(i) for i in state]) + def from_state(cls, state): + return cls([list(i) for i in state]) class ODictCaseless(ODict): diff --git a/netlib/tcp.py b/netlib/tcp.py index 85b4b0e2..2e91a70c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,7 @@ import six import OpenSSL from OpenSSL import SSL +from netlib.utils import Serializable from . import certutils, version_check # This is a rather hackish way to make sure that @@ -298,7 +299,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(object): +class Address(Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and @@ -309,6 +310,20 @@ class Address(object): self.address = tuple(address) self.use_ipv6 = use_ipv6 + def get_state(self): + return { + "address": self.address, + "use_ipv6": self.use_ipv6 + } + + def set_state(self, state): + self.address = state["address"] + self.use_ipv6 = state["use_ipv6"] + + @classmethod + def from_state(cls, state): + return Address(**state) + @classmethod def wrap(cls, t): if isinstance(t, cls): diff --git a/netlib/utils.py b/netlib/utils.py index 1c1b617a..a0c2035c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,14 +1,38 @@ from __future__ import absolute_import, print_function, division import os.path import re -import string import codecs import unicodedata +from abc import ABCMeta, abstractmethod + import six from six.moves import urllib import hyperframe + +@six.add_metaclass(ABCMeta) +class Serializable(object): + """ + ABC for Python's pickle protocol __getstate__ and __setstate__. + """ + + @classmethod + @abstractmethod + def from_state(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + @abstractmethod + def get_state(self): + raise NotImplementedError() + + @abstractmethod + def set_state(self, state): + raise NotImplementedError() + + def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) -- cgit v1.2.3 From 173ff0b235cdb45a8923f313807d9804830c2a2b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:28:49 +0100 Subject: fix py3 compat --- netlib/certutils.py | 3 ++- netlib/http/message.py | 5 ++--- netlib/tcp.py | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) (limited to 'netlib') diff --git a/netlib/certutils.py b/netlib/certutils.py index ecdc0624..616a778e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,8 +12,9 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL +from .utils import Serializable + # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 -from netlib.utils import Serializable DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. diff --git a/netlib/http/message.py b/netlib/http/message.py index 3d65f93e..e3d8ce37 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,7 +4,6 @@ import warnings import six -from netlib.utils import Serializable from .headers import Headers from .. import encoding, utils @@ -19,7 +18,7 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(Serializable): +class MessageData(utils.Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -45,7 +44,7 @@ class MessageData(Serializable): return cls(**state) -class Message(Serializable): +class Message(utils.Serializable): def __init__(self, data): self.data = data diff --git a/netlib/tcp.py b/netlib/tcp.py index 2e91a70c..c8548aea 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,8 +16,7 @@ import six import OpenSSL from OpenSSL import SSL -from netlib.utils import Serializable -from . import certutils, version_check +from . import certutils, version_check, utils # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. @@ -299,7 +298,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(Serializable): +class Address(utils.Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and -- cgit v1.2.3 From 655b521749efd5a600d342a1d95b67d32da280a8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:33:10 +0100 Subject: fix docstrings --- netlib/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'netlib') diff --git a/netlib/utils.py b/netlib/utils.py index a0c2035c..d2fc7195 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -14,22 +14,29 @@ import hyperframe @six.add_metaclass(ABCMeta) class Serializable(object): """ - ABC for Python's pickle protocol __getstate__ and __setstate__. + Abstract Base Class that defines an API to save an object's state and restore it later on. """ @classmethod @abstractmethod def from_state(cls, state): - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj + """ + Create a new object from the given state. + """ + raise NotImplementedError() @abstractmethod def get_state(self): + """ + Retrieve object state. + """ raise NotImplementedError() @abstractmethod def set_state(self, state): + """ + Set object state to the given state. + """ raise NotImplementedError() -- cgit v1.2.3 From ead9b0ab8c399feeb25e0851f2dadf654acf51f5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 15:09:25 +0100 Subject: fix http version string --- netlib/http/http2/connections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 5e877286..52fa7193 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -129,7 +129,7 @@ class HTTP2Protocol(object): host.encode('ascii'), port, path.encode('ascii'), - b'2.0', + b"HTTP/2.0", headers, body, timestamp_start, @@ -171,7 +171,7 @@ class HTTP2Protocol(object): timestamp_end = None response = Response( - b'2.0', + b"HTTP/2.0", int(headers.get(':status', 502)), b'', headers, -- cgit v1.2.3 From 1dcb8b14acc3ba1f474ee9673bf4271e576fab9f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 15:09:29 +0100 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 7a68ca39..8ff869cd 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 15, 1) +IVERSION = (0, 16) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From aafa69a73829a7ec291a2d6fa0c4522caf287d17 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 14 Feb 2016 17:25:30 +0100 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib') diff --git a/netlib/version.py b/netlib/version.py index 8ff869cd..bc35c30f 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 16) +IVERSION = (0, 17) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3