diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/http/http1/protocol.py | 153 | ||||
-rw-r--r-- | netlib/http/http2/protocol.py | 272 | ||||
-rw-r--r-- | netlib/http/semantics.py | 91 | ||||
-rw-r--r-- | netlib/odict.py | 7 | ||||
-rw-r--r-- | netlib/utils.py | 10 |
5 files changed, 415 insertions, 118 deletions
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e46ad7ab..b098110a 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -4,8 +4,10 @@ import collections import string import sys import urlparse +import time from netlib import odict, utils, tcp, http +from netlib.http import semantics from .. import status_codes from ..exceptions import * @@ -14,13 +16,10 @@ class TCPHandler(object): self.rfile = rfile self.wfile = wfile -class HTTP1Protocol(object): +class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): - if tcp_handler: - self.tcp_handler = tcp_handler - else: - self.tcp_handler = TCPHandler(rfile, wfile) + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -39,6 +38,10 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) @@ -106,6 +109,12 @@ class HTTP1Protocol(object): True ) + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + return http.Request( form_in, method, @@ -115,7 +124,9 @@ class HTTP1Protocol(object): path, httpversion, headers, - body + body, + timestamp_start, + timestamp_end, ) @@ -124,12 +135,15 @@ class HTTP1Protocol(object): Returns an http.Response By default, both response header and body are read. - If include_body=False is specified, content may be one of the + If include_body=False is specified, body may be one of the following: - None, if the response is technically allowed to have a response body - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() line = self.tcp_handler.rfile.readline() # Possible leftover from previous message @@ -149,7 +163,7 @@ class HTTP1Protocol(object): raise HttpError(502, "Invalid headers.") if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, request_method, @@ -157,10 +171,55 @@ class HTTP1Protocol(object): False ) else: - # if include_body==False then a None content means the body should be + # if include_body==False then a None body means the body should be # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + body = None + + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + return http.Response( + httpversion, + code, + msg, + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + if request.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_request_first_line(request) + headers = self._assemble_request_headers(request) + return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) + + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + if response.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_response_first_line(response) + headers = self._assemble_response_headers(response) + return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) def read_headers(self): @@ -331,7 +390,6 @@ class HTTP1Protocol(object): return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -494,3 +552,74 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) + + + @classmethod + def _assemble_request_first_line(self, request): + if request.form_in == "relative": + request_line = '%s %s HTTP/%s.%s' % ( + request.method, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % ( + request.method, + request.host, + request.port, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_request_headers(self, request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + del headers[k] + if 'host' not in headers and request.scheme and request.host and request.port: + headers["Host"] = [utils.hostport(request.scheme, + request.host, + request.port)] + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == "": + headers["Content-Length"] = [str(len(request.body))] + + return headers.format() + + + def _assemble_response_first_line(self, response): + return 'HTTP/%s.%s %s %s' % ( + response.httpversion[0], + response.httpversion[1], + response.status_code, + response.msg, + ) + + def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + headers = response.headers.copy() + for k in response._headers_to_strip_off: + del headers[k] + if not preserve_transfer_encoding: + del headers['Transfer-Encoding'] + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == "": + headers["Content-Length"] = [str(len(response.body))] + + return headers.format() diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 55b5ca76..a1ca4a18 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -1,12 +1,20 @@ from __future__ import (absolute_import, print_function, division) import itertools +import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict +from netlib.http import semantics from . import frame -class HTTP2Protocol(object): +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + +class HTTP2Protocol(semantics.ProtocolMixin): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, @@ -31,37 +39,146 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() self.connection_preface_performed = False - self.dump_frames = dump_frames - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + request = http.Request( + "relative", # TODO: use the correct value + headers.get_first(':method', 'GET'), + headers.get_first(':scheme', 'https'), + headers.get_first(':host', 'localhost'), + 443, # TODO: parse port number from host? + headers.get_first(':path', '/'), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + request.stream_id = stream_id - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break + return request - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break + def read_response(self, request_method='', body_size_limit=None, include_body=True): + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + int(headers.get_first(':status')), + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = request.headers.copy() + + if not ':authority' in headers.keys(): + headers.add(':authority', bytes(authority), prepend=True) + if not ':scheme' in headers.keys(): + headers.add(':scheme', bytes(request.scheme), prepend=True) + if not ':path' in headers.keys(): + headers.add(':path', bytes(request.path), prepend=True) + if not ':method' in headers.keys(): + headers.add(':method', bytes(request.method), prepend=True) + + headers = headers.items() + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = response.headers.copy() + + if not ':status' in headers.keys(): + headers.add(':status', bytes(str(response.status_code)), prepend=True) + + headers = headers.items() + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), + self._create_body(response.body, stream_id), + )) + + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -83,18 +200,6 @@ class HTTP2Protocol(object): self.send_frame(frame.SettingsFrame(state=self), hide=True) self._receive_settings(hide=True) - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - def send_frame(self, frm, hide=False): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) @@ -111,6 +216,39 @@ class HTTP2Protocol(object): return frm + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] @@ -164,51 +302,7 @@ class HTTP2Protocol(object): return [frm.to_bytes()] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self, *args): - stream_id, headers, body = self._receive_transmission() - - status = headers[':status'][0] - response = http.Response("HTTP/2", status, "", headers, body) - response.stream_id = stream_id - return response - - def read_request(self): - stream_id, headers, body = self._receive_transmission() - - form_in = "" - method = headers.get(':method', [''])[0] - scheme = headers.get(':scheme', [''])[0] - host = headers.get(':host', [''])[0] - port = '' # TODO: parse port number? - path = headers.get(':path', [''])[0] - - request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) - request.stream_id = stream_id - return request - - def _receive_transmission(self): + def _receive_transmission(self, include_body=True): body_expected = True stream_id = 0 @@ -239,19 +333,3 @@ class HTTP2Protocol(object): headers.add(header, value) return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - if isinstance(headers, odict.ODict): - headers = headers.items() - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9e13edaa..54bf83d2 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,32 @@ import urlparse from .. import utils, odict +CONTENT_MISSING = 0 + + +class ProtocolMixin(object): + + def read_request(self): + raise NotImplemented + + def read_response(self): + raise NotImplemented + + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + + def assemble_request(self, request): + raise NotImplemented + + def assemble_response(self, response): + raise NotImplemented + + class Request(object): def __init__( @@ -18,9 +44,15 @@ class Request(object): port, path, httpversion, - headers, - body, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, ): + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) + self.form_in = form_in self.method = method self.scheme = scheme @@ -30,17 +62,31 @@ class Request(object): self.httpversion = httpversion self.headers = headers self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + class EmptyRequest(Request): def __init__(self): @@ -63,28 +109,59 @@ class Response(object): self, httpversion, status_code, - msg, - headers, - body, + msg=None, + headers=None, + body=None, sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) + self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers self.body = body self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): + # TODO: remove deprecated setter + self.status_code = code + + def is_valid_port(port): if not 0 <= port <= 65535: diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..d02de08d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,8 +96,11 @@ class ODict(object): return True return False - def add(self, key, value): - self.lst.append([key, value]) + def add(self, key, value, prepend=False): + if prepend: + self.lst.insert(0, [key, value]) + else: + self.lst.append([key, value]) def get(self, k, d=None): if k in self: diff --git a/netlib/utils.py b/netlib/utils.py index bee412f9..86e33f33 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -129,3 +129,13 @@ class Data(object): if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https")]: + return host + else: + return "%s:%s" % (host, port) |