split mitmproxy.flow into mitmproxy.flow.*

This commit is contained in:
Maximilian Hils
2016-05-30 01:40:09 -07:00
parent 6652e3a369
commit 89f07603ca
10 changed files with 1295 additions and 1282 deletions

View File

@@ -7,7 +7,7 @@ import os
import netlib.utils
from .. import utils
from .. import flow_export
from .. import flow
from ..models import decoded
from . import signals
@@ -282,16 +282,16 @@ def copy_flow_format_data(part, scope, flow):
return data, False
def export_prompt(k, flow):
def export_prompt(k, f):
exporters = {
"c": flow_export.curl_command,
"p": flow_export.python_code,
"r": flow_export.raw_request,
"l": flow_export.locust_code,
"t": flow_export.locust_task,
"c": flow.export.curl_command,
"p": flow.export.python_code,
"r": flow.export.raw_request,
"l": flow.export.locust_code,
"t": flow.export.locust_task,
}
if k in exporters:
copy_to_clipboard_or_prompt(exporters[k](flow))
copy_to_clipboard_or_prompt(exporters[k](f))
def copy_to_clipboard_or_prompt(data):

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
from __future__ import absolute_import, print_function, division
from mitmproxy.flow import export, modules
from mitmproxy.flow.io import FlowWriter, FilteredFlowWriter, FlowReader, read_flows_from_paths
from mitmproxy.flow.master import FlowMaster
from mitmproxy.flow.modules import (
AppRegistry, ReplaceHooks, SetHeaders, StreamLargeBodies, ClientPlaybackState,
ServerPlaybackState, StickyCookieState, StickyAuthState
)
from mitmproxy.flow.state import State, FlowView
# TODO: We may want to remove the imports from .modules and just expose "modules"
# Names re-exported by this package: the two submodules, then the classes and
# helpers imported above from io, master, modules and state.
__all__ = [
"export", "modules",
"FlowWriter", "FilteredFlowWriter", "FlowReader", "read_flows_from_paths",
"FlowMaster",
"AppRegistry", "ReplaceHooks", "SetHeaders", "StreamLargeBodies", "ClientPlaybackState",
"ServerPlaybackState", "StickyCookieState", "StickyAuthState",
"State", "FlowView",
]

View File

@@ -1,13 +1,12 @@
import json
import re
from textwrap import dedent
from six.moves.urllib.parse import quote, quote_plus
import netlib.http
from netlib.utils import parse_content_type
import re
from six.moves.urllib.parse import quote
from six.moves.urllib.parse import quote_plus
def curl_command(flow):
data = "curl "

83
mitmproxy/flow/io.py Normal file
View File

@@ -0,0 +1,83 @@
import os
from mitmproxy import tnetstring, models
from mitmproxy.exceptions import FlowReadException
from mitmproxy.flow import io_compat
class FlowWriter:
    """Writes flows to a file object, one tnetstring dump per flow."""

    def __init__(self, fo):
        # fo: the file object every flow state is dumped to.
        self.fo = fo

    def add(self, flow):
        """Dump the state of ``flow`` to the underlying file object."""
        tnetstring.dump(flow.get_state(), self.fo)
class FlowReader:
    """Reads flows back from a file object containing tnetstring dumps."""

    def __init__(self, fo):
        # fo: the file object the dump is read from.
        self.fo = fo

    def stream(self):
        """
        Yields Flow objects from the dump.

        Raises:
            FlowReadException, if a record cannot be migrated, has an unknown
            flow type, or the data is not validly formatted.
        """
        # There is a weird mingw bug that breaks .tell() when reading from stdin.
        try:
            self.fo.tell()
        except IOError:  # pragma: no cover
            can_tell = False
        else:
            can_tell = True

        # Offset just past the last record that was read successfully.
        off = 0
        try:
            while True:
                data = tnetstring.load(self.fo)
                try:
                    data = io_compat.migrate_flow(data)
                except ValueError as e:
                    raise FlowReadException(str(e))
                if can_tell:
                    off = self.fo.tell()
                if data["type"] not in models.FLOW_TYPES:
                    raise FlowReadException("Unknown flow type: {}".format(data["type"]))
                yield models.FLOW_TYPES[data["type"]].from_state(data)
        except ValueError:
            # Error is due to EOF.
            # An emptiness test is used instead of comparing with '' so that
            # both text streams ('') and binary streams (b'') are recognised.
            if can_tell and self.fo.tell() == off and not self.fo.read():
                return
            raise FlowReadException("Invalid data format.")
class FilteredFlowWriter:
    """Flow writer that skips flows rejected by an optional filter."""

    def __init__(self, fo, filt):
        # fo: target file object; filt: filter handed to flow.match(), or a
        # falsy value to write every flow.
        self.fo, self.filt = fo, filt

    def add(self, f):
        """Dump the state of ``f`` unless a filter is set and ``f`` fails it."""
        if not self.filt or f.match(self.filt):
            tnetstring.dump(f.get_state(), self.fo)
def read_flows_from_paths(paths):
    """
    Read every flow from the given file paths and return them as one list.

    Streaming would be preferable for performance, but reading eagerly means
    an error in any of the files is raised immediately.

    Raises:
        FlowReadException, if any error occurs.
    """
    collected = []
    try:
        for raw_path in paths:
            with open(os.path.expanduser(raw_path), "rb") as handle:
                for flow in FlowReader(handle).stream():
                    collected.append(flow)
    except IOError as e:
        raise FlowReadException(e.strerror)
    return collected

View File

