Extract flow reading into addons

This patch moves the final pieces of master functionality into addons.

- Add a ReadFile addon to read from file
- Add a separate ReadStdin addon to read from stdin, only used by mitmdump
- Remove all methods that know about io and serialization from master.Master
This commit is contained in:
Aldo Cortesi
2017-03-15 12:47:03 +13:00
committed by Aldo Cortesi
parent eba6d4359c
commit ef582333ff
13 changed files with 235 additions and 123 deletions

View File

@@ -8,6 +8,7 @@ from mitmproxy.addons import disable_h2c
from mitmproxy.addons import onboarding from mitmproxy.addons import onboarding
from mitmproxy.addons import proxyauth from mitmproxy.addons import proxyauth
from mitmproxy.addons import replace from mitmproxy.addons import replace
from mitmproxy.addons import readfile
from mitmproxy.addons import script from mitmproxy.addons import script
from mitmproxy.addons import serverplayback from mitmproxy.addons import serverplayback
from mitmproxy.addons import setheaders from mitmproxy.addons import setheaders
@@ -37,5 +38,6 @@ def default_addons():
stickycookie.StickyCookie(), stickycookie.StickyCookie(),
streambodies.StreamBodies(), streambodies.StreamBodies(),
streamfile.StreamFile(), streamfile.StreamFile(),
readfile.ReadFile(),
upstream_auth.UpstreamAuth(), upstream_auth.UpstreamAuth(),
] ]

View File

@@ -0,0 +1,50 @@
import os.path
from mitmproxy import ctx
from mitmproxy import io
from mitmproxy import exceptions
class ReadFile:
"""
An addon that handles reading from file on startup.
"""
def __init__(self):
self.path = None
self.keepserving = False
def load_flows_file(self, path: str) -> int:
path = os.path.expanduser(path)
cnt = 0
try:
with open(path, "rb") as f:
freader = io.FlowReader(f)
for i in freader.stream():
cnt += 1
ctx.master.load_flow(i)
return cnt
except (IOError, exceptions.FlowReadException) as v:
if cnt:
ctx.log.warn(
"Flow file corrupted - loaded %i flows." % cnt,
)
else:
ctx.log.error("Flow file corrupted.")
raise exceptions.FlowReadException(v)
def configure(self, options, updated):
if "keepserving" in updated:
self.keepserving = options.keepserving
if "rfile" in updated and options.rfile:
self.path = options.rfile
def running(self):
if self.path:
try:
self.load_flows_file(self.path)
except exceptions.FlowReadException as v:
raise exceptions.OptionsError(v)
finally:
self.path = None
if not self.keepserving:
ctx.master.shutdown()

View File

@@ -0,0 +1,34 @@
from mitmproxy import ctx
from mitmproxy import io
from mitmproxy import exceptions
import sys
class ReadStdin:
"""
An addon that reads from stdin if we're not attached to (someting like)
a tty.
"""
def __init__(self):
self.keepserving = False
def configure(self, options, updated):
if "keepserving" in updated:
self.keepserving = options.keepserving
def running(self, stdin = sys.stdin):
if not stdin.isatty():
ctx.log.info("Reading from stdin")
try:
stdin.buffer.read(0)
except Exception as e:
ctx.log.warn("Cannot read from stdin: {}".format(e))
return
freader = io.FlowReader(stdin.buffer)
try:
for i in freader.stream():
ctx.master.load_flow(i)
except exceptions.FlowReadException as e:
ctx.log.error("Error reading from stdin: %s" % e)
if not self.keepserving:
ctx.master.shutdown()

View File

