update client support
This commit is contained in:
parent
6fb440c842
commit
618b32f993
|
@ -5,12 +5,30 @@ ipy = __import__("IPython")
|
|||
def dask(line, local_ns):
|
||||
"initializes dask"
|
||||
|
||||
from leylines.client import ClientSession, SERVER_NODE_ID
|
||||
from leylines.client import ClientSession, SERVER_NODE_ID, DEFAULT_PORT
|
||||
from leylines.dask import init_dask, tqdmprogress
|
||||
|
||||
local_ns['tqdmprogress'] = tqdmprogress
|
||||
|
||||
client = init_dask()
|
||||
import asyncio
|
||||
async def get_nodes_info():
|
||||
async with ClientSession() as api:
|
||||
nodes = (await api.getNodes().a_wait()).nodes
|
||||
server = None
|
||||
workers = []
|
||||
for node in nodes:
|
||||
info = await node.getInfo().a_wait()
|
||||
out_info = (info.ip, info.name)
|
||||
if info.id == SERVER_NODE_ID:
|
||||
server = out_info
|
||||
elif info.sshkey.which() != "none":
|
||||
workers.append(out_info)
|
||||
return server, workers
|
||||
|
||||
server, workers = asyncio.run(get_nodes_info())
|
||||
dest = f"{server[0]}:{DEFAULT_PORT}"
|
||||
|
||||
client = init_dask(dest)
|
||||
local_ns['client'] = client
|
||||
|
||||
def upload(client, file):
|
||||
|
@ -27,30 +45,13 @@ def dask(line, local_ns):
|
|||
|
||||
local_ns['upload'] = lambda file: upload(client, file)
|
||||
|
||||
import asyncio
|
||||
async def get_nodes_info():
|
||||
async with ClientSession() as api:
|
||||
nodes = (await api.getNodes().a_wait()).nodes
|
||||
server = None
|
||||
workers = []
|
||||
for node in nodes:
|
||||
info = (await node.getInfo().a_wait()).info
|
||||
out_info = (info.ip, info.name)
|
||||
if info.id == SERVER_NODE_ID:
|
||||
server = out_info
|
||||
elif info.sshkey.which() != "none":
|
||||
workers.append(out_info)
|
||||
return server, workers
|
||||
|
||||
server, workers = asyncio.run(get_nodes_info())
|
||||
dest = f"{server[0]}:{DEFAULT_PORT}"
|
||||
print("connected to APEX at", dest)
|
||||
|
||||
workers_by_ip = {str(node[0]):node for node in workers}
|
||||
workers_status = {str(node[0]):False for node in workers}
|
||||
for addr, info in client.scheduler_info()["workers"].items():
|
||||
workers_status[info["host"]] = True
|
||||
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name):
|
||||
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1][1]):
|
||||
if workers_status[ip]:
|
||||
print(f"{node[1]} ({node[0]}): up")
|
||||
else:
|
||||
|
|
|
@ -12,14 +12,12 @@ from pyroute2 import IPRoute, WireGuard
|
|||
import monocypher
|
||||
from .database import Database, Node, SERVER_NODE_ID
|
||||
from .leylines_capnp import Authenticator, LeylinesApi, Maybe
|
||||
from .defs import DEFAULT_PORT, API_PORT, API_SSL_PORT
|
||||
|
||||
|
||||
IFNAME = 'leyline-wg'
|
||||
DEFAULT_PORT = 31337
|
||||
API_PORT = 31338
|
||||
API_SSL_PORT = 31337
|
||||
logger = logging.getLogger(__name__)
|
||||
db = Database()
|
||||
db = None
|
||||
|
||||
|
||||
class AuthenticatorImpl(Authenticator.Server):
|
||||
|
@ -139,6 +137,8 @@ async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamW
|
|||
|
||||
|
||||
async def main() -> None:
|
||||
global db
|
||||
db = Database()
|
||||
try:
|
||||
sync_interface()
|
||||
except:
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||
import capnp
|
||||
|
||||
from .leylines_capnp import Authenticator, LeylinesApi
|
||||
from . import API_PORT, API_SSL_PORT
|
||||
from .defs import API_PORT, API_SSL_PORT, DEFAULT_PORT
|
||||
from .database import SERVER_NODE_ID
|
||||
|
||||
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import Iterable, TypeVar, List
|
||||
from typing import Iterable, TypeVar, List, Optional
|
||||
|
||||
from distributed.client import default_client
|
||||
from distributed.diagnostics.progressbar import TextProgressBar
|
||||
from distributed import Client, futures_of, Future
|
||||
from tqdm import tqdm
|
||||
|
||||
from leylines import get_server_node, DEFAULT_PORT
|
||||
from leylines.client import ClientSession, SERVER_NODE_ID, DEFAULT_PORT
|
||||
|
||||
|
||||
def init_dask() -> Client:
|
||||
def init_dask(dest: Optional[str] = None) -> Client:
|
||||
"initializes dask"
|
||||
|
||||
try:
|
||||
|
@ -18,10 +18,20 @@ def init_dask() -> Client:
|
|||
return default
|
||||
except ValueError:
|
||||
pass
|
||||
server_node = get_server_node()
|
||||
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
||||
client = Client(dest)
|
||||
return client
|
||||
|
||||
if dest is None:
|
||||
async def get_dest() -> str:
|
||||
async with ClientSession() as api:
|
||||
nodes = (await api.getNodes().a_wait()).nodes
|
||||
for node in nodes:
|
||||
info = await node.getInfo().a_wait()
|
||||
if info.id == SERVER_NODE_ID:
|
||||
return f"{info.ip}:{DEFAULT_PORT}"
|
||||
raise Exception("no server node defined!")
|
||||
|
||||
dest = asyncio.run(get_dest())
|
||||
|
||||
return Client(dest)
|
||||
|
||||
|
||||
async def init_dask_async() -> Client:
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
DEFAULT_PORT = 31337
|
||||
API_PORT = 31338
|
||||
API_SSL_PORT = 31337
|
Loading…
Reference in New Issue