@@ -2,7 +2,8 @@
This module handles the import of mitmproxy flows generated by old versions.
"""
from __future__ import absolute_import, print_function, division
from . import version
from mitmproxy import version
def convert_013_014(data):

538
mitmproxy/flow/master.py Normal file
View File

@@ -0,0 +1,538 @@
import os
import sys
from typing import List, Optional, Set
from mitmproxy import controller, script, filt, models
from mitmproxy.exceptions import FlowReadException, Kill
from mitmproxy.flow import io, modules
from mitmproxy.onboarding import app
from mitmproxy.protocol.http_replay import RequestReplayThread
from mitmproxy.proxy.config import HostMatcher
from netlib import utils
from netlib.exceptions import HttpException
class FlowMaster(controller.Master):
    """
    Master that tracks flows in a State object and applies the configured
    features (scripts, sticky cookies/auth, replace hooks, set-header hooks,
    client/server playback, streaming to file) to the events it handles.
    """

    @property
    def server(self):
        # At some point, we may want to have support for multiple servers.
        # For now, this suffices.
        # Returns None implicitly when no server has been added.
        if len(self.servers) > 0:
            return self.servers[0]

    def __init__(self, server, state):
        super(FlowMaster, self).__init__()
        if server:
            self.add_server(server)
        self.state = state
        self.active_flows = set()  # type: Set[models.Flow]
        self.server_playback = None  # type: Optional[modules.ServerPlaybackState]
        self.client_playback = None  # type: Optional[modules.ClientPlaybackState]
        self.kill_nonreplay = False
        self.scripts = []  # type: List[script.Script]
        self.pause_scripts = False

        self.stickycookie_state = None  # type: Optional[modules.StickyCookieState]
        self.stickycookie_txt = None

        # None (not False) so the "disabled" value matches the Optional type
        # and what set_stickyauth() assigns when it is switched off.
        self.stickyauth_state = None  # type: Optional[modules.StickyAuthState]
        self.stickyauth_txt = None

        self.anticache = False
        self.anticomp = False
        self.stream_large_bodies = None  # type: Optional[modules.StreamLargeBodies]
        self.refresh_server_playback = False
        self.replacehooks = modules.ReplaceHooks()
        self.setheaders = modules.SetHeaders()
        self.replay_ignore_params = False
        self.replay_ignore_content = None
        self.replay_ignore_host = False

        # Active FilteredFlowWriter, or None while not streaming to a file.
        self.stream = None
        self.apps = modules.AppRegistry()

    def start_app(self, host, port):
        """Register the onboarding app for requests to host:port."""
        self.apps.add(
            app.mapp,
            host,
            port
        )

    def add_event(self, e, level="info"):
        """
            level: debug, info, error
        """

    def unload_scripts(self):
        """Unload every loaded script."""
        # Iterate over a copy: unload_script() removes from self.scripts.
        for s in self.scripts[:]:
            self.unload_script(s)

    def unload_script(self, script_obj):
        """Unload one script; an unload error is logged, not raised."""
        try:
            script_obj.unload()
        except script.ScriptException as e:
            self.add_event("Script error:\n" + str(e), "error")
        script.reloader.unwatch(script_obj)
        self.scripts.remove(script_obj)

    def load_script(self, command, use_reloader=False):
        """
            Loads a script.

            Raises:
                ScriptException
        """
        s = script.Script(command, script.ScriptContext(self))
        s.load()
        if use_reloader:
            script.reloader.watch(s, lambda: self.event_queue.put(("script_change", s)))
        self.scripts.append(s)

    def _run_single_script_hook(self, script_obj, name, *args, **kwargs):
        """Run hook ``name`` on one script unless scripts are paused."""
        if script_obj and not self.pause_scripts:
            try:
                script_obj.run(name, *args, **kwargs)
            except script.ScriptException as e:
                self.add_event("Script error:\n{}".format(e), "error")

    def run_script_hook(self, name, *args, **kwargs):
        """Run hook ``name`` on every loaded script."""
        for script_obj in self.scripts:
            self._run_single_script_hook(script_obj, name, *args, **kwargs)

    def get_ignore_filter(self):
        return self.server.config.check_ignore.patterns

    def set_ignore_filter(self, host_patterns):
        self.server.config.check_ignore = HostMatcher(host_patterns)

    def get_tcp_filter(self):
        return self.server.config.check_tcp.patterns

    def set_tcp_filter(self, host_patterns):
        self.server.config.check_tcp = HostMatcher(host_patterns)

    def set_stickycookie(self, txt):
        """
            Enable sticky cookies for flows matching the filter expression
            ``txt``, or disable them if ``txt`` is falsy.

            Returns an error message if the expression cannot be parsed.
        """
        if txt:
            flt = filt.parse(txt)
            if not flt:
                return "Invalid filter expression."
            self.stickycookie_state = modules.StickyCookieState(flt)
            self.stickycookie_txt = txt
        else:
            self.stickycookie_state = None
            self.stickycookie_txt = None

    def set_stream_large_bodies(self, max_size):
        """Enable body streaming with ``max_size``, or disable it with None."""
        if max_size is not None:
            self.stream_large_bodies = modules.StreamLargeBodies(max_size)
        else:
            # Same "disabled" value as in __init__.
            self.stream_large_bodies = None

    def set_stickyauth(self, txt):
        """
            Enable sticky auth for flows matching the filter expression
            ``txt``, or disable it if ``txt`` is falsy.

            Returns an error message if the expression cannot be parsed.
        """
        if txt:
            flt = filt.parse(txt)
            if not flt:
                return "Invalid filter expression."
            self.stickyauth_state = modules.StickyAuthState(flt)
            self.stickyauth_txt = txt
        else:
            self.stickyauth_state = None
            self.stickyauth_txt = None

    def start_client_playback(self, flows, exit):
        """
            flows: List of flows.
        """
        self.client_playback = modules.ClientPlaybackState(flows, exit)

    def stop_client_playback(self):
        self.client_playback = None

    def start_server_playback(
            self,
            flows,
            kill,
            headers,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host):
        """
            flows: List of flows.
            kill: Boolean, should we kill requests not part of the replay?
            ignore_params: list of parameters to ignore in server replay
            ignore_content: true if request content should be ignored in server replay
            ignore_payload_params: list of content params to ignore in server replay
            ignore_host: true if request host should be ignored in server replay
        """
        self.server_playback = modules.ServerPlaybackState(
            headers,
            flows,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host)
        self.kill_nonreplay = kill

    def stop_server_playback(self):
        self.server_playback = None

    def do_server_playback(self, flow):
        """
            This method should be called by child classes in the request
            handler. Returns True if playback has taken place, None if not.
        """
        if self.server_playback:
            rflow = self.server_playback.next_flow(flow)
            if not rflow:
                return None
            response = rflow.response.copy()
            response.is_replay = True
            if self.refresh_server_playback:
                response.refresh()
            flow.response = response
            return True
        return None

    def tick(self, timeout):
        """Advance client/server playback, then run the parent's tick."""
        if self.client_playback:
            stop = (
                self.client_playback.done() and
                self.state.active_flow_count() == 0
            )
            # Read before stop_client_playback() drops the playback state.
            exit = self.client_playback.exit
            if stop:
                self.stop_client_playback()
                if exit:
                    self.shutdown()
            else:
                self.client_playback.tick(self)

        if self.server_playback:
            stop = (
                self.server_playback.count() == 0 and
                self.state.active_flow_count() == 0 and
                not self.kill_nonreplay
            )
            exit = self.server_playback.exit
            if stop:
                self.stop_server_playback()
                if exit:
                    self.shutdown()
        return super(FlowMaster, self).tick(timeout)

    def duplicate_flow(self, f):
        """Copy ``f``, load the copy and return it."""
        f2 = f.copy()
        self.load_flow(f2)
        return f2

    def create_request(self, method, scheme, host, port, path):
        """
            this method creates a new artificial and minimalist request also adds it to flowlist
        """
        c = models.ClientConnection.make_dummy(("", 0))
        s = models.ServerConnection.make_dummy((host, port))

        f = models.HTTPFlow(c, s)
        headers = models.Headers()

        req = models.HTTPRequest(
            "absolute",
            method,
            scheme,
            host,
            port,
            path,
            b"HTTP/1.1",
            headers,
            b""
        )
        f.request = req
        self.load_flow(f)
        return f

    def load_flow(self, f):
        """
        Loads a flow
        """
        if isinstance(f, models.HTTPFlow):
            if self.server and self.server.config.mode == "reverse":
                f.request.host = self.server.config.upstream_server.address.host
                f.request.port = self.server.config.upstream_server.address.port
                f.request.scheme = self.server.config.upstream_server.scheme

            f.reply = controller.DummyReply()
            if f.request:
                self.request(f)
            if f.response:
                self.responseheaders(f)
                self.response(f)
            if f.error:
                self.error(f)
        elif isinstance(f, models.TCPFlow):
            # Feed the stored messages through the handlers one at a time.
            messages = f.messages
            f.messages = []
            f.reply = controller.DummyReply()
            self.tcp_open(f)
            while messages:
                f.messages.append(messages.pop(0))
                self.tcp_message(f)
            if f.error:
                self.tcp_error(f)
            self.tcp_close(f)
        else:
            raise NotImplementedError()

    def load_flows(self, fr):
        """
            Load flows from a FlowReader object.
        """
        cnt = 0
        for i in fr.stream():
            cnt += 1
            self.load_flow(i)
        return cnt

    def load_flows_file(self, path):
        """
            Load flows from the file at ``path`` ("-" reads from stdin) and
            return how many were loaded.

            Raises:
                FlowReadException, if the file cannot be opened or read.
        """
        path = os.path.expanduser(path)
        try:
            if path == "-":
                # This is incompatible with Python 3 - maybe we can use click?
                freader = io.FlowReader(sys.stdin)
                return self.load_flows(freader)
            else:
                with open(path, "rb") as f:
                    freader = io.FlowReader(f)
                    return self.load_flows(freader)
        except IOError as v:
            raise FlowReadException(v.strerror)

    def process_new_request(self, f):
        """Apply sticky state, anticache/anticomp and server playback to ``f``."""
        if self.stickycookie_state:
            self.stickycookie_state.handle_request(f)
        if self.stickyauth_state:
            self.stickyauth_state.handle_request(f)

        if self.anticache:
            f.request.anticache()
        if self.anticomp:
            f.request.anticomp()

        if self.server_playback:
            pb = self.do_server_playback(f)
            if not pb and self.kill_nonreplay:
                f.kill(self)

    def process_new_response(self, f):
        if self.stickycookie_state:
            self.stickycookie_state.handle_response(f)

    def replay_request(self, f, block=False, run_scripthooks=True):
        """
            Returns None if successful, or error message if not.
        """
        if f.live and run_scripthooks:
            return "Can't replay live request."
        if f.intercepted:
            return "Can't replay while intercepting..."
        if f.request.content is None:
            return "Can't replay request with missing content..."
        if f.request:
            f.backup()
            f.request.is_replay = True
            if "Content-Length" in f.request.headers:
                f.request.headers["Content-Length"] = str(len(f.request.content))
            f.response = None
            f.error = None
            self.process_new_request(f)
            rt = RequestReplayThread(
                self.server.config,
                f,
                self.event_queue if run_scripthooks else False,
                self.should_exit
            )
            rt.start()  # pragma: no cover
            if block:
                rt.join()

    @controller.handler
    def log(self, l):
        self.add_event(l.msg, l.level)

    @controller.handler
    def clientconnect(self, root_layer):
        self.run_script_hook("clientconnect", root_layer)

    @controller.handler
    def clientdisconnect(self, root_layer):
        self.run_script_hook("clientdisconnect", root_layer)

    @controller.handler
    def serverconnect(self, server_conn):
        self.run_script_hook("serverconnect", server_conn)

    @controller.handler
    def serverdisconnect(self, server_conn):
        self.run_script_hook("serverdisconnect", server_conn)

    @controller.handler
    def next_layer(self, top_layer):
        self.run_script_hook("next_layer", top_layer)

    @controller.handler
    def error(self, f):
        self.state.update_flow(f)
        self.run_script_hook("error", f)
        if self.client_playback:
            self.client_playback.clear(f)
        return f

    @controller.handler
    def request(self, f):
        if f.live:
            app = self.apps.get(f.request)
            if app:
                err = app.serve(
                    f,
                    f.client_conn.wfile,
                    **{"mitmproxy.master": self}
                )
                if err:
                    self.add_event("Error in wsgi app. %s" % err, "error")
                f.reply(Kill)
                return
        if f not in self.state.flows:  # don't add again on replay
            self.state.add_flow(f)
        self.active_flows.add(f)
        self.replacehooks.run(f)
        self.setheaders.run(f)
        self.process_new_request(f)
        self.run_script_hook("request", f)
        return f

    @controller.handler
    def responseheaders(self, f):
        try:
            if self.stream_large_bodies:
                self.stream_large_bodies.run(f, False)
        except HttpException:
            f.reply(Kill)
            return
        self.run_script_hook("responseheaders", f)
        return f

    @controller.handler
    def response(self, f):
        self.active_flows.discard(f)
        self.state.update_flow(f)
        self.replacehooks.run(f)
        self.setheaders.run(f)
        self.run_script_hook("response", f)
        if self.client_playback:
            self.client_playback.clear(f)
        self.process_new_response(f)
        if self.stream:
            self.stream.add(f)
        return f

    def handle_intercept(self, f):
        self.state.update_flow(f)

    def handle_accept_intercept(self, f):
        self.state.update_flow(f)

    @controller.handler
    def script_change(self, s):
        """
        Handle a script whose contents have been changed on the file system.

        Args:
            s (script.Script): the changed script

        Returns:
            True, if reloading was successful.
            False, otherwise.
        """
        ok = True
        # We deliberately do not want to fail here.
        # In the worst case, we have an "empty" script object.
        try:
            s.unload()
        except script.ScriptException as e:
            ok = False
            self.add_event('Error reloading "{}":\n{}'.format(s.filename, e), 'error')
        try:
            s.load()
        except script.ScriptException as e:
            ok = False
            self.add_event('Error reloading "{}":\n{}'.format(s.filename, e), 'error')
        else:
            self.add_event('"{}" reloaded.'.format(s.filename), 'info')
        return ok

    @controller.handler
    def tcp_open(self, flow):
        # TODO: This would break mitmproxy currently.
        # self.state.add_flow(flow)
        self.active_flows.add(flow)
        self.run_script_hook("tcp_open", flow)

    @controller.handler
    def tcp_message(self, flow):
        self.run_script_hook("tcp_message", flow)
        message = flow.messages[-1]
        direction = "->" if message.from_client else "<-"
        self.add_event("{client} {direction} tcp {direction} {server}".format(
            client=repr(flow.client_conn.address),
            server=repr(flow.server_conn.address),
            direction=direction,
        ), "info")
        self.add_event(utils.clean_bin(message.content), "debug")

    @controller.handler
    def tcp_error(self, flow):
        self.add_event("Error in TCP connection to {}: {}".format(
            repr(flow.server_conn.address),
            flow.error
        ), "info")
        self.run_script_hook("tcp_error", flow)

    @controller.handler
    def tcp_close(self, flow):
        self.active_flows.discard(flow)
        if self.stream:
            self.stream.add(flow)
        self.run_script_hook("tcp_close", flow)

    def shutdown(self):
        super(FlowMaster, self).shutdown()

        # Add all flows that are still active
        if self.stream:
            for flow in self.active_flows:
                self.stream.add(flow)
            self.stop_stream()

        self.unload_scripts()

    def start_stream(self, fp, filt):
        """Start writing flows that match ``filt`` to the file object ``fp``."""
        self.stream = io.FilteredFlowWriter(fp, filt)

    def stop_stream(self):
        """Close the stream's file object; does nothing if no stream is active."""
        if not self.stream:
            return
        self.stream.fo.close()
        self.stream = None

    def start_stream_to_path(self, path, mode="wb", filt=None):
        """
            Open ``path`` and start streaming flows to it.

            Returns an error message if the file cannot be opened.
        """
        path = os.path.expanduser(path)
        try:
            f = open(path, mode)
            self.start_stream(f, filt)
        except IOError as v:
            return str(v)
        self.stream_path = path

