211 lines
5.4 KiB
OCaml
211 lines
5.4 KiB
OCaml
open! Import
|
|
open Lwt.Syntax
|
|
open Lwt.Infix
|
|
|
|
include (val Logging.sublogs logger "Server")
|
|
|
|
(** Static settings for one server instance; {!run} destructures it. *)
type config = {
  port : int;
      (** TCP port; the listening socket is bound on every IPv4 interface. *)
  listen_backlog : int;  (** Backlog handed to [Lwt_unix.listen]. *)
  ping_interval : int;
      (** Passed to [Wheel.make]; the wheel is advanced once per second and
          this value is logged as a number of seconds. *)
  whowas_history_len : int;
      (** Passed to [Router.make] (presumably the WHOWAS history size —
          confirm in [Router]). *)
  hostname : string;  (** Server hostname, given to [Server_info.make]. *)
  motd_file : string;
      (** Path of the message-of-the-day file, read line by line at startup. *)
  notify : [ `ready | `stopping ] -> unit;
      (** Lifecycle callback: [`ready] once the socket is listening,
          [`stopping] when shutdown begins. *)
}
|
|
|
|
(** [bind_server ~port ~listen_backlog] creates an IPv4 TCP socket, binds it
    to [port] on every local interface and puts it in listening mode with the
    given backlog.

    [SO_REUSEADDR] is enabled so that a restarted server can rebind the port
    while sockets from a previous run are still in TIME_WAIT.

    If setting the option, binding or listening fails, the socket is closed
    before the original exception is re-raised, so a failed startup does not
    leak a file descriptor. *)
let bind_server ~(port : int) ~(listen_backlog : int) : fd Lwt.t =
  let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
  let srv_adr = Unix.ADDR_INET (Unix.inet_addr_any, port) in
  Lwt.catch
    (fun () ->
      Lwt_unix.setsockopt fd Unix.SO_REUSEADDR true;
      let* () = Lwt_unix.bind fd srv_adr in
      Lwt_unix.listen fd listen_backlog;
      info (fun m -> m "listening on %a" pp_sockaddr srv_adr);
      Lwt.return fd)
    (fun exn ->
      (* best-effort cleanup: a failing close must not mask the original
         startup error, which is what the caller needs to see *)
      let* () =
        Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)
      in
      Lwt.fail exn)
