mirror of
https://github.com/zhigang1992/mitmproxy.git
synced 2026-06-19 01:36:36 +08:00
add websockets support to mitmproxy
This commit is contained in:
@@ -7,12 +7,15 @@ import traceback
|
||||
import h2.exceptions
|
||||
import six
|
||||
|
||||
import netlib.exceptions
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import models
|
||||
from mitmproxy.protocol import base
|
||||
from .websockets import WebSocketsLayer
|
||||
|
||||
import netlib.exceptions
|
||||
from netlib import http
|
||||
from netlib import tcp
|
||||
from netlib import websockets
|
||||
|
||||
|
||||
class _HttpTransmissionLayer(base.Layer):
|
||||
@@ -189,6 +192,21 @@ class HttpLayer(base.Layer):
|
||||
self.process_request_hook(flow)
|
||||
|
||||
try:
|
||||
# WebSockets
|
||||
if websockets.check_handshake(request.headers):
|
||||
if websockets.check_client_version(request.headers):
|
||||
layer = WebSocketsLayer(self, request)
|
||||
layer()
|
||||
return
|
||||
else:
|
||||
# we only support RFC6455 with WebSockets version 13
|
||||
self.send_response(models.make_error_response(
|
||||
400,
|
||||
http.status_codes.RESPONSES.get(400),
|
||||
http.Headers(sec_websocket_version="13")
|
||||
))
|
||||
return
|
||||
|
||||
if not flow.response:
|
||||
self.establish_server_connection(
|
||||
flow.request.host,
|
||||
|
||||
140
mitmproxy/protocol/websockets.py
Normal file
140
mitmproxy/protocol/websockets.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import absolute_import, print_function, division
|
||||
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import models
|
||||
from mitmproxy.protocol import base
|
||||
|
||||
import netlib.exceptions
|
||||
from netlib import tcp
|
||||
from netlib import http
|
||||
from netlib import websockets
|
||||
|
||||
|
||||
class WebSocketsLayer(base.Layer):
|
||||
"""
|
||||
WebSockets layer to intercept, modify, and forward WebSockets connections
|
||||
|
||||
Only version 13 is supported (as specified in RFC6455)
|
||||
Only HTTP/1.1-initiated connections are supported.
|
||||
|
||||
The client starts by sending an Upgrade-request.
|
||||
In order to determine the handshake and negotiate the correct protocol
|
||||
and extensions, the Upgrade-request is forwarded to the server.
|
||||
The response from the server is then parsed and negotiated settings are extracted.
|
||||
Finally the handshake is completed by forwarding the server-response to the client.
|
||||
After that, only WebSockets frames are exchanged.
|
||||
|
||||
PING/PONG frames pass through and must be answered by the other endpoint.
|
||||
|
||||
CLOSE frames are forwarded before this WebSocketsLayer terminates.
|
||||
|
||||
This layer is transparent to any negotiated extensions.
|
||||
This layer is transparent to any negotiated subprotocols.
|
||||
Only raw frames are forwarded to the other endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx, request):
|
||||
super(WebSocketsLayer, self).__init__(ctx)
|
||||
self._request = request
|
||||
|
||||
self.client_key = websockets.get_client_key(self._request.headers)
|
||||
self.client_protocol = websockets.get_protocol(self._request.headers)
|
||||
self.client_extensions = websockets.get_extensions(self._request.headers)
|
||||
|
||||
self.server_accept = None
|
||||
self.server_protocol = None
|
||||
self.server_extensions = None
|
||||
|
||||
def _initiate_server_conn(self):
|
||||
self.establish_server_connection(
|
||||
self._request.host,
|
||||
self._request.port,
|
||||
self._request.scheme,
|
||||
)
|
||||
|
||||
self.server_conn.send(netlib.http.http1.assemble_request(self._request))
|
||||
response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
|
||||
|
||||
if not websockets.check_handshake(response.headers):
|
||||
raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
|
||||
|
||||
self.server_accept = websockets.get_server_accept(response.headers)
|
||||
self.server_protocol = websockets.get_protocol(response.headers)
|
||||
self.server_extensions = websockets.get_extensions(response.headers)
|
||||
|
||||
def _complete_handshake(self):
|
||||
headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
|
||||
self.send_response(models.HTTPResponse(
|
||||
self._request.http_version,
|
||||
101,
|
||||
http.status_codes.RESPONSES.get(101),
|
||||
headers,
|
||||
b"",
|
||||
))
|
||||
|
||||
def _handle_frame(self, frame, source_conn, other_conn, is_server):
|
||||
self.log(
|
||||
"WebSockets Frame received from {}".format("server" if is_server else "client"),
|
||||
"debug",
|
||||
[repr(frame)]
|
||||
)
|
||||
|
||||
if frame.header.opcode & 0x8 == 0:
|
||||
# forward the data frame to the other side
|
||||
other_conn.send(bytes(frame))
|
||||
self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug")
|
||||
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
|
||||
# just forward the ping/pong to the other side
|
||||
other_conn.send(bytes(frame))
|
||||
elif frame.header.opcode == websockets.OPCODE.CLOSE:
|
||||
other_conn.send(bytes(frame))
|
||||
|
||||
code = '(status code missing)'
|
||||
msg = None
|
||||
reason = '(message missing)'
|
||||
if len(frame.payload) >= 2:
|
||||
code, = struct.unpack('!H', frame.payload[:2])
|
||||
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
|
||||
if len(frame.payload) > 2:
|
||||
reason = frame.payload[2:]
|
||||
self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info")
|
||||
|
||||
# close the connection
|
||||
return False
|
||||
else:
|
||||
# unknown frame - just forward it
|
||||
other_conn.send(bytes(frame))
|
||||
|
||||
# continue the connection
|
||||
return True
|
||||
|
||||
def __call__(self):
|
||||
self._initiate_server_conn()
|
||||
self._complete_handshake()
|
||||
|
||||
client = self.client_conn.connection
|
||||
server = self.server_conn.connection
|
||||
conns = [client, server]
|
||||
|
||||
try:
|
||||
while not self.channel.should_exit.is_set():
|
||||
r = tcp.ssl_read_select(conns, 1)
|
||||
for conn in r:
|
||||
source_conn = self.client_conn if conn == client else self.server_conn
|
||||
other_conn = self.server_conn if conn == client else self.client_conn
|
||||
is_server = (conn == self.server_conn.connection)
|
||||
|
||||
frame = websockets.Frame.from_file(source_conn.rfile)
|
||||
|
||||
if not self._handle_frame(frame, source_conn, other_conn, is_server):
|
||||
return
|
||||
except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e:
|
||||
self.log("WebSockets connection closed unexpectedly by {}: {}".format(
|
||||
"server" if is_server else "client", repr(e)), "info")
|
||||
except Exception as e: # pragma: no cover
|
||||
raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))
|
||||
@@ -198,7 +198,7 @@ class Response(_HTTPMessage):
|
||||
1,
|
||||
StatusCode(101)
|
||||
)
|
||||
headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers(
|
||||
headers = netlib.websockets.server_handshake_headers(
|
||||
settings.websocket_key
|
||||
)
|
||||
for i in headers.fields:
|
||||
@@ -310,7 +310,7 @@ class Request(_HTTPMessage):
|
||||
1,
|
||||
Method("get")
|
||||
)
|
||||
for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields:
|
||||
for i in netlib.websockets.client_handshake_headers().fields:
|
||||
if not get_header(i[0], self.headers):
|
||||
tokens.append(
|
||||
Header(
|
||||
|
||||
@@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread):
|
||||
except exceptions.TcpDisconnect:
|
||||
return
|
||||
self.frames_queue.put(frm)
|
||||
log("<< %s" % frm.header.human_readable())
|
||||
log("<< %s" % repr(frm.header))
|
||||
if self.ws_read_limit is not None:
|
||||
self.ws_read_limit -= 1
|
||||
starttime = time.time()
|
||||
|
||||
@@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler):
|
||||
retlog["cipher"] = self.get_current_cipher()
|
||||
|
||||
m = utils.MemBool()
|
||||
websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers)
|
||||
self.settings.websocket_key = websocket_key
|
||||
|
||||
valid_websockets_handshake = websockets.check_handshake(headers)
|
||||
self.settings.websocket_key = websockets.get_client_key(headers)
|
||||
|
||||
# If this is a websocket initiation, we respond with a proper
|
||||
# server response, unless over-ridden.
|
||||
if websocket_key:
|
||||
if valid_websockets_handshake:
|
||||
anchor_gen = language.parse_pathod("ws")
|
||||
else:
|
||||
anchor_gen = None
|
||||
@@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
|
||||
spec,
|
||||
lg
|
||||
)
|
||||
if nexthandler and websocket_key:
|
||||
if nexthandler and valid_websockets_handshake:
|
||||
self.protocol = protocols.websockets.WebsocketsProtocol(self)
|
||||
return self.protocol.handle_websocket, retlog
|
||||
else:
|
||||
|
||||
@@ -20,7 +20,7 @@ class WebsocketsProtocol:
|
||||
lg("Error reading websocket frame: %s" % e)
|
||||
return None, None
|
||||
ended = time.time()
|
||||
lg(frm.human_readable())
|
||||
lg(repr(frm))
|
||||
retlog = dict(
|
||||
type="inbound",
|
||||
protocol="websockets",
|
||||
|
||||
297
test/mitmproxy/protocol/test_websockets.py
Normal file
297
test/mitmproxy/protocol/test_websockets.py
Normal file
@@ -0,0 +1,297 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
|
||||
from mitmproxy import options
|
||||
from mitmproxy.proxy.config import ProxyConfig
|
||||
|
||||
import netlib
|
||||
from netlib import http
|
||||
from ...netlib import tservers as netlib_tservers
|
||||
from .. import tservers
|
||||
|
||||
from netlib import websockets
|
||||
|
||||
|
||||
class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
|
||||
|
||||
class handler(netlib.tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
try:
|
||||
request = http.http1.read_request(self.rfile)
|
||||
assert websockets.check_handshake(request.headers)
|
||||
|
||||
response = http.Response(
|
||||
"HTTP/1.1",
|
||||
101,
|
||||
reason=http.status_codes.RESPONSES.get(101),
|
||||
headers=http.Headers(
|
||||
connection='upgrade',
|
||||
upgrade='websocket',
|
||||
sec_websocket_accept=b'',
|
||||
),
|
||||
content=b'',
|
||||
)
|
||||
self.wfile.write(http.http1.assemble_response(response))
|
||||
self.wfile.flush()
|
||||
|
||||
self.server.handle_websockets(self.rfile, self.wfile)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class _WebSocketsTestBase(object):
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
opts = cls.get_options()
|
||||
cls.config = ProxyConfig(opts)
|
||||
|
||||
tmaster = tservers.TestMaster(opts, cls.config)
|
||||
tmaster.start_app(options.APP_HOST, options.APP_PORT)
|
||||
cls.proxy = tservers.ProxyThread(tmaster)
|
||||
cls.proxy.start()
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
cls.proxy.shutdown()
|
||||
|
||||
@classmethod
|
||||
def get_options(cls):
|
||||
opts = options.Options(
|
||||
listen_port=0,
|
||||
no_upstream_cert=False,
|
||||
ssl_insecure=True
|
||||
)
|
||||
opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
|
||||
return opts
|
||||
|
||||
@property
|
||||
def master(self):
|
||||
return self.proxy.tmaster
|
||||
|
||||
def setup(self):
|
||||
self.master.clear_log()
|
||||
self.master.state.clear()
|
||||
self.server.server.handle_websockets = self.handle_websockets
|
||||
|
||||
def _setup_connection(self):
|
||||
client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
|
||||
client.connect()
|
||||
|
||||
request = http.Request(
|
||||
"authority",
|
||||
"CONNECT",
|
||||
"",
|
||||
"localhost",
|
||||
self.server.server.address.port,
|
||||
"",
|
||||
"HTTP/1.1",
|
||||
content=b'')
|
||||
client.wfile.write(http.http1.assemble_request(request))
|
||||
client.wfile.flush()
|
||||
|
||||
response = http.http1.read_response(client.rfile, request)
|
||||
|
||||
if self.ssl:
|
||||
client.convert_to_ssl()
|
||||
assert client.ssl_established
|
||||
|
||||
request = http.Request(
|
||||
"relative",
|
||||
"GET",
|
||||
"http",
|
||||
"localhost",
|
||||
self.server.server.address.port,
|
||||
"/ws",
|
||||
"HTTP/1.1",
|
||||
headers=http.Headers(
|
||||
connection="upgrade",
|
||||
upgrade="websocket",
|
||||
sec_websocket_version="13",
|
||||
sec_websocket_key="1234",
|
||||
),
|
||||
content=b'')
|
||||
client.wfile.write(http.http1.assemble_request(request))
|
||||
client.wfile.flush()
|
||||
|
||||
response = http.http1.read_response(client.rfile, request)
|
||||
assert websockets.check_handshake(response.headers)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_WebSocketsTestBase.setup_class()
|
||||
_WebSocketsServerBase.setup_class(ssl=cls.ssl)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
_WebSocketsTestBase.teardown_class()
|
||||
_WebSocketsServerBase.teardown_class()
|
||||
|
||||
|
||||
class TestSimple(_WebSocketsTest):
|
||||
|
||||
@classmethod
|
||||
def handle_websockets(cls, rfile, wfile):
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
|
||||
wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(rfile)
|
||||
wfile.write(bytes(frame))
|
||||
wfile.flush()
|
||||
|
||||
def test_simple(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.payload == b'server-foobar'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
|
||||
client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.payload == b'client-foobar'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
client.wfile.flush()
|
||||
|
||||
|
||||
class TestSimpleTLS(_WebSocketsTest):
|
||||
ssl = True
|
||||
|
||||
@classmethod
|
||||
def handle_websockets(cls, rfile, wfile):
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
|
||||
wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(rfile)
|
||||
wfile.write(bytes(frame))
|
||||
wfile.flush()
|
||||
|
||||
def test_simple_tls(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.payload == b'server-foobar'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
|
||||
client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.payload == b'client-foobar'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
client.wfile.flush()
|
||||
|
||||
|
||||
class TestPing(_WebSocketsTest):
|
||||
|
||||
@classmethod
|
||||
def handle_websockets(cls, rfile, wfile):
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||
wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(rfile)
|
||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||
assert frame.payload == b'foobar'
|
||||
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
|
||||
wfile.flush()
|
||||
|
||||
def test_ping(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.header.opcode == websockets.OPCODE.PING
|
||||
assert frame.payload == b'foobar'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||
client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.header.opcode == websockets.OPCODE.TEXT
|
||||
assert frame.payload == b'pong-received'
|
||||
|
||||
|
||||
class TestPong(_WebSocketsTest):
|
||||
|
||||
@classmethod
|
||||
def handle_websockets(cls, rfile, wfile):
|
||||
frame = websockets.Frame.from_file(rfile)
|
||||
assert frame.header.opcode == websockets.OPCODE.PING
|
||||
assert frame.payload == b'foobar'
|
||||
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||
wfile.flush()
|
||||
|
||||
def test_pong(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||
client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||
assert frame.payload == b'foobar'
|
||||
|
||||
|
||||
class TestClose(_WebSocketsTest):
|
||||
|
||||
@classmethod
|
||||
def handle_websockets(cls, rfile, wfile):
|
||||
frame = websockets.Frame.from_file(rfile)
|
||||
wfile.write(bytes(frame))
|
||||
wfile.flush()
|
||||
|
||||
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||
websockets.Frame.from_file(rfile)
|
||||
|
||||
def test_close(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
client.wfile.flush()
|
||||
|
||||
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||
websockets.Frame.from_file(client.rfile)
|
||||
|
||||
def test_close_payload_1(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||
client.wfile.flush()
|
||||
|
||||
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||
websockets.Frame.from_file(client.rfile)
|
||||
|
||||
def test_close_payload_2(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||
client.wfile.flush()
|
||||
|
||||
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||
websockets.Frame.from_file(client.rfile)
|
||||
|
||||
|
||||
class TestInvalidFrame(_WebSocketsTest):
|
||||
|
||||
@classmethod
|
||||
def handle_websockets(cls, rfile, wfile):
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
|
||||
wfile.flush()
|
||||
|
||||
def test_invalid_frame(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
# with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.header.opcode == 15
|
||||
assert frame.payload == b'foobar'
|
||||
Reference in New Issue
Block a user