http2: improve protocol

This commit is contained in:
Thomas Kriechbaumer
2015-07-30 13:52:13 +02:00
parent c7fcc2cca5
commit 7b10817670
3 changed files with 53 additions and 26 deletions

View File

@@ -60,7 +60,9 @@ class HTTP2Protocol(semantics.ProtocolMixin):
self.current_stream_id = None
self.connection_preface_performed = False
def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
@@ -73,15 +75,13 @@ class HTTP2Protocol(semantics.ProtocolMixin):
timestamp_end = time.time()
port = '' # TODO: parse port number?
request = http.Request(
"",
headers.get_first(':method', ['']),
headers.get_first(':scheme', ['']),
headers.get_first(':host', ['']),
port,
headers.get_first(':path', ['']),
"relative", # TODO: use the correct value
headers.get_first(':method', 'GET'),
headers.get_first(':scheme', 'https'),
headers.get_first(':host', 'localhost'),
443, # TODO: parse port number from host?
headers.get_first(':path', '/'),
(2, 0),
headers,
body,
@@ -92,7 +92,9 @@ class HTTP2Protocol(semantics.ProtocolMixin):
return request
def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
def read_response(self, request_method='', body_size_limit=None, include_body=True):
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
@@ -110,7 +112,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
response = http.Response(
(2, 0),
headers[':status'][0],
int(headers.get_first(':status')),
"",
headers,
body,
@@ -121,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
return response
def assemble_request(self, request):
assert isinstance(request, semantics.Request)
@@ -128,12 +131,18 @@ class HTTP2Protocol(semantics.ProtocolMixin):
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = [
(b':method', bytes(request.method)),
(b':path', bytes(request.path)),
(b':scheme', b'https'),
(b':authority', authority),
] + request.headers.items()
headers = request.headers.copy()
if not ':authority' in headers.keys():
headers.add(':authority', bytes(authority), prepend=True)
if not ':scheme' in headers.keys():
headers.add(':scheme', bytes(request.scheme), prepend=True)
if not ':path' in headers.keys():
headers.add(':path', bytes(request.path), prepend=True)
if not ':method' in headers.keys():
headers.add(':method', bytes(request.method), prepend=True)
headers = headers.items()
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@@ -141,13 +150,18 @@ class HTTP2Protocol(semantics.ProtocolMixin):
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(request.body is None)),
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
assert isinstance(response, semantics.Response)
headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items()
headers = response.headers.copy()
if not ':status' in headers.keys():
headers.add(':status', bytes(str(response.status_code)), prepend=True)
headers = headers.items()
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
@@ -155,10 +169,17 @@ class HTTP2Protocol(semantics.ProtocolMixin):
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(response.body is None)),
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
self._create_body(response.body, stream_id),
))
def perform_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
if self.is_server:
self.perform_server_connection_preface(force)
else:
self.perform_client_connection_preface(force)
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True

View File

@@ -96,8 +96,11 @@ class ODict(object):
return True
return False
def add(self, key, value):
self.lst.append([key, value])
def add(self, key, value, prepend=False):
if prepend:
self.lst.insert(0, [key, value])
else:
self.lst.append([key, value])
def get(self, k, d=None):
if k in self:

View File

@@ -222,7 +222,7 @@ class TestAssembleRequest():
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'',
'https',
'',
'',
'/',
@@ -237,7 +237,7 @@ class TestAssembleRequest():
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'',
'https',
'',
'',
'/',
@@ -269,11 +269,12 @@ class TestReadResponse(tservers.ServerTestBase):
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
assert resp.httpversion == (2, 0)
assert resp.status_code == "200"
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b'foobar'
@@ -294,12 +295,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
assert resp.stream_id
assert resp.httpversion == (2, 0)
assert resp.status_code == "200"
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b''
@@ -322,6 +324,7 @@ class TestReadRequest(tservers.ServerTestBase):
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True
resp = protocol.read_request()