move ScatterGather to utils

This commit is contained in:
Jude Nelson
2017-05-17 22:05:32 -04:00
parent f5ecf9b148
commit 7e8a34ff7b
2 changed files with 123 additions and 119 deletions

View File

@@ -29,6 +29,7 @@ from ..proxy import *
from ..config import get_utxo_provider_client
from ..b40 import is_b40
from ..logger import get_logger
from ..utils import ScatterGather, ScatterGatherThread
from .blockchain import (
get_balance, is_address_usable, get_utxos,
@@ -120,125 +121,6 @@ def check_valid_namespace(nsid):
return 'The namespace ID is invalid'
class ScatterGatherThread(threading.Thread):
"""
Scatter/gatter thread worker
Useful for doing long-running queries in parallel
"""
def __init__(self, rpc_call):
threading.Thread.__init__(self)
self.rpc_call = rpc_call
self.result = None
self.has_result = False
self.result_mux = threading.Lock()
self.result_mux.acquire()
def get_result(self):
"""
Wait for data and get it
"""
self.result_mux.acquire()
res = self.result
self.result_mux.release()
return res
def post_result(self, res):
"""
Give back result and release
"""
if self.has_result:
return
self.has_result = True
self.result = res
self.result_mux.release()
return
@classmethod
def do_work(cls, rpc_call):
"""
Run the given RPC call and post the result
"""
try:
log.debug("Run task {}".format(rpc_call))
res = rpc_call()
log.debug("Task exit {}".format(rpc_call))
return res
except Exception as e:
log.exception(e)
log.debug("Task exit {}".format(rpc_call))
return {'error': 'Task encountered a fatal exception:\n{}'.format(traceback.format_exc())}
def run(self):
res = ScatterGatherThread.do_work(self.rpc_call)
self.post_result(res)
class ScatterGather(object):
"""
Scatter/gather work pool
Give it a few tasks, and it will run them
in parallel
"""
def __init__(self):
self.tasks = {}
self.ran = False
self.results = {}
def add_task(self, result_name, rpc_call):
assert result_name not in self.tasks.keys(), "Duplicate task: {}".format(result_name)
self.tasks[result_name] = rpc_call
def get_result(self, result_name):
assert self.ran
assert result_name in self.results, "Missing task: {}".format(result_name)
return self.results[result_name]
def get_results(self):
"""
Get the set of results
"""
assert self.ran
return self.results
def run_tasks(self, single_thread=False):
"""
Run all queued tasks, wait for them all to finish,
and return the set of results
"""
if not single_thread:
threads = {}
for task_name, task_call in self.tasks.items():
log.debug("Start task '{}'".format(task_name))
thr = ScatterGatherThread(task_call)
thr.start()
threads[task_name] = thr
for task_name, thr in threads.items():
log.debug("Join task '{}'".format(task_name))
thr.join()
res = thr.get_result()
self.results[task_name] = res
else:
# for testing purposes
for task_name, task_call in self.tasks.items():
res = ScatterGatherThread.do_work(task_call)
self.results[task_name] = res
self.ran = True
return self.results
def operation_sanity_checks(fqu_or_ns, operations, scatter_gather, payment_privkey_info, owner_privkey_info, required_checks=[],
min_confirmations=TX_MIN_CONFIRMATIONS, config_path=CONFIG_PATH,
transfer_address=None, owner_address=None, proxy=None):

View File

@@ -26,6 +26,9 @@ import sys
import os
import urllib2
import hashlib
import threading
import traceback
from .constants import DEFAULT_BLOCKSTACKD_PORT
from .logger import get_logger
@@ -223,3 +226,122 @@ def streq_constant(s1, s2):
return res == 0
class ScatterGatherThread(threading.Thread):
"""
Scatter/gatter thread worker
Useful for doing long-running queries in parallel
"""
def __init__(self, rpc_call):
threading.Thread.__init__(self)
self.rpc_call = rpc_call
self.result = None
self.has_result = False
self.result_mux = threading.Lock()
self.result_mux.acquire()
def get_result(self):
"""
Wait for data and get it
"""
self.result_mux.acquire()
res = self.result
self.result_mux.release()
return res
def post_result(self, res):
"""
Give back result and release
"""
if self.has_result:
return
self.has_result = True
self.result = res
self.result_mux.release()
return
@classmethod
def do_work(cls, rpc_call):
"""
Run the given RPC call and post the result
"""
try:
log.debug("Run task {}".format(rpc_call))
res = rpc_call()
log.debug("Task exit {}".format(rpc_call))
return res
except Exception as e:
log.exception(e)
log.debug("Task exit {}".format(rpc_call))
return {'error': 'Task encountered a fatal exception:\n{}'.format(traceback.format_exc())}
def run(self):
res = ScatterGatherThread.do_work(self.rpc_call)
self.post_result(res)
class ScatterGather(object):
"""
Scatter/gather work pool
Give it a few tasks, and it will run them
in parallel
"""
def __init__(self):
self.tasks = {}
self.ran = False
self.results = {}
def add_task(self, result_name, rpc_call):
assert result_name not in self.tasks.keys(), "Duplicate task: {}".format(result_name)
self.tasks[result_name] = rpc_call
def get_result(self, result_name):
assert self.ran
assert result_name in self.results, "Missing task: {}".format(result_name)
return self.results[result_name]
def get_results(self):
"""
Get the set of results
"""
assert self.ran
return self.results
def run_tasks(self, single_thread=False):
"""
Run all queued tasks, wait for them all to finish,
and return the set of results
"""
if not single_thread:
threads = {}
for task_name, task_call in self.tasks.items():
log.debug("Start task '{}'".format(task_name))
thr = ScatterGatherThread(task_call)
thr.start()
threads[task_name] = thr
for task_name, thr in threads.items():
log.debug("Join task '{}'".format(task_name))
thr.join()
res = thr.get_result()
self.results[task_name] = res
else:
# for testing purposes
for task_name, task_call in self.tasks.items():
res = ScatterGatherThread.do_work(task_call)
self.results[task_name] = res
self.ran = True
return self.results