refactor ipython integration
This commit is contained in:
parent
f5df1fb77f
commit
0a9e097c25
|
@ -4,36 +4,15 @@ ipy = __import__("IPython")
|
||||||
@ipy.core.magic.needs_local_scope
|
@ipy.core.magic.needs_local_scope
|
||||||
def dask(line, local_ns):
|
def dask(line, local_ns):
|
||||||
"initializes dask"
|
"initializes dask"
|
||||||
from distributed.client import default_client
|
|
||||||
from distributed.diagnostics.progressbar import TextProgressBar
|
|
||||||
from distributed import futures_of
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
class TqdmBar(TextProgressBar):
|
from leylines import get_server_node, DEFAULT_PORT, db
|
||||||
def __init__(self, *args, **kwargs):
|
from leylines.dask import init_dask, tqdmprogress
|
||||||
self.p = tqdm(desc="scheduling...")
|
|
||||||
self.last = None
|
|
||||||
TextProgressBar.__init__(self, *args, **kwargs)
|
|
||||||
def _draw_bar(self, remaining, all, **kwargs):
|
|
||||||
if not all:
|
|
||||||
return
|
|
||||||
if self.last is None:
|
|
||||||
self.last = 0
|
|
||||||
self.p.set_description("🦈")
|
|
||||||
self.p.reset(total=all)
|
|
||||||
self.p.update((all - remaining) - self.last)
|
|
||||||
self.last = (all - remaining)
|
|
||||||
def _draw_stop(self, **kwargs):
|
|
||||||
self.p.close()
|
|
||||||
|
|
||||||
def tqdmprogress(future):
|
|
||||||
futures = futures_of(future)
|
|
||||||
if not isinstance(futures, (set, list)):
|
|
||||||
futures = [futures]
|
|
||||||
TqdmBar(futures, complete=True)
|
|
||||||
|
|
||||||
local_ns['tqdmprogress'] = tqdmprogress
|
local_ns['tqdmprogress'] = tqdmprogress
|
||||||
|
|
||||||
|
client = init_dask()
|
||||||
|
local_ns['client'] = client
|
||||||
|
|
||||||
def upload(client, file):
|
def upload(client, file):
|
||||||
import dask
|
import dask
|
||||||
import distributed
|
import distributed
|
||||||
|
@ -46,22 +25,14 @@ def dask(line, local_ns):
|
||||||
|
|
||||||
return dask.delayed(get_file)()
|
return dask.delayed(get_file)()
|
||||||
|
|
||||||
try:
|
local_ns['upload'] = lambda file: upload(client, file)
|
||||||
default = default_client()
|
|
||||||
local_ns['client'] = default
|
|
||||||
local_ns['upload'] = lambda file: upload(client, file)
|
|
||||||
return
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
from distributed import Client
|
|
||||||
import re
|
|
||||||
from leylines import get_server_node, DEFAULT_PORT, db
|
|
||||||
server_node = get_server_node()
|
server_node = get_server_node()
|
||||||
workers = [node for node in db.get_nodes()
|
workers = [node for node in db.get_nodes()
|
||||||
if node.id != server_node.id and node.ssh_key is not None]
|
if node.id != server_node.id and node.ssh_key is not None]
|
||||||
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
||||||
print("connecting to APEX at", dest)
|
print("connected to APEX at", dest)
|
||||||
client = Client(dest)
|
|
||||||
workers_by_ip = {str(node.ip):node for node in workers}
|
workers_by_ip = {str(node.ip):node for node in workers}
|
||||||
workers_status = {str(node.ip):False for node in workers}
|
workers_status = {str(node.ip):False for node in workers}
|
||||||
for addr, info in client.scheduler_info()["workers"].items():
|
for addr, info in client.scheduler_info()["workers"].items():
|
||||||
|
@ -71,8 +42,6 @@ def dask(line, local_ns):
|
||||||
print(f"{node.name} ({node.ip}): up")
|
print(f"{node.name} ({node.ip}): up")
|
||||||
else:
|
else:
|
||||||
print(f"{node.name} ({node.ip}): down")
|
print(f"{node.name} ({node.ip}): down")
|
||||||
local_ns['client'] = client
|
|
||||||
local_ns['upload'] = lambda file: upload(client, file)
|
|
||||||
|
|
||||||
@ipy.core.magic.register_line_magic
|
@ipy.core.magic.register_line_magic
|
||||||
@ipy.core.magic.needs_local_scope
|
@ipy.core.magic.needs_local_scope
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from distributed.client import default_client
|
||||||
|
from distributed.diagnostics.progressbar import TextProgressBar
|
||||||
|
from distributed import Client, futures_of
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from leylines import get_server_node, DEFAULT_PORT
|
||||||
|
|
||||||
|
def init_dask():
|
||||||
|
"initializes dask"
|
||||||
|
|
||||||
|
try:
|
||||||
|
default = default_client()
|
||||||
|
return default
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
server_node = get_server_node()
|
||||||
|
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
||||||
|
client = Client(dest)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
class DaskTqdmBar(TextProgressBar):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.p = tqdm(desc="scheduling...")
|
||||||
|
self.last = None
|
||||||
|
TextProgressBar.__init__(self, *args, **kwargs)
|
||||||
|
def _draw_bar(self, remaining, all, **kwargs):
|
||||||
|
if not all:
|
||||||
|
return
|
||||||
|
if self.last is None:
|
||||||
|
self.last = 0
|
||||||
|
self.p.set_description("🦈")
|
||||||
|
self.p.reset(total=all)
|
||||||
|
self.p.update((all - remaining) - self.last)
|
||||||
|
self.last = (all - remaining)
|
||||||
|
def _draw_stop(self, **kwargs):
|
||||||
|
self.p.close()
|
||||||
|
|
||||||
|
|
||||||
|
def tqdmprogress(future):
|
||||||
|
futures = futures_of(future)
|
||||||
|
if not isinstance(futures, (set, list)):
|
||||||
|
futures = [futures]
|
||||||
|
DaskTqdmBar(futures, complete=True)
|
Loading…
Reference in New Issue