mirror of
https://github.com/zhigang1992/mitmproxy.git
synced 2026-04-29 04:35:02 +08:00
http2: improve protocol
This commit is contained in:
@@ -60,7 +60,9 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
self.current_stream_id = None
|
||||
self.connection_preface_performed = False
|
||||
|
||||
def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
|
||||
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
|
||||
self.perform_connection_preface()
|
||||
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
@@ -73,15 +75,13 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
timestamp_end = time.time()
|
||||
|
||||
port = '' # TODO: parse port number?
|
||||
|
||||
request = http.Request(
|
||||
"",
|
||||
headers.get_first(':method', ['']),
|
||||
headers.get_first(':scheme', ['']),
|
||||
headers.get_first(':host', ['']),
|
||||
port,
|
||||
headers.get_first(':path', ['']),
|
||||
"relative", # TODO: use the correct value
|
||||
headers.get_first(':method', 'GET'),
|
||||
headers.get_first(':scheme', 'https'),
|
||||
headers.get_first(':host', 'localhost'),
|
||||
443, # TODO: parse port number from host?
|
||||
headers.get_first(':path', '/'),
|
||||
(2, 0),
|
||||
headers,
|
||||
body,
|
||||
@@ -92,7 +92,9 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
return request
|
||||
|
||||
def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
|
||||
def read_response(self, request_method='', body_size_limit=None, include_body=True):
|
||||
self.perform_connection_preface()
|
||||
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
@@ -110,7 +112,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
response = http.Response(
|
||||
(2, 0),
|
||||
headers[':status'][0],
|
||||
int(headers.get_first(':status')),
|
||||
"",
|
||||
headers,
|
||||
body,
|
||||
@@ -121,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def assemble_request(self, request):
|
||||
assert isinstance(request, semantics.Request)
|
||||
|
||||
@@ -128,12 +131,18 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
if self.tcp_handler.address.port != 443:
|
||||
authority += ":%d" % self.tcp_handler.address.port
|
||||
|
||||
headers = [
|
||||
(b':method', bytes(request.method)),
|
||||
(b':path', bytes(request.path)),
|
||||
(b':scheme', b'https'),
|
||||
(b':authority', authority),
|
||||
] + request.headers.items()
|
||||
headers = request.headers.copy()
|
||||
|
||||
if not ':authority' in headers.keys():
|
||||
headers.add(':authority', bytes(authority), prepend=True)
|
||||
if not ':scheme' in headers.keys():
|
||||
headers.add(':scheme', bytes(request.scheme), prepend=True)
|
||||
if not ':path' in headers.keys():
|
||||
headers.add(':path', bytes(request.path), prepend=True)
|
||||
if not ':method' in headers.keys():
|
||||
headers.add(':method', bytes(request.method), prepend=True)
|
||||
|
||||
headers = headers.items()
|
||||
|
||||
if hasattr(request, 'stream_id'):
|
||||
stream_id = request.stream_id
|
||||
@@ -141,13 +150,18 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
stream_id = self._next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(request.body is None)),
|
||||
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
|
||||
self._create_body(request.body, stream_id)))
|
||||
|
||||
def assemble_response(self, response):
|
||||
assert isinstance(response, semantics.Response)
|
||||
|
||||
headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items()
|
||||
headers = response.headers.copy()
|
||||
|
||||
if not ':status' in headers.keys():
|
||||
headers.add(':status', bytes(str(response.status_code)), prepend=True)
|
||||
|
||||
headers = headers.items()
|
||||
|
||||
if hasattr(response, 'stream_id'):
|
||||
stream_id = response.stream_id
|
||||
@@ -155,10 +169,17 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
stream_id = self._next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(response.body is None)),
|
||||
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
|
||||
self._create_body(response.body, stream_id),
|
||||
))
|
||||
|
||||
def perform_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
if self.is_server:
|
||||
self.perform_server_connection_preface(force)
|
||||
else:
|
||||
self.perform_client_connection_preface(force)
|
||||
|
||||
def perform_server_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
@@ -96,8 +96,11 @@ class ODict(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add(self, key, value):
|
||||
self.lst.append([key, value])
|
||||
def add(self, key, value, prepend=False):
|
||||
if prepend:
|
||||
self.lst.insert(0, [key, value])
|
||||
else:
|
||||
self.lst.append([key, value])
|
||||
|
||||
def get(self, k, d=None):
|
||||
if k in self:
|
||||
|
||||
@@ -222,7 +222,7 @@ class TestAssembleRequest():
|
||||
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
@@ -237,7 +237,7 @@ class TestAssembleRequest():
|
||||
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
@@ -269,11 +269,12 @@ class TestReadResponse(tservers.ServerTestBase):
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == "200"
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b'foobar'
|
||||
@@ -294,12 +295,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.stream_id
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == "200"
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b''
|
||||
@@ -322,6 +324,7 @@ class TestReadRequest(tservers.ServerTestBase):
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_request()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user