@@ -1,8 +1,6 @@
import os
import threading import threading
import contextlib import contextlib
import queue import queue
import sys
from mitmproxy import addonmanager from mitmproxy import addonmanager
from mitmproxy import options from mitmproxy import options
@@ -12,7 +10,6 @@ from mitmproxy import exceptions
from mitmproxy import connections from mitmproxy import connections
from mitmproxy import http from mitmproxy import http
from mitmproxy import log from mitmproxy import log
from mitmproxy import io
from mitmproxy.proxy.protocol import http_replay from mitmproxy.proxy.protocol import http_replay
from mitmproxy.types import basethread from mitmproxy.types import basethread
import mitmproxy.net.http import mitmproxy.net.http
@@ -160,33 +157,6 @@ class Master:
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):
getattr(self, e)(o) getattr(self, e)(o)
def load_flows(self, fr: io.FlowReader) -> int:
"""
Load flows from a FlowReader object.
"""
cnt = 0
for i in fr.stream():
cnt += 1
self.load_flow(i)
return cnt
def load_flows_file(self, path: str) -> int:
path = os.path.expanduser(path)
try:
if path == "-":
try:
sys.stdin.buffer.read(0)
except Exception as e:
raise IOError("Cannot read from stdin: {}".format(e))
freader = io.FlowReader(sys.stdin.buffer)
return self.load_flows(freader)
else:
with open(path, "rb") as f:
freader = io.FlowReader(f)
return self.load_flows(freader)
except IOError as v:
raise exceptions.FlowReadException(v.strerror)
def replay_request( def replay_request(
self, self,
f: http.HTTPFlow, f: http.HTTPFlow,

View File

@@ -256,19 +256,6 @@ class ConsoleMaster(master.Master):
) )
self.ab = statusbar.ActionBar() self.ab = statusbar.ActionBar()
if self.options.rfile:
ret = self.load_flows_path(self.options.rfile)
if ret and self.view.store_count():
signals.add_log(
"File truncated or corrupted. "
"Loaded as many flows as possible.",
"error"
)
elif ret and not self.view.store_count():
self.shutdown()
print("Could not load file: {}".format(ret), file=sys.stderr)
sys.exit(1)
self.loop.set_alarm_in(0.01, self.ticker) self.loop.set_alarm_in(0.01, self.ticker)
self.loop.set_alarm_in( self.loop.set_alarm_in(
@@ -289,7 +276,10 @@ class ConsoleMaster(master.Master):
print("Shutting down...", file=sys.stderr) print("Shutting down...", file=sys.stderr)
finally: finally:
sys.stderr.flush() sys.stderr.flush()
self.shutdown() super().shutdown()
def shutdown(self):
raise urwid.ExitMainLoop
def view_help(self, helpctx): def view_help(self, helpctx):
signals.push_view_state.send( signals.push_view_state.send(
@@ -402,7 +392,7 @@ class ConsoleMaster(master.Master):
def quit(self, a): def quit(self, a):
if a != "n": if a != "n":
raise urwid.ExitMainLoop self.shutdown()
def clear_events(self): def clear_events(self):
self.logbuffer[:] = [] self.logbuffer[:] = []

View File

@@ -1,9 +1,8 @@
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import addons from mitmproxy import addons
from mitmproxy import options from mitmproxy import options
from mitmproxy import master from mitmproxy import master
from mitmproxy.addons import dumper, termlog, termstatus from mitmproxy.addons import dumper, termlog, termstatus, readstdin
class DumpMaster(master.Master): class DumpMaster(master.Master):
@@ -22,21 +21,9 @@ class DumpMaster(master.Master):
self.addons.add(*addons.default_addons()) self.addons.add(*addons.default_addons())
if with_dumper: if with_dumper:
self.addons.add(dumper.Dumper()) self.addons.add(dumper.Dumper())
self.addons.add(readstdin.ReadStdin())
if options.rfile:
try:
self.load_flows_file(options.rfile)
except exceptions.FlowReadException as v:
self.add_log("Flow file corrupted.", "error")
raise exceptions.OptionsError(v)
@controller.handler @controller.handler
def log(self, e): def log(self, e):
if e.level == "error": if e.level == "error":
self.has_errored = True self.has_errored = True
def run(self): # pragma: no cover
if self.options.rfile and not self.options.keepserving:
self.addons.done()
return
super().run()

View File

@@ -230,7 +230,8 @@ class DumpFlows(RequestHandler):
def post(self): def post(self):
self.view.clear() self.view.clear()
bio = BytesIO(self.filecontents) bio = BytesIO(self.filecontents)
self.master.load_flows(io.FlowReader(bio)) for i in io.FlowReader(bio).stream():
self.master.load_flow(i)
bio.close() bio.close()

View File

@@ -3,7 +3,6 @@ import webbrowser
import tornado.httpserver import tornado.httpserver
import tornado.ioloop import tornado.ioloop
from mitmproxy import addons from mitmproxy import addons
from mitmproxy import exceptions
from mitmproxy import log from mitmproxy import log
from mitmproxy import master from mitmproxy import master
from mitmproxy.addons import eventstore from mitmproxy.addons import eventstore
@@ -42,14 +41,6 @@ class WebMaster(master.Master):
) )
# This line is just for type hinting # This line is just for type hinting
self.options = self.options # type: Options self.options = self.options # type: Options
if options.rfile:
try:
self.load_flows_file(options.rfile)
except exceptions.FlowReadException as v:
self.add_log(
"Could not read flow file: %s" % v,
"error"
)
def _sig_view_add(self, view, flow): def _sig_view_add(self, view, flow):
app.ClientConnection.broadcast( app.ClientConnection.broadcast(

View File

@@ -0,0 +1,62 @@
from mitmproxy.addons import readfile
from mitmproxy.test import taddons
from mitmproxy.test import tflow
from mitmproxy import io
from mitmproxy import exceptions
from unittest import mock
import pytest
def write_data(path, corrupt=False):
with open(path, "wb") as tf:
w = io.FlowWriter(tf)
for i in range(3):
f = tflow.tflow(resp=True)
w.add(f)
for i in range(3):
f = tflow.tflow(err=True)
w.add(f)
f = tflow.ttcpflow()
w.add(f)
f = tflow.ttcpflow(err=True)
w.add(f)
if corrupt:
tf.write(b"flibble")
@mock.patch('mitmproxy.master.Master.load_flow')
def test_configure(mck, tmpdir):
rf = readfile.ReadFile()
with taddons.context() as tctx:
tf = str(tmpdir.join("tfile"))
write_data(tf)
tctx.configure(rf, rfile=str(tf), keepserving=False)
assert not mck.called
rf.running()
assert mck.called
write_data(tf, corrupt=True)
tctx.configure(rf, rfile=str(tf), keepserving=False)
with pytest.raises(exceptions.OptionsError):
rf.running()
@mock.patch('mitmproxy.master.Master.load_flow')
def test_corruption(mck, tmpdir):
rf = readfile.ReadFile()
with taddons.context() as tctx:
with pytest.raises(exceptions.FlowReadException):
rf.load_flows_file("nonexistent")
assert not mck.called
assert len(tctx.master.event_log) == 1
tfc = str(tmpdir.join("tfile"))
write_data(tfc, corrupt=True)
with pytest.raises(exceptions.FlowReadException):
rf.load_flows_file(tfc)
assert mck.called
assert len(tctx.master.event_log) == 2

View File

@@ -0,0 +1,59 @@
import io
from mitmproxy.addons import readstdin
from mitmproxy.test import taddons
from mitmproxy.test import tflow
import mitmproxy.io
from unittest import mock
def gen_data(corrupt=False):
tf = io.BytesIO()
w = mitmproxy.io.FlowWriter(tf)
for i in range(3):
f = tflow.tflow(resp=True)
w.add(f)
for i in range(3):
f = tflow.tflow(err=True)
w.add(f)
f = tflow.ttcpflow()
w.add(f)
f = tflow.ttcpflow(err=True)
w.add(f)
if corrupt:
tf.write(b"flibble")
tf.seek(0)
return tf
def test_configure(tmpdir):
rf = readstdin.ReadStdin()
with taddons.context() as tctx:
tctx.configure(rf, keepserving=False)
class mStdin:
def __init__(self, d):
self.buffer = d
def isatty(self):
return False
@mock.patch('mitmproxy.master.Master.load_flow')
def test_read(m, tmpdir):
rf = readstdin.ReadStdin()
with taddons.context() as tctx:
assert not m.called
rf.running(stdin=mStdin(gen_data()))
assert m.called
rf.running(stdin=mStdin(None))
assert tctx.master.event_log
tctx.master.clear()
m.reset_mock()
assert not m.called
rf.running(stdin=mStdin(gen_data(corrupt=True)))
assert m.called
assert tctx.master.event_log

View File

@@ -3,12 +3,13 @@ import pytest
from mitmproxy.test import tflow from mitmproxy.test import tflow
import mitmproxy.io import mitmproxy.io
from mitmproxy import flowfilter, options from mitmproxy import flowfilter
from mitmproxy import options
from mitmproxy.proxy import config
from mitmproxy.contrib import tnetstring from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException from mitmproxy.exceptions import FlowReadException
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import http from mitmproxy import http
from mitmproxy.proxy import ProxyConfig
from mitmproxy.proxy.server import DummyServer from mitmproxy.proxy.server import DummyServer
from mitmproxy import master from mitmproxy import master
from . import tservers from . import tservers
@@ -16,23 +17,6 @@ from . import tservers
class TestSerialize: class TestSerialize:
def _treader(self):
sio = io.BytesIO()
w = mitmproxy.io.FlowWriter(sio)
for i in range(3):
f = tflow.tflow(resp=True)
w.add(f)
for i in range(3):
f = tflow.tflow(err=True)
w.add(f)
f = tflow.ttcpflow()
w.add(f)
f = tflow.ttcpflow(err=True)
w.add(f)
sio.seek(0)
return mitmproxy.io.FlowReader(sio)
def test_roundtrip(self): def test_roundtrip(self):
sio = io.BytesIO() sio = io.BytesIO()
f = tflow.tflow() f = tflow.tflow()
@@ -51,26 +35,6 @@ class TestSerialize:
assert f2.request == f.request assert f2.request == f.request
assert f2.marked assert f2.marked
def test_load_flows(self):
r = self._treader()
s = tservers.TestState()
fm = master.Master(None, DummyServer())
fm.addons.add(s)
fm.load_flows(r)
assert len(s.flows) == 6
def test_load_flows_reverse(self):
r = self._treader()
s = tservers.TestState()
opts = options.Options(
mode="reverse:https://use-this-domain"
)
conf = ProxyConfig(opts)
fm = master.Master(opts, DummyServer(conf))
fm.addons.add(s)
fm.load_flows(r)
assert s.flows[0].request.host == "use-this-domain"
def test_filter(self): def test_filter(self):
sio = io.BytesIO() sio = io.BytesIO()
flt = flowfilter.parse("~c 200") flt = flowfilter.parse("~c 200")
@@ -122,6 +86,17 @@ class TestSerialize:
class TestFlowMaster: class TestFlowMaster:
def test_load_flow_reverse(self):
s = tservers.TestState()
opts = options.Options(
mode="reverse:https://use-this-domain"
)
conf = config.ProxyConfig(opts)
fm = master.Master(opts, DummyServer(conf))
fm.addons.add(s)
f = tflow.tflow(resp=True)
fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
def test_replay(self): def test_replay(self):
fm = master.Master(None, DummyServer()) fm = master.Master(None, DummyServer())

View File

@@ -5,6 +5,7 @@ from mitmproxy import proxy
from mitmproxy import options from mitmproxy import options
from mitmproxy.tools.console import common from mitmproxy.tools.console import common
from ... import tservers from ... import tservers
import urwid
def test_format_keyvals(): def test_format_keyvals():
@@ -35,7 +36,10 @@ class TestMaster(tservers.MasterTest):
def test_basic(self): def test_basic(self):
m = self.mkmaster() m = self.mkmaster()
for i in (1, 2, 3): for i in (1, 2, 3):
self.dummy_cycle(m, 1, b"") try:
self.dummy_cycle(m, 1, b"")
except urwid.ExitMainLoop:
pass
assert len(m.view) == i assert len(m.view) == i
def test_run_script_once(self): def test_run_script_once(self):

View File

@@ -2,7 +2,6 @@ import pytest
from unittest import mock from unittest import mock
from mitmproxy import proxy from mitmproxy import proxy
from mitmproxy import exceptions
from mitmproxy import log from mitmproxy import log
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import options from mitmproxy import options
@@ -17,18 +16,6 @@ class TestDumpMaster(tservers.MasterTest):
m = dump.DumpMaster(o, proxy.DummyServer(), with_termlog=False, with_dumper=False) m = dump.DumpMaster(o, proxy.DummyServer(), with_termlog=False, with_dumper=False)
return m return m
def test_read(self, tmpdir):
p = str(tmpdir.join("read"))
self.flowfile(p)
self.dummy_cycle(
self.mkmaster(None, rfile=p),
1, b"",
)
with pytest.raises(exceptions.OptionsError):
self.mkmaster(None, rfile="/nonexistent")
with pytest.raises(exceptions.OptionsError):
self.mkmaster(None, rfile="test_dump.py")
def test_has_error(self): def test_has_error(self):
m = self.mkmaster(None) m = self.mkmaster(None)
ent = log.LogEntry("foo", "error") ent = log.LogEntry("foo", "error")