353
mitmproxy/flow/modules.py Normal file
View File

@@ -0,0 +1,353 @@
import collections
import hashlib
import re
from six.moves import http_cookiejar
from six.moves import urllib
from mitmproxy import version, filt, controller
from netlib import wsgi
from netlib.http import http1, cookies
class AppRegistry:
    """Registry of WSGI apps keyed by (domain, port)."""

    def __init__(self):
        self.apps = {}

    def add(self, app, domain, port):
        """
            Add a WSGI app to the registry, to be served for requests to the
            specified domain, on the specified port.
        """
        adaptor = wsgi.WSGIAdaptor(app, domain, port, version.NAMEVERSION)
        self.apps[(domain, port)] = adaptor

    def get(self, request):
        """
            Returns an WSGIAdaptor instance if request matches an app, or None.
        """
        direct_key = (request.host, request.port)
        if direct_key in self.apps:
            return self.apps[direct_key]
        # Fall back to the Host header when the request's host does not match.
        if "host" in request.headers:
            header_key = (request.headers["host"], request.port)
            return self.apps.get(header_key)
        return None
class ReplaceHooks:
    """Ordered list of (filter, regex, replacement) hooks applied to flows."""

    def __init__(self):
        # Each entry is (fpatt, rex, s, cpatt); cpatt is the parsed filter.
        self.lst = []

    def set(self, r):
        """Replace all hooks with the (fpatt, rex, s) specs in ``r``."""
        self.clear()
        for spec in r:
            self.add(*spec)

    def add(self, fpatt, rex, s):
        """
            Add a replacement hook.

            fpatt: a string specifying a filter pattern.
            rex: a regular expression.
            s: the replacement string

            Returns True if the hook was added, False if the filter pattern
            could not be parsed or the regular expression does not compile.
        """
        cpatt = filt.parse(fpatt)
        if not cpatt:
            return False
        try:
            re.compile(rex)
        except re.error:
            return False
        self.lst.append((fpatt, rex, s, cpatt))
        return True

    def get_specs(self):
        """
            Retrieve the hook specifications. Returns a list of (fpatt, rex, s)
            tuples.
        """
        return [(fpatt, rex, s) for fpatt, rex, s, _ in self.lst]

    def count(self):
        return len(self.lst)

    def run(self, f):
        """Apply every matching hook to the response if present, else the request."""
        for _, rex, s, cpatt in self.lst:
            if not cpatt(f):
                continue
            target = f.response if f.response else f.request
            target.replace(rex, s)

    def clear(self):
        self.lst = []
class SetHeaders:
    """Ordered list of (filter, header, value) hooks applied to flows."""

    def __init__(self):
        # Each entry is (fpatt, header, value, cpatt); cpatt is the parsed filter.
        self.lst = []

    def set(self, r):
        """Replace all hooks with the (fpatt, header, value) specs in ``r``."""
        self.clear()
        for spec in r:
            self.add(*spec)

    def add(self, fpatt, header, value):
        """
            Add a set header hook.

            fpatt: String specifying a filter pattern.
            header: Header name.
            value: Header value string

            Returns True if hook was added, False if the pattern could not be
            parsed.
        """
        cpatt = filt.parse(fpatt)
        if not cpatt:
            return False
        self.lst.append((fpatt, header, value, cpatt))
        return True

    def get_specs(self):
        """
            Retrieve the hook specifications. Returns a list of
            (fpatt, header, value) tuples.
        """
        return [(fpatt, header, value) for fpatt, header, value, _ in self.lst]

    def count(self):
        return len(self.lst)

    def clear(self):
        self.lst = []

    def run(self, f):
        """
            Set the configured headers on the response if present, else on
            the request.
        """
        # First pass: drop the existing values of every matching header.
        for _, header, _value, cpatt in self.lst:
            if cpatt(f):
                message = f.response or f.request
                message.headers.pop(header, None)
        # Second pass: add the values; filters are evaluated again here.
        for _, header, value, cpatt in self.lst:
            if cpatt(f):
                message = f.response or f.request
                message.headers.add(header, value)
