Better check for async rewrite
Resolve the types Async.Std.Deferred.t instead of hard-coding a list of aliases. Fix #137
This commit is contained in:
parent
e608977856
commit
9e41bf85da
|
@ -399,6 +399,8 @@ let with_loc loc str = {
|
||||||
|
|
||||||
(* A rule for rewriting a toplevel expression. *)
|
(* A rule for rewriting a toplevel expression. *)
|
||||||
type rewrite_rule = {
|
type rewrite_rule = {
|
||||||
|
type_to_rewrite : Longident.t;
|
||||||
|
mutable path_to_rewrite : Path.t option;
|
||||||
required_values : Longident.t list;
|
required_values : Longident.t list;
|
||||||
(* Values that must exist and be persistent for the rule to apply. *)
|
(* Values that must exist and be persistent for the rule to apply. *)
|
||||||
rewrite : Location.t -> Parsetree.expression -> Parsetree.expression;
|
rewrite : Location.t -> Parsetree.expression -> Parsetree.expression;
|
||||||
|
@ -407,10 +409,6 @@ type rewrite_rule = {
|
||||||
(* Whether the rule is enabled or not. *)
|
(* Whether the rule is enabled or not. *)
|
||||||
}
|
}
|
||||||
|
|
||||||
(* Rewrite rules, indexed by the identifier of the type
|
|
||||||
constructor. *)
|
|
||||||
let rewrite_rules : (Longident.t, rewrite_rule) Hashtbl.t = Hashtbl.create 42
|
|
||||||
|
|
||||||
let longident_lwt_main_run = Longident.Ldot (Longident.Lident "Lwt_main", "run")
|
let longident_lwt_main_run = Longident.Ldot (Longident.Lident "Lwt_main", "run")
|
||||||
let longident_async_thread_safe_block_on_async_exn =
|
let longident_async_thread_safe_block_on_async_exn =
|
||||||
Longident.parse "Async.Std.Thread_safe.block_on_async_exn"
|
Longident.parse "Async.Std.Thread_safe.block_on_async_exn"
|
||||||
|
@ -436,9 +434,11 @@ let nolabel = Asttypes.Nolabel
|
||||||
let nolabel = ""
|
let nolabel = ""
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
let () =
|
let rewrite_rules = [
|
||||||
(* Rewrite Lwt.t expressions to Lwt_main.run <expr> *)
|
(* Rewrite Lwt.t expressions to Lwt_main.run <expr> *)
|
||||||
Hashtbl.add rewrite_rules (Longident.Ldot (Longident.Lident "Lwt", "t")) {
|
{
|
||||||
|
type_to_rewrite = Longident.parse "Lwt.t";
|
||||||
|
path_to_rewrite = None;
|
||||||
required_values = [longident_lwt_main_run];
|
required_values = [longident_lwt_main_run];
|
||||||
rewrite = (fun loc e ->
|
rewrite = (fun loc e ->
|
||||||
#if OCAML_VERSION < (4, 02, 0)
|
#if OCAML_VERSION < (4, 02, 0)
|
||||||
|
@ -460,7 +460,9 @@ let () =
|
||||||
|
|
||||||
(* Rewrite Async.Std.Defered.t expressions to
|
(* Rewrite Async.Std.Defered.t expressions to
|
||||||
Async.Std.Thread_safe.block_on_async_exn (fun () -> <expr>). *)
|
Async.Std.Thread_safe.block_on_async_exn (fun () -> <expr>). *)
|
||||||
let rule = {
|
{
|
||||||
|
type_to_rewrite = Longident.parse "Async.Std.Deferred.t";
|
||||||
|
path_to_rewrite = None;
|
||||||
required_values = [longident_async_thread_safe_block_on_async_exn];
|
required_values = [longident_async_thread_safe_block_on_async_exn];
|
||||||
rewrite = (fun loc e ->
|
rewrite = (fun loc e ->
|
||||||
#if OCAML_VERSION < (4, 02, 0)
|
#if OCAML_VERSION < (4, 02, 0)
|
||||||
|
@ -482,21 +484,32 @@ let () =
|
||||||
#endif
|
#endif
|
||||||
);
|
);
|
||||||
enabled = UTop.auto_run_async;
|
enabled = UTop.auto_run_async;
|
||||||
} in
|
}
|
||||||
let deferred_aliases =
|
]
|
||||||
[ "Async_core.Ivar.Deferred.t"
|
|
||||||
; "Async_kernel.Ivar.Deferred.t"
|
|
||||||
; "Async_kernel.Deferred0.t"
|
|
||||||
]
|
|
||||||
in
|
|
||||||
List.iter (fun s ->
|
|
||||||
Hashtbl.add rewrite_rules (Longident.parse s) rule)
|
|
||||||
deferred_aliases
|
|
||||||
|
|
||||||
(* Returns whether the argument is a toplevel expression. *)
|
let rule_path rule =
|
||||||
let is_eval = function
|
match rule.path_to_rewrite with
|
||||||
| { Parsetree.pstr_desc = Parsetree.Pstr_eval _ } -> true
|
| Some _ as x -> x
|
||||||
| _ -> false
|
| None ->
|
||||||
|
try
|
||||||
|
let env = !Toploop.toplevel_env in
|
||||||
|
let path =
|
||||||
|
match Env.lookup_type rule.type_to_rewrite env with
|
||||||
|
| path, { Types.type_kind = Types.Type_abstract
|
||||||
|
; Types.type_private = Asttypes.Public
|
||||||
|
; Types.type_manifest = Some ty
|
||||||
|
} -> begin
|
||||||
|
match Ctype.expand_head env ty with
|
||||||
|
| { Types.desc = Types.Tconstr (path, _, _) } -> path
|
||||||
|
| _ -> path
|
||||||
|
end
|
||||||
|
| path, _ -> path
|
||||||
|
in
|
||||||
|
let opt = Some path in
|
||||||
|
rule.path_to_rewrite <- opt;
|
||||||
|
opt
|
||||||
|
with _ ->
|
||||||
|
None
|
||||||
|
|
||||||
(* Returns whether the given path is persistent. *)
|
(* Returns whether the given path is persistent. *)
|
||||||
let rec is_persistent_path = function
|
let rec is_persistent_path = function
|
||||||
|
@ -504,38 +517,6 @@ let rec is_persistent_path = function
|
||||||
| Path.Pdot (p, _, _) -> is_persistent_path p
|
| Path.Pdot (p, _, _) -> is_persistent_path p
|
||||||
| Path.Papply (_, p) -> is_persistent_path p
|
| Path.Papply (_, p) -> is_persistent_path p
|
||||||
|
|
||||||
(* Convert a path to a long identifier. *)
|
|
||||||
let rec longident_of_path path =
|
|
||||||
match path with
|
|
||||||
| Path.Pident id ->
|
|
||||||
Longident.Lident (Ident.name id)
|
|
||||||
| Path.Pdot (path, s, _) ->
|
|
||||||
Longident.Ldot (longident_of_path path, s)
|
|
||||||
| Path.Papply (p1, p2) ->
|
|
||||||
Longident.Lapply (longident_of_path p1, longident_of_path p2)
|
|
||||||
|
|
||||||
(* Returns the rewrite rule associated to a type, if any. *)
|
|
||||||
let rec rule_of_type typ =
|
|
||||||
match typ.Types.desc with
|
|
||||||
| Types.Tlink typ ->
|
|
||||||
rule_of_type typ
|
|
||||||
| Types.Tconstr (path, _, _) -> begin
|
|
||||||
match try Some (Env.find_type path !Toploop.toplevel_env) with Not_found -> None with
|
|
||||||
| Some {
|
|
||||||
Types.type_kind = Types.Type_abstract;
|
|
||||||
Types.type_private = Asttypes.Public;
|
|
||||||
Types.type_manifest = Some typ;
|
|
||||||
} ->
|
|
||||||
rule_of_type typ
|
|
||||||
| _ ->
|
|
||||||
try
|
|
||||||
Some (Hashtbl.find rewrite_rules (longident_of_path path))
|
|
||||||
with Not_found ->
|
|
||||||
None
|
|
||||||
end
|
|
||||||
| _ ->
|
|
||||||
None
|
|
||||||
|
|
||||||
(* Check that the given long identifier is present in the environment
|
(* Check that the given long identifier is present in the environment
|
||||||
and is persistent. *)
|
and is persistent. *)
|
||||||
let is_persistent_in_env longident =
|
let is_persistent_in_env longident =
|
||||||
|
@ -544,6 +525,30 @@ let is_persistent_in_env longident =
|
||||||
with Not_found ->
|
with Not_found ->
|
||||||
false
|
false
|
||||||
|
|
||||||
|
let rule_matches rule path =
|
||||||
|
React.S.value rule.enabled &&
|
||||||
|
(match rule_path rule with
|
||||||
|
| None -> false
|
||||||
|
| Some path' -> Path.same path path') &&
|
||||||
|
List.for_all is_persistent_in_env rule.required_values
|
||||||
|
|
||||||
|
(* Returns whether the argument is a toplevel expression. *)
|
||||||
|
let is_eval = function
|
||||||
|
| { Parsetree.pstr_desc = Parsetree.Pstr_eval _ } -> true
|
||||||
|
| _ -> false
|
||||||
|
|
||||||
|
(* Returns the rewrite rule associated to a type, if any. *)
|
||||||
|
let rec rule_of_type typ =
|
||||||
|
match (Ctype.expand_head !Toploop.toplevel_env typ).Types.desc with
|
||||||
|
| Types.Tconstr (path, _, _) -> begin
|
||||||
|
try
|
||||||
|
Some (List.find (fun rule -> rule_matches rule path) rewrite_rules)
|
||||||
|
with _ ->
|
||||||
|
None
|
||||||
|
end
|
||||||
|
| _ ->
|
||||||
|
None
|
||||||
|
|
||||||
#if OCAML_VERSION < (4, 02, 0)
|
#if OCAML_VERSION < (4, 02, 0)
|
||||||
let rewrite_str_item pstr_item tstr_item =
|
let rewrite_str_item pstr_item tstr_item =
|
||||||
match pstr_item, tstr_item.Typedtree.str_desc with
|
match pstr_item, tstr_item.Typedtree.str_desc with
|
||||||
|
@ -552,11 +557,8 @@ let rewrite_str_item pstr_item tstr_item =
|
||||||
Typedtree.Tstr_eval { Typedtree.exp_type = typ }) -> begin
|
Typedtree.Tstr_eval { Typedtree.exp_type = typ }) -> begin
|
||||||
match rule_of_type typ with
|
match rule_of_type typ with
|
||||||
| Some rule ->
|
| Some rule ->
|
||||||
if React.S.value rule.enabled && List.for_all is_persistent_in_env rule.required_values then
|
|
||||||
{ Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e);
|
{ Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e);
|
||||||
Parsetree.pstr_loc = loc }
|
Parsetree.pstr_loc = loc }
|
||||||
else
|
|
||||||
pstr_item
|
|
||||||
| None ->
|
| None ->
|
||||||
pstr_item
|
pstr_item
|
||||||
end
|
end
|
||||||
|
@ -570,11 +572,8 @@ let rewrite_str_item pstr_item tstr_item =
|
||||||
Typedtree.Tstr_eval ({ Typedtree.exp_type = typ }, _)) -> begin
|
Typedtree.Tstr_eval ({ Typedtree.exp_type = typ }, _)) -> begin
|
||||||
match rule_of_type typ with
|
match rule_of_type typ with
|
||||||
| Some rule ->
|
| Some rule ->
|
||||||
if React.S.value rule.enabled && List.for_all is_persistent_in_env rule.required_values then
|
|
||||||
{ Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e, []);
|
{ Parsetree.pstr_desc = Parsetree.Pstr_eval (rule.rewrite loc e, []);
|
||||||
Parsetree.pstr_loc = loc }
|
Parsetree.pstr_loc = loc }
|
||||||
else
|
|
||||||
pstr_item
|
|
||||||
| None ->
|
| None ->
|
||||||
pstr_item
|
pstr_item
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue