From 150d25b1e0a2d88c1fa200957bb9fd6bceaf479e Mon Sep 17 00:00:00 2001 From: tali Date: Sun, 7 Jan 2024 15:54:39 -0500 Subject: [PATCH] user registration logic --- bin/dune | 2 +- bin/main.ml | 132 +++------------------------------------ lib/server/connection.ml | 94 ++++++++++++++++++++++++++++ lib/server/dune | 8 +++ lib/server/import.ml | 9 +++ lib/server/irc_server.ml | 98 +++++++++++++++++++++++++++++ lib/server/userbase.ml | 39 ++++++++++++ 7 files changed, 256 insertions(+), 126 deletions(-) create mode 100644 lib/server/connection.ml create mode 100644 lib/server/dune create mode 100644 lib/server/import.ml create mode 100644 lib/server/irc_server.ml create mode 100644 lib/server/userbase.ml diff --git a/bin/dune b/bin/dune index 3cff5f8..839f097 100644 --- a/bin/dune +++ b/bin/dune @@ -3,4 +3,4 @@ (name main) (libraries lwt lwt.unix logs fmt - irc_msg)) + irc_server)) diff --git a/bin/main.ml b/bin/main.ml index d3313b6..9f35c4b 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -1,126 +1,8 @@ -open Lwt.Syntax -open Lwt.Infix +Logs.set_level (Some Debug); +Logs.set_reporter (Logs.format_reporter ()); -type sockaddr = Unix.sockaddr -type fd = Lwt_unix.file_descr - -let pp_inet_addr = Fmt.of_to_string Unix.string_of_inet_addr - -let pp_sock_addr ppf = function - | Unix.ADDR_INET (adr, port) -> Fmt.pf ppf "%a:%d" pp_inet_addr adr port - | Unix.ADDR_UNIX path -> Fmt.pf ppf "unix:%s" path - -let listener ~(port : int) ~(backlog : int) : (fd * sockaddr) Lwt_stream.t = - let sock : fd Lwt.t = - let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in - Lwt_unix.setsockopt fd SO_KEEPALIVE false; - Lwt_unix.setsockopt fd SO_REUSEPORT true; - let srv_adr = Unix.ADDR_INET (Unix.inet_addr_any, port) in - let* () = Lwt_unix.bind fd srv_adr in - Lwt_unix.listen fd backlog; - Logs.info (fun m -> m "listening on %a" pp_sock_addr srv_adr); - Lwt.return fd - in - let accept () = sock >>= Lwt_unix.accept >|= Option.some in - Lwt_stream.from accept - -let reader (fd : fd) : Irc_msg.t Lwt_stream.t = - let chunk = Buffer.create 512 in - let rdbuf = Bytes.create 512 in - let gets () : Irc_msg.t list option Lwt.t = - Lwt.catch - (fun () -> - Lwt_unix.read fd rdbuf 0 (Bytes.length rdbuf) >>= function - | 0 -> Lwt.return_none - | n -> - Buffer.add_subbytes chunk rdbuf 0 n; - (* if Buffer.length chunk > 200_000 then panic *) - let msgs, rest = Irc_msg.parse (Buffer.contents chunk) in - Buffer.clear chunk; - Buffer.add_string chunk rest; - Lwt.return_some msgs) - (function - | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_none - | exn -> Lwt.fail exn) - in - Lwt_stream.from gets |> Lwt_stream.map_list Fun.id - -let writer (fd : fd) (obox : Irc_msg.t Lwt_stream.t) : unit Lwt.t = - let rec writeall bs i = - if i >= Bytes.length bs then - Lwt.return_unit - else - let* n = Lwt_unix.write fd bs i (Bytes.length bs - i) in - writeall bs (i + n) - in - let buf = Buffer.create 512 in - let on_msg msg = - Buffer.clear buf; - Irc_msg.write buf msg; - writeall (Buffer.to_bytes buf) 0 - in - Lwt.catch - (fun () -> Lwt_stream.iter_s on_msg obox) - (function - | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_unit - | exn -> Lwt.fail exn) - -let handle_client (con_fd : fd) (con_adr : sockaddr) = - Logs.info (fun m -> m "new connection %a" pp_sock_addr con_adr); - let ibox, push_evt = Lwt_stream.create () in - let obox, push_msg = Lwt_stream.create () in - let send_evt m = push_evt (Some m) in - let send_msg m = push_msg (Some m) in - - let on_msg (msg : Irc_msg.t) = - Logs.debug (fun m -> m "%a: %a" pp_sock_addr con_adr Irc_msg.pp msg); - match msg.command, msg.params with - | "NICK", [n] -> send_evt (`nick n) - | "QUIT", _ -> send_evt `quit - | c, _ -> send_evt (`invalid_cmd c) - in - - let on_evt = function - | `nick n -> - send_msg (Irc_msg.make "001" [n; "Welcome to the IRC network"]) - | `invalid_cmd c -> - send_msg (Irc_msg.make "421" [c; "Unknown command"]) - | `quit -> - push_evt None - in - - let rd = Lwt_stream.iter on_msg (reader con_fd) in - let wr = writer con_fd obox in - let eh = Lwt_stream.iter on_evt ibox in - - Lwt.finalize - (fun () -> Lwt.choose [rd; wr; eh]) - (fun () -> - Logs.info (fun m -> m "connection closed %a" pp_sock_addr con_adr); - Lwt_unix.close con_fd) - -type config = { - port : int; - tcp_listen_backlog : int; -} - -let run_server (cfg : config) = - let on_con (fd, adr) = - Lwt.on_failure - (handle_client fd adr) - (fun exn -> Logs.err (fun m -> m "%a: %a" pp_sock_addr adr Fmt.exn exn)) - in - Lwt_stream.iter - on_con - (listener - ~port:cfg.port - ~backlog:cfg.tcp_listen_backlog) - -let () = - Logs.set_level (Some Info); - Logs.set_reporter (Logs.format_reporter ()); - Lwt_main.run @@ - run_server { - port = 6667; - tcp_listen_backlog = 8; - } +Lwt_main.run + (Irc_server.run { + port = 6667; + tcp_listen_backlog = 8 + }) diff --git a/lib/server/connection.ml b/lib/server/connection.ml new file mode 100644 index 0000000..b4b7726 --- /dev/null +++ b/lib/server/connection.ml @@ -0,0 +1,94 @@ +open! Import + +type t = { + addr : sockaddr; + userbase : Userbase.t; + user : Userbase.user; + mutable regis : string option * (string * string) option; + outbox : Irc_msg.t Lwt_stream.t; + push_outbox : (Irc_msg.t option -> unit); + quit : unit Lwt_condition.t; +} + +let make ~(userbase : Userbase.t) ~(addr : sockaddr) : t = + let user = Userbase.make_user () in + let regis = None, None in + let outbox, push_outbox = Lwt_stream.create () in + let quit = Lwt_condition.create () in + { addr; userbase; user; regis; outbox; push_outbox; quit } + +let quitting t = Lwt_condition.wait t.quit +let outbox t = t.outbox +let send t msg = t.push_outbox (Some msg) + +let cleanup t = + Userbase.leave t.userbase t.user + +(* message handlers *) + +(* > user registration *) + +let update_regis t nick username = + t.regis <- (nick, username); + match nick, username with + | Some nick, Some _ -> + begin match Userbase.register t.userbase ~nick ~user:t.user with + | `inuse -> `nicknameinuse nick + | `ok -> `ok + end + | _, _ -> + (* wait for remaining credentials *) + `ok + +let on_nick_msg t new_nick = + (* TODO: validate nickname string *) + let _, username = t.regis in + update_regis t (Some new_nick) username + +let on_user_msg t new_username _mode = + (* TODO: validate user string *) + (* TODO: validate mode string *) + match t.regis with + | nick, None -> + update_regis t nick (Some new_username) + | _, Some _ -> + `alreadyregistered + +(* > misc *) + +let on_quit_msg t why = + Logs.debug (fun m -> m "%a: quit: %S" pp_sockaddr t.addr (String.concat " " why)); + Lwt_condition.broadcast t.quit (); + `ok + +(* message transmission *) + +module Rpl = struct + open Irc_msg + let unknowncommand cmd = make "421" [cmd; "Unknown command"] + let needmoreparams cmd = make "461" [cmd; "Not enough parameters"] + let tryagain cmd = make "263" [cmd; "Please wait a while and try again."] + let alreadyregistered () = make "462" ["Unauthorized command (already registered)"] + let nicknameinuse nick = make "433" [nick; "Nickname is already in use"] +end + +let on_msg t (msg : Irc_msg.t) : unit = + Logs.debug (fun m -> m "%a: %a" pp_sockaddr t.addr Irc_msg.pp msg); + let result = + match msg.command, msg.params with + | "NICK", new_nick :: _ -> + on_nick_msg t new_nick + | "USER", uname :: modestr :: _host :: rname :: _ -> + on_user_msg t (uname, rname) modestr + | "QUIT", why -> + on_quit_msg t why + | "NICK", _ | "USER", _ -> `needmoreparams + | _, _ -> `unknowncommand + in + match result with + | `ok -> () + | `unknowncommand -> send t (Rpl.unknowncommand msg.command) + | `needmoreparams -> send t (Rpl.needmoreparams msg.command) + | `tryagain -> send t (Rpl.tryagain msg.command) + | `alreadyregistered -> send t (Rpl.alreadyregistered ()) + | `nicknameinuse n -> send t (Rpl.nicknameinuse n) diff --git a/lib/server/dune b/lib/server/dune new file mode 100644 index 0000000..2ea573d --- /dev/null +++ b/lib/server/dune @@ -0,0 +1,8 @@ +(library + (package talircd) + (name irc_server) + ; (inline_tests) + ; (preprocess (pps ppx_expect ppx_deriving.show)) + (libraries + lwt lwt.unix logs fmt + irc_msg)) diff --git a/lib/server/import.ml b/lib/server/import.ml new file mode 100644 index 0000000..216c789 --- /dev/null +++ b/lib/server/import.ml @@ -0,0 +1,9 @@ +include Lwt.Syntax +include Lwt.Infix + +type sockaddr = Unix.sockaddr +type fd = Lwt_unix.file_descr + +let pp_sockaddr ppf = function + | Unix.ADDR_INET (adr, port) -> Fmt.pf ppf "%s:%d" (Unix.string_of_inet_addr adr) port + | Unix.ADDR_UNIX path -> Fmt.string ppf path diff --git a/lib/server/irc_server.ml b/lib/server/irc_server.ml new file mode 100644 index 0000000..e7cf114 --- /dev/null +++ b/lib/server/irc_server.ml @@ -0,0 +1,98 @@ +open! Import + +let listener ~(port : int) ~(backlog : int) : (fd * sockaddr) Lwt_stream.t = + let sock : fd Lwt.t = + let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in + Lwt_unix.setsockopt fd SO_KEEPALIVE false; + Lwt_unix.setsockopt fd SO_REUSEPORT true; + let srv_adr = Unix.ADDR_INET (Unix.inet_addr_any, port) in + let* () = Lwt_unix.bind fd srv_adr in + Lwt_unix.listen fd backlog; + Logs.info (fun m -> m "listening on %a" pp_sockaddr srv_adr); + Lwt.return fd + in + let accept () = sock >>= Lwt_unix.accept >|= Option.some in + Lwt_stream.from accept + +let reader (fd : fd) : Irc_msg.t Lwt_stream.t = + let chunk = Buffer.create 512 in + let rdbuf = Bytes.create 512 in + let gets () : Irc_msg.t list option Lwt.t = + Lwt.catch + (fun () -> + Lwt_unix.read fd rdbuf 0 (Bytes.length rdbuf) >>= function + | 0 -> Lwt.return_none + | n -> + Buffer.add_subbytes chunk rdbuf 0 n; + (* if Buffer.length chunk > 200_000 then panic *) + let msgs, rest = Irc_msg.parse (Buffer.contents chunk) in + Buffer.clear chunk; + Buffer.add_string chunk rest; + Lwt.return_some msgs) + (function + | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_none + | exn -> Lwt.fail exn) + in + Lwt_stream.from gets |> Lwt_stream.map_list Fun.id + +let writer (fd : fd) (obox : Irc_msg.t Lwt_stream.t) : unit Lwt.t = + let rec writeall bs i = + if i >= Bytes.length bs then + Lwt.return_unit + else + let* n = Lwt_unix.write fd bs i (Bytes.length bs - i) in + writeall bs (i + n) + in + let buf = Buffer.create 512 in + let on_msg msg = + Buffer.clear buf; + Irc_msg.write buf msg; + writeall (Buffer.to_bytes buf) 0 + in + Lwt.catch + (fun () -> Lwt_stream.iter_s on_msg obox) + (function + | Unix.Unix_error (ECONNRESET, _, _) -> Lwt.return_unit + | exn -> Lwt.fail exn) + +let handle_client cx (conn_fd : fd) (conn_addr : sockaddr) = + let conn : Connection.t = + Connection.make + ~addr:conn_addr + ~userbase:cx#get_userbase + in + let rd = Lwt_stream.iter (Connection.on_msg conn) (reader conn_fd) in + let wr = writer conn_fd (Connection.outbox conn) in + let qt = Connection.quitting conn in + Lwt.finalize + (fun () -> + Logs.info (fun m -> m "new connection %a" pp_sockaddr conn_addr); + Lwt.choose [rd; wr; qt]) + (fun () -> + Logs.info (fun m -> m "connection closed %a" pp_sockaddr conn_addr); + Connection.cleanup conn; + Lwt_unix.close conn_fd) + +type config = { + port : int; + tcp_listen_backlog : int; +} + +let run (cfg : config) : unit Lwt.t = + let cx = object + val userbase = Userbase.make () + method get_userbase = userbase + end in + + let on_con (fd, adr) = + Lwt.on_failure + (handle_client cx fd adr) + (fun exn -> + Logs.err (fun m -> m "%a: %a" pp_sockaddr adr Fmt.exn exn)) + in + + Lwt_stream.iter + on_con + (listener + ~port:cfg.port + ~backlog:cfg.tcp_listen_backlog) diff --git a/lib/server/userbase.ml b/lib/server/userbase.ml new file mode 100644 index 0000000..4f0101d --- /dev/null +++ b/lib/server/userbase.ml @@ -0,0 +1,39 @@ +type privmsg = { + pm_from : string; + pm_text : string; +} + +type user = { + mutable nick : string option; + inbox : notif Lwt_stream.t; + push_inbox : (notif option -> unit); +} + +and notif = privmsg + +let make_user () = + let inbox, push_inbox = Lwt_stream.create () in + { nick = None; inbox; push_inbox } + +let notify u no = u.push_inbox (Some no) + +type t = { + users : (string, user) Hashtbl.t + (* TODO: channels *) +} + +let make () = + { users = Hashtbl.create 4096 } + +let register t ~nick ~user = + if Hashtbl.mem t.users nick then + `inuse + else begin + Option.iter (Hashtbl.remove t.users) user.nick; + Hashtbl.add t.users nick user; + user.nick <- Some nick; + `ok + end + +let leave t user = + Option.iter (Hashtbl.remove t.users) user.nick