class StreamLargeBodies(object):
    """Marks messages for streaming when their expected body size is out of range."""

    def __init__(self, max_size):
        self.max_size = max_size

    def run(self, flow, is_request):
        """
            Set ``stream`` on the request (if ``is_request``) or the response
            when it has no content and its expected body size is not within
            0..max_size.
        """
        if is_request:
            r = flow.request
            response_arg = None
        else:
            r = flow.response
            response_arg = flow.response
        expected_size = http1.expected_http_body_size(flow.request, response_arg)
        if r.content:
            return
        if 0 <= expected_size <= self.max_size:
            return
        # r.stream may already be a callable, which we want to preserve.
        r.stream = r.stream or True
class ClientPlaybackState:
    """Queue of flows to replay one at a time from the client side."""

    def __init__(self, flows, exit):
        self.flows = flows
        self.exit = exit
        # Flow currently being replayed, or None between replays.
        self.current = None
        self.testing = False  # Disables actual replay for testing.

    def count(self):
        return len(self.flows)

    def done(self):
        """True once no flows are queued and none is being replayed."""
        return len(self.flows) == 0 and not self.current

    def clear(self, flow):
        """
           A request has returned in some way - if this is the one we're
           servicing, go to the next flow.
        """
        if flow is self.current:
            self.current = None

    def tick(self, master):
        """Start replaying the next queued flow if none is in progress."""
        if not self.flows or self.current:
            return
        self.current = self.flows.pop(0).copy()
        if self.testing:
            self.current.reply = controller.DummyReply()
            master.request(self.current)
            if self.current.response:
                master.response(self.current)
        else:
            master.replay_request(self.current)
