Merge commit '9bc5adf'

This commit is contained in:
Maximilian Hils
2016-11-23 22:45:21 +01:00
21 changed files with 457 additions and 181 deletions

View File

@@ -3,7 +3,7 @@
TCP Proxy
=========
WebSockets or other non-HTTP protocols are not supported by mitmproxy yet. However, you can exempt
In case mitmproxy does not handle a specific protocol, you can exempt
hostnames from processing, so that mitmproxy acts as a generic TCP forwarder.
This feature is closely related to the :ref:`passthrough` functionality,
but differs in two important aspects:

View File

@@ -158,21 +158,54 @@ HTTP Events
WebSocket Events
-----------------
These events are called only after a connection made an HTTP upgrade with
"101 Switching Protocols". No further HTTP-related events after the handshake
are issued, only new WebSocket messages are called.
.. list-table::
:widths: 40 60
:header-rows: 0
* - .. py:function:: websockets_handshake(flow)
- Called when a client wants to establish a WebSockets connection. The
WebSockets-specific headers can be manipulated to manipulate the
* - .. py:function:: websocket_handshake(flow)
- Called when a client wants to establish a WebSocket connection. The
WebSocket-specific headers can be manipulated to alter the
handshake. The ``flow`` object is guaranteed to have a non-None
``request`` attribute.
*flow*
The flow containing the HTTP websocket handshake request. The
The flow containing the HTTP WebSocket handshake request. The
object is guaranteed to have a non-None ``request`` attribute.
* - .. py:function:: websocket_start(flow)
- Called when WebSocket connection is established after a successful
handshake.
*flow*
A ``models.WebSocketFlow`` object.
* - .. py:function:: websocket_message(flow)
- Called when a WebSocket message is received from the client or server. The
sender and receiver are identifiable. The most recent message will be
``flow.messages[-1]``. The message is user-modifiable. Currently there are
two types of messages, corresponding to the BINARY and TEXT frame types.
*flow*
A ``models.WebSocketFlow`` object.
* - .. py:function:: websocket_end(flow)
- Called when WebSocket connection ends.
*flow*
A ``models.WebSocketFlow`` object.
* - .. py:function:: websocket_error(flow)
- Called when a WebSocket error occurs - e.g. the connection closing
unexpectedly.
*flow*
A ``models.WebSocketFlow`` object.
TCP Events
----------
@@ -185,6 +218,22 @@ connections.
:widths: 40 60
:header-rows: 0
* - .. py:function:: tcp_start(flow)
- Called when TCP streaming starts.
*flow*
A ``models.TCPFlow`` object.
* - .. py:function:: tcp_message(flow)
- Called when a TCP payload is received from the client or server. The
sender and receiver are identifiable. The most recent message will be
``flow.messages[-1]``. The message is user-modifiable.
*flow*
A ``models.TCPFlow`` object.
* - .. py:function:: tcp_end(flow)
- Called when TCP streaming ends.
@@ -197,18 +246,3 @@ connections.
*flow*
A ``models.TCPFlow`` object.
* - .. py:function:: tcp_message(flow)
- Called a TCP payload is received from the client or server. The
sender and receiver are identifiable. The most recent message will be
``flow.messages[-1]``. The message is user-modifiable.
*flow*
A ``models.TCPFlow`` object.
* - .. py:function:: tcp_start(flow)
- Called when TCP streaming starts.
*flow*
A ``models.TCPFlow`` object.

View File

@@ -223,6 +223,29 @@ class Dumper:
if self.match(f):
self.echo_flow(f)
def websocket_error(self, f):
self.echo(
"Error in WebSocket connection to {}: {}".format(
repr(f.server_conn.address), f.error
),
fg="red"
)
def websocket_message(self, f):
if self.match(f):
message = f.messages[-1]
self.echo(message.info)
if self.flow_detail >= 3:
self._echo_message(message)
def websocket_end(self, f):
if self.match(f):
self.echo("WebSocket connection closed by {}: {} {}, {}".format(
f.close_sender,
f.close_code,
f.close_message,
f.close_reason))
def tcp_error(self, f):
self.echo(
"Error in TCP connection to {}: {}".format(
@@ -240,4 +263,5 @@ class Dumper:
server=repr(f.server_conn.address),
direction=direction,
))
self._echo_message(message)
if self.flow_detail >= 3:
self._echo_message(message)

View File

@@ -1,6 +1,7 @@
from mitmproxy import controller
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
Events = frozenset([
"clientconnect",
@@ -24,6 +25,10 @@ Events = frozenset([
"resume",
"websocket_handshake",
"websocket_start",
"websocket_message",
"websocket_error",
"websocket_end",
"next_layer",
@@ -45,6 +50,17 @@ def event_sequence(f):
yield "response", f
if f.error:
yield "error", f
elif isinstance(f, websocket.WebSocketFlow):
messages = f.messages
f.messages = []
f.reply = controller.DummyReply()
yield "websocket_start", f
while messages:
f.messages.append(messages.pop(0))
yield "websocket_message", f
if f.error:
yield "websocket_error", f
yield "websocket_end", f
elif isinstance(f, tcp.TCPFlow):
messages = f.messages
f.messages = []

View File

@@ -4,12 +4,14 @@ from mitmproxy import exceptions
from mitmproxy import flowfilter
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.contrib import tnetstring
from mitmproxy import io_compat
FLOW_TYPES = dict(
http=http.HTTPFlow,
websocket=websocket.WebSocketFlow,
tcp=tcp.TCPFlow,
)

View File

@@ -283,6 +283,22 @@ class Master:
def websocket_handshake(self, f):
pass
@controller.handler
def websocket_start(self, flow):
pass
@controller.handler
def websocket_message(self, flow):
pass
@controller.handler
def websocket_error(self, flow):
pass
@controller.handler
def websocket_end(self, flow):
pass
@controller.handler
def tcp_start(self, flow):
pass

View File

@@ -90,7 +90,7 @@ class FrameHeader:
@classmethod
def _make_length_code(self, length):
"""
A websockets frame contains an initial length_code, and an optional
A WebSocket frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
@@ -149,7 +149,7 @@ class FrameHeader:
@classmethod
def from_file(cls, fp):
"""
read a websockets frame header
read a WebSocket frame header
"""
first_byte, second_byte = fp.safe_read(2)
fin = bits.getbit(first_byte, 7)
@@ -195,11 +195,11 @@ class FrameHeader:
class Frame:
"""
Represents a single WebSockets frame.
Represents a single WebSocket frame.
Constructor takes human readable forms of the frame components.
from_bytes() reads from a file-like object to create a new Frame.
WebSockets Frame as defined in RFC6455
WebSocket frame as defined in RFC6455
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
@@ -253,7 +253,7 @@ class Frame:
@classmethod
def from_file(cls, fp):
"""
read a websockets frame sent by a server or client
read a WebSocket frame sent by a server or client
fp is a "file like" object that could be backed by a network
stream or a disk or an in memory stream reader

View File

@@ -1,5 +1,5 @@
"""
Collection of WebSockets Protocol utility functions (RFC6455)
Collection of WebSocket protocol utility functions (RFC6455)
Spec: https://tools.ietf.org/html/rfc6455
"""

View File

@@ -73,7 +73,7 @@ class Options(optmanager.OptManager):
mode: str = "regular",
no_upstream_cert: bool = False,
rawtcp: bool = False,
websockets: bool = False,
websocket: bool = True,
spoof_source_address: bool = False,
upstream_server: Optional[str] = None,
upstream_auth: Optional[str] = None,
@@ -136,7 +136,7 @@ class Options(optmanager.OptManager):
self.mode = mode
self.no_upstream_cert = no_upstream_cert
self.rawtcp = rawtcp
self.websockets = websockets
self.websocket = websocket
self.spoof_source_address = spoof_source_address
self.upstream_server = upstream_server
self.upstream_auth = upstream_auth

View File

@@ -2,7 +2,7 @@
In mitmproxy, protocols are implemented as a set of layers, which are composed
on top each other. The first layer is usually the proxy mode, e.g. transparent
proxy or normal HTTP proxy. Next, various protocol layers are stacked on top of
each other - imagine WebSockets on top of an HTTP Upgrade request. An actual
each other - imagine WebSocket on top of an HTTP Upgrade request. An actual
mitmproxy connection may look as follows (outermost layer first):
Transparent HTTP proxy, no TLS:
@@ -10,7 +10,7 @@ mitmproxy connection may look as follows (outermost layer first):
- Http1Layer
- HttpLayer
Regular proxy, CONNECT request with WebSockets over SSL:
Regular proxy, CONNECT request with WebSocket over SSL:
- ReverseProxy
- Http1Layer
- HttpLayer
@@ -34,7 +34,7 @@ from .http import UpstreamConnectLayer
from .http import HttpLayer
from .http1 import Http1Layer
from .http2 import Http2Layer
from .websockets import WebSocketsLayer
from .websocket import WebSocketLayer
from .rawtcp import RawTCPLayer
from .tls import TlsClientHello
from .tls import TlsLayer
@@ -47,6 +47,6 @@ __all__ = [
"HttpLayer",
"Http1Layer",
"Http2Layer",
"WebSocketsLayer",
"WebSocketLayer",
"RawTCPLayer",
]

View File

@@ -8,7 +8,7 @@ from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.proxy.protocol import websockets as pwebsockets
from mitmproxy.proxy.protocol.websocket import WebSocketLayer
from mitmproxy.net import tcp
from mitmproxy.net import websockets
@@ -300,7 +300,7 @@ class HttpLayer(base.Layer):
try:
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
# We only support RFC6455 with WebSockets version 13
# We only support RFC6455 with WebSocket version 13
# allow inline scripts to manipulate the client handshake
self.channel.ask("websocket_handshake", f)
@@ -392,19 +392,19 @@ class HttpLayer(base.Layer):
if f.response.status_code == 101:
# Handle a successful HTTP 101 Switching Protocols Response,
# received after e.g. a WebSocket upgrade request.
# Check for WebSockets handshake
is_websockets = (
# Check for WebSocket handshake
is_websocket = (
websockets.check_handshake(f.request.headers) and
websockets.check_handshake(f.response.headers)
)
if is_websockets and not self.config.options.websockets:
if is_websocket and not self.config.options.websocket:
self.log(
"Client requested WebSocket connection, but the protocol is disabled.",
"info"
)
if is_websockets and self.config.options.websockets:
layer = pwebsockets.WebSocketsLayer(self, f)
if is_websocket and self.config.options.websocket:
layer = WebSocketLayer(self, f)
else:
layer = self.ctx.next_layer(self)
layer()

View File

@@ -121,7 +121,7 @@ class Http2Layer(base.Layer):
self.client_conn.send(self.connections[self.client_conn].data_to_send())
def next_layer(self): # pragma: no cover
# WebSockets over HTTP/2?
# WebSocket over HTTP/2?
# CONNECT for proxying?
raise NotImplementedError()

View File

@@ -0,0 +1,146 @@
import os
import socket
import struct
from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketBinaryMessage, WebSocketTextMessage
class WebSocketLayer(base.Layer):
"""
WebSocket layer to intercept, modify, and forward WebSocket messages.
Only version 13 is supported (as specified in RFC6455).
Only HTTP/1.1-initiated connections are supported.
The client starts by sending an Upgrade-request.
In order to determine the handshake and negotiate the correct protocol
and extensions, the Upgrade-request is forwarded to the server.
The response from the server is then parsed and negotiated settings are extracted.
Finally the handshake is completed by forwarding the server-response to the client.
After that, only WebSocket frames are exchanged.
PING/PONG frames pass through and must be answered by the other endpoint.
CLOSE frames are forwarded before this WebSocketLayer terminates.
This layer is transparent to any negotiated extensions.
This layer is transparent to any negotiated subprotocols.
Only raw frames are forwarded to the other endpoint.
WebSocket messages are stored in a WebSocketFlow.
"""
def __init__(self, ctx, handshake_flow):
super().__init__(ctx)
self.handshake_flow = handshake_flow
self.flow = None # type: WebSocketFlow
self.client_frame_buffer = []
self.server_frame_buffer = []
def _handle_frame(self, frame, source_conn, other_conn, is_server):
if frame.header.opcode & 0x8 == 0:
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
elif frame.header.opcode == websockets.OPCODE.CLOSE:
return self._handle_close(frame, source_conn, other_conn, is_server)
else:
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
fb.append(frame)
if frame.header.fin:
if frame.header.opcode == websockets.OPCODE.TEXT:
t = WebSocketTextMessage
else:
t = WebSocketBinaryMessage
payload = b''.join(f.payload for f in fb)
fb.clear()
websocket_message = t(self.flow, not is_server, payload)
self.flow.messages.append(websocket_message)
self.channel.ask("websocket_message", self.flow)
# chunk payload into multiple 10kB frames, and send them
payload = websocket_message.content
chunk_size = 10240 # 10kB
chunks = range(0, len(payload), chunk_size)
frms = [
websockets.Frame(
payload=payload[i:i + chunk_size],
opcode=frame.header.opcode,
mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4))) for i in chunks
]
frms[-1].header.fin = 1
for frm in frms:
other_conn.send(bytes(frm))
return True
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
return True
def _handle_close(self, frame, source_conn, other_conn, is_server):
self.flow.close_sender = "server" if is_server else "client"
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
self.flow.close_code = code
self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
self.flow.close_reason = frame.payload[2:]
other_conn.send(bytes(frame))
# close the connection
return False
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
# unknown frame - just forward it
other_conn.send(bytes(frame))
sender = "server" if is_server else "client"
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
return True
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow
self.handshake_flow.metadata['websocket_flow'] = self.flow
self.channel.ask("websocket_start", self.flow)
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 1)
for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn
is_server = (conn == self.server_conn.connection)
frame = websockets.Frame.from_file(source_conn.rfile)
if not self._handle_frame(frame, source_conn, other_conn, is_server):
return
except (socket.error, exceptions.TcpException, SSL.Error) as e:
self.flow.error = flow.Error("WebSocket connection closed unexpectedly: {}".format(repr(e)))
self.channel.tell("websocket_error", self.flow)
finally:
self.channel.tell("websocket_end", self.flow)

View File

@@ -1,111 +0,0 @@
import socket
import struct
from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy.proxy.protocol import base
from mitmproxy.utils import strutils
from mitmproxy.net import tcp
from mitmproxy.net import websockets
class WebSocketsLayer(base.Layer):
"""
WebSockets layer to intercept, modify, and forward WebSockets connections
Only version 13 is supported (as specified in RFC6455)
Only HTTP/1.1-initiated connections are supported.
The client starts by sending an Upgrade-request.
In order to determine the handshake and negotiate the correct protocol
and extensions, the Upgrade-request is forwarded to the server.
The response from the server is then parsed and negotiated settings are extracted.
Finally the handshake is completed by forwarding the server-response to the client.
After that, only WebSockets frames are exchanged.
PING/PONG frames pass through and must be answered by the other endpoint.
CLOSE frames are forwarded before this WebSocketsLayer terminates.
This layer is transparent to any negotiated extensions.
This layer is transparent to any negotiated subprotocols.
Only raw frames are forwarded to the other endpoint.
"""
def __init__(self, ctx, flow):
super().__init__(ctx)
self._flow = flow
self.client_key = websockets.get_client_key(self._flow.request.headers)
self.client_protocol = websockets.get_protocol(self._flow.request.headers)
self.client_extensions = websockets.get_extensions(self._flow.request.headers)
self.server_accept = websockets.get_server_accept(self._flow.response.headers)
self.server_protocol = websockets.get_protocol(self._flow.response.headers)
self.server_extensions = websockets.get_extensions(self._flow.response.headers)
def _handle_frame(self, frame, source_conn, other_conn, is_server):
sender = "server" if is_server else "client"
self.log(
"WebSockets Frame received from {}".format(sender),
"debug",
[repr(frame)]
)
if frame.header.opcode & 0x8 == 0:
self.log(
"{direction} websocket {direction} {server}".format(
server=repr(self.server_conn.address),
direction="<-" if is_server else "->",
),
"info",
strutils.bytes_to_escaped_str(frame.payload, keep_spacing=True).splitlines()
)
# forward the data frame to the other side
other_conn.send(bytes(frame))
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
elif frame.header.opcode == websockets.OPCODE.CLOSE:
code = '(status code missing)'
msg = None
reason = '(message missing)'
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
reason = frame.payload[2:]
self.log("WebSockets connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info")
other_conn.send(bytes(frame))
# close the connection
return False
else:
self.log("Unknown WebSockets frame received from {}".format(sender), "info", [repr(frame)])
# unknown frame - just forward it
other_conn.send(bytes(frame))
# continue the connection
return True
def __call__(self):
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 1)
for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn
is_server = (conn == self.server_conn.connection)
frame = websockets.Frame.from_file(source_conn.rfile)
if not self._handle_frame(frame, source_conn, other_conn, is_server):
return
except (socket.error, exceptions.TcpException, SSL.Error) as e:
self.log("WebSockets connection closed unexpectedly by {}: {}".format(
"server" if is_server else "client", repr(e)), "info")
except Exception as e: # pragma: no cover
raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))

View File

@@ -11,9 +11,7 @@ class TCPMessage(serializable.Serializable):
def __init__(self, from_client, content, timestamp=None):
self.content = content
self.from_client = from_client
if timestamp is None:
timestamp = time.time()
self.timestamp = timestamp
self.timestamp = timestamp or time.time()
@classmethod
def from_state(cls, state):

View File

@@ -256,7 +256,7 @@ def get_common_options(args):
no_upstream_cert = args.no_upstream_cert,
spoof_source_address = args.spoof_source_address,
rawtcp = args.rawtcp,
websockets = args.websockets,
websocket = args.websocket,
upstream_server = upstream_server,
upstream_auth = args.upstream_auth,
ssl_version_client = args.ssl_version_client,
@@ -459,6 +459,12 @@ def proxy_options(parser):
If your OpenSSL version supports ALPN, HTTP/2 is enabled by default.
"""
)
group.add_argument(
"--no-websocket",
action="store_false", dest="websocket",
help="Explicitly disable WebSocket support."
)
parser.add_argument(
"--upstream-auth",
action="store", dest="upstream_auth", default=None,
@@ -468,6 +474,7 @@ def proxy_options(parser):
requests. Format: username:password
"""
)
rawtcp = group.add_mutually_exclusive_group()
rawtcp.add_argument("--raw-tcp", action="store_true", dest="rawtcp")
rawtcp.add_argument("--no-raw-tcp", action="store_false", dest="rawtcp",
@@ -475,13 +482,7 @@ def proxy_options(parser):
"Disabled by default. "
"Default value will change in a future version."
)
websockets = group.add_mutually_exclusive_group()
websockets.add_argument("--websockets", action="store_true", dest="websockets")
websockets.add_argument("--no-websockets", action="store_false", dest="websockets",
help="Explicitly enable/disable experimental WebSocket support. "
"Disabled by default as messages are only printed to the event log and not retained. "
"Default value will change in a future version."
)
group.add_argument(
"--spoof-source-address",
action="store_true", dest="spoof_source_address",

View File

@@ -14,10 +14,16 @@ def maybe_timestamp(base, attr):
def flowdetails(state, flow):
text = []
cc = flow.client_conn
sc = flow.server_conn
cc = flow.client_conn
req = flow.request
resp = flow.response
metadata = flow.metadata
if metadata is not None and len(metadata.items()) > 0:
parts = [[str(k), repr(v)] for k, v in metadata.items()]
text.append(urwid.Text([("head", "Metadata:")]))
text.extend(common.format_keyvals(parts, key="key", val="text", indent=4))
if sc is not None:
text.append(urwid.Text([("head", "Server Connection:")]))
@@ -109,6 +115,7 @@ def flowdetails(state, flow):
maybe_timestamp(cc, "timestamp_ssl_setup")
]
)
if sc is not None and sc.timestamp_start:
parts.append(
[
@@ -129,6 +136,7 @@ def flowdetails(state, flow):
maybe_timestamp(sc, "timestamp_ssl_setup")
]
)
if req is not None and req.timestamp_start:
parts.append(
[
@@ -142,6 +150,7 @@ def flowdetails(state, flow):
maybe_timestamp(req, "timestamp_end")
]
)
if resp is not None and resp.timestamp_start:
parts.append(
[
@@ -162,4 +171,5 @@ def flowdetails(state, flow):
text.append(urwid.Text([("head", "Timing:")]))
text.extend(common.format_keyvals(parts, key="key", val="text", indent=4))
return searchable.Searchable(state, text)

View File

@@ -446,17 +446,33 @@ class ConsoleMaster(master.Master):
self.logbuffer[:] = []
# Handlers
@controller.handler
def websocket_message(self, f):
super().websocket_message(f)
message = f.messages[-1]
signals.add_log(message.info, "info")
signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
@controller.handler
def websocket_end(self, f):
super().websocket_end(f)
signals.add_log("WebSocket connection closed by {}: {} {}, {}".format(
f.close_sender,
f.close_code,
f.close_message,
f.close_reason), "info")
@controller.handler
def tcp_message(self, f):
super().tcp_message(f)
message = f.messages[-1]
direction = "->" if message.from_client else "<-"
self.add_log("{client} {direction} tcp {direction} {server}".format(
signals.add_log("{client} {direction} tcp {direction} {server}".format(
client=repr(f.client_conn.address),
server=repr(f.server_conn.address),
direction=direction,
), "info")
self.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
@controller.handler
def log(self, evt):

87
mitmproxy/websocket.py Normal file
View File

@@ -0,0 +1,87 @@
import time
from typing import List
from mitmproxy import flow
from mitmproxy.http import HTTPFlow
from mitmproxy.net import websockets
from mitmproxy.utils import strutils
from mitmproxy.types import serializable
class WebSocketMessage(serializable.Serializable):
def __init__(self, flow, from_client, content, timestamp=None):
self.flow = flow
self.content = content
self.from_client = from_client
self.timestamp = timestamp or time.time()
@classmethod
def from_state(cls, state):
return cls(*state)
def get_state(self):
return self.from_client, self.content, self.timestamp
def set_state(self, state):
self.from_client = state.pop("from_client")
self.content = state.pop("content")
self.timestamp = state.pop("timestamp")
@property
def info(self):
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=self.type,
client=repr(self.flow.client_conn.address),
server=repr(self.flow.server_conn.address),
direction="->" if self.from_client else "<-",
endpoint=self.flow.handshake_flow.request.path,
)
class WebSocketBinaryMessage(WebSocketMessage):
type = 'binary'
def __repr__(self):
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
class WebSocketTextMessage(WebSocketMessage):
type = 'text'
def __repr__(self):
return "text message: {}".format(repr(self.content))
class WebSocketFlow(flow.Flow):
"""
A WebsocketFlow is a simplified representation of a Websocket session.
"""
def __init__(self, client_conn, server_conn, handshake_flow, live=None):
super().__init__("websocket", client_conn, server_conn, live)
self.messages = [] # type: List[WebSocketMessage]
self.close_sender = 'client'
self.close_code = '(status code missing)'
self.close_message = '(message missing)'
self.close_reason = 'unknown status code'
self.handshake_flow = handshake_flow
self.client_key = websockets.get_client_key(self.handshake_flow.request.headers)
self.client_protocol = websockets.get_protocol(self.handshake_flow.request.headers)
self.client_extensions = websockets.get_extensions(self.handshake_flow.request.headers)
self.server_accept = websockets.get_server_accept(self.handshake_flow.response.headers)
self.server_protocol = websockets.get_protocol(self.handshake_flow.response.headers)
self.server_extensions = websockets.get_extensions(self.handshake_flow.response.headers)
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
_stateobject_attributes.update(
messages=List[WebSocketMessage],
handshake_flow=HTTPFlow,
)
def __repr__(self):
return "WebSocketFlow ({} messages)".format(len(self.messages))

View File

@@ -5,6 +5,8 @@ import traceback
from mitmproxy import options
from mitmproxy import exceptions
from mitmproxy.http import HTTPFlow
from mitmproxy.websocket import WebSocketFlow
from mitmproxy.proxy.config import ProxyConfig
import mitmproxy.net
@@ -15,7 +17,7 @@ from .. import tservers
from mitmproxy.net import websockets
class _WebSocketsServerBase(net_tservers.ServerTestBase):
class _WebSocketServerBase(net_tservers.ServerTestBase):
class handler(mitmproxy.net.tcp.BaseHandler):
@@ -43,7 +45,7 @@ class _WebSocketsServerBase(net_tservers.ServerTestBase):
traceback.print_exc()
class _WebSocketsTestBase:
class _WebSocketTestBase:
@classmethod
def setup_class(cls):
@@ -64,7 +66,7 @@ class _WebSocketsTestBase:
listen_port=0,
no_upstream_cert=False,
ssl_insecure=True,
websockets=True,
websocket=True,
)
opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
return opts
@@ -123,20 +125,20 @@ class _WebSocketsTestBase:
return client
class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase):
@classmethod
def setup_class(cls):
_WebSocketsTestBase.setup_class()
_WebSocketsServerBase.setup_class(ssl=cls.ssl)
_WebSocketTestBase.setup_class()
_WebSocketServerBase.setup_class(ssl=cls.ssl)
@classmethod
def teardown_class(cls):
_WebSocketsTestBase.teardown_class()
_WebSocketsServerBase.teardown_class()
_WebSocketTestBase.teardown_class()
_WebSocketServerBase.teardown_class()
class TestSimple(_WebSocketsTest):
class TestSimple(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
@@ -147,6 +149,10 @@ class TestSimple(_WebSocketsTest):
wfile.write(bytes(frame))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
def test_simple(self):
client = self._setup_connection()
@@ -159,11 +165,33 @@ class TestSimple(_WebSocketsTest):
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'client-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
assert len(self.master.state.flows) == 2
assert isinstance(self.master.state.flows[0], HTTPFlow)
assert isinstance(self.master.state.flows[1], WebSocketFlow)
assert len(self.master.state.flows[1].messages) == 5
assert self.master.state.flows[1].messages[0].content == b'server-foobar'
assert self.master.state.flows[1].messages[0].type == 'text'
assert self.master.state.flows[1].messages[1].content == b'client-foobar'
assert self.master.state.flows[1].messages[1].type == 'text'
assert self.master.state.flows[1].messages[2].content == b'client-foobar'
assert self.master.state.flows[1].messages[2].type == 'text'
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[3].type == 'binary'
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[4].type == 'binary'
assert [m.info for m in self.master.state.flows[1].messages]
class TestSimpleTLS(_WebSocketsTest):
class TestSimpleTLS(_WebSocketTest):
ssl = True
@classmethod
@@ -191,7 +219,7 @@ class TestSimpleTLS(_WebSocketsTest):
client.wfile.flush()
class TestPing(_WebSocketsTest):
class TestPing(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
@@ -220,7 +248,7 @@ class TestPing(_WebSocketsTest):
assert frame.payload == b'pong-received'
class TestPong(_WebSocketsTest):
class TestPong(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
@@ -242,7 +270,7 @@ class TestPong(_WebSocketsTest):
assert frame.payload == b'foobar'
class TestClose(_WebSocketsTest):
class TestClose(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
@@ -281,7 +309,7 @@ class TestClose(_WebSocketsTest):
websockets.Frame.from_file(client.rfile)
class TestInvalidFrame(_WebSocketsTest):
class TestInvalidFrame(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):

View File

@@ -26,6 +26,15 @@ class TestState:
if f not in self.flows:
self.flows.append(f)
def websocket_start(self, f):
if f not in self.flows:
self.flows.append(f)
# TODO: add TCP support?
# def tcp_start(self, f):
# if f not in self.flows:
# self.flows.append(f)
# FIXME: compat with old state - remove in favor of len(state.flows)
def flow_count(self):
return len(self.flows)