(* talircd/lib/server/server.ml
   211 lines, 5.4 KiB, OCaml *)

open! Import
open Lwt.Syntax
open Lwt.Infix
(* Bring this module's logging functions (info, debug, trace, error, used
   below) into scope; presumably a sub-log of the project [logger] tagged
   "Server" — see Logging.sublogs. *)
include (val Logging.sublogs logger "Server")
(** Startup parameters for {!run}. *)
type config = {
port : int;
(** TCP port the listening socket is bound to (on all IPv4 addresses). *)
listen_backlog : int;
(** Backlog passed to [Lwt_unix.listen]. *)
ping_interval : int;
(** Passed to [Wheel.make] for the ping wheel; logged as seconds. *)
whowas_history_len : int;
(** Passed to [Router.make] as [~whowas_history_len]. *)
hostname : string;
(** Server hostname, passed to [Server_info.make]. *)
motd_file : string;
(** Path of a file read line-by-line at startup to form the MOTD. *)
notify : [`ready | `stopping] -> unit;
(** Called with [`ready] once the listening socket is bound, and with
    [`stopping] when shutdown begins. *)
}
(** [bind_server ~port ~listen_backlog] creates an IPv4 TCP socket bound to
    [port] on every local address and puts it in listening mode.

    [SO_REUSEADDR] is set so the server can be restarted immediately even
    while connections from a previous run linger in TIME_WAIT. If any step
    fails, the socket is closed before the original exception is re-raised,
    so no descriptor is leaked. *)
let bind_server
    ~(port : int)
    ~(listen_backlog : int)
  : fd Lwt.t =
  let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
  let srv_adr = Unix.ADDR_INET (Unix.inet_addr_any, port) in
  Lwt.catch
    (fun () ->
       Lwt_unix.setsockopt fd Unix.SO_REUSEADDR true;
       let* () = Lwt_unix.bind fd srv_adr in
       Lwt_unix.listen fd listen_backlog;
       info (fun m -> m "listening on %a" pp_sockaddr srv_adr);
       Lwt.return fd)
    (fun exn ->
       (* best-effort cleanup; the bind/listen error is what matters *)
       Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)
       >>= fun () -> Lwt.fail exn)
(** [accepts fd] is an infinite stream of [(client_fd, client_addr)] pairs
    accepted from the listening socket [fd].

    Transient accept failures must not end the stream, because doing so would
    stop the whole server:
    - [ECONNABORTED]: the pending connection vanished before it was accepted,
      so the accept is simply retried.
    - Descriptor or buffer exhaustion ([EMFILE], [ENFILE], [ENOBUFS],
      [ENOMEM]): the error is logged and the accept is retried after a short
      pause, to avoid spinning.

    Any other error is propagated to the consumer. *)
let accepts (fd : fd) : (fd * sockaddr) Lwt_stream.t =
  let rec accept () =
    Lwt.catch
      (fun () -> Lwt_unix.accept fd >>= Lwt.return_some)
      (function
        | Unix.Unix_error (Unix.ECONNABORTED, _, _) ->
          accept ()
        | Unix.Unix_error
            ((Unix.EMFILE | Unix.ENFILE | Unix.ENOBUFS | Unix.ENOMEM) as err,
             _, _) ->
          error (fun m -> m "accept failed:@ %s" (Unix.error_message err));
          let* () = Lwt_unix.sleep 0.1 in
          accept ()
        | exn -> Lwt.fail exn)
  in
  Lwt_stream.from accept
(** [reader fd] is the stream of messages parsed from the bytes read on
    [fd].

    Data is read in chunks of at most 512 bytes. Whatever [Msg.parse] leaves
    unparsed is kept and prepended to the next chunk.

    The stream ends in three cases:
    - on EOF;
    - on [ECONNRESET];
    - when more than [max_pending] bytes accumulate without forming a
      complete message. This stops a client from growing the buffer without
      bound; messages parsed in that final batch are discarded.

    Other read errors are propagated. *)
