update client support
This commit is contained in:
parent
6fb440c842
commit
618b32f993
|
@ -5,12 +5,30 @@ ipy = __import__("IPython")
|
||||||
def dask(line, local_ns):
|
def dask(line, local_ns):
|
||||||
"initializes dask"
|
"initializes dask"
|
||||||
|
|
||||||
from leylines.client import ClientSession, SERVER_NODE_ID
|
from leylines.client import ClientSession, SERVER_NODE_ID, DEFAULT_PORT
|
||||||
from leylines.dask import init_dask, tqdmprogress
|
from leylines.dask import init_dask, tqdmprogress
|
||||||
|
|
||||||
local_ns['tqdmprogress'] = tqdmprogress
|
local_ns['tqdmprogress'] = tqdmprogress
|
||||||
|
|
||||||
client = init_dask()
|
import asyncio
|
||||||
|
async def get_nodes_info():
|
||||||
|
async with ClientSession() as api:
|
||||||
|
nodes = (await api.getNodes().a_wait()).nodes
|
||||||
|
server = None
|
||||||
|
workers = []
|
||||||
|
for node in nodes:
|
||||||
|
info = await node.getInfo().a_wait()
|
||||||
|
out_info = (info.ip, info.name)
|
||||||
|
if info.id == SERVER_NODE_ID:
|
||||||
|
server = out_info
|
||||||
|
elif info.sshkey.which() != "none":
|
||||||
|
workers.append(out_info)
|
||||||
|
return server, workers
|
||||||
|
|
||||||
|
server, workers = asyncio.run(get_nodes_info())
|
||||||
|
dest = f"{server[0]}:{DEFAULT_PORT}"
|
||||||
|
|
||||||
|
client = init_dask(dest)
|
||||||
local_ns['client'] = client
|
local_ns['client'] = client
|
||||||
|
|
||||||
def upload(client, file):
|
def upload(client, file):
|
||||||
|
@ -27,30 +45,13 @@ def dask(line, local_ns):
|
||||||
|
|
||||||
local_ns['upload'] = lambda file: upload(client, file)
|
local_ns['upload'] = lambda file: upload(client, file)
|
||||||
|
|
||||||
import asyncio
|
|
||||||
async def get_nodes_info():
|
|
||||||
async with ClientSession() as api:
|
|
||||||
nodes = (await api.getNodes().a_wait()).nodes
|
|
||||||
server = None
|
|
||||||
workers = []
|
|
||||||
for node in nodes:
|
|
||||||
info = (await node.getInfo().a_wait()).info
|
|
||||||
out_info = (info.ip, info.name)
|
|
||||||
if info.id == SERVER_NODE_ID:
|
|
||||||
server = out_info
|
|
||||||
elif info.sshkey.which() != "none":
|
|
||||||
workers.append(out_info)
|
|
||||||
return server, workers
|
|
||||||
|
|
||||||
server, workers = asyncio.run(get_nodes_info())
|
|
||||||
dest = f"{server[0]}:{DEFAULT_PORT}"
|
|
||||||
print("connected to APEX at", dest)
|
print("connected to APEX at", dest)
|
||||||
|
|
||||||
workers_by_ip = {str(node[0]):node for node in workers}
|
workers_by_ip = {str(node[0]):node for node in workers}
|
||||||
workers_status = {str(node[0]):False for node in workers}
|
workers_status = {str(node[0]):False for node in workers}
|
||||||
for addr, info in client.scheduler_info()["workers"].items():
|
for addr, info in client.scheduler_info()["workers"].items():
|
||||||
workers_status[info["host"]] = True
|
workers_status[info["host"]] = True
|
||||||
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name):
|
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1][1]):
|
||||||
if workers_status[ip]:
|
if workers_status[ip]:
|
||||||
print(f"{node[1]} ({node[0]}): up")
|
print(f"{node[1]} ({node[0]}): up")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -12,14 +12,12 @@ from pyroute2 import IPRoute, WireGuard
|
||||||
import monocypher
|
import monocypher
|
||||||
from .database import Database, Node, SERVER_NODE_ID
|
from .database import Database, Node, SERVER_NODE_ID
|
||||||
from .leylines_capnp import Authenticator, LeylinesApi, Maybe
|
from .leylines_capnp import Authenticator, LeylinesApi, Maybe
|
||||||
|
from .defs import DEFAULT_PORT, API_PORT, API_SSL_PORT
|
||||||
|
|
||||||
|
|
||||||
IFNAME = 'leyline-wg'
|
IFNAME = 'leyline-wg'
|
||||||
DEFAULT_PORT = 31337
|
|
||||||
API_PORT = 31338
|
|
||||||
API_SSL_PORT = 31337
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
db = Database()
|
db = None
|
||||||
|
|
||||||
|
|
||||||
class AuthenticatorImpl(Authenticator.Server):
|
class AuthenticatorImpl(Authenticator.Server):
|
||||||
|
@ -139,6 +137,8 @@ async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamW
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
|
global db
|
||||||
|
db = Database()
|
||||||
try:
|
try:
|
||||||
sync_interface()
|
sync_interface()
|
||||||
except:
|
except:
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
||||||
import capnp
|
import capnp
|
||||||
|
|
||||||
from .leylines_capnp import Authenticator, LeylinesApi
|
from .leylines_capnp import Authenticator, LeylinesApi
|
||||||
from . import API_PORT, API_SSL_PORT
|
from .defs import API_PORT, API_SSL_PORT, DEFAULT_PORT
|
||||||
from .database import SERVER_NODE_ID
|
from .database import SERVER_NODE_ID
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Iterable, TypeVar, List
|
from typing import Iterable, TypeVar, List, Optional
|
||||||
|
|
||||||
from distributed.client import default_client
|
from distributed.client import default_client
|
||||||
from distributed.diagnostics.progressbar import TextProgressBar
|
from distributed.diagnostics.progressbar import TextProgressBar
|
||||||
from distributed import Client, futures_of, Future
|
from distributed import Client, futures_of, Future
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from leylines import get_server_node, DEFAULT_PORT
|
from leylines.client import ClientSession, SERVER_NODE_ID, DEFAULT_PORT
|
||||||
|
|
||||||
|
|
||||||
def init_dask() -> Client:
|
def init_dask(dest: Optional[str] = None) -> Client:
|
||||||
"initializes dask"
|
"initializes dask"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -18,10 +18,20 @@ def init_dask() -> Client:
|
||||||
return default
|
return default
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
server_node = get_server_node()
|
|
||||||
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
if dest is None:
|
||||||
client = Client(dest)
|
async def get_dest() -> str:
|
||||||
return client
|
async with ClientSession() as api:
|
||||||
|
nodes = (await api.getNodes().a_wait()).nodes
|
||||||
|
for node in nodes:
|
||||||
|
info = await node.getInfo().a_wait()
|
||||||
|
if info.id == SERVER_NODE_ID:
|
||||||
|
return f"{info.ip}:{DEFAULT_PORT}"
|
||||||
|
raise Exception("no server node defined!")
|
||||||
|
|
||||||
|
dest = asyncio.run(get_dest())
|
||||||
|
|
||||||
|
return Client(dest)
|
||||||
|
|
||||||
|
|
||||||
async def init_dask_async() -> Client:
|
async def init_dask_async() -> Client:
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
DEFAULT_PORT = 31337
|
||||||
|
API_PORT = 31338
|
||||||
|
API_SSL_PORT = 31337
|
Loading…
Reference in New Issue