diff --git a/src/lib/uTop_main.cppo.ml b/src/lib/uTop_main.cppo.ml index f613ea7..493e6c9 100644 --- a/src/lib/uTop_main.cppo.ml +++ b/src/lib/uTop_main.cppo.ml @@ -399,6 +399,8 @@ let with_loc loc str = { (* A rule for rewriting a toplevel expression. *) type rewrite_rule = { + type_to_rewrite : Longident.t; + mutable path_to_rewrite : Path.t option; required_values : Longident.t list; (* Values that must exist and be persistent for the rule to apply. *) rewrite : Location.t -> Parsetree.expression -> Parsetree.expression; @@ -407,10 +409,6 @@ type rewrite_rule = { (* Whether the rule is enabled or not. *) } -(* Rewrite rules, indexed by the identifier of the type - constructor. *) -let rewrite_rules : (Longident.t, rewrite_rule) Hashtbl.t = Hashtbl.create 42 - let longident_lwt_main_run = Longident.Ldot (Longident.Lident "Lwt_main", "run") let longident_async_thread_safe_block_on_async_exn = Longident.parse "Async.Std.Thread_safe.block_on_async_exn" @@ -436,9 +434,11 @@ let nolabel = Asttypes.Nolabel let nolabel = "" #endif -let () = +let rewrite_rules = [ (* Rewrite Lwt.t expressions to Lwt_main.run *) - Hashtbl.add rewrite_rules (Longident.Ldot (Longident.Lident "Lwt", "t")) { + { + type_to_rewrite = Longident.parse "Lwt.t"; + path_to_rewrite = None; required_values = [longident_lwt_main_run]; rewrite = (fun loc e -> #if OCAML_VERSION < (4, 02, 0) @@ -460,7 +460,9 @@ let () = (* Rewrite Async.Std.Defered.t expressions to Async.Std.Thread_safe.block_on_async_exn (fun () -> ). *) - let rule = { + { + type_to_rewrite = Longident.parse "Async.Std.Deferred.t"; + path_to_rewrite = None; required_values = [longident_async_thread_safe_block_on_async_exn]; rewrite = (fun loc e -> #if OCAML_VERSION < (4, 02, 0) @@ -482,21 +484,32 @@ let () = #endif ); enabled = UTop.auto_run_async; - } in - let deferred_aliases = - [ "Async_core.Ivar.Deferred.t" - ; "Async_kernel.Ivar.Deferred.t" - ; "Async_kernel.Deferred0.t" - ] - in - List.iter (fun s -> - Hashtbl.add rewrite_rules (Longident.parse s) rule) - deferred_aliases + } +] -(* Returns whether the argument is a toplevel expression. *) -let is_eval = function - | { Parsetree.pstr_desc = Parsetree.Pstr_eval _ } -> true - | _ -> false +let rule_path rule = + match rule.path_to_rewrite with + | Some _ as x -> x + | None -> + try + let env = !Toploop.toplevel_env in + let path = + match Env.lookup_type rule.type_to_rewrite env with + | path, { Types.type_kind = Types.Type_abstract + ; Types.type_private = Asttypes.Public + ; Types.type_manifest = Some ty + } -> begin + match Ctype.expand_head env ty with + | { Types.desc = Types.Tconstr (path, _, _) } -> path + | _ -> path + end + | path, _ -> path + in + let opt = Some path in + rule.path_to_rewrite <- opt; + opt + with _ -> + None (* Returns whether the given path is persistent. *) let rec is_persistent_path = function @@ -504,38 +517,6 @@ let rec is_persistent_path = function | Path.Pdot (p, _, _) -> is_persistent_path p | Path.Papply (_, p) -> is_persistent_path p -(* Convert a path to a long identifier. *) -let rec longident_of_path path = - match path with - | Path.Pident id -> - Longident.Lident (Ident.name id) - | Path.Pdot (path, s, _) -> - Longident.Ldot (longident_of_path path, s) - | Path.Papply (p1, p2) -> - Longident.Lapply (longident_of_path p1, longident_of_path p2) - -(* Returns the rewrite rule associated to a type, if any. *) -let rec rule_of_type typ = - match typ.Types.desc with - | Types.Tlink typ -> - rule_of_type typ - | Types.Tconstr (path, _, _) -> begin - match try Some (Env.find_type path !Toploop.toplevel_env) with Not_found -> None with - | Some { - Types.type_kind = Types.Type_abstract; - Types.type_private = Asttypes.Public; - Types.type_manifest = Some typ; - } -> - rule_of_type typ - | _ -> - try - Some (Hashtbl.find rewrite_rules (longident_of_path path)) - with Not_found -> - None - end - | _ -> - None - (* Check that the given long identifier is present in the environment and is persistent. *) let is_persistent_in_env longident = @@ -544,6 +525,30 @@ let is_persistent_in_env longident = with Not_found -> false +let rule_matches rule path = + React.S.value rule.enabled && + (match rule_path rule with + | None -> false + | Some path' -> Path.same path path') && + List.for_all is_persistent_in_env rule.required_values + +(* Returns whether the argument is a toplevel expression. *) +let is_eval = function + | { Parsetree.pstr_desc = Parsetree.Pstr_eval _ } -> true + | _ -> false + +(* Returns the rewrite rule associated to a type, if any. *) +let rec rule_of_type typ = + match (Ctype.expand_head !Toploop.toplevel_env typ).Types.desc with + | Types.Tconstr (path, _, _) -> begin + try + Some (List.find (fun rule -> rule_matches rule path) rewrite_rules) + with _ -> + None + end + | _ -> + None + #if OCAML_VERSION < (4, 02, 0) let rewrite_str_item pstr_item tstr_item = match pstr_item, tstr_item.Typedtree.str_desc with @@ -552,11 +557,8 @@ let rewrite_str_item pstr_item tstr_item = Typedtree.Tstr_eval { Typedtree.exp_type = typ }) -> begin match rule_of_type typ with | Some rule -> - if React.S.value rule.enabled && List.for_all is_persistent_in_env rule.required_values then - { Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e); - Parsetree.pstr_loc = loc } - else - pstr_item + { Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e); + Parsetree.pstr_loc = loc } | None -> pstr_item end @@ -570,11 +572,8 @@ let rewrite_str_item pstr_item tstr_item = Typedtree.Tstr_eval ({ Typedtree.exp_type = typ }, _)) -> begin match rule_of_type typ with | Some rule -> - if React.S.value rule.enabled && List.for_all is_persistent_in_env rule.required_values then - { Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e, []); - Parsetree.pstr_loc = loc } - else - pstr_item + { Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e, []); + Parsetree.pstr_loc = loc } | None -> pstr_item end