make a proper RPC protocol

This commit is contained in:
xenia 2021-06-18 03:07:35 -04:00
parent 19944c1fa2
commit 06c0c85e79
8 changed files with 396 additions and 57 deletions

View File

@ -5,7 +5,7 @@ ipy = __import__("IPython")
def dask(line, local_ns): def dask(line, local_ns):
"initializes dask" "initializes dask"
from leylines import get_server_node, DEFAULT_PORT, db from leylines.client import ClientSession, SERVER_NODE_ID
from leylines.dask import init_dask, tqdmprogress from leylines.dask import init_dask, tqdmprogress
local_ns['tqdmprogress'] = tqdmprogress local_ns['tqdmprogress'] = tqdmprogress
@ -27,21 +27,34 @@ def dask(line, local_ns):
local_ns['upload'] = lambda file: upload(client, file) local_ns['upload'] = lambda file: upload(client, file)
server_node = get_server_node() import asyncio
workers = [node for node in db.get_nodes() async def get_nodes_info():
if node.id != server_node.id and node.ssh_key is not None] async with ClientSession() as api:
dest = f"{server_node.ip}:{DEFAULT_PORT}" nodes = (await api.getNodes().a_wait()).nodes
server = None
workers = []
for node in nodes:
info = (await node.getInfo().a_wait()).info
out_info = (info.ip, info.name)
if info.id == SERVER_NODE_ID:
server = out_info
elif info.sshkey.which() != "none":
workers.append(out_info)
return server, workers
server, workers = asyncio.run(get_nodes_info())
dest = f"{server[0]}:{DEFAULT_PORT}"
print("connected to APEX at", dest) print("connected to APEX at", dest)
workers_by_ip = {str(node.ip):node for node in workers} workers_by_ip = {str(node[0]):node for node in workers}
workers_status = {str(node.ip):False for node in workers} workers_status = {str(node[0]):False for node in workers}
for addr, info in client.scheduler_info()["workers"].items(): for addr, info in client.scheduler_info()["workers"].items():
workers_status[info["host"]] = True workers_status[info["host"]] = True
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name): for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name):
if workers_status[ip]: if workers_status[ip]:
print(f"{node.name} ({node.ip}): up") print(f"{node[1]} ({node[0]}): up")
else: else:
print(f"{node.name} ({node.ip}): down") print(f"{node[1]} ({node[0]}): down")
@ipy.core.magic.register_line_magic @ipy.core.magic.register_line_magic
@ipy.core.magic.needs_local_scope @ipy.core.magic.needs_local_scope

View File

@ -0,0 +1,9 @@
[Unit]
Description=leylines server
[Service]
Type=simple
ExecStart=/usr/bin/env python3 -m leylines daemon
[Install]
WantedBy=default.target

View File

