aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/http/http2
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/http/http2')
-rw-r--r--netlib/http/http2/protocol.py272
1 files changed, 175 insertions, 97 deletions
diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py
index 55b5ca76..a1ca4a18 100644
--- a/netlib/http/http2/protocol.py
+++ b/netlib/http/http2/protocol.py
@@ -1,12 +1,20 @@
from __future__ import (absolute_import, print_function, division)
import itertools
+import time
from hpack.hpack import Encoder, Decoder
from netlib import http, utils, odict
+from netlib.http import semantics
from . import frame
-class HTTP2Protocol(object):
+class TCPHandler(object):
+ def __init__(self, rfile, wfile=None):
+ self.rfile = rfile
+ self.wfile = wfile
+
+
+class HTTP2Protocol(semantics.ProtocolMixin):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
@@ -31,37 +39,146 @@ class HTTP2Protocol(object):
ALPN_PROTO_H2 = 'h2'
- def __init__(self, tcp_handler, is_server=False, dump_frames=False):
- self.tcp_handler = tcp_handler
+
+ def __init__(
+ self,
+ tcp_handler=None,
+ rfile=None,
+ wfile=None,
+ is_server=False,
+ dump_frames=False,
+ encoder=None,
+ decoder=None,
+ ):
+ self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
self.is_server = is_server
+ self.dump_frames = dump_frames
+ self.encoder = encoder or Encoder()
+ self.decoder = decoder or Decoder()
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
- self.encoder = Encoder()
- self.decoder = Decoder()
self.connection_preface_performed = False
- self.dump_frames = dump_frames
- def check_alpn(self):
- alp = self.tcp_handler.get_alpn_proto_negotiated()
- if alp != self.ALPN_PROTO_H2:
- raise NotImplementedError(
- "HTTP2Protocol can not handle unknown ALP: %s" % alp)
- return True
+ def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
+ self.perform_connection_preface()
+
+ timestamp_start = time.time()
+ if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
+ self.tcp_handler.rfile.reset_timestamps()
+
+ stream_id, headers, body = self._receive_transmission(include_body)
+
+ if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
+ # more accurate timestamp_start
+ timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
+
+ timestamp_end = time.time()
+
+ request = http.Request(
+ "relative", # TODO: use the correct value
+ headers.get_first(':method', 'GET'),
+ headers.get_first(':scheme', 'https'),
+ headers.get_first(':host', 'localhost'),
+ 443, # TODO: parse port number from host?
+ headers.get_first(':path', '/'),
+ (2, 0),
+ headers,
+ body,
+ timestamp_start,
+ timestamp_end,
+ )
+ request.stream_id = stream_id
- def _receive_settings(self, hide=False):
- while True:
- frm = self.read_frame(hide)
- if isinstance(frm, frame.SettingsFrame):
- break
+ return request
- def _read_settings_ack(self, hide=False): # pragma no cover
- while True:
- frm = self.read_frame(hide)
- if isinstance(frm, frame.SettingsFrame):
- assert frm.flags & frame.Frame.FLAG_ACK
- assert len(frm.settings) == 0
- break
+ def read_response(self, request_method='', body_size_limit=None, include_body=True):
+ self.perform_connection_preface()
+
+ timestamp_start = time.time()
+ if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
+ self.tcp_handler.rfile.reset_timestamps()
+
+ stream_id, headers, body = self._receive_transmission(include_body)
+
+ if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
+ # more accurate timestamp_start
+ timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
+
+ if include_body:
+ timestamp_end = time.time()
+ else:
+ timestamp_end = None
+
+ response = http.Response(
+ (2, 0),
+ int(headers.get_first(':status')),
+ "",
+ headers,
+ body,
+ timestamp_start=timestamp_start,
+ timestamp_end=timestamp_end,
+ )
+ response.stream_id = stream_id
+
+ return response
+
+
+ def assemble_request(self, request):
+ assert isinstance(request, semantics.Request)
+
+ authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
+ if self.tcp_handler.address.port != 443:
+ authority += ":%d" % self.tcp_handler.address.port
+
+ headers = request.headers.copy()
+
+ if not ':authority' in headers.keys():
+ headers.add(':authority', bytes(authority), prepend=True)
+ if not ':scheme' in headers.keys():
+ headers.add(':scheme', bytes(request.scheme), prepend=True)
+ if not ':path' in headers.keys():
+ headers.add(':path', bytes(request.path), prepend=True)
+ if not ':method' in headers.keys():
+ headers.add(':method', bytes(request.method), prepend=True)
+
+ headers = headers.items()
+
+ if hasattr(request, 'stream_id'):
+ stream_id = request.stream_id
+ else:
+ stream_id = self._next_stream_id()
+
+ return list(itertools.chain(
+ self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
+ self._create_body(request.body, stream_id)))
+
+ def assemble_response(self, response):
+ assert isinstance(response, semantics.Response)
+
+ headers = response.headers.copy()
+
+ if not ':status' in headers.keys():
+ headers.add(':status', bytes(str(response.status_code)), prepend=True)
+
+ headers = headers.items()
+
+ if hasattr(response, 'stream_id'):
+ stream_id = response.stream_id
+ else:
+ stream_id = self._next_stream_id()
+
+ return list(itertools.chain(
+ self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
+ self._create_body(response.body, stream_id),
+ ))
+
+ def perform_connection_preface(self, force=False):
+ if force or not self.connection_preface_performed:
+ if self.is_server:
+ self.perform_server_connection_preface(force)
+ else:
+ self.perform_client_connection_preface(force)
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
@@ -83,18 +200,6 @@ class HTTP2Protocol(object):
self.send_frame(frame.SettingsFrame(state=self), hide=True)
self._receive_settings(hide=True)
- def next_stream_id(self):
- if self.current_stream_id is None:
- if self.is_server:
- # servers must use even stream ids
- self.current_stream_id = 2
- else:
- # clients must use odd stream ids
- self.current_stream_id = 1
- else:
- self.current_stream_id += 2
- return self.current_stream_id
-
def send_frame(self, frm, hide=False):
raw_bytes = frm.to_bytes()
self.tcp_handler.wfile.write(raw_bytes)
@@ -111,6 +216,39 @@ class HTTP2Protocol(object):
return frm
+ def check_alpn(self):
+ alp = self.tcp_handler.get_alpn_proto_negotiated()
+ if alp != self.ALPN_PROTO_H2:
+ raise NotImplementedError(
+ "HTTP2Protocol can not handle unknown ALP: %s" % alp)
+ return True
+
+ def _receive_settings(self, hide=False):
+ while True:
+ frm = self.read_frame(hide)
+ if isinstance(frm, frame.SettingsFrame):
+ break
+
+ def _read_settings_ack(self, hide=False): # pragma no cover
+ while True:
+ frm = self.read_frame(hide)
+ if isinstance(frm, frame.SettingsFrame):
+ assert frm.flags & frame.Frame.FLAG_ACK
+ assert len(frm.settings) == 0
+ break
+
+ def _next_stream_id(self):
+ if self.current_stream_id is None:
+ if self.is_server:
+ # servers must use even stream ids
+ self.current_stream_id = 2
+ else:
+ # clients must use odd stream ids
+ self.current_stream_id = 1
+ else:
+ self.current_stream_id += 2
+ return self.current_stream_id
+
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
@@ -164,51 +302,7 @@ class HTTP2Protocol(object):
return [frm.to_bytes()]
-
- def create_request(self, method, path, headers=None, body=None):
- if headers is None:
- headers = []
-
- authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
- if self.tcp_handler.address.port != 443:
- authority += ":%d" % self.tcp_handler.address.port
-
- headers = [
- (b':method', bytes(method)),
- (b':path', bytes(path)),
- (b':scheme', b'https'),
- (b':authority', authority),
- ] + headers
-
- stream_id = self.next_stream_id()
-
- return list(itertools.chain(
- self._create_headers(headers, stream_id, end_stream=(body is None)),
- self._create_body(body, stream_id)))
-
- def read_response(self, *args):
- stream_id, headers, body = self._receive_transmission()
-
- status = headers[':status'][0]
- response = http.Response("HTTP/2", status, "", headers, body)
- response.stream_id = stream_id
- return response
-
- def read_request(self):
- stream_id, headers, body = self._receive_transmission()
-
- form_in = ""
- method = headers.get(':method', [''])[0]
- scheme = headers.get(':scheme', [''])[0]
- host = headers.get(':host', [''])[0]
- port = '' # TODO: parse port number?
- path = headers.get(':path', [''])[0]
-
- request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body)
- request.stream_id = stream_id
- return request
-
- def _receive_transmission(self):
+ def _receive_transmission(self, include_body=True):
body_expected = True
stream_id = 0
@@ -239,19 +333,3 @@ class HTTP2Protocol(object):
headers.add(header, value)
return stream_id, headers, body
-
- def create_response(self, code, stream_id=None, headers=None, body=None):
- if headers is None:
- headers = []
- if isinstance(headers, odict.ODict):
- headers = headers.items()
-
- headers = [(b':status', bytes(str(code)))] + headers
-
- if not stream_id:
- stream_id = self.next_stream_id()
-
- return list(itertools.chain(
- self._create_headers(headers, stream_id, end_stream=(body is None)),
- self._create_body(body, stream_id),
- ))