add tqdm_await
This commit is contained in:
parent
58064f5420
commit
19944c1fa2
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import Iterable
|
||||
from typing import Iterable, TypeVar, List
|
||||
|
||||
from distributed.client import default_client
|
||||
from distributed.diagnostics.progressbar import TextProgressBar
|
||||
|
@ -61,6 +62,28 @@ def tqdmprogress(future: Future) -> None:
|
|||
DaskTqdmBar(futures, complete=True)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
async def tqdm_await(tasks: Iterable[T], pbar: tqdm = None) -> None:
|
||||
pending = list(tasks)
|
||||
need_close = False
|
||||
if pbar is None:
|
||||
pbar = tqdm(total=len(pending))
|
||||
need_close = True
|
||||
else:
|
||||
pbar.reset(len(pending))
|
||||
results = []
|
||||
try:
|
||||
while len(pending) > 0:
|
||||
done, pending = await asyncio.wait(pending, timeout=1,
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
pbar.update(len(done))
|
||||
for task in done:
|
||||
yield task
|
||||
finally:
|
||||
if need_close:
|
||||
pbar.close()
|
||||
|
||||
|
||||
class tqdm2(tqdm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._next_desc = kwargs.get("run_desc", None)
|
||||
|
|
Loading…
Reference in New Issue