update client support

This commit is contained in:
xenia 2021-06-18 03:34:01 -04:00
parent 6fb440c842
commit 618b32f993
5 changed files with 46 additions and 32 deletions

View File

@ -5,12 +5,30 @@ ipy = __import__("IPython")
def dask(line, local_ns): def dask(line, local_ns):
"initializes dask" "initializes dask"
from leylines.client import ClientSession, SERVER_NODE_ID from leylines.client import ClientSession, SERVER_NODE_ID, DEFAULT_PORT
from leylines.dask import init_dask, tqdmprogress from leylines.dask import init_dask, tqdmprogress
local_ns['tqdmprogress'] = tqdmprogress local_ns['tqdmprogress'] = tqdmprogress
client = init_dask() import asyncio
async def get_nodes_info():
async with ClientSession() as api:
nodes = (await api.getNodes().a_wait()).nodes
server = None
workers = []
for node in nodes:
info = await node.getInfo().a_wait()
out_info = (info.ip, info.name)
if info.id == SERVER_NODE_ID:
server = out_info
elif info.sshkey.which() != "none":
workers.append(out_info)
return server, workers
server, workers = asyncio.run(get_nodes_info())
dest = f"{server[0]}:{DEFAULT_PORT}"
client = init_dask(dest)
local_ns['client'] = client local_ns['client'] = client
def upload(client, file): def upload(client, file):
@ -27,30 +45,13 @@ def dask(line, local_ns):
local_ns['upload'] = lambda file: upload(client, file) local_ns['upload'] = lambda file: upload(client, file)
import asyncio
async def get_nodes_info():
async with ClientSession() as api:
nodes = (await api.getNodes().a_wait()).nodes
server = None
workers = []
for node in nodes:
info = (await node.getInfo().a_wait()).info
out_info = (info.ip, info.name)
if info.id == SERVER_NODE_ID:
server = out_info
elif info.sshkey.which() != "none":
workers.append(out_info)
return server, workers
server, workers = asyncio.run(get_nodes_info())
dest = f"{server[0]}:{DEFAULT_PORT}"
print("connected to APEX at", dest) print("connected to APEX at", dest)
workers_by_ip = {str(node[0]):node for node in workers} workers_by_ip = {str(node[0]):node for node in workers}
workers_status = {str(node[0]):False for node in workers} workers_status = {str(node[0]):False for node in workers}
for addr, info in client.scheduler_info()["workers"].items(): for addr, info in client.scheduler_info()["workers"].items():
workers_status[info["host"]] = True workers_status[info["host"]] = True
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name): for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1][1]):
if workers_status[ip]: if workers_status[ip]:
print(f"{node[1]} ({node[0]}): up") print(f"{node[1]} ({node[0]}): up")
else: else:

View File

@ -12,14 +12,12 @@ from pyroute2 import IPRoute, WireGuard
import monocypher import monocypher
from .database import Database, Node, SERVER_NODE_ID from .database import Database, Node, SERVER_NODE_ID
from .leylines_capnp import Authenticator, LeylinesApi, Maybe from .leylines_capnp import Authenticator, LeylinesApi, Maybe
from .defs import DEFAULT_PORT, API_PORT, API_SSL_PORT
IFNAME = 'leyline-wg' IFNAME = 'leyline-wg'
DEFAULT_PORT = 31337
API_PORT = 31338
API_SSL_PORT = 31337
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
db = Database() db = None
class AuthenticatorImpl(Authenticator.Server): class AuthenticatorImpl(Authenticator.Server):
@ -139,6 +137,8 @@ async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamW
async def main() -> None: async def main() -> None:
global db
db = Database()
try: try:
sync_interface() sync_interface()
except: except:

View File

@ -5,7 +5,7 @@ from typing import Optional
import capnp import capnp
from .leylines_capnp import Authenticator, LeylinesApi from .leylines_capnp import Authenticator, LeylinesApi
from . import API_PORT, API_SSL_PORT from .defs import API_PORT, API_SSL_PORT, DEFAULT_PORT
from .database import SERVER_NODE_ID from .database import SERVER_NODE_ID

View File

@ -1,16 +1,16 @@
import asyncio import asyncio
import os import os
from typing import Iterable, TypeVar, List from typing import Iterable, TypeVar, List, Optional
from distributed.client import default_client from distributed.client import default_client
from distributed.diagnostics.progressbar import TextProgressBar from distributed.diagnostics.progressbar import TextProgressBar
from distributed import Client, futures_of, Future from distributed import Client, futures_of, Future
from tqdm import tqdm from tqdm import tqdm
from leylines import get_server_node, DEFAULT_PORT from leylines.client import ClientSession, SERVER_NODE_ID, DEFAULT_PORT
def init_dask() -> Client: def init_dask(dest: Optional[str] = None) -> Client:
"initializes dask" "initializes dask"
try: try:
@ -18,10 +18,20 @@ def init_dask() -> Client:
return default return default
except ValueError: except ValueError:
pass pass
server_node = get_server_node()
dest = f"{server_node.ip}:{DEFAULT_PORT}" if dest is None:
client = Client(dest) async def get_dest() -> str:
return client async with ClientSession() as api:
nodes = (await api.getNodes().a_wait()).nodes
for node in nodes:
info = await node.getInfo().a_wait()
if info.id == SERVER_NODE_ID:
return f"{info.ip}:{DEFAULT_PORT}"
raise Exception("no server node defined!")
dest = asyncio.run(get_dest())
return Client(dest)
async def init_dask_async() -> Client: async def init_dask_async() -> Client:

View File

@ -0,0 +1,3 @@
DEFAULT_PORT = 31337
API_PORT = 31338
API_SSL_PORT = 31337