refactor ipython integration

This commit is contained in:
xenia 2021-06-17 05:46:37 -04:00
parent f5df1fb77f
commit 0a9e097c25
2 changed files with 55 additions and 40 deletions

View File

@ -4,36 +4,15 @@ ipy = __import__("IPython")
@ipy.core.magic.needs_local_scope
def dask(line, local_ns):
"initializes dask"
from distributed.client import default_client
from distributed.diagnostics.progressbar import TextProgressBar
from distributed import futures_of
from tqdm import tqdm
class TqdmBar(TextProgressBar):
def __init__(self, *args, **kwargs):
self.p = tqdm(desc="scheduling...")
self.last = None
TextProgressBar.__init__(self, *args, **kwargs)
def _draw_bar(self, remaining, all, **kwargs):
if not all:
return
if self.last is None:
self.last = 0
self.p.set_description("🦈")
self.p.reset(total=all)
self.p.update((all - remaining) - self.last)
self.last = (all - remaining)
def _draw_stop(self, **kwargs):
self.p.close()
def tqdmprogress(future):
futures = futures_of(future)
if not isinstance(futures, (set, list)):
futures = [futures]
TqdmBar(futures, complete=True)
from leylines import get_server_node, DEFAULT_PORT, db
from leylines.dask import init_dask, tqdmprogress
local_ns['tqdmprogress'] = tqdmprogress
client = init_dask()
local_ns['client'] = client
def upload(client, file):
import dask
import distributed
@ -46,22 +25,14 @@ def dask(line, local_ns):
return dask.delayed(get_file)()
try:
default = default_client()
local_ns['client'] = default
local_ns['upload'] = lambda file: upload(client, file)
return
except ValueError:
pass
from distributed import Client
import re
from leylines import get_server_node, DEFAULT_PORT, db
local_ns['upload'] = lambda file: upload(client, file)
server_node = get_server_node()
workers = [node for node in db.get_nodes()
if node.id != server_node.id and node.ssh_key is not None]
dest = f"{server_node.ip}:{DEFAULT_PORT}"
print("connecting to APEX at", dest)
client = Client(dest)
print("connected to APEX at", dest)
workers_by_ip = {str(node.ip):node for node in workers}
workers_status = {str(node.ip):False for node in workers}
for addr, info in client.scheduler_info()["workers"].items():
@ -71,8 +42,6 @@ def dask(line, local_ns):
print(f"{node.name} ({node.ip}): up")
else:
print(f"{node.name} ({node.ip}): down")
local_ns['client'] = client
local_ns['upload'] = lambda file: upload(client, file)
@ipy.core.magic.register_line_magic
@ipy.core.magic.needs_local_scope

46
leylines/leylines/dask.py Normal file
View File

@ -0,0 +1,46 @@
import os
from distributed.client import default_client
from distributed.diagnostics.progressbar import TextProgressBar
from distributed import Client, futures_of
from tqdm import tqdm
from leylines import get_server_node, DEFAULT_PORT
def init_dask():
"initializes dask"
try:
default = default_client()
return default
except ValueError:
pass
server_node = get_server_node()
dest = f"{server_node.ip}:{DEFAULT_PORT}"
client = Client(dest)
return client
class DaskTqdmBar(TextProgressBar):
def __init__(self, *args, **kwargs):
self.p = tqdm(desc="scheduling...")
self.last = None
TextProgressBar.__init__(self, *args, **kwargs)
def _draw_bar(self, remaining, all, **kwargs):
if not all:
return
if self.last is None:
self.last = 0
self.p.set_description("🦈")
self.p.reset(total=all)
self.p.update((all - remaining) - self.last)
self.last = (all - remaining)
def _draw_stop(self, **kwargs):
self.p.close()
def tqdmprogress(future):
futures = futures_of(future)
if not isinstance(futures, (set, list)):
futures = [futures]
DaskTqdmBar(futures, complete=True)