open! Import
open Lwt.Syntax
open Lwt.Infix

(* Logging functions ([trace], [debug], [info], [error]) for this module. *)
include (val Logging.sublogs logger "Server")

(** Settings consumed by {!run}. *)
type config = {
  port : int;  (** TCP port the listening socket is bound to. *)
  listen_backlog : int;  (** Backlog passed to [Lwt_unix.listen]. *)
  ping_interval : int;
      (** Size passed to [Wheel.make]; the wheel is ticked once per second. *)
  whowas_history_len : int;  (** Passed to [Router.make]. *)
  hostname : string;  (** Passed to [Server_info.make]. *)
  motd_file : string;  (** Path of the file whose lines become the MOTD. *)
  notify : [`ready | `stopping] -> unit;
      (** Called with [`ready] once the socket is listening and with
          [`stopping] when shutdown begins. *)
}

(** [bind_server ~port ~listen_backlog] creates an IPv4 stream socket, binds
    it to [port] on every local address and puts it in listening mode. *)
let bind_server ~(port : int) ~(listen_backlog : int) : fd Lwt.t =
  let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
  let srv_adr = Unix.ADDR_INET (Unix.inet_addr_any, port) in
  let* () = Lwt_unix.bind fd srv_adr in
  Lwt_unix.listen fd listen_backlog;
  info (fun m -> m "listening on %a" pp_sockaddr srv_adr);
  Lwt.return fd

(** [accepts fd] is the never-ending stream of connections accepted on the
    listening socket [fd]. The stream fails if [Lwt_unix.accept] fails. *)
let accepts (fd : fd) : (fd * sockaddr) Lwt_stream.t =
  let accept () = Lwt_unix.accept fd >>= Lwt.return_some in
  Lwt_stream.from accept

(** [reader fd] is the stream of messages parsed from the bytes read on [fd].

    The stream ends when the peer closes the connection (a read of 0 bytes)
    or resets it ([ECONNRESET]); any other exception is propagated to the
    consumer of the stream. *)
let reader (fd : fd) : Msg.t Lwt_stream.t =
  (* [chunk] holds the bytes left unparsed between reads. *)
  let chunk = Buffer.create 512 in
  let rdbuf = Bytes.create 512 in
  let gets () : Msg.t list option Lwt.t =
    Lwt.catch
      (fun () ->
        Lwt_unix.read fd rdbuf 0 (Bytes.length rdbuf) >>= function
        | 0 -> Lwt.return_none
        | n ->
          Buffer.add_subbytes chunk rdbuf 0 n;
          (* NOTE(review): [chunk] has no upper bound, so a peer that never
             completes a message can grow it without limit. *)
          (* if Buffer.length chunk > 200_000 then panic *)
          let msgs, rest = Msg.parse (Buffer.contents chunk) in
          (* Keep only the part [Msg.parse] handed back as [rest]. *)
          Buffer.clear chunk;
          Buffer.add_string chunk rest;
          Lwt.return_some msgs)
      (function
        | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_none
        | exn -> Lwt.fail exn)
  in
  (* Each read yields a list of messages; flatten into a stream of messages. *)
  Lwt_stream.from gets |> Lwt_stream.map_list (fun msgs -> msgs)

(** [writer fd obox] serialises every message of [obox] with [Msg.write] and
    writes it to [fd], one message at a time.

    The promise resolves when [obox] ends or when the peer resets the
    connection ([ECONNRESET]); any other exception rejects it. *)
let writer (fd : fd) (obox : Msg.t Lwt_stream.t) : unit Lwt.t =
  (* [Lwt_unix.write] may write fewer bytes than requested: loop until done. *)
  let rec writeall bs i =
    if i >= Bytes.length bs then Lwt.return_unit
    else
      let* n = Lwt_unix.write fd bs i (Bytes.length bs - i) in
      writeall bs (i + n)
  in
  let buf = Buffer.create 512 in
  let on_msg msg =
    Buffer.clear buf;
    Msg.write buf msg;
    writeall (Buffer.to_bytes buf) 0
  in
  Lwt.catch
    (fun () -> Lwt_stream.iter_s on_msg obox)
    (function
      | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_unit
      | exn -> Lwt.fail exn)

(** [handle_client conn_fd conn_addr ~server_info ~router ~ping_wheel] sets up
    a newly accepted connection: it creates the [Connection.t], adds it to
    [ping_wheel], and starts the read and write loops in the background.

    Returns immediately. [conn_fd] is closed once both loops have finished. *)