@ -1,35 +1,151 @@
import asyncio import asyncio
import binascii import binascii
import ipaddress
import logging
import secrets import secrets
from typing import Optional, Any from typing import Optional, Any, Dict, List
import capnp
from capnp.lib.capnp import _CallContext as _Ctx
from pyroute2 import IPRoute, WireGuard from pyroute2 import IPRoute, WireGuard
import monocypher import monocypher
from .database import Database, Node, SERVER_NODE_ID from .database import Database, Node, SERVER_NODE_ID
from .leylines_capnp import Authenticator, LeylinesApi, Maybe
IFNAME = 'leyline-wg' IFNAME = 'leyline-wg'
DEFAULT_PORT = 31337 DEFAULT_PORT = 31337
API_PORT = 31338 API_PORT = 31338
API_SSL_PORT = 31337
logger = logging.getLogger(__name__)
db = Database() db = Database()
async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: class AuthenticatorImpl(Authenticator.Server):
def authenticate(self, _context: _Ctx, creds: Authenticator.Credentials,
**kwargs) -> Authenticator.AuthResult:
logger.debug("got auth request!")
which = creds.which()
if which == "token" and creds.token == db.get_token():
return Authenticator.AuthResult(succeeded=LeylinesApiImpl())
else:
return Authenticator.AuthResult(unauthorized=None)
class LeylinesApiImpl(LeylinesApi.Server):
def getSetting(self, name: str, _context: _Ctx, **kwargs) -> str:
return db.get_setting(name)
def putSetting(self, name: str, value: str, _context: _Ctx, **kwargs) -> None:
db.put_setting(name, value)
def getNodes(self, _context: _Ctx, **kwargs) -> List:
nodes = db.get_nodes()
return [NodeImpl(node) for node in nodes]
def initServer(self, name: str, ip: str, sshkey: str, _context: _Ctx, **kwargs) -> None:
db.init_server(name, ipaddress.IPv4Address(ip), sshkey)
def sync(self, _context: _Ctx, **kwargs) -> None:
try: try:
while True: sync_interface()
line = reader.readline() except:
if line.strip() != db.get_token(): logger.exception("failed to sync")
raise Exception("unauthorized")
# TODO def addNode(self, name: str, ip: Maybe, sshkey: Maybe, _context: _Ctx, **kwargs) -> None:
finally: db.add_node(
writer.close() name,
await writer.wait_closed() ipaddress.IPv4Address(ip.some) if ip.which == "some" else None,
sshkey.some if sshkey.which == "some" else None
)
def getNode(self, id: int, _context: _Ctx, **kwargs) -> None:
node = db.get_node(id)
if node is None:
setattr(_context.results, "none", None)
else:
setattr(_context.results, "some", NodeImpl(node))
def set_all(target: Any, **kwargs) -> None:
for k,v in kwargs.items():
setattr(target, k, v)
class NodeImpl(LeylinesApi.Node.Server):
def __init__(self, node: Node) -> None:
self._node = node
def getInfo(self, _context: _Ctx, **kwargs) -> None:
set_all(_context.results,
id=self._node.id,
name=self._node.name,
publicIp=({"some": str(self._node.public_ip)}
if self._node.public_ip is not None else {"none": None}),
ip=str(self._node.ip),
seckey=self._node.seckey,
pubkey=seckey_to_pubkey(self._node.seckey),
sshkey=({"some": self._node.ssh_key}
if self._node.ssh_key is not None else {"none": None}),
)
def getResources(self, _context: _Ctx, **kwargs) -> List[str]:
return db.get_node_resources(self._node.id)
def addResource(self, resource: str, _context: _Ctx, **kwargs) -> None:
db.add_node_resource(self._node.id, resource)
def delResource(self, resource: str, _context: _Ctx, **kwargs) -> None:
db.remove_node_resource(self._node.id, resource)
def getConfig(self, _context: _Ctx, **kwargs) -> str:
return generate_node_config(self._node.id)
async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
server = capnp.TwoPartyServer(bootstrap=AuthenticatorImpl())
logger.debug("got connection")
async def read_task():
while not reader.at_eof():
data = await reader.read(4096)
await server.write(data)
async def write_task():
while not reader.at_eof():
data = await server.read(4096)
writer.write(data.tobytes())
await writer.drain()
tasks = [asyncio.create_task(read_task()),
asyncio.create_task(write_task())]
while not reader.at_eof():
server.poll_once()
await asyncio.sleep(0.01)
logger.debug("connection ended")
try:
await tasks[0]
except:
logger.exception("client task error")
try:
await tasks[1]
except:
logger.exception("client task error")
async def main() -> None: async def main() -> None:
try:
sync_interface() sync_interface()
server = await asyncio.create_server(client_connected, host="127.0.0.1", port=API_PORT) except:
logger.exception("failed to sync")
loop = asyncio.get_event_loop()
server = await asyncio.start_server(client_connected, host="127.0.0.1", port=API_PORT)
logger.info("serving forever")
await server.serve_forever() await server.serve_forever()

View File