|
|
|
|
(** [accepts fd] is the endless stream of connections accepted on the
    listening socket [fd], each paired with the peer's address.

    [ECONNABORTED] means one client reset its connection before it could be
    accepted. It concerns only that client, so it is logged and the accept
    is retried. Ending the stream instead would fail the whole listener
    because of a single aborted handshake. Any other error is propagated to
    the stream's consumer. *)
let accepts (fd : fd) : (fd * sockaddr) Lwt_stream.t =
  let rec accept () =
    Lwt.catch
      (fun () -> Lwt_unix.accept fd >>= Lwt.return_some)
      (function
        | Unix.Unix_error (ECONNABORTED, _, _) ->
          debug (fun m -> m "accept: connection aborted by peer, retrying");
          accept ()
        | exn -> Lwt.fail exn)
  in
  Lwt_stream.from accept
|
|
|
|
(** [reader fd] is the stream of messages parsed from the bytes received on
    [fd].

    Bytes are read in chunks of up to 512 and appended to a pending buffer.
    After each read, [Msg.parse] splits the buffer into complete messages
    and an unparsed remainder, and the remainder is kept for the next read.

    The stream ends in any of these cases:
    - the peer closes the connection (a read returns 0 bytes);
    - the connection is reset ([ECONNRESET]);
    - the unparsed remainder exceeds [max_pending] bytes.

    The last case guards against a peer that keeps sending data without ever
    completing a message, which would otherwise grow the buffer without
    bound. Messages parsed in that final read are still delivered before the
    stream ends. Other errors are propagated to the consumer. *)
let reader (fd : fd) : Msg.t Lwt_stream.t =
  (* far above any legitimate partial message; kept local so the function's
     signature is unchanged *)
  let max_pending = 200_000 in
  let chunk = Buffer.create 512 in
  let rdbuf = Bytes.create 512 in
  (* set once the remainder overflows; the next pull then ends the stream *)
  let overflowed = ref false in
  let gets () : Msg.t list option Lwt.t =
    if !overflowed then Lwt.return_none
    else
      Lwt.catch
        (fun () ->
          Lwt_unix.read fd rdbuf 0 (Bytes.length rdbuf) >>= function
          | 0 -> Lwt.return_none
          | n ->
            Buffer.add_subbytes chunk rdbuf 0 n;
            let msgs, rest = Msg.parse (Buffer.contents chunk) in
            Buffer.clear chunk;
            if String.length rest > max_pending then begin
              (* drop the oversized remainder instead of buffering it *)
              info (fun m ->
                  m "unparsed input exceeds %d bytes; closing" max_pending);
              overflowed := true
            end
            else Buffer.add_string chunk rest;
            Lwt.return_some msgs)
        (function
          | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_none
          | exn -> Lwt.fail exn)
  in
  Lwt_stream.from gets |> Lwt_stream.map_list Fun.id
|
|
|
|
(** [writer fd obox] serializes each message taken from [obox] with
    [Msg.write] and writes it in full to [fd], one message at a time.

    The promise resolves when [obox] ends, or quietly when the peer has gone
    away. [ECONNRESET] (reset by the peer) and [EPIPE] (writing to a
    connection the peer already closed) are both treated as a normal
    disconnect rather than an error. Any other error rejects the promise.

    NOTE(review): [EPIPE] only reaches this handler as an exception if the
    process ignores SIGPIPE. Confirm where signal dispositions are set. *)
let writer (fd : fd) (obox : Msg.t Lwt_stream.t) : unit Lwt.t =
  (* [Lwt_unix.write] may accept fewer bytes than requested; keep writing
     from the new offset until the whole buffer has gone out *)
  let rec writeall bs i =
    if i >= Bytes.length bs then Lwt.return_unit
    else
      let* n = Lwt_unix.write fd bs i (Bytes.length bs - i) in
      writeall bs (i + n)
  in
  (* one scratch buffer reused for every message *)
  let buf = Buffer.create 512 in
  let on_msg msg =
    Buffer.clear buf;
    Msg.write buf msg;
    writeall (Buffer.to_bytes buf) 0
  in
  Lwt.catch
    (fun () -> Lwt_stream.iter_s on_msg obox)
    (function
      | Unix.Unix_error ((ECONNRESET | EPIPE), _, _) -> Lwt.return_unit
      | exn -> Lwt.fail exn)
|
|
|
|
(** [handle_client conn_fd conn_addr ~server_info ~router ~ping_wheel] starts
    serving one accepted client and returns immediately.

    A [Connection.t] is built for the client and registered on [ping_wheel].
    Two promises then run side by side:
    - one feeds every message parsed from the socket to [Connection.on_msg];
    - the other drains the connection's outbox onto the socket.

    When either promise finishes it calls [Connection.close]. The socket is
    closed only after both have finished. Closure, and any failure, is
    logged with the peer address.

    NOTE(review): nothing here cancels the reading side when the writing side
    ends first, so the socket stays open until that read completes. Confirm
    that [Connection.close] or the peer reliably ends the read. Otherwise an
    unresponsive client keeps its descriptor. *)
let handle_client
    (conn_fd : fd)
    (conn_addr : sockaddr)
    ~(server_info : Server_info.t)
    ~(router : Router.t)
    ~(ping_wheel : Connection.t Wheel.t) =
  info (fun m -> m "new connection %a" pp_sockaddr conn_addr);
  let conn : Connection.t =
    Connection.make ~router ~server_info ~addr:conn_addr
  in
  Wheel.add ping_wheel conn;
  (* socket -> connection *)
  let inbound = reader conn_fd |> Lwt_stream.iter (Connection.on_msg conn) in
  (* connection -> socket *)
  let outbound = Connection.outbox conn |> Outbox.stream |> writer conn_fd in
  let session =
    Lwt.finalize
      (fun () -> Lwt.join [ inbound; outbound ])
      (fun () -> Lwt_unix.close conn_fd)
  in
  let close_conn () = Connection.close conn in
  Lwt.on_termination inbound close_conn;
  Lwt.on_termination outbound close_conn;
  Lwt.on_termination session (fun () ->
      info (fun m -> m "connection closed %a" pp_sockaddr conn_addr));
  Lwt.on_failure session (fun exn ->
      error (fun m -> m "%a:@ %a" pp_sockaddr conn_addr Fmt.exn exn))
|
|
|
|
(** [interval dt] is an infinite stream of [()]. Each element is produced by
    sleeping [dt] seconds at the moment it is pulled, so consecutive elements
    are [dt] seconds apart plus whatever time the consumer spends between
    pulls. *)
let interval dt =
  let next_tick () = Lwt_unix.sleep dt >>= Lwt.return_some in
  Lwt_stream.from next_tick
|
|
|
|
(** [interrupt ()] registers a handler for signals 2 (SIGINT) and 15
    (SIGTERM). It returns a promise that is fulfilled with [()] the first
    time either signal is delivered.

    On any later delivery the promise is already resolved, so [Lwt.wakeup]
    raises [Invalid_argument]. The handler turns that into
    [Failure "unceremoniously exiting"] so a repeated signal is not silently
    swallowed. NOTE(review): how that exception surfaces depends on Lwt's
    signal dispatch. Confirm that it terminates the process as intended.

    The registration ids are discarded, so the handlers are never removed. *)
let interrupt () =
  let stopped, resolver = Lwt.wait () in
  let handle signum =
    trace (fun m -> m "caught signal %d" signum);
    match Lwt.wakeup resolver () with
    | () -> ()
    | exception Invalid_argument _ -> failwith "unceremoniously exiting"
  in
  List.iter
    (fun signum -> Lwt_unix.on_signal signum handle |> ignore)
    [ 2 (* SIGINT *); 15 (* SIGTERM *) ];
  stopped
|
|
|
|
(** [load_motd path] reads the message-of-the-day file at [path] and returns
    its lines. The input channel is closed whether or not reading succeeds. *)
let load_motd (path : string) : string list Lwt.t =
  let+ lines =
    Lwt_io.with_file ~mode:Input path (fun ic ->
        Lwt_io.read_lines ic |> Lwt_stream.to_list)
  in
  debug (fun m -> m "motd file:@ %d lines" (List.length lines));
  lines

(** [run config] starts the server described by [config] and serves clients
    until SIGINT or SIGTERM is received.

    Startup:
    - load the MOTD file;
    - log the server identity;
    - bind the listening socket, then call [notify `ready].

    While running:
    - every accepted client is handed to [handle_client];
    - the ping wheel is advanced once per second. Connections for which
      [Connection.on_ping] returns [Error _] are closed with reason
      ["Connection timeout"]; the others are re-added to the wheel.

    Shutdown:
    - call [notify `stopping] and close the listening socket;
    - call [Router.nuke];
    - close every connection still on the ping wheel;
    - wait 0.5 s so queued messages can be flushed.

    If the accept or ping loop fails instead, the exception propagates and
    the shutdown steps are skipped. *)
let run
    { port;
      listen_backlog;
      ping_interval;
      whowas_history_len;
      hostname;
      motd_file;
      notify;
    } : unit Lwt.t =
  debug (fun m -> m "ping interval:@ %ds" ping_interval);
  debug (fun m -> m "whowas history:@ %d" whowas_history_len);

  let* motd = load_motd motd_file in

  let server_info = Server_info.make () ~hostname ~motd in
  info (fun m -> m "hostname:@ %s" server_info.hostname);
  info (fun m -> m "version:@ %s" server_info.version);
  info (fun m -> m "created:@ %s" server_info.created);

  let* server : fd = bind_server ~port ~listen_backlog in
  notify `ready;

  let router : Router.t = Router.make ~whowas_history_len in
  let ping_wheel : _ Wheel.t = Wheel.make ping_interval in

  (* re-arm connections that are still alive; drop those that timed out *)
  let ping conn =
    match Connection.on_ping conn with
    | Ok () -> Wheel.add ping_wheel conn
    | Error _ -> Connection.close conn ~reason:"Connection timeout"
  in

  let pinger_promise =
    Lwt_stream.iter
      (fun () -> List.iter ping (Wheel.tick ping_wheel))
      (interval 1.0)
  in

  let listener_promise =
    Lwt_stream.iter
      (fun (fd, addr) -> handle_client fd addr ~server_info ~router ~ping_wheel)
      (accepts server)
  in

  let* () = Lwt.pick [ listener_promise <?> pinger_promise; interrupt () ] in
  notify `stopping;

  info (fun m -> m "shutting down");
  let* () = Lwt_unix.close server in

  Router.nuke router;
  (* ping wheel should contain every active connection *)
  Wheel.iter
    (fun conn -> Connection.close conn ~reason:"Server shutting down")
    ping_wheel;

  (* give some time for the messages to send *)
  Lwt_unix.sleep 0.5
|