open! Import
open Lwt.Syntax
open Lwt.Infix

include (val Logging.sublogs logger "Server")

(** Timing wheel of connections waiting for their next ping check. *)
type ping_wheel = Connection.t Wheel.t

(** [listener ~port ~listen_backlog] creates an IPv4 TCP socket bound to
    [port] on all interfaces and returns the stream of accepted
    [(fd, peer address)] pairs.

    The socket is created and [bind] is started when [listener] is called.
    Each pull from the stream waits for that setup to finish and then accepts
    one connection. A failed bind rejects every pull.

    Transient accept failures are retried instead of ending the stream. This
    is needed because a single error would otherwise stop the server from
    accepting forever:
    - [ECONNABORTED] is retried immediately (the peer gave up first).
    - [EMFILE]/[ENFILE]/[ENOBUFS]/[ENOMEM] are retried after a short sleep,
      since the listening socket stays readable and an immediate retry would
      spin. *)
let listener ~(port : int) ~(listen_backlog : int) :
    (fd * sockaddr) Lwt_stream.t =
  let sock : fd Lwt.t =
    let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
    Lwt_unix.setsockopt fd SO_KEEPALIVE false;
    Lwt_unix.setsockopt fd SO_REUSEPORT true;
    let srv_adr = Unix.ADDR_INET (Unix.inet_addr_any, port) in
    let* () = Lwt_unix.bind fd srv_adr in
    Lwt_unix.listen fd listen_backlog;
    info (fun m -> m "listening on %a" pp_sockaddr srv_adr);
    Lwt.return fd
  in
  let rec accept () =
    Lwt.catch
      (fun () -> sock >>= Lwt_unix.accept >|= Option.some)
      (function
        | Unix.Unix_error (ECONNABORTED, _, _) ->
          (* The client reset before we got to it; not a listener failure. *)
          accept ()
        | Unix.Unix_error (((EMFILE | ENFILE | ENOBUFS | ENOMEM) as err), _, _)
          ->
          error (fun m ->
              m "accept failed: %s; retrying" (Unix.error_message err));
          (* Back off briefly: the pending connection keeps the socket
             readable, so retrying at once would busy-loop. *)
          let* () = Lwt_unix.sleep 0.1 in
          accept ()
        | exn -> Lwt.fail exn)
  in
  Lwt_stream.from accept

(** [reader ?max_pending fd] is the stream of messages received on [fd].

    Bytes are read in chunks of up to 512 bytes and handed to [Msg.parse]
    together with any text left over from previous reads. The remainder
    returned by [Msg.parse] (presumably a trailing partial message) is kept
    for the next read.

    The stream ends when the peer closes the connection (a read of 0 bytes)
    or resets it ([ECONNRESET]). Any buffered partial message is then
    dropped.

    @param max_pending upper bound, in bytes, on the unparsed remainder kept
    between reads (default [200_000]). A peer exceeding it makes the stream
    fail with [Failure], which guards against unbounded memory growth from a
    client that never completes a message. *)
let reader ?(max_pending = 200_000) (fd : fd) : Msg.t Lwt_stream.t =
  let chunk = Buffer.create 512 in
  let rdbuf = Bytes.create 512 in
  let gets () : Msg.t list option Lwt.t =
    Lwt.catch
      (fun () ->
        Lwt_unix.read fd rdbuf 0 (Bytes.length rdbuf) >>= function
        | 0 -> Lwt.return_none
        | n ->
          Buffer.add_subbytes chunk rdbuf 0 n;
          let msgs, rest = Msg.parse (Buffer.contents chunk) in
          Buffer.clear chunk;
          Buffer.add_string chunk rest;
          if String.length rest > max_pending then
            Lwt.fail_with
              (Printf.sprintf
                 "client sent %d bytes without a complete message (limit %d)"
                 (String.length rest) max_pending)
          else Lwt.return_some msgs)
      (function
        | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_none
        | exn -> Lwt.fail exn)
  in
  Lwt_stream.from gets |> Lwt_stream.map_list Fun.id

(** [writer fd obox] serializes each message of [obox] with [Msg.write] and
    writes it to [fd] in full, one message at a time.

    Resolves when [obox] ends. [ECONNRESET] and [EPIPE] mean the peer has
    gone away; they end the writer normally instead of being reported as
    failures.

    NOTE(review): [EPIPE] is only seen if the process ignores [SIGPIPE];
    otherwise the signal terminates the process. Confirm that the program
    entry point ignores [SIGPIPE]. *)
