add websockets support to mitmproxy

This commit is contained in:
Thomas Kriechbaumer
2016-08-16 18:31:50 +02:00
parent d12515f84b
commit e5b0dae7e9
7 changed files with 465 additions and 9 deletions

View File

@@ -7,12 +7,15 @@ import traceback
import h2.exceptions
import six
import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
from .websockets import WebSocketsLayer
import netlib.exceptions
from netlib import http
from netlib import tcp
from netlib import websockets
class _HttpTransmissionLayer(base.Layer):
@@ -189,6 +192,21 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow)
try:
# WebSockets
if websockets.check_handshake(request.headers):
if websockets.check_client_version(request.headers):
layer = WebSocketsLayer(self, request)
layer()
return
else:
# we only support RFC6455 with WebSockets version 13
self.send_response(models.make_error_response(
400,
http.status_codes.RESPONSES.get(400),
http.Headers(sec_websocket_version="13")
))
return
if not flow.response:
self.establish_server_connection(
flow.request.host,

View File

@@ -0,0 +1,140 @@
from __future__ import absolute_import, print_function, division
import socket
import struct
from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
import netlib.exceptions
from netlib import tcp
from netlib import http
from netlib import websockets
class WebSocketsLayer(base.Layer):
"""
WebSockets layer to intercept, modify, and forward WebSockets connections
Only version 13 is supported (as specified in RFC6455)
Only HTTP/1.1-initiated connections are supported.
The client starts by sending an Upgrade-request.
In order to determine the handshake and negotiate the correct protocol
and extensions, the Upgrade-request is forwarded to the server.
The response from the server is then parsed and negotiated settings are extracted.
Finally the handshake is completed by forwarding the server-response to the client.
After that, only WebSockets frames are exchanged.
PING/PONG frames pass through and must be answered by the other endpoint.
CLOSE frames are forwarded before this WebSocketsLayer terminates.
This layer is transparent to any negotiated extensions.
This layer is transparent to any negotiated subprotocols.
Only raw frames are forwarded to the other endpoint.
"""
def __init__(self, ctx, request):
super(WebSocketsLayer, self).__init__(ctx)
self._request = request
self.client_key = websockets.get_client_key(self._request.headers)
self.client_protocol = websockets.get_protocol(self._request.headers)
self.client_extensions = websockets.get_extensions(self._request.headers)
self.server_accept = None
self.server_protocol = None
self.server_extensions = None
def _initiate_server_conn(self):
self.establish_server_connection(
self._request.host,
self._request.port,
self._request.scheme,
)
self.server_conn.send(netlib.http.http1.assemble_request(self._request))
response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
if not websockets.check_handshake(response.headers):
raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
self.server_accept = websockets.get_server_accept(response.headers)
self.server_protocol = websockets.get_protocol(response.headers)
self.server_extensions = websockets.get_extensions(response.headers)
def _complete_handshake(self):
headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
self.send_response(models.HTTPResponse(
self._request.http_version,
101,
http.status_codes.RESPONSES.get(101),
headers,
b"",
))
def _handle_frame(self, frame, source_conn, other_conn, is_server):
self.log(
"WebSockets Frame received from {}".format("server" if is_server else "client"),
"debug",
[repr(frame)]
)
if frame.header.opcode & 0x8 == 0:
# forward the data frame to the other side
other_conn.send(bytes(frame))
self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug")
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
elif frame.header.opcode == websockets.OPCODE.CLOSE:
other_conn.send(bytes(frame))
code = '(status code missing)'
msg = None
reason = '(message missing)'
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
reason = frame.payload[2:]
self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info")
# close the connection
return False
else:
# unknown frame - just forward it
other_conn.send(bytes(frame))
# continue the connection
return True
def __call__(self):
self._initiate_server_conn()
self._complete_handshake()
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 1)
for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn
is_server = (conn == self.server_conn.connection)
frame = websockets.Frame.from_file(source_conn.rfile)
if not self._handle_frame(frame, source_conn, other_conn, is_server):
return
except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e:
self.log("WebSockets connection closed unexpectedly by {}: {}".format(
"server" if is_server else "client", repr(e)), "info")
except Exception as e: # pragma: no cover
raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))

View File

@@ -198,7 +198,7 @@ class Response(_HTTPMessage):
1,
StatusCode(101)
)
headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers(
headers = netlib.websockets.server_handshake_headers(
settings.websocket_key
)
for i in headers.fields:
@@ -310,7 +310,7 @@ class Request(_HTTPMessage):
1,
Method("get")
)
for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields:
for i in netlib.websockets.client_handshake_headers().fields:
if not get_header(i[0], self.headers):
tokens.append(
Header(

View File

@@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread):
except exceptions.TcpDisconnect:
return
self.frames_queue.put(frm)
log("<< %s" % frm.header.human_readable())
log("<< %s" % repr(frm.header))
if self.ws_read_limit is not None:
self.ws_read_limit -= 1
starttime = time.time()

View File

@@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler):
retlog["cipher"] = self.get_current_cipher()
m = utils.MemBool()
websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers)
self.settings.websocket_key = websocket_key
valid_websockets_handshake = websockets.check_handshake(headers)
self.settings.websocket_key = websockets.get_client_key(headers)
# If this is a websocket initiation, we respond with a proper
# server response, unless over-ridden.
if websocket_key:
if valid_websockets_handshake:
anchor_gen = language.parse_pathod("ws")
else:
anchor_gen = None
@@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
spec,
lg
)
if nexthandler and websocket_key:
if nexthandler and valid_websockets_handshake:
self.protocol = protocols.websockets.WebsocketsProtocol(self)
return self.protocol.handle_websocket, retlog
else:

View File

@@ -20,7 +20,7 @@ class WebsocketsProtocol:
lg("Error reading websocket frame: %s" % e)
return None, None
ended = time.time()
lg(frm.human_readable())
lg(repr(frm))
retlog = dict(
type="inbound",
protocol="websockets",

View File

@@ -0,0 +1,297 @@
import pytest
import os
import tempfile
import traceback
from mitmproxy import options
from mitmproxy.proxy.config import ProxyConfig
import netlib
from netlib import http
from ...netlib import tservers as netlib_tservers
from .. import tservers
from netlib import websockets
class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
class handler(netlib.tcp.BaseHandler):
def handle(self):
try:
request = http.http1.read_request(self.rfile)
assert websockets.check_handshake(request.headers)
response = http.Response(
"HTTP/1.1",
101,
reason=http.status_codes.RESPONSES.get(101),
headers=http.Headers(
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
),
content=b'',
)
self.wfile.write(http.http1.assemble_response(response))
self.wfile.flush()
self.server.handle_websockets(self.rfile, self.wfile)
except:
traceback.print_exc()
class _WebSocketsTestBase(object):
@classmethod
def setup_class(cls):
opts = cls.get_options()
cls.config = ProxyConfig(opts)
tmaster = tservers.TestMaster(opts, cls.config)
tmaster.start_app(options.APP_HOST, options.APP_PORT)
cls.proxy = tservers.ProxyThread(tmaster)
cls.proxy.start()
@classmethod
def teardown_class(cls):
cls.proxy.shutdown()
@classmethod
def get_options(cls):
opts = options.Options(
listen_port=0,
no_upstream_cert=False,
ssl_insecure=True
)
opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
return opts
@property
def master(self):
return self.proxy.tmaster
def setup(self):
self.master.clear_log()
self.master.state.clear()
self.server.server.handle_websockets = self.handle_websockets
def _setup_connection(self):
client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
client.connect()
request = http.Request(
"authority",
"CONNECT",
"",
"localhost",
self.server.server.address.port,
"",
"HTTP/1.1",
content=b'')
client.wfile.write(http.http1.assemble_request(request))
client.wfile.flush()
response = http.http1.read_response(client.rfile, request)
if self.ssl:
client.convert_to_ssl()
assert client.ssl_established
request = http.Request(
"relative",
"GET",
"http",
"localhost",
self.server.server.address.port,
"/ws",
"HTTP/1.1",
headers=http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
),
content=b'')
client.wfile.write(http.http1.assemble_request(request))
client.wfile.flush()
response = http.http1.read_response(client.rfile, request)
assert websockets.check_handshake(response.headers)
return client
class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
@classmethod
def setup_class(cls):
_WebSocketsTestBase.setup_class()
_WebSocketsServerBase.setup_class(ssl=cls.ssl)
@classmethod
def teardown_class(cls):
_WebSocketsTestBase.teardown_class()
_WebSocketsServerBase.teardown_class()
class TestSimple(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
def test_simple(self):
client = self._setup_connection()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'server-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'client-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
class TestSimpleTLS(_WebSocketsTest):
ssl = True
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
def test_simple_tls(self):
client = self._setup_connection()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'server-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'client-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
class TestPing(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
wfile.flush()
def test_ping(self):
client = self._setup_connection()
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
assert frame.payload == b'pong-received'
class TestPong(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
wfile.flush()
def test_pong(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
class TestClose(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(rfile)
def test_close(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile)
def test_close_payload_1(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
client.wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile)
def test_close_payload_2(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
client.wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile)
class TestInvalidFrame(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
wfile.flush()
def test_invalid_frame(self):
client = self._setup_connection()
# with pytest.raises(netlib.exceptions.TcpDisconnect):
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == 15
assert frame.payload == b'foobar'