add tqdm_await
This commit is contained in:
parent
58064f5420
commit
19944c1fa2
|
@ -1,5 +1,6 @@
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Iterable
|
from typing import Iterable, TypeVar, List
|
||||||
|
|
||||||
from distributed.client import default_client
|
from distributed.client import default_client
|
||||||
from distributed.diagnostics.progressbar import TextProgressBar
|
from distributed.diagnostics.progressbar import TextProgressBar
|
||||||
|
@ -61,6 +62,28 @@ def tqdmprogress(future: Future) -> None:
|
||||||
DaskTqdmBar(futures, complete=True)
|
DaskTqdmBar(futures, complete=True)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
async def tqdm_await(tasks: Iterable[T], pbar: tqdm = None) -> None:
|
||||||
|
pending = list(tasks)
|
||||||
|
need_close = False
|
||||||
|
if pbar is None:
|
||||||
|
pbar = tqdm(total=len(pending))
|
||||||
|
need_close = True
|
||||||
|
else:
|
||||||
|
pbar.reset(len(pending))
|
||||||
|
results = []
|
||||||
|
try:
|
||||||
|
while len(pending) > 0:
|
||||||
|
done, pending = await asyncio.wait(pending, timeout=1,
|
||||||
|
return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
pbar.update(len(done))
|
||||||
|
for task in done:
|
||||||
|
yield task
|
||||||
|
finally:
|
||||||
|
if need_close:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
|
||||||
class tqdm2(tqdm):
|
class tqdm2(tqdm):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._next_desc = kwargs.get("run_desc", None)
|
self._next_desc = kwargs.get("run_desc", None)
|
||||||
|
|
Loading…
Reference in New Issue