# IPython startup script: registers the %dask and %daskworker line magics.
# Bind the IPython package under a short alias for the decorator lookups
# below. A plain aliased import is equivalent to the original
# __import__("IPython") call but idiomatic and lint-friendly.
import IPython as ipy
|
|
@ipy.core.magic.register_line_magic
@ipy.core.magic.needs_local_scope
def dask(line, local_ns):
    "initializes dask"
    # %dask line magic. Puts two names into the caller's namespace:
    #   client       -- a distributed.Client (existing default client if one
    #                   is already registered, otherwise a new connection to
    #                   the leylines server node)
    #   tqdmprogress -- helper that attaches a tqdm progress bar to a future
    # `line` (the magic's argument text) is unused; the signature is dictated
    # by register_line_magic / needs_local_scope.
    from distributed.client import default_client
    from distributed.diagnostics.progressbar import TextProgressBar
    from distributed import futures_of
    from tqdm import tqdm

    # Progress bar that renders distributed's scheduler-driven progress
    # callbacks through tqdm instead of a plain text bar.
    class TqdmBar(TextProgressBar):
        def __init__(self, *args, **kwargs):
            # p: the tqdm display; last: tasks-done count at the previous
            # _draw_bar call (None until the first callback arrives).
            self.p = tqdm(desc="scheduling...")
            self.last = None
            TextProgressBar.__init__(self, *args, **kwargs)

        def _draw_bar(self, remaining, all, **kwargs):
            # Callback from distributed's progress machinery; the parameter
            # name `all` (shadowing the builtin) matches the upstream
            # _draw_bar signature and cannot be renamed safely.
            if not all:
                # Total not known yet -- nothing meaningful to draw.
                return
            if self.last is None:
                # First real callback: switch from the "scheduling..."
                # placeholder and size the bar to the total task count.
                self.last = 0
                self.p.set_description("🦈")
                self.p.reset(total=all)
            # tqdm.update takes a delta, so feed it the change in completed
            # tasks since the previous callback.
            self.p.update((all - remaining) - self.last)
            self.last = (all - remaining)

        def _draw_stop(self, **kwargs):
            # Computation finished (or bar torn down): close the tqdm display.
            self.p.close()

    def tqdmprogress(future):
        # Attach a TqdmBar to `future` (a Future or a collection holding
        # futures). futures_of may return a single future or a set/list;
        # normalize to a list before handing it to the bar.
        futures = futures_of(future)
        if not isinstance(futures, (set, list)):
            futures = [futures]
        TqdmBar(futures, complete=True)

    local_ns['tqdmprogress'] = tqdmprogress

    # Reuse an already-registered client if there is one; default_client()
    # raises ValueError when none exists, which is the signal to connect.
    try:
        default = default_client()
        local_ns['client'] = default
        return
    except ValueError:
        pass
    from distributed import Client
    import re  # NOTE(review): `re` appears unused in this function.
    from leylines import get_server_node, DEFAULT_PORT, db
    server_node = get_server_node()
    # Expected workers: every registered node except the server itself that
    # has an SSH key configured.
    workers = [node for node in db.get_nodes()
               if node.id != server_node.id and node.ssh_key is not None]
    dest = f"{server_node.ip}:{DEFAULT_PORT}"
    print("connecting to APEX at", dest)
    client = Client(dest)
    # Compare the expected worker nodes against the workers the scheduler
    # actually reports, then print an up/down line per node, sorted by name.
    workers_by_ip = {str(node.ip):node for node in workers}
    workers_status = {str(node.ip):False for node in workers}
    for addr, info in client.scheduler_info()["workers"].items():
        # assumes info["host"] matches the node IPs from leylines --
        # an unexpected host just adds an unused key here.
        workers_status[info["host"]] = True
    for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name):
        if workers_status[ip]:
            print(f"{node.name} ({node.ip}): up")
        else:
            print(f"{node.name} ({node.ip}): down")
    local_ns['client'] = client
|
|
|
|
@ipy.core.magic.register_line_magic
@ipy.core.magic.needs_local_scope
def daskworker(line, local_ns):
    "picks a worker to launch ipython on"
    # %daskworker line magic: choose the least-loaded dask worker, start an
    # IPython kernel on it, and open a jupyter console attached to that
    # kernel in a new terminal pane/window. `line` and `local_ns` are unused;
    # the signature is dictated by the magic decorators.
    from distributed.client import default_client
    import subprocess
    import json
    import tempfile
    import time
    import os
    import shutil

    # Decide how to open the console: split the current tmux pane when
    # running inside tmux, otherwise spawn a gnome-terminal window.
    splitter = None
    if os.environ.get("TMUX", "") != "":
        splitter = ["tmux", "split"]
    elif shutil.which("gnome-terminal") is not None:
        splitter = ["/usr/bin/env", "gnome-terminal", "--"]
    else:
        raise Exception("don't know how to split terminal!")

    client = default_client()
    workers = client.scheduler_info()["workers"].items()
    # Rank workers by free memory headroom (memory used minus limit, so more
    # negative = more headroom) with a huge penalty per currently-executing
    # task, so idle workers always win over busy ones.
    sorted_workers = sorted(workers,
        key=lambda w: w[1]["metrics"]["memory"] - w[1]["memory_limit"]
        + (100000000000 * w[1]["metrics"]["executing"]))
    worker = sorted_workers[0][0]
    print("starting ipython kernel on", sorted_workers[0][1]["id"])
    # start_ipython_workers returns {worker_addr: connection_info}; we asked
    # for a single worker, so take the lone value.
    info = list(client.start_ipython_workers([worker]).values())[0]
    # The kernel key may come back as bytes; JSON needs str.
    if isinstance(info["key"], bytes):
        info["key"] = info["key"].decode()

    # Write the kernel connection info to a temp file and point
    # `jupyter console --existing` at it.
    with tempfile.NamedTemporaryFile(mode="w", prefix="apex-") as f:
        json.dump(info, f)
        f.flush()
        subprocess.check_call(splitter + ["/usr/bin/env", "jupyter", "console", "--existing",
                                          f.name])
        # The splitter commands return immediately; sleep so the console has
        # a chance to read the file before the context manager deletes it --
        # presumably a race mitigation, not a guarantee.
        time.sleep(1)
|
|
|
|
# The magics are registered with IPython above; drop the module-level names
# so they don't clutter the interactive namespace this startup file runs in.
del dask
del daskworker
del ipy