class ServerPlaybackState:
    """Recorded flows indexed by a loose hash of their request, for server replay."""

    def __init__(
            self,
            headers,
            flows,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host):
        """
        headers: Case-insensitive list of request headers that should be
        included in request-response matching.
        """
        self.headers = headers
        self.exit = exit
        self.nopop = nopop
        self.ignore_params = ignore_params
        self.ignore_content = ignore_content
        self.ignore_payload_params = ignore_payload_params
        self.ignore_host = ignore_host
        # Maps request hash -> list of recorded flows, in recording order.
        self.fmap = {}
        for i in flows:
            # Flows without a response cannot be replayed, so skip them.
            if i.response:
                self.fmap.setdefault(self._hash(i), []).append(i)

    def count(self):
        """Return the total number of flows still available for replay."""
        return sum(len(i) for i in self.fmap.values())

    def _hash(self, flow):
        """
            Calculates a loose hash of the flow request.
        """
        r = flow.request

        _, _, path, _, query, _ = urllib.parse.urlparse(r.url)
        queriesArray = urllib.parse.parse_qsl(query, keep_blank_values=True)

        key = [
            str(r.port),
            str(r.scheme),
            str(r.method),
            str(path),
        ]

        if not self.ignore_content:
            form_contents = r.urlencoded_form or r.multipart_form
            if self.ignore_payload_params and form_contents:
                key.extend(
                    p for p in form_contents.items(multi=True)
                    if p[0] not in self.ignore_payload_params
                )
            else:
                key.append(str(r.content))

        if not self.ignore_host:
            key.append(r.host)

        ignore_params = self.ignore_params or []
        for name, value in queriesArray:
            if name not in ignore_params:
                key.append(name)
                key.append(value)

        if self.headers:
            key.append([(i, r.headers.get(i)) for i in self.headers])

        # hashlib only accepts bytes on Python 3, so encode the repr first;
        # "backslashreplace" keeps the encoding from failing on any character.
        return hashlib.sha256(
            repr(key).encode("utf8", "backslashreplace")
        ).digest()

    def next_flow(self, request):
        """
            Returns the next flow object, or None if no matching flow was
            found.
        """
        l = self.fmap.get(self._hash(request))
        if not l:
            return None

        if self.nopop:
            # Keep the recorded flow so it can be served again.
            return l[0]
        else:
            return l.pop(0)
class StickyCookieState:
    """Remembers cookies set by responses and re-adds them to matching requests."""

    def __init__(self, flt):
        """
            flt: Compiled filter.
        """
        # Maps (domain, port, path) -> {cookie name: cookie data}.
        self.jar = collections.defaultdict(dict)
        self.flt = flt

    def ckey(self, attrs, f):
        """
            Returns a (domain, port, path) tuple.
        """
        domain = attrs["domain"] if "domain" in attrs else f.request.host
        path = attrs["path"] if "path" in attrs else "/"
        return (domain, f.request.port, path)

    def domain_match(self, a, b):
        """True if ``a`` domain-matches ``b``, with or without ``b``'s surrounding dots."""
        if http_cookiejar.domain_match(a, b):
            return True
        return bool(http_cookiejar.domain_match(a, b.strip(".")))

    def handle_response(self, f):
        """Store each response cookie whose domain matches the request host."""
        for name, (value, attrs) in f.response.cookies.items(multi=True):
            # FIXME: We now know that Cookie.py screws up some cookies with
            # valid RFC 822/1123 datetime specifications for expiry. Sigh.
            jar_key = self.ckey(attrs, f)
            if self.domain_match(f.request.host, jar_key[0]):
                self.jar[jar_key][name] = attrs.with_insert(0, name, value)

    def handle_request(self, f):
        """Set the cookie header of a matching request from the stored cookies."""
        parts = []
        if f.match(self.flt):
            for (domain, port, path), stored in self.jar.items():
                # All three checks are evaluated before they are combined.
                domain_ok = self.domain_match(f.request.host, domain)
                port_ok = f.request.port == port
                path_ok = f.request.path.startswith(path)
                if domain_ok and port_ok and path_ok:
                    for name in stored.keys():
                        parts.append(
                            cookies.format_cookie_header(stored[name].items(multi=True))
                        )
        if parts:
            f.request.stickycookie = True
            f.request.headers["cookie"] = "; ".join(parts)
class StickyAuthState:
    """Remembers the authorization header per host and re-adds it to matching requests."""

    def __init__(self, flt):
        """
            flt: Compiled filter.
        """
        self.flt = flt
        # Maps request host -> last authorization header value seen for it.
        self.hosts = {}

    def handle_request(self, f):
        """Record the request's authorization header, or restore a stored one."""
        host = f.request.host
        headers = f.request.headers
        if "authorization" in headers:
            self.hosts[host] = headers["authorization"]
            return
        if f.match(self.flt) and host in self.hosts:
            headers["authorization"] = self.hosts[host]

266
mitmproxy/flow/state.py Normal file
View File

@@ -0,0 +1,266 @@
from abc import abstractmethod, ABCMeta
import six
from typing import List
from mitmproxy import models, filt
@six.add_metaclass(ABCMeta)
class FlowList(object):
    """
    Abstract list-like container of flows. Subclasses implement the
    _add/_update/_remove hooks.
    """

    def __init__(self):
        self._list = []  # type: List[models.Flow]

    def __iter__(self):
        return iter(self._list)

    def __contains__(self, item):
        return item in self._list

    def __getitem__(self, item):
        return self._list[item]

    def __bool__(self):
        return len(self._list) > 0

    # Python 2 looks up __nonzero__ instead of __bool__.
    if six.PY2:
        __nonzero__ = __bool__

    def __len__(self):
        return len(self._list)

    def index(self, f):
        """Return the position of flow ``f`` in the list."""
        return self._list.index(f)

    @abstractmethod
    def _add(self, f):
        """Hook for adding a flow."""
        return None

    @abstractmethod
    def _update(self, f):
        """Hook for reacting to a changed flow."""
        return None

    @abstractmethod
    def _remove(self, f):
        """Hook for removing a flow."""
        return None