@ -1,7 +1,9 @@
import argparse import argparse
import ipaddress import asyncio
import logging
from typing import Optional, Dict
from leylines import main, db, sync_interface, seckey_to_pubkey, generate_node_config from leylines import main
from leylines.database import SERVER_NODE_ID from leylines.database import SERVER_NODE_ID
@ -10,11 +12,58 @@ def nop():
pass pass
def optional_to_maybe(value: Optional[str]) -> Dict:
if value is None:
return {"none": None}
else:
return {"some": value}
async def client_main(args: argparse.Namespace) -> None:
from leylines.client import ClientSession
async with ClientSession() as api:
if args.cmd == "status":
subnet = (await api.getSetting("subnet").a_wait()).value
nodes = (await api.getNodes().a_wait()).nodes
print("SUBNET:", subnet)
server_node = None
for node in nodes:
info = (await node.getInfo().a_wait());
public_ip = info.publicIp.some if info.publicIp.which() == "some" else None
print(f"NODE {info.id}:", info.name, public_ip, info.ip, info.pubkey,
"<Have SSH Key>" if info.sshkey.which() == "some" else "")
if info.id == SERVER_NODE_ID:
server_node = node
if server_node is None:
print("! server not defined. run leylines init>")
elif args.cmd == "init":
await api.initServer(args.name, args.ip, args.ssh_key).a_wait()
elif args.cmd == "add":
await api.addNode(
args.name,
optional_to_maybe(args.ip),
optional_to_maybe(
args.ssh_key.read() if args.ssh_key is not None else None)
).a_wait()
elif args.cmd == "get-conf":
node = await api.getNode(args.id).a_wait()
if node.which() == "none":
print("no such node!")
else:
print((await node.some.getConfig().a_wait()).config)
elif args.cmd == "sync":
await api.sync().a_wait()
parser = argparse.ArgumentParser(description="wireguard management system for dragons") parser = argparse.ArgumentParser(description="wireguard management system for dragons")
cmd = parser.add_subparsers(dest="cmd") cmd = parser.add_subparsers(dest="cmd")
cmd.required = True cmd.required = True
cmd_daemon = cmd.add_parser("daemon") cmd_daemon = cmd.add_parser("daemon")
cmd_local_daemon = cmd.add_parser("local-daemon")
cmd_print_token = cmd.add_parser("print-token")
cmd_status = cmd.add_parser("status") cmd_status = cmd.add_parser("status")
@ -37,32 +86,35 @@ cmd_sync = cmd.add_parser("sync")
args = parser.parse_args() args = parser.parse_args()
if args.cmd == "daemon": if args.cmd == "print-token":
from leylines import db
print(db.get_token())
elif args.cmd == "daemon" or args.cmd == "local-daemon":
# Set up logging
root = logging.getLogger()
root.setLevel(logging.DEBUG)
logging.getLogger("asyncio").setLevel(logging.INFO)
try:
if args.cmd == "local-daemon":
raise ImportError()
from systemd.journal import JournalHandler
ch = JournalHandler()
ch.setLevel(logging.INFO)
fm = logging.Formatter(
"%(name)-25s - %(funcName)-10s - %(levelname)-5s - %(message)s")
ch.setFormatter(fm)
root.addHandler(ch)
except ImportError:
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
fm = logging.Formatter(
"%(asctime)s - %(name)-25s - %(funcName)-10s - %(levelname)-5s"
+ " - %(message)s")
ch.setFormatter(fm)
root.addHandler(ch)
# run
logging.info("starting leylines daemon")
asyncio.run(main()) asyncio.run(main())
elif args.cmd == "status": else:
token = db.get_token() asyncio.run(client_main(args))
subnet = db.get_subnet()
nodes = db.get_nodes()
server_node = db.get_server_node()
print("TOKEN:", token)
print("SUBNET:", subnet)
if server_node is None:
print("SERVER: <not defined. run leylines init>")
else:
print("SERVER:", server_node.name, server_node.ip, server_node.public_ip,
seckey_to_pubkey(server_node.seckey),
"<Have SSH Key>" if server_node.ssh_key is not None else "")
for node in nodes:
if node.id == SERVER_NODE_ID:
continue
print(f"NODE {node.id}:", node.name, node.public_ip, node.ip,
seckey_to_pubkey(node.seckey), "<Have SSH Key>" if node.ssh_key is not None else "")
elif args.cmd == "init":
db.init_server(args.name, ipaddress.IPv4Address(args.ip))
elif args.cmd == "add":
db.add_node(args.name, args.ip, args.ssh_key.read() if args.ssh_key is not None else None)
sync_interface()
elif args.cmd == "get-conf":
print(generate_node_config(args.id))
elif args.cmd == "sync":
sync_interface()

View File

