mirror of
https://github.com/zhigang1992/mitmproxy.git
synced 2026-04-24 04:14:57 +08:00
http2: test throttling at MAX_CONCURRENT_STREAMS
This commit is contained in:
@@ -3,9 +3,10 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
|
||||
import pytest
|
||||
import traceback
|
||||
import os
|
||||
import traceback
|
||||
import tempfile
|
||||
|
||||
import h2
|
||||
|
||||
from mitmproxy.proxy.config import ProxyConfig
|
||||
@@ -46,6 +47,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
|
||||
if 'h2_server_settings' in self.kwargs:
|
||||
h2_conn.update_settings(self.kwargs['h2_server_settings'])
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
@@ -508,3 +514,69 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase):
|
||||
|
||||
if len(self.master.state.flows) == 1:
|
||||
assert self.master.state.flows[0].response is None
|
||||
|
||||
|
||||
@requires_alpn
|
||||
class TestMaxConcurrentStreams(_Http2TestBase, _Http2ServerBase):
|
||||
|
||||
@classmethod
|
||||
def setup_class(self):
|
||||
_Http2TestBase.setup_class()
|
||||
_Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2})
|
||||
|
||||
@classmethod
|
||||
def teardown_class(self):
|
||||
_Http2TestBase.teardown_class()
|
||||
_Http2ServerBase.teardown_class()
|
||||
|
||||
@classmethod
|
||||
def handle_server_event(self, event, h2_conn, rfile, wfile):
|
||||
if isinstance(event, h2.events.ConnectionTerminated):
|
||||
return False
|
||||
elif isinstance(event, h2.events.RequestReceived):
|
||||
h2_conn.send_headers(event.stream_id, [
|
||||
(':status', '200'),
|
||||
('X-Stream-ID', str(event.stream_id)),
|
||||
])
|
||||
h2_conn.send_data(event.stream_id, b'Stream-ID {}'.format(event.stream_id))
|
||||
h2_conn.end_stream(event.stream_id)
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
return True
|
||||
|
||||
def test_max_concurrent_streams(self):
|
||||
client, h2_conn = self._setup_connection()
|
||||
new_streams = [1, 3, 5, 7, 9, 11]
|
||||
for id in new_streams:
|
||||
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
|
||||
# and cause mitmproxy to throttle stream creation to the server
|
||||
self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
|
||||
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/'),
|
||||
('X-Stream-ID', str(id)),
|
||||
])
|
||||
|
||||
ended_streams = 0
|
||||
while ended_streams != len(new_streams):
|
||||
try:
|
||||
header, body = framereader.http2_read_raw_frame(client.rfile)
|
||||
events = h2_conn.receive_data(b''.join([header, body]))
|
||||
except:
|
||||
break
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.StreamEnded):
|
||||
ended_streams += 1
|
||||
|
||||
h2_conn.close_connection()
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
assert len(self.master.state.flows) == len(new_streams)
|
||||
for flow in self.master.state.flows:
|
||||
assert flow.response.status_code == 200
|
||||
assert "Stream-ID" in flow.response.body
|
||||
|
||||
@@ -24,7 +24,7 @@ class _ServerThread(threading.Thread):
|
||||
|
||||
class _TServer(tcp.TCPServer):
|
||||
|
||||
def __init__(self, ssl, q, handler_klass, addr):
|
||||
def __init__(self, ssl, q, handler_klass, addr, **kwargs):
|
||||
"""
|
||||
ssl: A dictionary of SSL parameters:
|
||||
|
||||
@@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer):
|
||||
|
||||
self.q = q
|
||||
self.handler_klass = handler_klass
|
||||
if self.handler_klass is not None:
|
||||
self.handler_klass.kwargs = kwargs
|
||||
self.last_handler = None
|
||||
|
||||
def handle_client_connection(self, request, client_address):
|
||||
@@ -89,16 +91,16 @@ class ServerTestBase(object):
|
||||
addr = ("localhost", 0)
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
def setup_class(cls, **kwargs):
|
||||
cls.q = queue.Queue()
|
||||
s = cls.makeserver()
|
||||
s = cls.makeserver(**kwargs)
|
||||
cls.port = s.address.port
|
||||
cls.server = _ServerThread(s)
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return _TServer(cls.ssl, cls.q, cls.handler, cls.addr)
|
||||
def makeserver(cls, **kwargs):
|
||||
return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
|
||||
Reference in New Issue
Block a user