(*
 * uTop_main.ml
 * ------------
 * Copyright : (c) 2011, Jeremie Dimino <jeremie@dimino.org>
 * Licence   : BSD3
 *
 * This file is a part of utop.
 *)

open CamomileLibraryDyn.Camomile
open Lwt
open Lwt_react
open LTerm_text
open LTerm_geom
open UTop_token
open UTop_styles
open UTop_private

(* Sets of strings. NOTE(review): not referenced within this part of the
   file; it may be used further on or be dead code — confirm. *)
module String_set = Set.Make(String)

(* +-----------------------------------------------------------------+
   | History                                                         |
   +-----------------------------------------------------------------+ *)

(* Read-line history entries: filled by [init_history], extended with every
   accepted input and saved to disk at exit. *)
let history : string list ref = ref []

(* Load the persistent history from "~/.utop-history" into [history] and
   register an exit hook that writes [history] back to that file. *)
let init_history () =
  let history_file = Filename.concat LTerm_resources.home ".utop-history" in
  (* Register the saver first so it is installed even while loading. *)
  Lwt_main.at_exit (fun () -> LTerm_read_line.save_history history_file !history);
  LTerm_read_line.load_history history_file >>= fun entries ->
  history := entries;
  return ()

(* +-----------------------------------------------------------------+
   | offset --> index                                                |
   +-----------------------------------------------------------------+ *)

(* Return the index (in unicode characters) of the character starting at
   byte offset [ofs] in [src]. When [ofs] falls inside a multi-byte
   character, the index of that character is returned; when [ofs] lies
   beyond the end of [src], -1 is returned. *)
let index_of_offset src ofs =
  let len = String.length src in
  let index = ref 0 and pos = ref 0 in
  (* Walk character by character until we reach or pass [ofs], or hit
     the end of the string. *)
  while !pos < ofs && !pos <> len do
    incr index;
    pos := Zed_utf8.unsafe_next src !pos
  done;
  if !pos = ofs then
    !index
  else if !pos > ofs then
    !index - 1
  else
    -1

(* Convert a list of (start, stop) byte-offset pairs in [str] into
   unicode-character index pairs. *)
let convert_locs str locs =
  let to_index byte_ofs = index_of_offset str byte_ofs in
  List.map (fun (start_ofs, stop_ofs) -> (to_index start_ofs, to_index stop_ofs)) locs

(* +-----------------------------------------------------------------+
   | The read-line class                                             |
   +-----------------------------------------------------------------+ *)

(* Parse [input] with the current toplevel parser, then type-check the
   phrase. Error locations are converted to character indices. Parse
   error messages are prefixed with "Error: " and end with a newline.
   [UTop.Need_more] raised by the parser is propagated. *)
let parse_and_check input eos_is_error =
  let parsed = !UTop.parse_toplevel_phrase input eos_is_error in
  match parsed with
    | UTop.Value phrase ->
        (match UTop.check_phrase phrase with
           | Some (locs, msg) ->
               UTop.Error (convert_locs input locs, msg)
           | None ->
               UTop.Value phrase)
    | UTop.Error (locs, msg) ->
        UTop.Error (convert_locs input locs, "Error: " ^ msg ^ "\n")

(* Read a phrase. If the result is a value, it is guaranteed to be a
   valid phrase (i.e. typable and compilable). *)