def _pos(*args):
return True
class FlowView(FlowList):
    """Filtered view on a flow store; registers itself in the store's views."""

    def __init__(self, store, filt=None):
        super(FlowView, self).__init__()
        # Without a filter, fall back to one that accepts every flow.
        self._build(store, filt or _pos)

        self.store = store
        self.store.views.append(self)

    def _close(self):
        """Remove this view from its store's list of views."""
        self.store.views.remove(self)

    def _build(self, flows, filt=None):
        """Rebuild the view from ``flows``, replacing the filter if one is given."""
        if filt:
            self.filt = filt
        accept = self.filt
        self._list = [f for f in flows if accept(f)]

    def _add(self, f):
        if self.filt(f):
            self._list.append(f)

    def _update(self, f):
        present = f in self._list
        if not present:
            # May have started to match the filter.
            self._add(f)
        elif not self.filt(f):
            self._remove(f)

    def _remove(self, f):
        if f in self._list:
            self._list.remove(f)

    def _recalculate(self, flows):
        self._build(flows)
class FlowStore(FlowList):
    """
    Responsible for handling flows in the state:
    Keeps a list of all flows and provides views on them.
    """

    def __init__(self):
        super(FlowStore, self).__init__()
        self._set = set()  # Used for O(1) lookups
        self.views = []
        self._recalculate_views()

    def get(self, flow_id):
        """Return the first flow whose id equals ``flow_id``, or None."""
        return next((f for f in self._list if f.id == flow_id), None)

    def __contains__(self, f):
        return f in self._set

    def _add(self, f):
        """
        Adds a flow to the state.
        The flow to add must not be present in the state.
        """
        self._list.append(f)
        self._set.add(f)
        for view in self.views:
            view._add(f)

    def _update(self, f):
        """
        Notifies the state that a flow has been updated.
        The flow must be present in the state.
        """
        if f not in self:
            return
        for view in self.views:
            view._update(f)

    def _remove(self, f):
        """
        Deletes a flow from the state.
        The flow must be present in the state.
        """
        self._list.remove(f)
        self._set.remove(f)
        for view in self.views:
            view._remove(f)

    # Expensive bulk operations

    def _extend(self, flows):
        """
        Adds a list of flows to the state.
        The list of flows to add must not contain flows that are already in the state.
        """
        self._list.extend(flows)
        self._set.update(flows)
        self._recalculate_views()

    def _clear(self):
        """Drop all flows and rebuild the views."""
        self._list = []
        self._set = set()
        self._recalculate_views()

    def _recalculate_views(self):
        """
        Expensive operation: Recalculate all the views after a bulk change.
        """
        for view in self.views:
            view._recalculate(self)

    # Utility functions.
    # There are some common cases where we need to argue about all flows
    # irrespective of filters on the view etc (i.e. on shutdown).

    def active_count(self):
        """Count the flows that have neither a response nor an error."""
        return sum(1 for i in self._list if not i.response and not i.error)

    # TODO: Should accept_all operate on views or on all flows?
    def accept_all(self, master):
        for f in self._list:
            f.accept_intercept(master)

    def kill_all(self, master):
        """Kill every flow whose reply has not been acked."""
        for f in self._list:
            if f.reply.acked:
                continue
            f.kill(master)
class State(object):
    """Holds the flow store, the current filtered view and the intercept filter."""

    def __init__(self):
        self.flows = FlowStore()
        self.view = FlowView(self.flows, None)

        # These are compiled filt expressions:
        self.intercept = None

    @property
    def limit_txt(self):
        return getattr(self.view.filt, "pattern", None)

    def flow_count(self):
        return len(self.flows)

    # TODO: All functions regarding flows that don't cause side-effects should
    # be moved into FlowStore.
    def index(self, f):
        return self.flows.index(f)

    def active_flow_count(self):
        return self.flows.active_count()

    def add_flow(self, f):
        """
            Add a request to the state.
        """
        self.flows._add(f)
        return f

    def update_flow(self, f):
        """
            Add a response to the state.
        """
        self.flows._update(f)
        return f

    def delete_flow(self, f):
        self.flows._remove(f)

    def load_flows(self, flows):
        self.flows._extend(flows)

    def set_limit(self, txt):
        """
            Replace the view with one filtered by ``txt`` (unfiltered if
            ``txt`` is falsy). Returns an error message if ``txt`` cannot be
            parsed.
        """
        if txt == self.limit_txt:
            return
        compiled = None
        if txt:
            compiled = filt.parse(txt)
            if not compiled:
                return "Invalid filter expression."
        # Close the old view only once the new filter is known to be valid.
        self.view._close()
        self.view = FlowView(self.flows, compiled)

    def set_intercept(self, txt):
        """
            Set the intercept filter from ``txt``, or clear it if ``txt`` is
            falsy. Returns an error message if ``txt`` cannot be parsed.
        """
        if not txt:
            self.intercept = None
            return
        compiled = filt.parse(txt)
        if not compiled:
            return "Invalid filter expression."
        self.intercept = compiled

    @property
    def intercept_txt(self):
        return getattr(self.intercept, "pattern", None)

    def clear(self):
        self.flows._clear()

    def accept_all(self, master):
        self.flows.accept_all(master)

    def backup(self, f):
        f.backup()
        self.update_flow(f)

    def revert(self, f):
        f.revert()
        self.update_flow(f)

    def killall(self, master):
        self.flows.kill_all(master)

