diff --git a/leylines/leylines/dask.py b/leylines/leylines/dask.py index 156bd99..f8f7a95 100644 --- a/leylines/leylines/dask.py +++ b/leylines/leylines/dask.py @@ -1,13 +1,15 @@ import os +from typing import Iterable from distributed.client import default_client from distributed.diagnostics.progressbar import TextProgressBar -from distributed import Client, futures_of +from distributed import Client, futures_of, Future from tqdm import tqdm from leylines import get_server_node, DEFAULT_PORT -def init_dask(): + +def init_dask() -> Client: "initializes dask" try: @@ -21,12 +23,25 @@ def init_dask(): return client +async def init_dask_async() -> Client: + "initializes dask" + + try: + default = default_client() + return default + except ValueError: + pass + server_node = get_server_node() + dest = f"{server_node.ip}:{DEFAULT_PORT}" + return await Client(dest, asynchronous=True) + + class DaskTqdmBar(TextProgressBar): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: self.p = tqdm(desc="scheduling...") self.last = None TextProgressBar.__init__(self, *args, **kwargs) - def _draw_bar(self, remaining, all, **kwargs): + def _draw_bar(self, remaining, all, **kwargs) -> None: if not all: return if self.last is None: @@ -35,12 +50,27 @@ class DaskTqdmBar(TextProgressBar): self.p.reset(total=all) self.p.update((all - remaining) - self.last) self.last = (all - remaining) - def _draw_stop(self, **kwargs): + def _draw_stop(self, **kwargs) -> None: self.p.close() -def tqdmprogress(future): +def tqdmprogress(future: Future) -> None: futures = futures_of(future) if not isinstance(futures, (set, list)): futures = [futures] DaskTqdmBar(futures, complete=True) + + +class tqdm2(tqdm): + def __init__(self, *args, **kwargs): + self._next_desc = kwargs.get("run_desc", None) + if "run_desc" in kwargs: + del kwargs["run_desc"] + tqdm.__init__(self, *args, **kwargs) + self._first_time = True + + def update(self, *args, **kwargs): + if self._first_time and args[0] > 0: + self._first_time = False + self.set_description(self._next_desc) + tqdm.update(self, *args, **kwargs)