make a proper RPC protocol
This commit is contained in:
parent
19944c1fa2
commit
06c0c85e79
|
@ -5,7 +5,7 @@ ipy = __import__("IPython")
|
|||
def dask(line, local_ns):
|
||||
"initializes dask"
|
||||
|
||||
from leylines import get_server_node, DEFAULT_PORT, db
|
||||
from leylines.client import ClientSession, SERVER_NODE_ID
|
||||
from leylines.dask import init_dask, tqdmprogress
|
||||
|
||||
local_ns['tqdmprogress'] = tqdmprogress
|
||||
|
@ -27,21 +27,34 @@ def dask(line, local_ns):
|
|||
|
||||
local_ns['upload'] = lambda file: upload(client, file)
|
||||
|
||||
server_node = get_server_node()
|
||||
workers = [node for node in db.get_nodes()
|
||||
if node.id != server_node.id and node.ssh_key is not None]
|
||||
dest = f"{server_node.ip}:{DEFAULT_PORT}"
|
||||
import asyncio
|
||||
async def get_nodes_info():
|
||||
async with ClientSession() as api:
|
||||
nodes = (await api.getNodes().a_wait()).nodes
|
||||
server = None
|
||||
workers = []
|
||||
for node in nodes:
|
||||
info = (await node.getInfo().a_wait()).info
|
||||
out_info = (info.ip, info.name)
|
||||
if info.id == SERVER_NODE_ID:
|
||||
server = out_info
|
||||
elif info.sshkey.which() != "none":
|
||||
workers.append(out_info)
|
||||
return server, workers
|
||||
|
||||
server, workers = asyncio.run(get_nodes_info())
|
||||
dest = f"{server[0]}:{DEFAULT_PORT}"
|
||||
print("connected to APEX at", dest)
|
||||
|
||||
workers_by_ip = {str(node.ip):node for node in workers}
|
||||
workers_status = {str(node.ip):False for node in workers}
|
||||
workers_by_ip = {str(node[0]):node for node in workers}
|
||||
workers_status = {str(node[0]):False for node in workers}
|
||||
for addr, info in client.scheduler_info()["workers"].items():
|
||||
workers_status[info["host"]] = True
|
||||
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name):
|
||||
if workers_status[ip]:
|
||||
print(f"{node.name} ({node.ip}): up")
|
||||
print(f"{node[1]} ({node[0]}): up")
|
||||
else:
|
||||
print(f"{node.name} ({node.ip}): down")
|
||||
print(f"{node[1]} ({node[0]}): down")
|
||||
|
||||
@ipy.core.magic.register_line_magic
|
||||
@ipy.core.magic.needs_local_scope
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
[Unit]
|
||||
Description=leylines server
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=/usr/bin/env python3 -m leylines daemon
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
|
@ -1,35 +1,151 @@
|
|||
import asyncio
|
||||
import binascii
|
||||
import ipaddress
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Dict, List
|
||||
|
||||
import capnp
|
||||
from capnp.lib.capnp import _CallContext as _Ctx
|
||||
from pyroute2 import IPRoute, WireGuard
|
||||
|
||||
import monocypher
|
||||
from .database import Database, Node, SERVER_NODE_ID
|
||||
from .leylines_capnp import Authenticator, LeylinesApi, Maybe
|
||||
|
||||
|
||||
IFNAME = 'leyline-wg'
|
||||
DEFAULT_PORT = 31337
|
||||
API_PORT = 31338
|
||||
API_SSL_PORT = 31337
|
||||
logger = logging.getLogger(__name__)
|
||||
db = Database()
|
||||
|
||||
|
||||
async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
class AuthenticatorImpl(Authenticator.Server):
|
||||
def authenticate(self, _context: _Ctx, creds: Authenticator.Credentials,
|
||||
**kwargs) -> Authenticator.AuthResult:
|
||||
logger.debug("got auth request!")
|
||||
which = creds.which()
|
||||
if which == "token" and creds.token == db.get_token():
|
||||
return Authenticator.AuthResult(succeeded=LeylinesApiImpl())
|
||||
else:
|
||||
return Authenticator.AuthResult(unauthorized=None)
|
||||
|
||||
|
||||
class LeylinesApiImpl(LeylinesApi.Server):
|
||||
def getSetting(self, name: str, _context: _Ctx, **kwargs) -> str:
|
||||
return db.get_setting(name)
|
||||
|
||||
def putSetting(self, name: str, value: str, _context: _Ctx, **kwargs) -> None:
|
||||
db.put_setting(name, value)
|
||||
|
||||
def getNodes(self, _context: _Ctx, **kwargs) -> List:
|
||||
nodes = db.get_nodes()
|
||||
return [NodeImpl(node) for node in nodes]
|
||||
|
||||
def initServer(self, name: str, ip: str, sshkey: str, _context: _Ctx, **kwargs) -> None:
|
||||
db.init_server(name, ipaddress.IPv4Address(ip), sshkey)
|
||||
|
||||
def sync(self, _context: _Ctx, **kwargs) -> None:
|
||||
try:
|
||||
while True:
|
||||
line = reader.readline()
|
||||
if line.strip() != db.get_token():
|
||||
raise Exception("unauthorized")
|
||||
# TODO
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
sync_interface()
|
||||
except:
|
||||
logger.exception("failed to sync")
|
||||
|
||||
def addNode(self, name: str, ip: Maybe, sshkey: Maybe, _context: _Ctx, **kwargs) -> None:
|
||||
db.add_node(
|
||||
name,
|
||||
ipaddress.IPv4Address(ip.some) if ip.which == "some" else None,
|
||||
sshkey.some if sshkey.which == "some" else None
|
||||
)
|
||||
|
||||
def getNode(self, id: int, _context: _Ctx, **kwargs) -> None:
|
||||
node = db.get_node(id)
|
||||
if node is None:
|
||||
setattr(_context.results, "none", None)
|
||||
else:
|
||||
setattr(_context.results, "some", NodeImpl(node))
|
||||
|
||||
|
||||
def set_all(target: Any, **kwargs) -> None:
|
||||
for k,v in kwargs.items():
|
||||
setattr(target, k, v)
|
||||
|
||||
|
||||
class NodeImpl(LeylinesApi.Node.Server):
|
||||
def __init__(self, node: Node) -> None:
|
||||
self._node = node
|
||||
|
||||
def getInfo(self, _context: _Ctx, **kwargs) -> None:
|
||||
set_all(_context.results,
|
||||
id=self._node.id,
|
||||
name=self._node.name,
|
||||
publicIp=({"some": str(self._node.public_ip)}
|
||||
if self._node.public_ip is not None else {"none": None}),
|
||||
ip=str(self._node.ip),
|
||||
seckey=self._node.seckey,
|
||||
pubkey=seckey_to_pubkey(self._node.seckey),
|
||||
sshkey=({"some": self._node.ssh_key}
|
||||
if self._node.ssh_key is not None else {"none": None}),
|
||||
)
|
||||
|
||||
def getResources(self, _context: _Ctx, **kwargs) -> List[str]:
|
||||
return db.get_node_resources(self._node.id)
|
||||
|
||||
def addResource(self, resource: str, _context: _Ctx, **kwargs) -> None:
|
||||
db.add_node_resource(self._node.id, resource)
|
||||
|
||||
def delResource(self, resource: str, _context: _Ctx, **kwargs) -> None:
|
||||
db.remove_node_resource(self._node.id, resource)
|
||||
|
||||
def getConfig(self, _context: _Ctx, **kwargs) -> str:
|
||||
return generate_node_config(self._node.id)
|
||||
|
||||
|
||||
async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
server = capnp.TwoPartyServer(bootstrap=AuthenticatorImpl())
|
||||
logger.debug("got connection")
|
||||
|
||||
async def read_task():
|
||||
while not reader.at_eof():
|
||||
data = await reader.read(4096)
|
||||
await server.write(data)
|
||||
|
||||
async def write_task():
|
||||
while not reader.at_eof():
|
||||
data = await server.read(4096)
|
||||
writer.write(data.tobytes())
|
||||
await writer.drain()
|
||||
|
||||
tasks = [asyncio.create_task(read_task()),
|
||||
asyncio.create_task(write_task())]
|
||||
|
||||
while not reader.at_eof():
|
||||
server.poll_once()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
logger.debug("connection ended")
|
||||
|
||||
try:
|
||||
await tasks[0]
|
||||
except:
|
||||
logger.exception("client task error")
|
||||
|
||||
try:
|
||||
await tasks[1]
|
||||
except:
|
||||
logger.exception("client task error")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
try:
|
||||
sync_interface()
|
||||
server = await asyncio.create_server(client_connected, host="127.0.0.1", port=API_PORT)
|
||||
except:
|
||||
logger.exception("failed to sync")
|
||||
loop = asyncio.get_event_loop()
|
||||
server = await asyncio.start_server(client_connected, host="127.0.0.1", port=API_PORT)
|
||||
logger.info("serving forever")
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import argparse
|
||||
import ipaddress
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict
|
||||
|
||||
from leylines import main, db, sync_interface, seckey_to_pubkey, generate_node_config
|
||||
from leylines import main
|
||||
from leylines.database import SERVER_NODE_ID
|
||||
|
||||
|
||||
|
@ -10,11 +12,58 @@ def nop():
|
|||
pass
|
||||
|
||||
|
||||
def optional_to_maybe(value: Optional[str]) -> Dict:
|
||||
if value is None:
|
||||
return {"none": None}
|
||||
else:
|
||||
return {"some": value}
|
||||
|
||||
|
||||
async def client_main(args: argparse.Namespace) -> None:
|
||||
from leylines.client import ClientSession
|
||||
async with ClientSession() as api:
|
||||
if args.cmd == "status":
|
||||
subnet = (await api.getSetting("subnet").a_wait()).value
|
||||
nodes = (await api.getNodes().a_wait()).nodes
|
||||
print("SUBNET:", subnet)
|
||||
|
||||
server_node = None
|
||||
for node in nodes:
|
||||
info = (await node.getInfo().a_wait());
|
||||
public_ip = info.publicIp.some if info.publicIp.which() == "some" else None
|
||||
print(f"NODE {info.id}:", info.name, public_ip, info.ip, info.pubkey,
|
||||
"<Have SSH Key>" if info.sshkey.which() == "some" else "")
|
||||
if info.id == SERVER_NODE_ID:
|
||||
server_node = node
|
||||
|
||||
if server_node is None:
|
||||
print("! server not defined. run leylines init>")
|
||||
elif args.cmd == "init":
|
||||
await api.initServer(args.name, args.ip, args.ssh_key).a_wait()
|
||||
elif args.cmd == "add":
|
||||
await api.addNode(
|
||||
args.name,
|
||||
optional_to_maybe(args.ip),
|
||||
optional_to_maybe(
|
||||
args.ssh_key.read() if args.ssh_key is not None else None)
|
||||
).a_wait()
|
||||
elif args.cmd == "get-conf":
|
||||
node = await api.getNode(args.id).a_wait()
|
||||
if node.which() == "none":
|
||||
print("no such node!")
|
||||
else:
|
||||
print((await node.some.getConfig().a_wait()).config)
|
||||
elif args.cmd == "sync":
|
||||
await api.sync().a_wait()
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="wireguard management system for dragons")
|
||||
cmd = parser.add_subparsers(dest="cmd")
|
||||
cmd.required = True
|
||||
|
||||
cmd_daemon = cmd.add_parser("daemon")
|
||||
cmd_local_daemon = cmd.add_parser("local-daemon")
|
||||
cmd_print_token = cmd.add_parser("print-token")
|
||||
|
||||
cmd_status = cmd.add_parser("status")
|
||||
|
||||
|
@ -37,32 +86,35 @@ cmd_sync = cmd.add_parser("sync")
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cmd == "daemon":
|
||||
if args.cmd == "print-token":
|
||||
from leylines import db
|
||||
print(db.get_token())
|
||||
elif args.cmd == "daemon" or args.cmd == "local-daemon":
|
||||
# Set up logging
|
||||
root = logging.getLogger()
|
||||
root.setLevel(logging.DEBUG)
|
||||
logging.getLogger("asyncio").setLevel(logging.INFO)
|
||||
try:
|
||||
if args.cmd == "local-daemon":
|
||||
raise ImportError()
|
||||
from systemd.journal import JournalHandler
|
||||
ch = JournalHandler()
|
||||
ch.setLevel(logging.INFO)
|
||||
fm = logging.Formatter(
|
||||
"%(name)-25s - %(funcName)-10s - %(levelname)-5s - %(message)s")
|
||||
ch.setFormatter(fm)
|
||||
root.addHandler(ch)
|
||||
except ImportError:
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.DEBUG)
|
||||
fm = logging.Formatter(
|
||||
"%(asctime)s - %(name)-25s - %(funcName)-10s - %(levelname)-5s"
|
||||
+ " - %(message)s")
|
||||
ch.setFormatter(fm)
|
||||
root.addHandler(ch)
|
||||
|
||||
# run
|
||||
logging.info("starting leylines daemon")
|
||||
asyncio.run(main())
|
||||
elif args.cmd == "status":
|
||||
token = db.get_token()
|
||||
subnet = db.get_subnet()
|
||||
nodes = db.get_nodes()
|
||||
server_node = db.get_server_node()
|
||||
print("TOKEN:", token)
|
||||
print("SUBNET:", subnet)
|
||||
if server_node is None:
|
||||
print("SERVER: <not defined. run leylines init>")
|
||||
else:
|
||||
print("SERVER:", server_node.name, server_node.ip, server_node.public_ip,
|
||||
seckey_to_pubkey(server_node.seckey),
|
||||
"<Have SSH Key>" if server_node.ssh_key is not None else "")
|
||||
for node in nodes:
|
||||
if node.id == SERVER_NODE_ID:
|
||||
continue
|
||||
print(f"NODE {node.id}:", node.name, node.public_ip, node.ip,
|
||||
seckey_to_pubkey(node.seckey), "<Have SSH Key>" if node.ssh_key is not None else "")
|
||||
elif args.cmd == "init":
|
||||
db.init_server(args.name, ipaddress.IPv4Address(args.ip))
|
||||
elif args.cmd == "add":
|
||||
db.add_node(args.name, args.ip, args.ssh_key.read() if args.ssh_key is not None else None)
|
||||
sync_interface()
|
||||
elif args.cmd == "get-conf":
|
||||
print(generate_node_config(args.id))
|
||||
elif args.cmd == "sync":
|
||||
sync_interface()
|
||||
else:
|
||||
asyncio.run(client_main(args))
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
import asyncio
|
||||
from contextlib import suppress
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import capnp
|
||||
|
||||
from .leylines_capnp import Authenticator, LeylinesApi
|
||||
from . import API_PORT, API_SSL_PORT
|
||||
from .database import SERVER_NODE_ID
|
||||
|
||||
|
||||
XDG_CONFIG_HOME = os.path.expanduser(os.environ.get("XDG_CONFIG_HOME", "~/.config"))
|
||||
DEFAULT_TOKEN_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "token")
|
||||
DEFAULT_HOST_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "host")
|
||||
|
||||
|
||||
def default_get_token() -> str:
|
||||
if "LEYLINES_TOKEN" in os.environ:
|
||||
return os.environ["LEYLINES_TOKEN"]
|
||||
with open(DEFAULT_TOKEN_PATH, "r") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
def default_get_host() -> str:
|
||||
if "LEYLINES_HOST" in os.environ:
|
||||
return os.environ["LEYLINES_HOST"]
|
||||
with open(DEFAULT_HOST_PATH, "r") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
class ClientSession:
|
||||
def __init__(self, host: Optional[str] = None, token: Optional[str] = None) -> None:
|
||||
self._host = host
|
||||
self._token = token
|
||||
if self._host == None:
|
||||
self._host = default_get_host()
|
||||
if self._token == None:
|
||||
self._token = default_get_token()
|
||||
|
||||
async def __aenter__(self):
|
||||
reader, writer = await asyncio.open_connection(self._host, API_SSL_PORT, ssl=True)
|
||||
self.writer = writer
|
||||
client = capnp.TwoPartyClient()
|
||||
|
||||
async def read_task():
|
||||
while not reader.at_eof():
|
||||
data = await reader.read(4096)
|
||||
client.write(data)
|
||||
|
||||
async def write_task():
|
||||
while not reader.at_eof():
|
||||
data = await client.read(4096)
|
||||
writer.write(data.tobytes())
|
||||
await writer.drain()
|
||||
|
||||
self.tasks = [asyncio.create_task(read_task()), asyncio.create_task(write_task())]
|
||||
|
||||
auth = client.bootstrap().cast_as(Authenticator)
|
||||
creds = Authenticator.Credentials(token=self._token)
|
||||
response = await auth.authenticate(creds).a_wait()
|
||||
if response.result.which() != "succeeded":
|
||||
self.__aexit__()
|
||||
raise Exception("authentication failure!")
|
||||
|
||||
api = response.result.succeeded.cast_as(LeylinesApi)
|
||||
return api
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
with suppress(Exception):
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
with suppress(Exception):
|
||||
self.tasks[0].cancel()
|
||||
await self.tasks[0]
|
||||
with suppress(Exception):
|
||||
self.tasks[1].cancel()
|
||||
await self.tasks[1]
|
|
@ -53,15 +53,28 @@ class Database:
|
|||
raise Exception("there is already a server node defined")
|
||||
self.add_node(name, public_ip, ssh_key)
|
||||
|
||||
def get_subnet(self) -> ipaddress.IPv4Interface:
|
||||
def get_setting(self, name: str) -> Optional[str]:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("SELECT value FROM settings WHERE name='subnet'")
|
||||
return ipaddress.IPv4Interface(cur.fetchone()[0])
|
||||
cur.execute("SELECT value FROM settings WHERE name=?", (name,))
|
||||
res = cur.fetchone()
|
||||
return res[0] if res is not None else None
|
||||
|
||||
def put_setting(self, name: str, value: str) -> None:
|
||||
self.conn.execute("INSERT OR REPLACE INTO settings(name, value) VALUES(?, ?)",
|
||||
(name, value))
|
||||
self.conn.commit()
|
||||
|
||||
def get_subnet(self) -> ipaddress.IPv4Interface:
|
||||
res = self.get_setting("subnet")
|
||||
if res is None:
|
||||
raise Exception("invalid state: no subnet setting")
|
||||
return ipaddress.IPv4Interface(res)
|
||||
|
||||
def get_token(self) -> str:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("SELECT value FROM settings WHERE name='token'")
|
||||
return cur.fetchone()[0]
|
||||
res = self.get_setting("token")
|
||||
if res is None:
|
||||
raise Exception("invalid state: no token setting")
|
||||
return res
|
||||
|
||||
def _get_free_ip(self) -> ipaddress.IPv4Address:
|
||||
subnet = self.get_subnet().network
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
@0xa4819f0f2c639488;
|
||||
|
||||
interface Authenticator {
|
||||
authenticate @0 (creds :Credentials) -> (result :AuthResult);
|
||||
|
||||
struct Credentials {
|
||||
union {
|
||||
nop @0 :Void;
|
||||
token @1 :Text;
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthResult {
|
||||
union {
|
||||
unauthorized @0 :Void;
|
||||
succeeded @1 :LeylinesApi;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface LeylinesApi {
|
||||
getSetting @0 (name :Text) -> (value :Text);
|
||||
putSetting @1 (name :Text, value :Text) -> ();
|
||||
getNodes @2 () -> (nodes :List(Node));
|
||||
initServer @3 (name :Text, ip :Text, sshkey :Text) -> ();
|
||||
sync @4 () -> ();
|
||||
addNode @5 (name :Text, ip :Maybe(Text), sshkey :Maybe(Text));
|
||||
getNode @6 (id :Int32) -> Maybe(Node);
|
||||
|
||||
interface Node {
|
||||
getInfo @0 () -> Info;
|
||||
|
||||
struct Info {
|
||||
id @0 :Int32;
|
||||
name @1 :Text;
|
||||
publicIp @2 :Maybe(Text);
|
||||
ip @3 :Text;
|
||||
seckey @4 :Text;
|
||||
pubkey @5 :Text;
|
||||
sshkey @6 :Maybe(Text);
|
||||
}
|
||||
|
||||
getResources @1 () -> (resources :List(Text));
|
||||
addResource @2 (resource :Text) -> ();
|
||||
delResource @3 (resource :Text) -> ();
|
||||
|
||||
getConfig @4 () -> (config :Text);
|
||||
}
|
||||
}
|
||||
|
||||
struct Maybe(T) {
|
||||
union {
|
||||
none @0 :Void;
|
||||
some @1 :T;
|
||||
}
|
||||
}
|
|
@ -11,7 +11,9 @@ setup(
|
|||
packages=['leylines'],
|
||||
install_requires=[
|
||||
"leylines-monocypher",
|
||||
"pyroute2"
|
||||
"pyroute2",
|
||||
"pycapnp",
|
||||
"systemd"
|
||||
],
|
||||
include_package_data=True,
|
||||
entry_points={
|
||||
|
|
Loading…
Reference in New Issue