View File

@@ -3,7 +3,7 @@ import re
import netlib.tutils
from netlib.http import Headers
from mitmproxy import flow_export
from mitmproxy.flow import export # heh
from . import tutils
@@ -36,38 +36,38 @@ class TestExportCurlCommand():
def test_get(self):
flow = tutils.tflow(req=req_get())
result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path?a=foo&a=bar&b=baz'"""
assert flow_export.curl_command(flow) == result
assert export.curl_command(flow) == result
def test_post(self):
flow = tutils.tflow(req=req_post())
result = """curl -X POST 'http://address/path' --data-binary 'content'"""
assert flow_export.curl_command(flow) == result
assert export.curl_command(flow) == result
def test_patch(self):
flow = tutils.tflow(req=req_patch())
result = """curl -H 'header:qvalue' -H 'content-length:7' -X PATCH 'http://address/path?query=param' --data-binary 'content'"""
assert flow_export.curl_command(flow) == result
assert export.curl_command(flow) == result
class TestExportPythonCode():
def test_get(self):
flow = tutils.tflow(req=req_get())
python_equals("data/test_flow_export/python_get.py", flow_export.python_code(flow))
python_equals("data/test_flow_export/python_get.py", export.python_code(flow))
def test_post(self):
flow = tutils.tflow(req=req_post())
python_equals("data/test_flow_export/python_post.py", flow_export.python_code(flow))
python_equals("data/test_flow_export/python_post.py", export.python_code(flow))
def test_post_json(self):
p = req_post()
p.content = '{"name": "example", "email": "example@example.com"}'
p.headers = Headers(content_type="application/json")
flow = tutils.tflow(req=p)
python_equals("data/test_flow_export/python_post_json.py", flow_export.python_code(flow))
python_equals("data/test_flow_export/python_post_json.py", export.python_code(flow))
def test_patch(self):
flow = tutils.tflow(req=req_patch())
python_equals("data/test_flow_export/python_patch.py", flow_export.python_code(flow))
python_equals("data/test_flow_export/python_patch.py", export.python_code(flow))
class TestRawRequest():
@@ -80,7 +80,7 @@ class TestRawRequest():
host: address:22\r
\r
""").strip(" ").lstrip()
assert flow_export.raw_request(flow) == result
assert export.raw_request(flow) == result
def test_post(self):
flow = tutils.tflow(req=req_post())
@@ -90,7 +90,7 @@ class TestRawRequest():
\r
content
""").strip()
assert flow_export.raw_request(flow) == result
assert export.raw_request(flow) == result
def test_patch(self):
flow = tutils.tflow(req=req_patch())
@@ -102,54 +102,54 @@ class TestRawRequest():
\r
content
""").strip()
assert flow_export.raw_request(flow) == result
assert export.raw_request(flow) == result
class TestExportLocustCode():
def test_get(self):
flow = tutils.tflow(req=req_get())
python_equals("data/test_flow_export/locust_get.py", flow_export.locust_code(flow))
python_equals("data/test_flow_export/locust_get.py", export.locust_code(flow))
def test_post(self):
p = req_post()
p.content = '''content'''
p.headers = ''
flow = tutils.tflow(req=p)
python_equals("data/test_flow_export/locust_post.py", flow_export.locust_code(flow))
python_equals("data/test_flow_export/locust_post.py", export.locust_code(flow))
def test_patch(self):
flow = tutils.tflow(req=req_patch())
python_equals("data/test_flow_export/locust_patch.py", flow_export.locust_code(flow))
python_equals("data/test_flow_export/locust_patch.py", export.locust_code(flow))
class TestExportLocustTask():
def test_get(self):
flow = tutils.tflow(req=req_get())
python_equals("data/test_flow_export/locust_task_get.py", flow_export.locust_task(flow))
python_equals("data/test_flow_export/locust_task_get.py", export.locust_task(flow))
def test_post(self):
flow = tutils.tflow(req=req_post())
python_equals("data/test_flow_export/locust_task_post.py", flow_export.locust_task(flow))
python_equals("data/test_flow_export/locust_task_post.py", export.locust_task(flow))
def test_patch(self):
flow = tutils.tflow(req=req_patch())
python_equals("data/test_flow_export/locust_task_patch.py", flow_export.locust_task(flow))
python_equals("data/test_flow_export/locust_task_patch.py", export.locust_task(flow))
class TestIsJson():
def test_empty(self):
assert flow_export.is_json(None, None) is False
assert export.is_json(None, None) is False
def test_json_type(self):
headers = Headers(content_type="application/json")
assert flow_export.is_json(headers, "foobar") is False
assert export.is_json(headers, "foobar") is False
def test_valid(self):
headers = Headers(content_type="application/foobar")
j = flow_export.is_json(headers, '{"name": "example", "email": "example@example.com"}')
j = export.is_json(headers, '{"name": "example", "email": "example@example.com"}')
assert j is False
def test_valid2(self):
headers = Headers(content_type="application/json")
j = flow_export.is_json(headers, '{"name": "example", "email": "example@example.com"}')
j = export.is_json(headers, '{"name": "example", "email": "example@example.com"}')
assert isinstance(j, dict)