@ -0,0 +1,78 @@
import asyncio
from contextlib import suppress
import os
from typing import Optional
import capnp
from .leylines_capnp import Authenticator, LeylinesApi
from . import API_PORT, API_SSL_PORT
from .database import SERVER_NODE_ID
XDG_CONFIG_HOME = os.path.expanduser(os.environ.get("XDG_CONFIG_HOME", "~/.config"))
DEFAULT_TOKEN_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "token")
DEFAULT_HOST_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "host")
def default_get_token() -> str:
if "LEYLINES_TOKEN" in os.environ:
return os.environ["LEYLINES_TOKEN"]
with open(DEFAULT_TOKEN_PATH, "r") as f:
return f.read().strip()
def default_get_host() -> str:
if "LEYLINES_HOST" in os.environ:
return os.environ["LEYLINES_HOST"]
with open(DEFAULT_HOST_PATH, "r") as f:
return f.read().strip()
class ClientSession:
def __init__(self, host: Optional[str] = None, token: Optional[str] = None) -> None:
self._host = host
self._token = token
if self._host == None:
self._host = default_get_host()
if self._token == None:
self._token = default_get_token()
async def __aenter__(self):
reader, writer = await asyncio.open_connection(self._host, API_SSL_PORT, ssl=True)
self.writer = writer
client = capnp.TwoPartyClient()
async def read_task():
while not reader.at_eof():
data = await reader.read(4096)
client.write(data)
async def write_task():
while not reader.at_eof():
data = await client.read(4096)
writer.write(data.tobytes())
await writer.drain()
self.tasks = [asyncio.create_task(read_task()), asyncio.create_task(write_task())]
auth = client.bootstrap().cast_as(Authenticator)
creds = Authenticator.Credentials(token=self._token)
response = await auth.authenticate(creds).a_wait()
if response.result.which() != "succeeded":
self.__aexit__()
raise Exception("authentication failure!")
api = response.result.succeeded.cast_as(LeylinesApi)
return api
async def __aexit__(self, *args):
with suppress(Exception):
self.writer.close()
await self.writer.wait_closed()
with suppress(Exception):
self.tasks[0].cancel()
await self.tasks[0]
with suppress(Exception):
self.tasks[1].cancel()
await self.tasks[1]

View File

@ -53,15 +53,28 @@ class Database:
raise Exception("there is already a server node defined") raise Exception("there is already a server node defined")
self.add_node(name, public_ip, ssh_key) self.add_node(name, public_ip, ssh_key)
def get_subnet(self) -> ipaddress.IPv4Interface: def get_setting(self, name: str) -> Optional[str]:
cur = self.conn.cursor() cur = self.conn.cursor()
cur.execute("SELECT value FROM settings WHERE name='subnet'") cur.execute("SELECT value FROM settings WHERE name=?", (name,))
return ipaddress.IPv4Interface(cur.fetchone()[0]) res = cur.fetchone()
return res[0] if res is not None else None
def put_setting(self, name: str, value: str) -> None:
self.conn.execute("INSERT OR REPLACE INTO settings(name, value) VALUES(?, ?)",
(name, value))
self.conn.commit()
def get_subnet(self) -> ipaddress.IPv4Interface:
res = self.get_setting("subnet")
if res is None:
raise Exception("invalid state: no subnet setting")
return ipaddress.IPv4Interface(res)
def get_token(self) -> str: def get_token(self) -> str:
cur = self.conn.cursor() res = self.get_setting("token")
cur.execute("SELECT value FROM settings WHERE name='token'") if res is None:
return cur.fetchone()[0] raise Exception("invalid state: no token setting")
return res
def _get_free_ip(self) -> ipaddress.IPv4Address: def _get_free_ip(self) -> ipaddress.IPv4Address:
subnet = self.get_subnet().network subnet = self.get_subnet().network

View File

@ -0,0 +1,56 @@
@0xa4819f0f2c639488;
interface Authenticator {
authenticate @0 (creds :Credentials) -> (result :AuthResult);
struct Credentials {
union {
nop @0 :Void;
token @1 :Text;
}
}
struct AuthResult {
union {
unauthorized @0 :Void;
succeeded @1 :LeylinesApi;
}
}
}
interface LeylinesApi {
getSetting @0 (name :Text) -> (value :Text);
putSetting @1 (name :Text, value :Text) -> ();
getNodes @2 () -> (nodes :List(Node));
initServer @3 (name :Text, ip :Text, sshkey :Text) -> ();
sync @4 () -> ();
addNode @5 (name :Text, ip :Maybe(Text), sshkey :Maybe(Text));
getNode @6 (id :Int32) -> Maybe(Node);
interface Node {
getInfo @0 () -> Info;
struct Info {
id @0 :Int32;
name @1 :Text;
publicIp @2 :Maybe(Text);
ip @3 :Text;
seckey @4 :Text;
pubkey @5 :Text;
sshkey @6 :Maybe(Text);
}
getResources @1 () -> (resources :List(Text));
addResource @2 (resource :Text) -> ();
delResource @3 (resource :Text) -> ();
getConfig @4 () -> (config :Text);
}
}
struct Maybe(T) {
union {
none @0 :Void;
some @1 :T;
}
}

View File

@ -11,7 +11,9 @@ setup(
packages=['leylines'], packages=['leylines'],
install_requires=[ install_requires=[
"leylines-monocypher", "leylines-monocypher",
"pyroute2" "pyroute2",
"pycapnp",
"systemd"
], ],
include_package_data=True, include_package_data=True,
entry_points={ entry_points={