let reader (fd : fd) : Msg.t Lwt_stream.t =
  let max_pending = 200_000 in
  let chunk = Buffer.create 512 in
  let rdbuf = Bytes.create 512 in
  let gets () : Msg.t list option Lwt.t =
    Lwt.catch
      (fun () ->
         Lwt_unix.read fd rdbuf 0 (Bytes.length rdbuf) >>= function
         | 0 -> Lwt.return_none
         | n ->
           Buffer.add_subbytes chunk rdbuf 0 n;
           let msgs, rest = Msg.parse (Buffer.contents chunk) in
           Buffer.clear chunk;
           Buffer.add_string chunk rest;
           if String.length rest > max_pending then begin
             info (fun m ->
                 m "dropping client:@ %d bytes without a complete message"
                   (String.length rest));
             Lwt.return_none
           end else
             Lwt.return_some msgs)
      (function
        | Unix.Unix_error (Unix.ECONNRESET, _, _) -> Lwt.return_none
        | exn -> Lwt.fail exn)
  in
  Lwt_stream.from gets |> Lwt_stream.map_list Fun.id
(** [writer fd obox] serializes each message of [obox] with [Msg.write] and
    writes it fully to [fd], one message at a time, until [obox] ends.

    A peer that has gone away shows up as [ECONNRESET] or [EPIPE]. Both are
    treated as a normal end of the connection and resolve the promise
    successfully. Other write errors are propagated. *)
let writer (fd : fd) (obox : Msg.t Lwt_stream.t) : unit Lwt.t =
  (* [Lwt_unix.write] may perform a short write; loop until all is sent *)
  let rec writeall bs i =
    if i >= Bytes.length bs then
      Lwt.return_unit
    else
      let* n = Lwt_unix.write fd bs i (Bytes.length bs - i) in
      writeall bs (i + n)
  in
  let buf = Buffer.create 512 in
  let on_msg msg =
    Buffer.clear buf;
    Msg.write buf msg;
    writeall (Buffer.to_bytes buf) 0
  in
  Lwt.catch
    (fun () -> Lwt_stream.iter_s on_msg obox)
    (function
      | Unix.Unix_error ((Unix.ECONNRESET | Unix.EPIPE), _, _) ->
        Lwt.return_unit
      | exn -> Lwt.fail exn)
(** [handle_client conn_fd conn_addr ~server_info ~router ~ping_wheel] sets
    up a newly accepted client connection and returns immediately.

    It creates a [Connection.t] and registers it in [ping_wheel]. It then
    starts two concurrent loops:
    - a reader, feeding parsed messages to [Connection.on_msg];
    - a writer, draining the connection's outbox to the socket.

    When either loop terminates, the connection is closed. When the writer
    terminates, the socket is also shut down, so a reader blocked in [read]
    sees EOF. Without this, a client that never hangs up would keep the
    descriptor open forever after the server dropped it. [conn_fd] is closed
    once both loops are done. *)
let handle_client
    (conn_fd : fd)
    (conn_addr : sockaddr)
    ~(server_info : Server_info.t)
    ~(router : Router.t)
    ~(ping_wheel : Connection.t Wheel.t)
  =
  info (fun m -> m "new connection %a" pp_sockaddr conn_addr);
  let conn : Connection.t =
    Connection.make
      ~router
      ~server_info
      ~addr:conn_addr
  in
  Wheel.add ping_wheel conn;
  let reader = Lwt_stream.iter (Connection.on_msg conn) (reader conn_fd) in
  let writer = writer conn_fd (Outbox.stream (Connection.outbox conn)) in
  let both =
    Lwt.finalize
      (fun () -> reader <&> writer)
      (fun () -> Lwt_unix.close conn_fd)
  in
  (* Nothing more will be written, so wake the reader with EOF. Failure is
     expected if the fd is already closed or the peer already disconnected;
     Lwt_unix refuses to act on a closed descriptor, so a reused fd number
     cannot be affected. *)
  let stop_reading () =
    try Lwt_unix.shutdown conn_fd Unix.SHUTDOWN_ALL
    with Unix.Unix_error _ -> ()
  in
  Lwt.on_termination reader (fun () -> Connection.close conn);
  Lwt.on_termination writer (fun () ->
      Connection.close conn;
      stop_reading ());
  Lwt.on_termination both
    (fun () -> info (fun m -> m "connection closed %a" pp_sockaddr conn_addr));
  Lwt.on_failure both
    (fun e -> error (fun m -> m "%a:@ %a" pp_sockaddr conn_addr Fmt.exn e))