let handle_client (conn_fd : fd) (conn_addr : sockaddr)
    ~(server_info : Server_info.t) ~(router : Router.t)
    ~(ping_wheel : Connection.t Wheel.t) =
  info (fun m -> m "new connection %a" pp_sockaddr conn_addr);
  let conn : Connection.t =
    Connection.make ~router ~server_info ~addr:conn_addr
  in
  Wheel.add ping_wheel conn;
  (* These local bindings shadow the top-level [reader] and [writer]. *)
  let reader = Lwt_stream.iter (Connection.on_msg conn) (reader conn_fd) in
  let writer = writer conn_fd (Connection.outbox conn) in
  let both =
    Lwt.finalize
      (fun () -> reader <&> writer)
      (fun () -> Lwt_unix.close conn_fd)
  in
  (* Whichever loop stops first, close the connection. *)
  Lwt.on_termination reader (fun () -> Connection.close conn);
  Lwt.on_termination writer (fun () -> Connection.close conn);
  Lwt.on_termination both (fun () ->
      info (fun m -> m "connection closed %a" pp_sockaddr conn_addr));
  Lwt.on_failure both (fun e ->
      error (fun m -> m "%a:@ %a" pp_sockaddr conn_addr Fmt.exn e))

(** [interval dt] is an endless stream that produces [()] after sleeping [dt]
    seconds, each time an element is requested. *)
let interval dt =
  let tick () =
    let* () = Lwt_unix.sleep dt in
    Lwt.return_some ()
  in
  Lwt_stream.from tick

(** [interrupt ()] installs SIGINT and SIGTERM handlers and returns a promise
    resolved by the first such signal.

    A later signal finds the promise already resolved, so the handler raises
    [Failure "unceremoniously exiting"]. *)
let interrupt () =
  let signal, signal_waiter = Lwt.wait () in
  let on_signal num =
    trace (fun m -> m "caught signal %d" num);
    try Lwt.wakeup signal_waiter ()
    with Invalid_argument _ -> failwith "unceremoniously exiting"
  in
  Lwt_unix.on_signal (2 (* SIGINT *)) on_signal |> ignore;
  Lwt_unix.on_signal (15 (* SIGTERM *)) on_signal |> ignore;
  signal

(** [run config] starts the server and resolves after shutdown.

    It loads the MOTD, binds the listening socket, then serves connections and
    ticks the ping wheel once per second until SIGINT/SIGTERM arrives or one
    of the two loops stops. On the way out it closes the listening socket and
    every connection held in the ping wheel. *)
let run
    {
      port;
      listen_backlog;
      ping_interval;
      whowas_history_len;
      hostname;
      motd_file;
      notify;
    } : unit Lwt.t =
  debug (fun m -> m "ping interval:@ %ds" ping_interval);
  debug (fun m -> m "whowas history:@ %d" whowas_history_len);
  (* [with_file] closes the channel even when reading fails. *)
  let* motd =
    Lwt_io.with_file ~mode:Input motd_file (fun file ->
        Lwt_io.read_lines file |> Lwt_stream.to_list)
  in
  debug (fun m -> m "motd file:@ %d lines" (List.length motd));
  let server_info = Server_info.make () ~hostname ~motd in
  info (fun m -> m "hostname:@ %s" server_info.hostname);
  info (fun m -> m "version:@ %s" server_info.version);
  info (fun m -> m "created:@ %s" server_info.created);
  let* server : fd = bind_server ~port ~listen_backlog in
  notify `ready;
  let router : Router.t = Router.make ~whowas_history_len in
  let ping_wheel : _ Wheel.t = Wheel.make ping_interval in
  (* A connection that answers [on_ping] with [Ok] goes back on the wheel;
     otherwise it is closed. *)
  let ping conn =
    match Connection.on_ping conn with
    | Ok () -> Wheel.add ping_wheel conn
    | Error _ -> Connection.close conn ~reason:"Connection timeout"
  in
  let pinger_promise =
    Lwt_stream.iter
      (fun () -> List.iter ping (Wheel.tick ping_wheel))
      (interval 1.0)
  in
  let listener_promise =
    Lwt_stream.iter
      (fun (fd, addr) ->
        handle_client fd addr ~server_info ~router ~ping_wheel)
      (accepts server)
  in
  (* Whichever promise settles first ends the serving phase; [Lwt.pick]
     cancels the others. *)
  let* () = Lwt.pick [ listener_promise; pinger_promise; interrupt () ] in
  notify `stopping;
  info (fun m -> m "shutting down");
  let* () = Lwt_unix.close server in
  Router.nuke router;
  (* ping wheel should contain every active connection *)
  Wheel.iter
    (fun conn -> Connection.close conn ~reason:"Server shutting down")
    ping_wheel;
  (* give some time for the messages to send *)
  Lwt_unix.sleep 0.5