make a proper RPC protocol

This commit is contained in:
xenia 2021-06-18 03:07:35 -04:00
parent 19944c1fa2
commit 06c0c85e79
8 changed files with 396 additions and 57 deletions

View File

@ -5,7 +5,7 @@ ipy = __import__("IPython")
def dask(line, local_ns):
"initializes dask"
from leylines import get_server_node, DEFAULT_PORT, db
from leylines.client import ClientSession, SERVER_NODE_ID
from leylines.dask import init_dask, tqdmprogress
local_ns['tqdmprogress'] = tqdmprogress
@ -27,21 +27,34 @@ def dask(line, local_ns):
local_ns['upload'] = lambda file: upload(client, file)
server_node = get_server_node()
workers = [node for node in db.get_nodes()
if node.id != server_node.id and node.ssh_key is not None]
dest = f"{server_node.ip}:{DEFAULT_PORT}"
import asyncio
async def get_nodes_info():
async with ClientSession() as api:
nodes = (await api.getNodes().a_wait()).nodes
server = None
workers = []
for node in nodes:
info = (await node.getInfo().a_wait()).info
out_info = (info.ip, info.name)
if info.id == SERVER_NODE_ID:
server = out_info
elif info.sshkey.which() != "none":
workers.append(out_info)
return server, workers
server, workers = asyncio.run(get_nodes_info())
dest = f"{server[0]}:{DEFAULT_PORT}"
print("connected to APEX at", dest)
workers_by_ip = {str(node.ip):node for node in workers}
workers_status = {str(node.ip):False for node in workers}
workers_by_ip = {str(node[0]):node for node in workers}
workers_status = {str(node[0]):False for node in workers}
for addr, info in client.scheduler_info()["workers"].items():
workers_status[info["host"]] = True
for ip, node in sorted(workers_by_ip.items(), key=lambda x:x[1].name):
if workers_status[ip]:
print(f"{node.name} ({node.ip}): up")
print(f"{node[1]} ({node[0]}): up")
else:
print(f"{node.name} ({node.ip}): down")
print(f"{node[1]} ({node[0]}): down")
@ipy.core.magic.register_line_magic
@ipy.core.magic.needs_local_scope

View File

@ -0,0 +1,9 @@
[Unit]
Description=leylines server
[Service]
Type=simple
ExecStart=/usr/bin/env python3 -m leylines daemon
[Install]
WantedBy=default.target

View File

