diff --git a/leylines/leylines/dask.py b/leylines/leylines/dask.py index f8f7a95..7398efa 100644 --- a/leylines/leylines/dask.py +++ b/leylines/leylines/dask.py @@ -1,5 +1,6 @@ +import asyncio import os -from typing import Iterable +from typing import Iterable, TypeVar, List from distributed.client import default_client from distributed.diagnostics.progressbar import TextProgressBar @@ -61,6 +62,28 @@ def tqdmprogress(future: Future) -> None: DaskTqdmBar(futures, complete=True) +T = TypeVar("T") +async def tqdm_await(tasks: Iterable[T], pbar: tqdm = None) -> None: + pending = list(tasks) + need_close = False + if pbar is None: + pbar = tqdm(total=len(pending)) + need_close = True + else: + pbar.reset(len(pending)) + results = [] + try: + while len(pending) > 0: + done, pending = await asyncio.wait(pending, timeout=1, + return_when=asyncio.FIRST_COMPLETED) + pbar.update(len(done)) + for task in done: + yield task + finally: + if need_close: + pbar.close() + + class tqdm2(tqdm): def __init__(self, *args, **kwargs): self._next_desc = kwargs.get("run_desc", None)