let writer (fd : fd) (obox : Msg.t Lwt_stream.t) : unit Lwt.t =
  let rec writeall bs i =
    if i >= Bytes.length bs then Lwt.return_unit
    else
      let* n = Lwt_unix.write fd bs i (Bytes.length bs - i) in
      writeall bs (i + n)
  in
  (* One buffer is reused for every message; its bytes are copied out before
     each write. *)
  let buf = Buffer.create 512 in
  let on_msg msg =
    Buffer.clear buf;
    Msg.write buf msg;
    writeall (Buffer.to_bytes buf) 0
  in
  Lwt.catch
    (fun () -> Lwt_stream.iter_s on_msg obox)
    (function
      | Unix.Unix_error ((ECONNRESET | EPIPE), _, _) -> Lwt.return_unit
      | exn -> Lwt.fail exn)

(** [handle_client conn_fd conn_addr ~server_info ~router ~ping_wheel] serves
    one accepted connection.

    It creates a [Connection.t], registers it in [ping_wheel], and runs two
    concurrent loops:
    - a reader feeding incoming messages to [Connection.on_msg];
    - a writer draining the connection's outbox to the socket.

    When either loop terminates, the connection is closed. When both have
    terminated, [conn_fd] is closed and the closure is logged. A failure of
    either loop is logged with the peer address.

    Returns immediately; the work continues in the background. *)
let handle_client (conn_fd : fd) (conn_addr : sockaddr)
    ~(server_info : Server_info.t) ~(router : Router.t)
    ~(ping_wheel : ping_wheel) =
  info (fun m -> m "new connection %a" pp_sockaddr conn_addr);
  let conn : Connection.t =
    Connection.make ~router ~server_info ~addr:conn_addr
  in
  Wheel.add ping_wheel conn;
  let reader = Lwt_stream.iter (Connection.on_msg conn) (reader conn_fd) in
  let writer = writer conn_fd (Outbox.stream (Connection.outbox conn)) in
  (* The fd is only closed once BOTH loops have finished. If the connection
     is closed from elsewhere (e.g. a ping timeout), the writer stops, but
     the reader would stay blocked on a silent peer forever and the fd would
     leak. Shutting the socket down makes that pending read return EOF.
     Errors are ignored: the socket may already be disconnected or closed. *)
  let shutdown_socket () =
    try Lwt_unix.shutdown conn_fd SHUTDOWN_ALL with Unix.Unix_error _ -> ()
  in
  let both =
    Lwt.finalize
      (fun () -> reader <&> writer)
      (fun () -> Lwt_unix.close conn_fd)
  in
  begin
    Lwt.on_termination reader (fun () -> Connection.close conn);
    Lwt.on_termination writer (fun () ->
        Connection.close conn;
        shutdown_socket ());
    Lwt.on_termination both (fun () ->
        info (fun m -> m "connection closed %a" pp_sockaddr conn_addr));
    Lwt.on_failure both (fun e ->
        error (fun m -> m "%a:@ %a" pp_sockaddr conn_addr Fmt.exn e))
  end

(** Server configuration, consumed by [run]. *)
type config = {
  port : int;  (** TCP port to listen on (all IPv4 interfaces). *)
  listen_backlog : int;  (** Backlog passed to [listen]. *)
  ping_interval : int;
      (** Passed to [Wheel.make]. The wheel is ticked once per second, so
          this is presumably the ping period in seconds — confirm against
          [Wheel]. *)
  whowas_history_len : int;  (** Passed to [Router.make]. *)
  hostname : string;  (** Passed to [Server_info.make]. *)
  (* TODO: motd *)
}

(** [run config] starts the server and returns a promise that, in normal
    operation, never resolves. It runs two loops:
    - the accept loop, which calls [handle_client] for every new connection;
    - the pinger, which ticks the ping wheel once per second.

    On each tick, every connection the wheel returns is checked with
    [Connection.on_ping]:
    - [Ok ()] puts the connection back into the wheel;
    - [Error ()] closes it with reason ["Connection timed out"]. *)
let run { port; listen_backlog; ping_interval; whowas_history_len; hostname } :
    unit Lwt.t =
  let server_info = Server_info.make ~hostname (* ~motd *) in
  let router : Router.t = Router.make ~whowas_history_len in
  let ping_wheel : _ Wheel.t = Wheel.make ping_interval in
  let on_tick () =
    (* trace (fun m -> m "tick"); *)
    List.iter
      (fun conn ->
        match Connection.on_ping conn with
        | Ok () -> Wheel.add ping_wheel conn
        | Error () -> Connection.close conn ~reason:"Connection timed out")
      (Wheel.tick ping_wheel)
  in
  let pinger_promise =
    Lwt_stream.iter on_tick
      ( Lwt_stream.from @@ fun () ->
        let* () = Lwt_unix.sleep 1.0 in
        Lwt.return_some () )
  in
  let on_con (fd, adr) = handle_client fd adr ~server_info ~router ~ping_wheel in
  let listener_promise =
    Lwt_stream.iter on_con (listener ~port ~listen_backlog)
  in
  listener_promise <&> pinger_promise