diff --git a/leylines-support/02-dask.ipy b/leylines-support/02-dask.ipy index 77d6de6..41450c7 100644 --- a/leylines-support/02-dask.ipy +++ b/leylines-support/02-dask.ipy @@ -5,7 +5,7 @@ ipy = __import__("IPython") def dask(line, local_ns): "initializes dask" - from leylines import get_server_node, DEFAULT_PORT, db + from leylines.client import ClientSession, SERVER_NODE_ID from leylines.dask import init_dask, tqdmprogress local_ns['tqdmprogress'] = tqdmprogress @@ -27,21 +27,34 @@ def dask(line, local_ns): local_ns['upload'] = lambda file: upload(client, file) - server_node = get_server_node() - workers = [node for node in db.get_nodes() - if node.id != server_node.id and node.ssh_key is not None] - dest = f"{server_node.ip}:{DEFAULT_PORT}" + import asyncio + async def get_nodes_info(): + async with ClientSession() as api: + nodes = (await api.getNodes().a_wait()).nodes + server = None + workers = [] + for node in nodes: + info = (await node.getInfo().a_wait()).info + out_info = (info.ip, info.name) + if info.id == SERVER_NODE_ID: + server = out_info + elif info.sshkey.which() != "none": + workers.append(out_info) + return server, workers + + server, workers = asyncio.run(get_nodes_info()) + dest = f"{server[0]}:{DEFAULT_PORT}" print("connected to APEX at", dest) - workers_by_ip = {str(node.ip):node for node in workers} - workers_status = {str(node.ip):False for node in workers} + workers_by_ip = {str(node[0]):node for node in workers} + workers_status = {str(node[0]):False for node in workers} for addr, info in client.scheduler_info()["workers"].items(): workers_status[info["host"]] = True for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name): if workers_status[ip]: - print(f"{node.name} ({node.ip}): up") + print(f"{node[1]} ({node[0]}): up") else: - print(f"{node.name} ({node.ip}): down") + print(f"{node[1]} ({node[0]}): down") @ipy.core.magic.register_line_magic @ipy.core.magic.needs_local_scope diff --git a/leylines-support/leylines-daemon.service b/leylines-support/leylines-daemon.service new file mode 100644 index 0000000..fd7a6ea --- /dev/null +++ b/leylines-support/leylines-daemon.service @@ -0,0 +1,9 @@ +[Unit] +Description=leylines server + +[Service] +Type=simple +ExecStart=/usr/bin/env python3 -m leylines daemon + +[Install] +WantedBy=default.target diff --git a/leylines/leylines/__init__.py b/leylines/leylines/__init__.py index 22104b3..aaa3d96 100644 --- a/leylines/leylines/__init__.py +++ b/leylines/leylines/__init__.py @@ -1,35 +1,151 @@ import asyncio import binascii +import ipaddress +import logging import secrets -from typing import Optional, Any +from typing import Optional, Any, Dict, List +import capnp +from capnp.lib.capnp import _CallContext as _Ctx from pyroute2 import IPRoute, WireGuard import monocypher from .database import Database, Node, SERVER_NODE_ID +from .leylines_capnp import Authenticator, LeylinesApi, Maybe IFNAME = 'leyline-wg' DEFAULT_PORT = 31337 API_PORT = 31338 +API_SSL_PORT = 31337 +logger = logging.getLogger(__name__) db = Database() +class AuthenticatorImpl(Authenticator.Server): + def authenticate(self, _context: _Ctx, creds: Authenticator.Credentials, + **kwargs) -> Authenticator.AuthResult: + logger.debug("got auth request!") + which = creds.which() + if which == "token" and creds.token == db.get_token(): + return Authenticator.AuthResult(succeeded=LeylinesApiImpl()) + else: + return Authenticator.AuthResult(unauthorized=None) + + +class LeylinesApiImpl(LeylinesApi.Server): + def getSetting(self, name: str, _context: _Ctx, **kwargs) -> str: + return db.get_setting(name) + + def putSetting(self, name: str, value: str, _context: _Ctx, **kwargs) -> None: + db.put_setting(name, value) + + def getNodes(self, _context: _Ctx, **kwargs) -> List: + nodes = db.get_nodes() + return [NodeImpl(node) for node in nodes] + + def initServer(self, name: str, ip: str, sshkey: str, _context: _Ctx, **kwargs) -> None: + db.init_server(name, ipaddress.IPv4Address(ip), sshkey) + + def sync(self, _context: _Ctx, **kwargs) -> None: + try: + sync_interface() + except: + logger.exception("failed to sync") + + def addNode(self, name: str, ip: Maybe, sshkey: Maybe, _context: _Ctx, **kwargs) -> None: + db.add_node( + name, + ipaddress.IPv4Address(ip.some) if ip.which == "some" else None, + sshkey.some if sshkey.which == "some" else None + ) + + def getNode(self, id: int, _context: _Ctx, **kwargs) -> None: + node = db.get_node(id) + if node is None: + setattr(_context.results, "none", None) + else: + setattr(_context.results, "some", NodeImpl(node)) + + +def set_all(target: Any, **kwargs) -> None: + for k,v in kwargs.items(): + setattr(target, k, v) + + +class NodeImpl(LeylinesApi.Node.Server): + def __init__(self, node: Node) -> None: + self._node = node + + def getInfo(self, _context: _Ctx, **kwargs) -> None: + set_all(_context.results, + id=self._node.id, + name=self._node.name, + publicIp=({"some": str(self._node.public_ip)} + if self._node.public_ip is not None else {"none": None}), + ip=str(self._node.ip), + seckey=self._node.seckey, + pubkey=seckey_to_pubkey(self._node.seckey), + sshkey=({"some": self._node.ssh_key} + if self._node.ssh_key is not None else {"none": None}), + ) + + def getResources(self, _context: _Ctx, **kwargs) -> List[str]: + return db.get_node_resources(self._node.id) + + def addResource(self, resource: str, _context: _Ctx, **kwargs) -> None: + db.add_node_resource(self._node.id, resource) + + def delResource(self, resource: str, _context: _Ctx, **kwargs) -> None: + db.remove_node_resource(self._node.id, resource) + + def getConfig(self, _context: _Ctx, **kwargs) -> str: + return generate_node_config(self._node.id) + + async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + server = capnp.TwoPartyServer(bootstrap=AuthenticatorImpl()) + logger.debug("got connection") + + async def read_task(): + while not reader.at_eof(): + data = await reader.read(4096) + await server.write(data) + + async def write_task(): + while not reader.at_eof(): + data = await server.read(4096) + writer.write(data.tobytes()) + await writer.drain() + + tasks = [asyncio.create_task(read_task()), + asyncio.create_task(write_task())] + + while not reader.at_eof(): + server.poll_once() + await asyncio.sleep(0.01) + + logger.debug("connection ended") + try: - while True: - line = reader.readline() - if line.strip() != db.get_token(): - raise Exception("unauthorized") - # TODO - finally: - writer.close() - await writer.wait_closed() + await tasks[0] + except: + logger.exception("client task error") + + try: + await tasks[1] + except: + logger.exception("client task error") async def main() -> None: - sync_interface() - server = await asyncio.create_server(client_connected, host="127.0.0.1", port=API_PORT) + try: + sync_interface() + except: + logger.exception("failed to sync") + loop = asyncio.get_event_loop() + server = await asyncio.start_server(client_connected, host="127.0.0.1", port=API_PORT) + logger.info("serving forever") await server.serve_forever() diff --git a/leylines/leylines/__main__.py b/leylines/leylines/__main__.py index 2e01a72..d2408ce 100644 --- a/leylines/leylines/__main__.py +++ b/leylines/leylines/__main__.py @@ -1,7 +1,9 @@ import argparse -import ipaddress +import asyncio +import logging +from typing import Optional, Dict -from leylines import main, db, sync_interface, seckey_to_pubkey, generate_node_config +from leylines import main from leylines.database import SERVER_NODE_ID @@ -10,11 +12,58 @@ def nop(): pass +def optional_to_maybe(value: Optional[str]) -> Dict: + if value is None: + return {"none": None} + else: + return {"some": value} + + +async def client_main(args: argparse.Namespace) -> None: + from leylines.client import ClientSession + async with ClientSession() as api: + if args.cmd == "status": + subnet = (await api.getSetting("subnet").a_wait()).value + nodes = (await api.getNodes().a_wait()).nodes + print("SUBNET:", subnet) + + server_node = None + for node in nodes: + info = (await node.getInfo().a_wait()); + public_ip = info.publicIp.some if info.publicIp.which() == "some" else None + print(f"NODE {info.id}:", info.name, public_ip, info.ip, info.pubkey, + "" if info.sshkey.which() == "some" else "") + if info.id == SERVER_NODE_ID: + server_node = node + + if server_node is None: + print("! server not defined. run leylines init>") + elif args.cmd == "init": + await api.initServer(args.name, args.ip, args.ssh_key).a_wait() + elif args.cmd == "add": + await api.addNode( + args.name, + optional_to_maybe(args.ip), + optional_to_maybe( + args.ssh_key.read() if args.ssh_key is not None else None) + ).a_wait() + elif args.cmd == "get-conf": + node = await api.getNode(args.id).a_wait() + if node.which() == "none": + print("no such node!") + else: + print((await node.some.getConfig().a_wait()).config) + elif args.cmd == "sync": + await api.sync().a_wait() + + parser = argparse.ArgumentParser(description="wireguard management system for dragons") cmd = parser.add_subparsers(dest="cmd") cmd.required = True cmd_daemon = cmd.add_parser("daemon") +cmd_local_daemon = cmd.add_parser("local-daemon") +cmd_print_token = cmd.add_parser("print-token") cmd_status = cmd.add_parser("status") @@ -37,32 +86,35 @@ cmd_sync = cmd.add_parser("sync") args = parser.parse_args() -if args.cmd == "daemon": +if args.cmd == "print-token": + from leylines import db + print(db.get_token()) +elif args.cmd == "daemon" or args.cmd == "local-daemon": + # Set up logging + root = logging.getLogger() + root.setLevel(logging.DEBUG) + logging.getLogger("asyncio").setLevel(logging.INFO) + try: + if args.cmd == "local-daemon": + raise ImportError() + from systemd.journal import JournalHandler + ch = JournalHandler() + ch.setLevel(logging.INFO) + fm = logging.Formatter( + "%(name)-25s - %(funcName)-10s - %(levelname)-5s - %(message)s") + ch.setFormatter(fm) + root.addHandler(ch) + except ImportError: + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + fm = logging.Formatter( + "%(asctime)s - %(name)-25s - %(funcName)-10s - %(levelname)-5s" + + " - %(message)s") + ch.setFormatter(fm) + root.addHandler(ch) + + # run + logging.info("starting leylines daemon") asyncio.run(main()) -elif args.cmd == "status": - token = db.get_token() - subnet = db.get_subnet() - nodes = db.get_nodes() - server_node = db.get_server_node() - print("TOKEN:", token) - print("SUBNET:", subnet) - if server_node is None: - print("SERVER: ") - else: - print("SERVER:", server_node.name, server_node.ip, server_node.public_ip, - seckey_to_pubkey(server_node.seckey), - "" if server_node.ssh_key is not None else "") - for node in nodes: - if node.id == SERVER_NODE_ID: - continue - print(f"NODE {node.id}:", node.name, node.public_ip, node.ip, - seckey_to_pubkey(node.seckey), "" if node.ssh_key is not None else "") -elif args.cmd == "init": - db.init_server(args.name, ipaddress.IPv4Address(args.ip)) -elif args.cmd == "add": - db.add_node(args.name, args.ip, args.ssh_key.read() if args.ssh_key is not None else None) - sync_interface() -elif args.cmd == "get-conf": - print(generate_node_config(args.id)) -elif args.cmd == "sync": - sync_interface() +else: + asyncio.run(client_main(args)) diff --git a/leylines/leylines/client.py b/leylines/leylines/client.py new file mode 100644 index 0000000..96fbec1 --- /dev/null +++ b/leylines/leylines/client.py @@ -0,0 +1,78 @@ +import asyncio +from contextlib import suppress +import os +from typing import Optional + +import capnp + +from .leylines_capnp import Authenticator, LeylinesApi +from . import API_PORT, API_SSL_PORT +from .database import SERVER_NODE_ID + + +XDG_CONFIG_HOME = os.path.expanduser(os.environ.get("XDG_CONFIG_HOME", "~/.config")) +DEFAULT_TOKEN_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "token") +DEFAULT_HOST_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "host") + + +def default_get_token() -> str: + if "LEYLINES_TOKEN" in os.environ: + return os.environ["LEYLINES_TOKEN"] + with open(DEFAULT_TOKEN_PATH, "r") as f: + return f.read().strip() + + +def default_get_host() -> str: + if "LEYLINES_HOST" in os.environ: + return os.environ["LEYLINES_HOST"] + with open(DEFAULT_HOST_PATH, "r") as f: + return f.read().strip() + + +class ClientSession: + def __init__(self, host: Optional[str] = None, token: Optional[str] = None) -> None: + self._host = host + self._token = token + if self._host == None: + self._host = default_get_host() + if self._token == None: + self._token = default_get_token() + + async def __aenter__(self): + reader, writer = await asyncio.open_connection(self._host, API_SSL_PORT, ssl=True) + self.writer = writer + client = capnp.TwoPartyClient() + + async def read_task(): + while not reader.at_eof(): + data = await reader.read(4096) + client.write(data) + + async def write_task(): + while not reader.at_eof(): + data = await client.read(4096) + writer.write(data.tobytes()) + await writer.drain() + + self.tasks = [asyncio.create_task(read_task()), asyncio.create_task(write_task())] + + auth = client.bootstrap().cast_as(Authenticator) + creds = Authenticator.Credentials(token=self._token) + response = await auth.authenticate(creds).a_wait() + if response.result.which() != "succeeded": + self.__aexit__() + raise Exception("authentication failure!") + + api = response.result.succeeded.cast_as(LeylinesApi) + return api + + async def __aexit__(self, *args): + with suppress(Exception): + self.writer.close() + await self.writer.wait_closed() + with suppress(Exception): + self.tasks[0].cancel() + await self.tasks[0] + with suppress(Exception): + self.tasks[1].cancel() + await self.tasks[1] diff --git a/leylines/leylines/database.py b/leylines/leylines/database.py index ca79e17..5a88b7c 100644 --- a/leylines/leylines/database.py +++ b/leylines/leylines/database.py @@ -53,15 +53,28 @@ class Database: raise Exception("there is already a server node defined") self.add_node(name, public_ip, ssh_key) - def get_subnet(self) -> ipaddress.IPv4Interface: + def get_setting(self, name: str) -> Optional[str]: cur = self.conn.cursor() - cur.execute("SELECT value FROM settings WHERE name='subnet'") - return ipaddress.IPv4Interface(cur.fetchone()[0]) + cur.execute("SELECT value FROM settings WHERE name=?", (name,)) + res = cur.fetchone() + return res[0] if res is not None else None + + def put_setting(self, name: str, value: str) -> None: + self.conn.execute("INSERT OR REPLACE INTO settings(name, value) VALUES(?, ?)", + (name, value)) + self.conn.commit() + + def get_subnet(self) -> ipaddress.IPv4Interface: + res = self.get_setting("subnet") + if res is None: + raise Exception("invalid state: no subnet setting") + return ipaddress.IPv4Interface(res) def get_token(self) -> str: - cur = self.conn.cursor() - cur.execute("SELECT value FROM settings WHERE name='token'") - return cur.fetchone()[0] + res = self.get_setting("token") + if res is None: + raise Exception("invalid state: no token setting") + return res def _get_free_ip(self) -> ipaddress.IPv4Address: subnet = self.get_subnet().network diff --git a/leylines/leylines/leylines.capnp b/leylines/leylines/leylines.capnp new file mode 100644 index 0000000..bb5e3d5 --- /dev/null +++ b/leylines/leylines/leylines.capnp @@ -0,0 +1,56 @@ +@0xa4819f0f2c639488; + +interface Authenticator { + authenticate @0 (creds :Credentials) -> (result :AuthResult); + + struct Credentials { + union { + nop @0 :Void; + token @1 :Text; + } + } + + struct AuthResult { + union { + unauthorized @0 :Void; + succeeded @1 :LeylinesApi; + } + } +} + +interface LeylinesApi { + getSetting @0 (name :Text) -> (value :Text); + putSetting @1 (name :Text, value :Text) -> (); + getNodes @2 () -> (nodes :List(Node)); + initServer @3 (name :Text, ip :Text, sshkey :Text) -> (); + sync @4 () -> (); + addNode @5 (name :Text, ip :Maybe(Text), sshkey :Maybe(Text)); + getNode @6 (id :Int32) -> Maybe(Node); + + interface Node { + getInfo @0 () -> Info; + + struct Info { + id @0 :Int32; + name @1 :Text; + publicIp @2 :Maybe(Text); + ip @3 :Text; + seckey @4 :Text; + pubkey @5 :Text; + sshkey @6 :Maybe(Text); + } + + getResources @1 () -> (resources :List(Text)); + addResource @2 (resource :Text) -> (); + delResource @3 (resource :Text) -> (); + + getConfig @4 () -> (config :Text); + } +} + +struct Maybe(T) { + union { + none @0 :Void; + some @1 :T; + } +} diff --git a/leylines/setup.py b/leylines/setup.py index de1ba60..0627576 100644 --- a/leylines/setup.py +++ b/leylines/setup.py @@ -11,7 +11,9 @@ setup( packages=['leylines'], install_requires=[ "leylines-monocypher", - "pyroute2" + "pyroute2", + "pycapnp", + "systemd" ], include_package_data=True, entry_points={