class read_phrase ~term = object(self)
  inherit [Parsetree.toplevel_phrase UTop.result] LTerm_read_line.engine ~history:!history () as super
  inherit [Parsetree.toplevel_phrase UTop.result] LTerm_read_line.term term as super_term

  (* Result of the last completed parse of the input, if any. Read by
     [eval], and by [stylise] to underline error locations. *)
  val mutable return_value = None

  (* Parse and check the whole current input. When parsing completes
     (with a value or an error), the result is stored in [return_value]
     and the input is added to [history]. Raises [UTop.Need_more] when
     [eos_is_error] is [false] and the input is not finished; in that
     case nothing is recorded. *)
  method private parse_input eos_is_error =
    let input = Zed_rope.to_string (Zed_edit.text self#edit) in
    (* Toploop does that: *)
    Location.reset ();
    let result = parse_and_check input eos_is_error in
    return_value <- Some result;
    history := LTerm_read_line.add_entry input !history;
    result

  method eval =
    match return_value with
      | Some x ->
          x
      | None ->
          (* [exec] below only intercepts [Accept] with smart accept
             enabled in edition mode; otherwise the base class handles it
             and (presumably) asks for the result here before any parse
             happened. The user explicitly accepted the input, so parse
             it now and treat an unterminated phrase as an error instead
             of failing with [assert false]. *)
          self#parse_input true

  method exec = function
    | LTerm_read_line.Accept :: actions when !UTop.smart_accept && S.value self#mode = LTerm_read_line.Edition -> begin
        Zed_macro.add self#macro LTerm_read_line.Accept;
        (* Try to parse the input; an unfinished phrase is not an error. *)
        try
          return (self#parse_input false)
        with UTop.Need_more ->
          (* Input not finished, continue. *)
          self#insert (UChar.of_char '\n');
          self#exec actions
      end
    | actions ->
        super_term#exec actions

  method stylise last =
    let styled, position = super#stylise last in

    (* Syntax highlighting *)
    let stylise start stop token_style =
      for i = start to stop - 1 do
        let ch, style = styled.(i) in
        styled.(i) <- (ch, LTerm_style.merge token_style style)
      done
    in
    UTop_styles.stylise stylise (UTop_lexer.lex_string ~camlp4:(UTop.get_camlp4 ()) (LTerm_text.to_string styled));

    if not last then
      (* Parenthesis matching. *)
      LTerm_text.stylise_parenthesis styled position styles.style_paren
    else begin
      match return_value with
        | Some (UTop.Error (locs, _)) ->
            (* Highlight error locations. Indices come from
               [index_of_offset], which returns -1 for offsets past the
               end of the input, so clamp them to the styled text. *)
            let len = Array.length styled in
            List.iter
              (fun (start, stop) ->
                 for i = max 0 start to min len stop - 1 do
                   let ch, style = styled.(i) in
                   styled.(i) <- (ch, { style with LTerm_style.underline = Some true })
                 done)
              locs
        | _ ->
            ()
    end;

    (styled, position)

  method completion =
    let pos, words = UTop_complete.complete (Zed_rope.to_string self#input_prev) in
    self#set_completion pos words

  initializer
    (* Set the source signal for the size of the terminal. *)
    UTop_private.set_size self#size;
    (* Set the source signal for the key sequence. *)
    UTop_private.set_key_sequence self#key_sequence;
    (* Set the prompt. *)
    self#set_prompt !UTop.prompt
end

(* +-----------------------------------------------------------------+
   | Out phrase printing                                             |
   +-----------------------------------------------------------------+ *)

(* Outcome printer meant to be installed as [Toploop.print_out_phrase].
   The formatter supplied by the toplevel is ignored. [printer] renders
   [out_phrase] into a buffer whose margin is the terminal width. The
   text is then syntax-highlighted and written to [term]. *)
let print_out_phrase term printer _pp out_phrase =
  (* Make pending output of the evaluated code appear first. *)
  flush stdout;
  flush stderr;
  begin match out_phrase with
    | Outcometree.Ophr_exception _ ->
        if Printexc.backtrace_status () then begin
          Printexc.print_backtrace stdout;
          flush stdout
        end
    | _ ->
        ()
  end;
  let out_buffer = Buffer.create 1024 in
  let fmt = Format.formatter_of_buffer out_buffer in
  Format.pp_set_margin fmt (LTerm.size term).cols;
  printer fmt out_phrase;
  Format.pp_print_flush fmt ();
  let text = Buffer.contents out_buffer in
  let styled = LTerm_text.of_string text in
  UTop_styles.stylise
    (fun start stop token_style ->
       for i = start to stop - 1 do
         let ch, style = styled.(i) in
         styled.(i) <- (ch, LTerm_style.merge token_style style)
       done)
    (UTop_lexer.lex_string text);
  Lwt_main.run (LTerm.fprints term styled)

(* +-----------------------------------------------------------------+
   | Lwt_main.run auto-insertion                                     |
   +-----------------------------------------------------------------+ *)

(* The long identifier [Lwt_main.run]. *)
let longident_lwt_main_run = Longident.(Ldot (Lident "Lwt_main", "run"))

(* Whether a structure item is a toplevel expression ([Pstr_eval]). *)
let is_eval item =
  match item.Parsetree.pstr_desc with
    | Parsetree.Pstr_eval _ -> true
    | _ -> false

(* When [Lwt_main.run] in the current toplevel environment resolves to the
   persistent [Lwt_main] module, wrap every toplevel expression whose type
   is [_ Lwt.t] into a call to [Lwt_main.run]. Directives, and definitions
   without toplevel expressions, are returned unchanged. *)
let insert_lwt_main_run phrase =
  match phrase with
    | Parsetree.Ptop_def pstr ->
        let env = !Toploop.toplevel_env in
        let lwt_main_run_is_the_real_one =
          try
            match Env.lookup_value longident_lwt_main_run env with
              | Path.Pdot (Path.Pident id, "run", 0), _ ->
                  Ident.persistent id
              | _ ->
                  false
          with Not_found ->
            false
        in
        if lwt_main_run_is_the_real_one && List.exists is_eval pstr then begin
          let tstr, _, _ = Typemod.type_structure env pstr Location.none in
          (* Parsed and typed items are paired positionally below. If typing
             did not yield exactly one typed item per parsed item, the
             pairing would be meaningless (and [List.map2] would raise), so
             leave the phrase alone. *)
          if List.length tstr <> List.length pstr then
            phrase
          else
            Parsetree.Ptop_def
              (List.map2
                 (fun pstr_item tstr_item ->
                    match pstr_item, tstr_item with
                      | { Parsetree.pstr_desc = Parsetree.Pstr_eval e; Parsetree.pstr_loc = loc },
                        Typedtree.Tstr_eval { Typedtree.exp_type = typ } -> begin
                          (* The type of an expression (e.g. an [if] or a
                             [match]) may be a link to the actual type
                             constructor: follow links before inspecting it. *)
                          match (Btype.repr typ).Types.desc with
                            | Types.Tconstr (Path.Pdot (Path.Pident id, "t", -1), _, _)
                              when Ident.persistent id && Ident.name id = "Lwt" -> {
                                Parsetree.pstr_desc =
                                  Parsetree.Pstr_eval {
                                    Parsetree.pexp_desc =
                                      Parsetree.Pexp_apply
                                        ({ Parsetree.pexp_desc = Parsetree.Pexp_ident longident_lwt_main_run; Parsetree.pexp_loc = loc },
                                         [("", e)]);
                                    Parsetree.pexp_loc = loc;
                                  };
                                Parsetree.pstr_loc = loc;
                              }
                            | _ ->
                                pstr_item
                        end
                      | _ ->
                          pstr_item)
                 pstr tstr)
        end else
          phrase
    | Parsetree.Ptop_dir _ ->
        phrase

(* +-----------------------------------------------------------------+
   | Main loop                                                       |
   +-----------------------------------------------------------------+ *)

(* Read one phrase on [term]. On [Sys.Break] print "Interrupted." and
   start reading a new phrase. Other exceptions are propagated. *)
let rec read_phrase term =
  Lwt.catch
    (fun () -> (new read_phrase ~term)#run)
    (function
       | Sys.Break ->
           LTerm.fprintl term "Interrupted." >>= fun () ->
           read_phrase term
       | exn ->
           Lwt.fail exn)

(* Set the margin of [pp] to [cols], touching the formatter only when the
   margin actually differs. *)
let update_margin pp cols =
  let current = Format.pp_get_margin pp () in
  if current <> cols then Format.pp_set_margin pp cols

(* Print [msg] on standard output using the error style, then restore the
   default style and flush the terminal. *)
let print_error msg =
  Lazy.force LTerm.stdout >>= fun term ->
  LTerm.set_style term styles.style_error >>= fun () ->
  Lwt_io.print msg >>= fun () ->
  LTerm.set_style term LTerm_style.none >>= fun () ->
  LTerm.flush term

(* Interactive main loop. It reads a phrase on [term], prints it if it
   is an error, and otherwise executes it, then starts over. It never
   returns normally; exceptions not handled by [read_phrase] propagate
   to the caller. *)
let rec loop term =
  (* Reset completion. *)
  UTop_complete.reset ();
  (* Increment the command counter. *)
  UTop_private.set_count (S.value UTop_private.count + 1);
  (* Call hooks. *)
  Lwt_sequence.iter_l (fun f -> f ()) UTop.new_command_hooks;

  (* Read one phrase interactively; errors are reported here and turned
     into [None]. *)
  let read_one () =
    read_phrase term >>= function
      | UTop.Value phrase ->
          return (Some phrase)
      | UTop.Error (_, msg) ->
          print_error msg >>= fun () ->
          return None
  in
  let phrase_opt =
    Lwt_main.run (Lwt.finalize read_one (fun () -> LTerm.flush term))
  in

  begin match phrase_opt with
    | None ->
        ()
    | Some phrase ->
        (* Add Lwt_main.run to toplevel evals. *)
        let phrase = if UTop.get_auto_run_lwt () then insert_lwt_main_run phrase else phrase in
        (* Set the margin of standard formatters. *)
        let cols = (LTerm.size term).cols in
        List.iter (fun pp -> update_margin pp cols) [Format.std_formatter; Format.err_formatter];
        (* No exception can be raised at this stage. *)
        ignore (Toploop.execute_phrase true Format.std_formatter phrase)
  end;
  loop term

(* +-----------------------------------------------------------------+
   | Welcome message                                                 |
   +-----------------------------------------------------------------+ *)

(* Render a 3-row banner on [term]: a horizontal rule across the whole
   width with a centered frame containing the welcome message. The cursor
   is then moved to the next line and the terminal is flushed. *)
let welcome term =
  let message = Printf.sprintf "Welcome to utop version %s (using OCaml version %s)!" UTop.version Sys.ocaml_version in
  let message_width = String.length message in
  let width = (LTerm.size term).cols in
  let size = { rows = 3; cols = width } in
  let matrix = LTerm_draw.make_matrix size in
  let ctx = LTerm_draw.context matrix size in
  (* The frame leaves two columns of padding on each side of the text. *)
  let frame_width = message_width + 4 in

  LTerm_draw.fill_style ctx LTerm_style.({ none with foreground = Some lcyan });
  LTerm_draw.draw_hline ctx 0 0 width LTerm_draw.Light;
  LTerm_draw.draw_frame ctx
    { row1 = 0; row2 = 3; col1 = (width - frame_width) / 2; col2 = (width + frame_width) / 2 }
    LTerm_draw.Light;
  LTerm_draw.draw_styled ctx 1 ((width - message_width) / 2)
    (eval [B_fg LTerm_style.yellow; S message]);

  LTerm.print_box term matrix >>= fun () ->
  LTerm.fprint term "\n" >>= fun () ->
  LTerm.flush term

(* +-----------------------------------------------------------------+
   | Classic mode                                                    |
   +-----------------------------------------------------------------+ *)

(* Replacement for [Toploop.read_interactive_input] that reads through
   Lwt, so Lwt threads keep running while waiting for input. It writes
   [prompt], then stores characters into [buffer] until a newline
   (included), [len] characters, or end of input. It returns the number
   of characters stored and whether end of input was reached. *)
let read_input_classic prompt buffer len =
  let rec fill pos =
    if pos = len then
      return (pos, false)
    else
      Lwt_io.read_char_opt Lwt_io.stdin >>= fun char_opt ->
      match char_opt with
        | None ->
            return (pos, true)
        | Some '\n' ->
            buffer.[pos] <- '\n';
            return (pos + 1, false)
        | Some c ->
            buffer.[pos] <- c;
            fill (pos + 1)
  in
  Lwt_main.run (Lwt_io.write Lwt_io.stdout prompt >>= fun () -> fill 0)

(* +-----------------------------------------------------------------+
   | Emacs mode                                                      |
   +-----------------------------------------------------------------+ *)

module Emacs(M : sig end) = struct

  (* Copy standard output, which will be used to send commands. *)
  let command_oc = Unix.out_channel_of_descr (Unix.dup Unix.stdout)

  (* Split [str] on '\n'. Text after the last newline is returned as a
     final line when it is non-empty, so a message without a trailing
     newline does not lose its last line. *)
  let split_lines str =
    let len = String.length str in
    let rec aux i j =
      if j = len then
        (if i < len then [String.sub str i (len - i)] else [])
      else if str.[j] = '\n' then
        String.sub str i (j - i) :: aux (j + 1) (j + 1)
      else
        aux i (j + 1)
    in
    aux 0 0

  (* +---------------------------------------------------------------+
     | Sending commands to Emacs                                     |
     +---------------------------------------------------------------+ *)

  (* Mutex used to send commands to Emacs. *)
  let command_mutex = Mutex.create ()

  (* Send one "command:argument" line to Emacs. The mutex serialises the
     writers (the main loop and the output-copying threads). It is
     released even if writing fails, otherwise every later [send] would
     block forever. *)
  let send command argument =
    Mutex.lock command_mutex;
    try
      output_string command_oc command;
      output_char command_oc ':';
      output_string command_oc argument;
      output_char command_oc '\n';
      flush command_oc;
      Mutex.unlock command_mutex
    with exn ->
      Mutex.unlock command_mutex;
      raise exn

  (* Keep the [utop-phrase-terminator] variable of the emacs part in sync. *)
  let () =
    S.keep (S.map (send "phrase-terminator") UTop.phrase_terminator)

  (* +---------------------------------------------------------------+
     | Standard outputs redirection                                  |
     +---------------------------------------------------------------+ *)

  (* The output of ocaml (stdout and stderr) is redirected so the
     emacs parts of utop can recognize it. *)

  (* Continuously copy the output of ocaml to Emacs, line by line, until
     the pipe is closed. *)
  let rec copy_output which ic =
    match (try Some (input_line ic) with End_of_file -> None) with
      | Some line ->
          send which line;
          copy_output which ic
      | None ->
          (* All writers are gone: stop the thread quietly. *)
          ()

  (* Create a thread which redirects the given output: *)
  let redirect which fd =
    let fdr, fdw = Unix.pipe () in
    Unix.dup2 fdw fd;
    Unix.close fdw;
    Thread.create (copy_output which) (Unix.in_channel_of_descr fdr)

  (* Redirects stdout and stderr: *)
  let (_ : Thread.t) = redirect "stdout" Unix.stdout
  let (_ : Thread.t) = redirect "stderr" Unix.stderr

  (* +---------------------------------------------------------------+
     | Loop                                                          |
     +---------------------------------------------------------------+ *)

  (* Read one line from stdin with SIGINT ignored, restoring the previous
     SIGINT behavior afterwards (also on exceptions). Returns [None] at
     end of input. *)
  let read_line () =
    let behavior = Sys.signal Sys.sigint Sys.Signal_ignore in
    try
      let line = Lwt_main.run (Lwt_io.read_line_opt Lwt_io.stdin) in
      Sys.set_signal Sys.sigint behavior;
      line
    with exn ->
      Sys.set_signal Sys.sigint behavior;
      raise exn

  (* Read a "command:argument" line. Returns [None] at end of input and
     exits with code 1 if the line has no ':'. *)
  let read_command () =
    match read_line () with
      | None ->
          None
      | Some line ->
          match try Some (String.index line ':') with Not_found -> None with
            | None ->
                send "stderr" "':' missing!";
                exit 1
            | Some idx ->
                Some (String.sub line 0 idx, String.sub line (idx + 1) (String.length line - (idx + 1)))

  (* Read a sequence of "data:..." commands terminated by "end:", joining
     the data lines with '\n' (plus a final '\n' if [final_newline]). *)
  let read_data ?(final_newline = true) () =
    let buf = Buffer.create 1024 in
    let rec loop first =
      match read_command () with
        | None ->
            send "stderr" "'end' command missing!";
            exit 1
        | Some ("data", data) ->
            if not first then Buffer.add_char buf '\n';
            Buffer.add_string buf data;
            loop false
        | Some ("end", _) ->
            if final_newline then Buffer.add_char buf '\n';
            Buffer.contents buf
        | Some (command, _) ->
            Printf.ksprintf (send "stderr") "'data' or 'end' command expected, got %S!" command;
            exit 1
    in
    loop true

  (* Read a phrase from Emacs, then either execute it or report its error
     locations and message. Raises [UTop.Need_more] when [eos_is_error]
     is [false] and the phrase is incomplete. *)
  let process_input eos_is_error =
    match parse_and_check (read_data ()) eos_is_error with
      | UTop.Value phrase ->
          send "accept" "";
          (* Add Lwt_main.run to toplevel evals. *)
          let phrase = if UTop.get_auto_run_lwt () then insert_lwt_main_run phrase else phrase in
          (* No exception can be raised at this stage. *)
          ignore (Toploop.execute_phrase true Format.std_formatter phrase)
      | UTop.Error (locs, msg) ->
          send "accept" (String.concat "," (List.map (fun (a, b) -> Printf.sprintf "%d,%d" a b) locs));
          List.iter (send "stderr") (split_lines msg)

  let rec loop () =
    (* Reset completion. *)
    UTop_complete.reset ();

    (* Increment the command counter. *)
    UTop_private.set_count (S.value UTop_private.count + 1);

    (* Call hooks. *)
    Lwt_sequence.iter_l (fun f -> f ()) UTop.new_command_hooks;

    (* Tell emacs we are ready. *)
    send "prompt" "";

    loop_commands ()

  and loop_commands () =
    match read_command () with
      | None ->
          ()
      | Some ("input", "allow-incomplete") ->
          let continue =
            try
              process_input false;
              false
            with UTop.Need_more ->
              send "continue" "";
              true
          in
          if continue then
            loop_commands ()
          else
            loop ()
      | Some ("input", "") ->
          process_input true;
          loop ()
      | Some ("complete", _) ->
          let input = read_data ~final_newline:false () in
          let start, words = UTop_complete.complete input in
          let words = List.map fst words in
          let prefix = LTerm_read_line.common_prefix words in
          let index = String.length input - start in
          let suffix =
            if index > 0 && index <= String.length prefix then
              String.sub prefix index (String.length prefix - index)
            else
              ""
          in
          if suffix = "" then begin
            send "completion-start" "";
            List.iter (fun word -> send "completion" word) words;
            send "completion-stop" "";
          end else
            send "completion-word" suffix;
          loop_commands ()
      | Some (command, _) ->
          Printf.ksprintf (send "stderr") "unrecognized command %S!" command;
          exit 1
end

(* +-----------------------------------------------------------------+
   | Entry point                                                     |
   +-----------------------------------------------------------------+ *)

(* Set by the [-emacs] command-line option. *)
let emacs_mode : bool ref = ref false

(* Object files ([.cmo]/[.cma]) given on the command line, most recent
   first; [prepare] loads them in command-line order. *)
let preload_objects : string list ref = ref []

(* Set up toplevel paths, load the preloaded object files in command-line
   order and run the toplevel startup hook. Returns [false] if an object
   failed to load or an exception was raised; such exceptions are
   reported on stderr. *)
let prepare () =
  Toploop.set_paths ();
  try
    let objects = List.rev !preload_objects in
    let all_loaded = List.for_all (Topdirs.load_file Format.err_formatter) objects in
    !Toploop.toplevel_startup_hook ();
    all_loaded
  with exn ->
    (try
       Errors.report_error Format.err_formatter exn
     with exn ->
       (* [Errors.report_error] re-raises exceptions it does not know. *)
       Format.eprintf "Uncaught exception: %s\n" (Printexc.to_string exn));
    false

(* Handler of [-stdin]: run a script read from standard input. The
   remaining command-line arguments, starting at the current one, become
   the script's argv. Exits with 0 on success and 2 on failure. *)
let read_script_from_stdin () =
  let first = !Arg.current in
  let script_args = Array.sub Sys.argv first (Array.length Sys.argv - first) in
  let ok = prepare () && Toploop.run_script Format.err_formatter "" script_args in
  exit (if ok then 0 else 2)

(* Handler for anonymous command-line arguments. [.cmo]/[.cma] files are
   queued for preloading. Any other file is run as a script, with the
   remaining arguments as its argv, and the program exits with 0 on
   success and 2 on failure. *)
let file_argument name =
  let is_object_file =
    Filename.check_suffix name ".cmo" || Filename.check_suffix name ".cma"
  in
  if is_object_file then
    preload_objects := name :: !preload_objects
  else begin
    let first = !Arg.current in
    let script_args = Array.sub Sys.argv first (Array.length Sys.argv - first) in
    let ok = prepare () && Toploop.run_script Format.err_formatter name script_args in
    exit (if ok then 0 else 2)
  end

(* Handler of [-version]: print the full version line and exit. *)
let print_version () =
  print_endline
    (Printf.sprintf "The universal toplevel for OCaml, version %s, compiled for OCaml version %s"
       UTop.version Sys.ocaml_version);
  exit 0

(* Handler of [-vnum]: print only the version number and exit, as
   advertised in the option's help text. Without the [exit], utop would
   go on to start an interactive session after printing. *)
let print_version_num () =
  Printf.printf "%s\n" UTop.version;
  exit 0

(* Command-line options accepted by utop, formatted with [Arg.align].
   Most options set a flag in [Clflags] or configure [Warnings]. [-stdin],
   [-version] and [-vnum] call their handlers while the command line is
   parsed, and [-emacs] sets [emacs_mode]. [-absname] is only included
   when compiled with OCaml >= 3.13 (see the [#if] below). *)
let args = Arg.align [
#if ocaml_version >= (3, 13, 0)
  "-absname", Arg.Set Location.absname, " Show absolute filenames in error message";
#endif
  "-I", Arg.String (fun dir ->  Clflags.include_dirs := Misc.expand_directory Config.standard_library dir :: !Clflags.include_dirs), "<dir> Add <dir> to the list of include directories";
  "-init", Arg.String (fun s -> Clflags.init_file := Some s), "<file> Load <file> instead of default init file";
  "-labels", Arg.Clear Clflags.classic, " Use commuting label mode";
  "-no-app-funct", Arg.Clear Clflags.applicative_functors, " Deactivate applicative functors";
  "-noassert", Arg.Set Clflags.noassert, " Do not compile assertion checks";
  "-nolabels", Arg.Set Clflags.classic, " Ignore non-optional labels in types";
  "-nostdlib", Arg.Set Clflags.no_std_include, " Do not add default directory to the list of include directories";
  "-principal", Arg.Set Clflags.principal, " Check principality of type inference";
  "-rectypes", Arg.Set Clflags.recursive_types, " Allow arbitrary recursive types";
  "-stdin", Arg.Unit read_script_from_stdin, " Read script from standard input";
  "-strict-sequence", Arg.Set Clflags.strict_sequence, " Left-hand part of a sequence must have type unit";
  "-unsafe", Arg.Set Clflags.fast, " Do not compile bounds checking on array and string access";
  "-version", Arg.Unit print_version, " Print version and exit";
  "-vnum", Arg.Unit print_version_num, " Print version number and exit";
  (* The multi-line help strings below use '\' continuations; the
     leading "\ " keeps the indentation inside the printed text. *)
  "-w", Arg.String (Warnings.parse_options false),
  Printf.sprintf
    "<list>  Enable or disable warnings according to <list>:\n\
    \        +<spec>   enable warnings in <spec>\n\
    \        -<spec>   disable warnings in <spec>\n\
    \        @<spec>   enable warnings in <spec> and treat them as errors\n\
    \     <spec> can be:\n\
    \        <num>             a single warning number\n\
    \        <num1>..<num2>    a range of consecutive warning numbers\n\
    \        <letter>          a predefined set\n\
    \     default setting is %S" Warnings.defaults_w;
  "-warn-error", Arg.String (Warnings.parse_options true),
  Printf.sprintf
    "<list>  Enable or disable error status for warnings according to <list>\n\
    \     See option -w for the syntax of <list>.\n\
    \     Default setting is %S" Warnings.defaults_warn_error;
  "-warn-help", Arg.Unit Warnings.help_warnings, " Show description of warning numbers";
  "-emacs", Arg.Set emacs_mode, " Run in emacs mode";
]

(* Usage header printed by [Arg] before the option list. *)
let usage =
  String.concat "\n" [
    "Usage: utop <options> <object-files> [script-file [arguments]]";
    "options are:";
  ]

(* Initialization shared by the console and Emacs modes. It sets up the
   toplevel environment and input name, makes SIGINT raise [Sys.Break],
   and loads the user's init file. The init file is the one given with
   [-init], else "./.ocamlinit", else "$HOME/.ocamlinit". *)
let common_init () =
  (* Initializes toplevel environment. *)
  Toploop.initialize_toplevel_env ();
  (* Set the global input name. *)
  Location.input_name := UTop.input_name;
  (* Make sure SIGINT is caught while executing OCaml code. *)
  Sys.catch_break true;
  (* Load user's .ocamlinit file. *)
  let use_file fn = ignore (Toploop.use_silently Format.err_formatter fn) in
  match !Clflags.init_file with
    | Some fn when Sys.file_exists fn ->
        use_file fn
    | Some fn ->
        Printf.eprintf "Init file not found: \"%s\".\n" fn
    | None when Sys.file_exists ".ocamlinit" ->
        use_file ".ocamlinit"
    | None ->
        let home_init = Filename.concat LTerm_resources.home ".ocamlinit" in
        if Sys.file_exists home_init then use_file home_init

(* Program entry point. It parses the command line, preloads objects,
   and runs either the Emacs mode, the interactive console mode (when
   both ends are TTYs) or the classic toplevel loop. It always finishes
   with [exit], so the standard toplevel never runs afterwards. *)
let main () =
  Arg.parse args file_argument usage;
  if not (prepare ()) then exit 2;

  (* Emacs mode: applying the functor redirects stdout/stderr to Emacs,
     so it must happen before the welcome message is printed. *)
  let run_emacs () =
    UTop_private.set_ui UTop_private.Emacs;
    let module Emacs = Emacs (struct end) in
    Printf.printf "Welcome to utop version %s (using OCaml version %s)!\n\n%!" UTop.version Sys.ocaml_version;
    common_init ();
    Emacs.loop ()
  in

  let run_console () =
    UTop_private.set_ui UTop_private.Console;
    let term = Lwt_main.run (Lazy.force LTerm.stdout) in
    let interactive = LTerm.incoming_is_a_tty term && LTerm.outgoing_is_a_tty term in
    if not interactive then begin
      (* Use the standard toplevel. Just make sure that Lwt threads can
         run while reading phrases. *)
      Toploop.read_interactive_input := read_input_classic;
      Toploop.loop Format.std_formatter
    end else begin
      (* Set the initial size. *)
      UTop_private.set_size (S.const (LTerm.size term));
      (* Install our out phrase printer. *)
      Toploop.print_out_phrase := print_out_phrase term !Toploop.print_out_phrase;
      (* Load user data. *)
      Lwt_main.run (join [init_history (); UTop_styles.load (); LTerm_inputrc.load ()]);
      (* Display a welcome message. *)
      Lwt_main.run (welcome term);
      (* Common initialization. *)
      common_init ();
      (* Print help message. *)
      print_string "\nType #utop_help for help about using utop.\n\n";
      flush stdout;
      (* Main loop; it only ends with [LTerm_read_line.Interrupt]. *)
      (try loop term with LTerm_read_line.Interrupt -> ())
    end
  in

  if !emacs_mode then run_emacs () else run_console ();
  (* Don't let the standard toplevel run... *)
  exit 0