(** [interval dt] is an endless stream that yields [()] each time a fresh
    [dt]-second sleep completes. The next sleep starts only when the consumer
    asks for the next element. *)
let interval dt =
  Lwt_stream.from (fun () -> Lwt_unix.sleep dt >|= fun () -> Some ())
(** [interrupt ()] installs handlers for SIGINT (2) and SIGTERM (15). It
    returns a promise that resolves on the first such signal.

    A second signal finds the promise already resolved, so the handler raises
    [Failure "unceremoniously exiting"]. That exception presumably reaches
    Lwt's async exception hook and forces an exit — TODO confirm the hook is
    not overridden elsewhere. *)
let interrupt () =
  let stopped, stop = Lwt.wait () in
  let handler signum =
    trace (fun m -> m "caught signal %d" signum);
    match Lwt.wakeup stop () with
    | () -> ()
    | exception Invalid_argument _ -> failwith "unceremoniously exiting"
  in
  List.iter
    (fun signum -> ignore (Lwt_unix.on_signal signum handler))
    [ 2 (* SIGINT *); 15 (* SIGTERM *) ];
  stopped
(** [run config] runs the server until SIGINT or SIGTERM.

    Startup:
    - loads the MOTD from [motd_file];
    - builds the server info and binds the listening socket;
    - calls [notify `ready].

    While running, it accepts clients and ticks the ping wheel once per
    second. Connections whose [Connection.on_ping] returns [Error _] are
    closed with reason "Connection timeout".

    On shutdown it:
    - calls [notify `stopping];
    - closes the listening socket and nukes the router;
    - closes every connection still in the ping wheel;
    - waits 0.5 s so final messages can be flushed. *)
let run {
    port;
    listen_backlog;
    ping_interval;
    whowas_history_len;
    hostname;
    motd_file;
    notify;
  }
  : unit Lwt.t =
  debug (fun m -> m "ping interval:@ %ds" ping_interval);
  debug (fun m -> m "whowas history:@ %d" whowas_history_len);
  let* motd =
    (* with_file closes the channel even when reading fails *)
    let+ lines =
      Lwt_io.with_file ~mode:Input motd_file
        (fun file -> Lwt_io.read_lines file |> Lwt_stream.to_list)
    in
    debug (fun m -> m "motd file:@ %d lines" (List.length lines));
    lines
  in
  let server_info =
    Server_info.make ()
      ~hostname
      ~motd
  in
  info (fun m -> m "hostname:@ %s" server_info.hostname);
  info (fun m -> m "version:@ %s" server_info.version);
  info (fun m -> m "created:@ %s" server_info.created);
  let* server : fd =
    bind_server
      ~port
      ~listen_backlog
  in
  notify `ready;
  let router : Router.t =
    Router.make
      ~whowas_history_len
  in
  let ping_wheel : _ Wheel.t =
    Wheel.make
      ping_interval
  in
  (* re-arm the connection on a successful ping, drop it otherwise *)
  let ping conn =
    match Connection.on_ping conn with
    | Ok () -> Wheel.add ping_wheel conn
    | Error _ -> Connection.close conn ~reason:"Connection timeout"
  in
  let pinger_promise =
    Lwt_stream.iter
      (fun () ->
         List.iter ping
           (Wheel.tick ping_wheel))
      (interval 1.0)
  in
  let listener_promise =
    Lwt_stream.iter
      (fun (fd, addr) ->
         handle_client fd addr
           ~server_info
           ~router
           ~ping_wheel)
      (accepts server)
  in
  let* () =
    Lwt.pick [
      listener_promise <?> pinger_promise;
      interrupt ()
    ]
  in
  notify `stopping;
  info (fun m -> m "shutting down");
  let* () = Lwt_unix.close server in
  Router.nuke router;
  (* ping wheel should contain every active connection *)
  Wheel.iter (fun conn -> Connection.close conn ~reason:"Server shutting down")
    ping_wheel;
  (* give some time for the messages to send *)
  Lwt_unix.sleep 0.5