diff --git a/leylines-support/02-dask.ipy b/leylines-support/02-dask.ipy index 1a2690e..77d6de6 100644 --- a/leylines-support/02-dask.ipy +++ b/leylines-support/02-dask.ipy @@ -4,36 +4,15 @@ ipy = __import__("IPython") @ipy.core.magic.needs_local_scope def dask(line, local_ns): "initializes dask" - from distributed.client import default_client - from distributed.diagnostics.progressbar import TextProgressBar - from distributed import futures_of - from tqdm import tqdm - class TqdmBar(TextProgressBar): - def __init__(self, *args, **kwargs): - self.p = tqdm(desc="scheduling...") - self.last = None - TextProgressBar.__init__(self, *args, **kwargs) - def _draw_bar(self, remaining, all, **kwargs): - if not all: - return - if self.last is None: - self.last = 0 - self.p.set_description("🦈") - self.p.reset(total=all) - self.p.update((all - remaining) - self.last) - self.last = (all - remaining) - def _draw_stop(self, **kwargs): - self.p.close() - - def tqdmprogress(future): - futures = futures_of(future) - if not isinstance(futures, (set, list)): - futures = [futures] - TqdmBar(futures, complete=True) + from leylines import get_server_node, DEFAULT_PORT, db + from leylines.dask import init_dask, tqdmprogress local_ns['tqdmprogress'] = tqdmprogress + client = init_dask() + local_ns['client'] = client + def upload(client, file): import dask import distributed @@ -46,22 +25,14 @@ def dask(line, local_ns): return dask.delayed(get_file)() - try: - default = default_client() - local_ns['client'] = default - local_ns['upload'] = lambda file: upload(client, file) - return - except ValueError: - pass - from distributed import Client - import re - from leylines import get_server_node, DEFAULT_PORT, db + local_ns['upload'] = lambda file: upload(client, file) + server_node = get_server_node() workers = [node for node in db.get_nodes() if node.id != server_node.id and node.ssh_key is not None] dest = f"{server_node.ip}:{DEFAULT_PORT}" - print("connecting to APEX at", dest) - client = Client(dest) + print("connected to APEX at", dest) + workers_by_ip = {str(node.ip):node for node in workers} workers_status = {str(node.ip):False for node in workers} for addr, info in client.scheduler_info()["workers"].items(): @@ -71,8 +42,6 @@ def dask(line, local_ns): print(f"{node.name} ({node.ip}): up") else: print(f"{node.name} ({node.ip}): down") - local_ns['client'] = client - local_ns['upload'] = lambda file: upload(client, file) @ipy.core.magic.register_line_magic @ipy.core.magic.needs_local_scope diff --git a/leylines/leylines/dask.py b/leylines/leylines/dask.py new file mode 100644 index 0000000..156bd99 --- /dev/null +++ b/leylines/leylines/dask.py @@ -0,0 +1,46 @@ +import os + +from distributed.client import default_client +from distributed.diagnostics.progressbar import TextProgressBar +from distributed import Client, futures_of +from tqdm import tqdm + +from leylines import get_server_node, DEFAULT_PORT + +def init_dask(): + "initializes dask" + + try: + default = default_client() + return default + except ValueError: + pass + server_node = get_server_node() + dest = f"{server_node.ip}:{DEFAULT_PORT}" + client = Client(dest) + return client + + +class DaskTqdmBar(TextProgressBar): + def __init__(self, *args, **kwargs): + self.p = tqdm(desc="scheduling...") + self.last = None + TextProgressBar.__init__(self, *args, **kwargs) + def _draw_bar(self, remaining, all, **kwargs): + if not all: + return + if self.last is None: + self.last = 0 + self.p.set_description("🦈") + self.p.reset(total=all) + self.p.update((all - remaining) - self.last) + self.last = (all - remaining) + def _draw_stop(self, **kwargs): + self.p.close() + + +def tqdmprogress(future): + futures = futures_of(future) + if not isinstance(futures, (set, list)): + futures = [futures] + DaskTqdmBar(futures, complete=True)