@ -1,35 +1,151 @@
import asyncio
import binascii
import ipaddress
import logging
import secrets
from typing import Optional, Any
from typing import Optional, Any, Dict, List
import capnp
from capnp.lib.capnp import _CallContext as _Ctx
from pyroute2 import IPRoute, WireGuard
import monocypher
from .database import Database, Node, SERVER_NODE_ID
from .leylines_capnp import Authenticator, LeylinesApi, Maybe
IFNAME = 'leyline-wg'
DEFAULT_PORT = 31337
API_PORT = 31338
API_SSL_PORT = 31337
logger = logging.getLogger(__name__)
db = Database()
async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
class AuthenticatorImpl(Authenticator.Server):
def authenticate(self, _context: _Ctx, creds: Authenticator.Credentials,
**kwargs) -> Authenticator.AuthResult:
logger.debug("got auth request!")
which = creds.which()
if which == "token" and creds.token == db.get_token():
return Authenticator.AuthResult(succeeded=LeylinesApiImpl())
else:
return Authenticator.AuthResult(unauthorized=None)
class LeylinesApiImpl(LeylinesApi.Server):
def getSetting(self, name: str, _context: _Ctx, **kwargs) -> str:
return db.get_setting(name)
def putSetting(self, name: str, value: str, _context: _Ctx, **kwargs) -> None:
db.put_setting(name, value)
def getNodes(self, _context: _Ctx, **kwargs) -> List:
nodes = db.get_nodes()
return [NodeImpl(node) for node in nodes]
def initServer(self, name: str, ip: str, sshkey: str, _context: _Ctx, **kwargs) -> None:
db.init_server(name, ipaddress.IPv4Address(ip), sshkey)
def sync(self, _context: _Ctx, **kwargs) -> None:
try:
while True:
line = reader.readline()
if line.strip() != db.get_token():
raise Exception("unauthorized")
# TODO
finally:
writer.close()
await writer.wait_closed()
sync_interface()
except:
logger.exception("failed to sync")
def addNode(self, name: str, ip: Maybe, sshkey: Maybe, _context: _Ctx, **kwargs) -> None:
db.add_node(
name,
ipaddress.IPv4Address(ip.some) if ip.which == "some" else None,
sshkey.some if sshkey.which == "some" else None
)
def getNode(self, id: int, _context: _Ctx, **kwargs) -> None:
node = db.get_node(id)
if node is None:
setattr(_context.results, "none", None)
else:
setattr(_context.results, "some", NodeImpl(node))
def set_all(target: Any, **kwargs) -> None:
for k,v in kwargs.items():
setattr(target, k, v)
class NodeImpl(LeylinesApi.Node.Server):
def __init__(self, node: Node) -> None:
self._node = node
def getInfo(self, _context: _Ctx, **kwargs) -> None:
set_all(_context.results,
id=self._node.id,
name=self._node.name,
publicIp=({"some": str(self._node.public_ip)}
if self._node.public_ip is not None else {"none": None}),
ip=str(self._node.ip),
seckey=self._node.seckey,
pubkey=seckey_to_pubkey(self._node.seckey),
sshkey=({"some": self._node.ssh_key}
if self._node.ssh_key is not None else {"none": None}),
)
def getResources(self, _context: _Ctx, **kwargs) -> List[str]:
return db.get_node_resources(self._node.id)
def addResource(self, resource: str, _context: _Ctx, **kwargs) -> None:
db.add_node_resource(self._node.id, resource)
def delResource(self, resource: str, _context: _Ctx, **kwargs) -> None:
db.remove_node_resource(self._node.id, resource)
def getConfig(self, _context: _Ctx, **kwargs) -> str:
return generate_node_config(self._node.id)
async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
server = capnp.TwoPartyServer(bootstrap=AuthenticatorImpl())
logger.debug("got connection")
async def read_task():
while not reader.at_eof():
data = await reader.read(4096)
await server.write(data)
async def write_task():
while not reader.at_eof():
data = await server.read(4096)
writer.write(data.tobytes())
await writer.drain()
tasks = [asyncio.create_task(read_task()),
asyncio.create_task(write_task())]
while not reader.at_eof():
server.poll_once()
await asyncio.sleep(0.01)
logger.debug("connection ended")
try:
await tasks[0]
except:
logger.exception("client task error")
try:
await tasks[1]
except:
logger.exception("client task error")
async def main() -> None:
try:
sync_interface()
server = await asyncio.create_server(client_connected, host="127.0.0.1", port=API_PORT)
except:
logger.exception("failed to sync")
loop = asyncio.get_event_loop()
server = await asyncio.start_server(client_connected, host="127.0.0.1", port=API_PORT)
logger.info("serving forever")
await server.serve_forever()

View File

@ -1,7 +1,9 @@
import argparse
import ipaddress
import asyncio
import logging
from typing import Optional, Dict
from leylines import main, db, sync_interface, seckey_to_pubkey, generate_node_config
from leylines import main
from leylines.database import SERVER_NODE_ID
@ -10,11 +12,58 @@ def nop():
pass
def optional_to_maybe(value: Optional[str]) -> Dict:
if value is None:
return {"none": None}
else:
return {"some": value}
async def client_main(args: argparse.Namespace) -> None:
from leylines.client import ClientSession
async with ClientSession() as api:
if args.cmd == "status":
subnet = (await api.getSetting("subnet").a_wait()).value
nodes = (await api.getNodes().a_wait()).nodes
print("SUBNET:", subnet)
server_node = None
for node in nodes:
info = (await node.getInfo().a_wait());
public_ip = info.publicIp.some if info.publicIp.which() == "some" else None
print(f"NODE {info.id}:", info.name, public_ip, info.ip, info.pubkey,
"<Have SSH Key>" if info.sshkey.which() == "some" else "")
if info.id == SERVER_NODE_ID:
server_node = node
if server_node is None:
print("! server not defined. run leylines init>")
elif args.cmd == "init":
await api.initServer(args.name, args.ip, args.ssh_key).a_wait()
elif args.cmd == "add":
await api.addNode(
args.name,
optional_to_maybe(args.ip),
optional_to_maybe(
args.ssh_key.read() if args.ssh_key is not None else None)
).a_wait()
elif args.cmd == "get-conf":
node = await api.getNode(args.id).a_wait()
if node.which() == "none":
print("no such node!")
else:
print((await node.some.getConfig().a_wait()).config)
elif args.cmd == "sync":
await api.sync().a_wait()
parser = argparse.ArgumentParser(description="wireguard management system for dragons")
cmd = parser.add_subparsers(dest="cmd")
cmd.required = True
cmd_daemon = cmd.add_parser("daemon")
cmd_local_daemon = cmd.add_parser("local-daemon")
cmd_print_token = cmd.add_parser("print-token")
cmd_status = cmd.add_parser("status")
@ -37,32 +86,35 @@ cmd_sync = cmd.add_parser("sync")
args = parser.parse_args()
if args.cmd == "daemon":
if args.cmd == "print-token":
from leylines import db
print(db.get_token())
elif args.cmd == "daemon" or args.cmd == "local-daemon":
# Set up logging
root = logging.getLogger()
root.setLevel(logging.DEBUG)
logging.getLogger("asyncio").setLevel(logging.INFO)
try:
if args.cmd == "local-daemon":
raise ImportError()
from systemd.journal import JournalHandler
ch = JournalHandler()
ch.setLevel(logging.INFO)
fm = logging.Formatter(
"%(name)-25s - %(funcName)-10s - %(levelname)-5s - %(message)s")
ch.setFormatter(fm)
root.addHandler(ch)
except ImportError:
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
fm = logging.Formatter(
"%(asctime)s - %(name)-25s - %(funcName)-10s - %(levelname)-5s"
+ " - %(message)s")
ch.setFormatter(fm)
root.addHandler(ch)
# run
logging.info("starting leylines daemon")
asyncio.run(main())
elif args.cmd == "status":
token = db.get_token()
subnet = db.get_subnet()
nodes = db.get_nodes()
server_node = db.get_server_node()
print("TOKEN:", token)
print("SUBNET:", subnet)
if server_node is None:
print("SERVER: <not defined. run leylines init>")
else:
print("SERVER:", server_node.name, server_node.ip, server_node.public_ip,
seckey_to_pubkey(server_node.seckey),
"<Have SSH Key>" if server_node.ssh_key is not None else "")
for node in nodes:
if node.id == SERVER_NODE_ID:
continue
print(f"NODE {node.id}:", node.name, node.public_ip, node.ip,
seckey_to_pubkey(node.seckey), "<Have SSH Key>" if node.ssh_key is not None else "")
elif args.cmd == "init":
db.init_server(args.name, ipaddress.IPv4Address(args.ip))
elif args.cmd == "add":
db.add_node(args.name, args.ip, args.ssh_key.read() if args.ssh_key is not None else None)
sync_interface()
elif args.cmd == "get-conf":
print(generate_node_config(args.id))
elif args.cmd == "sync":
sync_interface()
else:
asyncio.run(client_main(args))

View File

@ -0,0 +1,78 @@
import asyncio
from contextlib import suppress
import os
from typing import Optional
import capnp
from .leylines_capnp import Authenticator, LeylinesApi
from . import API_PORT, API_SSL_PORT
from .database import SERVER_NODE_ID
XDG_CONFIG_HOME = os.path.expanduser(os.environ.get("XDG_CONFIG_HOME", "~/.config"))
DEFAULT_TOKEN_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "token")
DEFAULT_HOST_PATH = os.path.join(XDG_CONFIG_HOME, "leylines", "host")
def default_get_token() -> str:
if "LEYLINES_TOKEN" in os.environ:
return os.environ["LEYLINES_TOKEN"]
with open(DEFAULT_TOKEN_PATH, "r") as f:
return f.read().strip()
def default_get_host() -> str:
if "LEYLINES_HOST" in os.environ:
return os.environ["LEYLINES_HOST"]
with open(DEFAULT_HOST_PATH, "r") as f:
return f.read().strip()
class ClientSession:
def __init__(self, host: Optional[str] = None, token: Optional[str] = None) -> None:
self._host = host
self._token = token
if self._host == None:
self._host = default_get_host()
if self._token == None:
self._token = default_get_token()
async def __aenter__(self):
reader, writer = await asyncio.open_connection(self._host, API_SSL_PORT, ssl=True)
self.writer = writer
client = capnp.TwoPartyClient()
async def read_task():
while not reader.at_eof():
data = await reader.read(4096)
client.write(data)
async def write_task():
while not reader.at_eof():
data = await client.read(4096)
writer.write(data.tobytes())
await writer.drain()
self.tasks = [asyncio.create_task(read_task()), asyncio.create_task(write_task())]
auth = client.bootstrap().cast_as(Authenticator)
creds = Authenticator.Credentials(token=self._token)
response = await auth.authenticate(creds).a_wait()
if response.result.which() != "succeeded":
self.__aexit__()
raise Exception("authentication failure!")
api = response.result.succeeded.cast_as(LeylinesApi)
return api
async def __aexit__(self, *args):
with suppress(Exception):
self.writer.close()
await self.writer.wait_closed()
with suppress(Exception):
self.tasks[0].cancel()
await self.tasks[0]
with suppress(Exception):
self.tasks[1].cancel()
await self.tasks[1]

View File

@ -53,15 +53,28 @@ class Database:
raise Exception("there is already a server node defined")
self.add_node(name, public_ip, ssh_key)
def get_subnet(self) -> ipaddress.IPv4Interface:
def get_setting(self, name: str) -> Optional[str]:
cur = self.conn.cursor()
cur.execute("SELECT value FROM settings WHERE name='subnet'")
return ipaddress.IPv4Interface(cur.fetchone()[0])
cur.execute("SELECT value FROM settings WHERE name=?", (name,))
res = cur.fetchone()
return res[0] if res is not None else None
def put_setting(self, name: str, value: str) -> None:
self.conn.execute("INSERT OR REPLACE INTO settings(name, value) VALUES(?, ?)",
(name, value))
self.conn.commit()
def get_subnet(self) -> ipaddress.IPv4Interface:
res = self.get_setting("subnet")
if res is None:
raise Exception("invalid state: no subnet setting")
return ipaddress.IPv4Interface(res)
def get_token(self) -> str:
cur = self.conn.cursor()
cur.execute("SELECT value FROM settings WHERE name='token'")
return cur.fetchone()[0]
res = self.get_setting("token")
if res is None:
raise Exception("invalid state: no token setting")
return res
def _get_free_ip(self) -> ipaddress.IPv4Address:
subnet = self.get_subnet().network

View File

@ -0,0 +1,56 @@
@0xa4819f0f2c639488;
interface Authenticator {
authenticate @0 (creds :Credentials) -> (result :AuthResult);
struct Credentials {
union {
nop @0 :Void;
token @1 :Text;
}
}
struct AuthResult {
union {
unauthorized @0 :Void;
succeeded @1 :LeylinesApi;
}
}
}
interface LeylinesApi {
getSetting @0 (name :Text) -> (value :Text);
putSetting @1 (name :Text, value :Text) -> ();
getNodes @2 () -> (nodes :List(Node));
initServer @3 (name :Text, ip :Text, sshkey :Text) -> ();
sync @4 () -> ();
addNode @5 (name :Text, ip :Maybe(Text), sshkey :Maybe(Text));
getNode @6 (id :Int32) -> Maybe(Node);
interface Node {
getInfo @0 () -> Info;
struct Info {
id @0 :Int32;
name @1 :Text;
publicIp @2 :Maybe(Text);
ip @3 :Text;
seckey @4 :Text;
pubkey @5 :Text;
sshkey @6 :Maybe(Text);
}
getResources @1 () -> (resources :List(Text));
addResource @2 (resource :Text) -> ();
delResource @3 (resource :Text) -> ();
getConfig @4 () -> (config :Text);
}
}
struct Maybe(T) {
union {
none @0 :Void;
some @1 :T;
}
}

View File

@ -11,7 +11,9 @@ setup(
packages=['leylines'],
install_requires=[
"leylines-monocypher",
"pyroute2"
"pyroute2",
"pycapnp",
"systemd"
],
include_package_data=True,
entry_points={