dask improvements
This commit is contained in:
parent
0a9e097c25
commit
58064f5420
|
@ -1,13 +1,15 @@
|
||||||
import os
|
import os
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from distributed.client import default_client
|
from distributed.client import default_client
|
||||||
from distributed.diagnostics.progressbar import TextProgressBar
|
from distributed.diagnostics.progressbar import TextProgressBar
|
||||||
from distributed import Client, futures_of
|
from distributed import Client, futures_of, Future
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from leylines import get_server_node, DEFAULT_PORT
|
from leylines import get_server_node, DEFAULT_PORT
|
||||||
|
|
||||||
def init_dask():
|
|
||||||
|
def init_dask() -> Client:
|
||||||
"initializes dask"
|
"initializes dask"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -21,12 +23,25 @@ def init_dask():
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
async def init_dask_async() -> Client:
|
||||||
|
"initializes dask"
|
||||||
|
|
||||||
|
try:
|
||||||
|
default = default_client()
|
||||||
|
return default
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
server_node = get_server_node()
|
||||||
|
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
||||||
|
return await Client(dest, asynchronous=True)
|
||||||
|
|
||||||
|
|
||||||
class DaskTqdmBar(TextProgressBar):
|
class DaskTqdmBar(TextProgressBar):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
self.p = tqdm(desc="scheduling...")
|
self.p = tqdm(desc="scheduling...")
|
||||||
self.last = None
|
self.last = None
|
||||||
TextProgressBar.__init__(self, *args, **kwargs)
|
TextProgressBar.__init__(self, *args, **kwargs)
|
||||||
def _draw_bar(self, remaining, all, **kwargs):
|
def _draw_bar(self, remaining, all, **kwargs) -> None:
|
||||||
if not all:
|
if not all:
|
||||||
return
|
return
|
||||||
if self.last is None:
|
if self.last is None:
|
||||||
|
@ -35,12 +50,27 @@ class DaskTqdmBar(TextProgressBar):
|
||||||
self.p.reset(total=all)
|
self.p.reset(total=all)
|
||||||
self.p.update((all - remaining) - self.last)
|
self.p.update((all - remaining) - self.last)
|
||||||
self.last = (all - remaining)
|
self.last = (all - remaining)
|
||||||
def _draw_stop(self, **kwargs):
|
def _draw_stop(self, **kwargs) -> None:
|
||||||
self.p.close()
|
self.p.close()
|
||||||
|
|
||||||
|
|
||||||
def tqdmprogress(future):
|
def tqdmprogress(future: Future) -> None:
|
||||||
futures = futures_of(future)
|
futures = futures_of(future)
|
||||||
if not isinstance(futures, (set, list)):
|
if not isinstance(futures, (set, list)):
|
||||||
futures = [futures]
|
futures = [futures]
|
||||||
DaskTqdmBar(futures, complete=True)
|
DaskTqdmBar(futures, complete=True)
|
||||||
|
|
||||||
|
|
||||||
|
class tqdm2(tqdm):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._next_desc = kwargs.get("run_desc", None)
|
||||||
|
if "run_desc" in kwargs:
|
||||||
|
del kwargs["run_desc"]
|
||||||
|
tqdm.__init__(self, *args, **kwargs)
|
||||||
|
self._first_time = True
|
||||||
|
|
||||||
|
def update(self, *args, **kwargs):
|
||||||
|
if self._first_time and args[0] > 0:
|
||||||
|
self._first_time = False
|
||||||
|
self.set_description(self._next_desc)
|
||||||
|
tqdm.update(self, *args, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue