implement closure conversion on Ir.id's

This commit is contained in:
tali 2023-12-13 19:18:54 -05:00
parent c986aa6ec0
commit c080982044
3 changed files with 85 additions and 26 deletions

View File

@ -7,8 +7,12 @@ let () =
try try
let ast = parse " let ast = parse "
val two = 2 val two = 2
val zero = 0 val one = 1
fun twice(x) x*two+zero fun twice(x) {
fun f(x) (x - one) * two
fun g() x + one
f(g())
}
println(twice(4)) println(twice(4))
" in " in
Logs.debug (fun m -> m "[AST] %a" Ast.pp_modl ast); Logs.debug (fun m -> m "[AST] %a" Ast.pp_modl ast);

View File

@ -14,7 +14,7 @@ let undef_method =
Value.Native_function Value.Native_function
(fun _ -> failwith "BUG: method undefined") (fun _ -> failwith "BUG: method undefined")
let rec compile_lambda (lam : Ir.lambda) = let rec compile_lambda ?clos_map (lam : Ir.lambda) =
let entrypoint = Code.make_block () in let entrypoint = Code.make_block () in
let currb = ref entrypoint in let currb = ref entrypoint in
let emit i = Code.extend !currb i in let emit i = Code.extend !currb i in
@ -104,20 +104,50 @@ let rec compile_lambda (lam : Ir.lambda) =
in in
emit (CAL (sp, obj, mth, args)) emit (CAL (sp, obj, mth, args))
| Ir.Obj { vals; funs } -> | Ir.Obj { vals; funs; clos } ->
let n_slots = List.length vals in (* assign each captured id to a slot *)
let elems = Hashtbl.create (List.length vals + List.length funs) in let clos_map = Hashtbl.create 64 in
let mthds = Array.make (List.length funs) undef_method in let n_slots =
List.fold_left
(fun n id ->
Hashtbl.add clos_map id n;
n + 1)
(List.length vals)
clos
in
(* assign each val to a slot *)
let elems = Hashtbl.create 64 in
List.iteri List.iteri
(fun i name -> (fun i name ->
Hashtbl.add elems name (Value.Field i)) Hashtbl.add elems name (Value.Field i))
vals; vals;
(* compile methods and assign to an index *)
let mthds = Array.make (List.length funs) undef_method in
List.iteri List.iteri
(fun i (name, lambda) -> (fun i (name, lambda) ->
Hashtbl.add elems name (Value.Method i); Hashtbl.add elems name (Value.Method i);
mthds.(i) <- Code.Function (compile_lambda lambda)) mthds.(i) <- Code.Function (compile_lambda lambda ~clos_map))
funs; funs;
emit (CON (sp, { n_slots; elems; mthds }))
(* construct object and save captured id's *)
emit (CON (sp, { n_slots; elems; mthds }));
Hashtbl.iter
(fun id idx ->
let obj = sp in
let loc = suc sp in
emit (LDI (loc, Value.of_int idx));
emit (SET (get_reg id, obj, loc)))
clos_map
| Ir.Open id ->
let idx = try Hashtbl.find (Option.get clos_map) id
with Not_found -> failwith "BUG: %S not captured"
| Invalid_argument _ -> failwith "BUG: no captured variables"
in
emit (LDI (sp, Value.of_int idx));
emit (GET (sp, get_reg lam.self, sp))
| ir -> | ir ->
let rv = emit_exp sp ir in let rv = emit_exp sp ir in
@ -131,6 +161,8 @@ let rec compile_lambda (lam : Ir.lambda) =
in in
(* R0 = self *)
(* R(i+1) = args[i] *)
set_reg lam.self (Code.R 0); set_reg lam.self (Code.R 0);
let sp = let sp =
List.fold_left List.fold_left

View File

@ -44,13 +44,14 @@ type exp =
| Bop of bop * exp * exp | Bop of bop * exp * exp
| Call of path * exp list | Call of path * exp list
| Obj of obj | Obj of obj
| Open of id
and path = id * string and path = id * string
and obj = { and obj = {
vals : string list; vals : string list;
funs : (string * lambda) list; funs : (string * lambda) list;
(* clos : id list; *) clos : id list;
} }
and lambda = { and lambda = {
@ -73,17 +74,21 @@ module Env = struct
| Fun of { | Fun of {
pred : t; pred : t;
args : (string * id) list; args : (string * id) list;
(* clos : ??? *) clos : (id, unit) Hashtbl.t;
} }
let rec find name = function let rec find name = function
| Empty -> | Empty ->
raise Not_found raise Not_found
| Fun { pred; args } -> | Fun { pred; args; clos } ->
begin match List.assoc name args with begin match List.assoc name args with
| id -> id, None | id -> id, None
| exception Not_found -> find name pred | exception Not_found ->
let id, fld = find name pred in
(* mark id's from pred env as needing capture *)
Hashtbl.replace clos id ();
id, fld
end end
| Obj { pred; self; elems } -> | Obj { pred; self; elems } ->
@ -95,6 +100,10 @@ end
let seq_r a b = Seq (b, a) let seq_r a b = Seq (b, a)
let union xs ys =
List.sort_uniq compare
(List.rev_append ys xs)
let lower ~lib (modl : Ast.modl) = let lower ~lib (modl : Ast.modl) =
let new_id = make_id_dispenser () in let new_id = make_id_dispenser () in
@ -186,34 +195,33 @@ let lower ~lib (modl : Ast.modl) =
let self = new_id "obj" in let self = new_id "obj" in
let env = Env.Obj { self; elems; pred = env } in let env = Env.Obj { self; elems; pred = env } in
let env' = Env.Obj { self; elems; pred = Empty } in
let funs_r, vals_r, inits_r = let funs_r, vals_r, inits_r, clos =
List.fold_left List.fold_left
(fun (fns, vls, ins) -> function (fun (fns, vls, ins, clos) -> function
| Ast.Item_exp exp -> | Ast.Item_exp exp ->
let init = lower_exp env exp in let init = lower_exp env exp in
fns, vls, init :: ins fns, vls, init :: ins, clos
| Ast.Item_val (name, exp) -> | Ast.Item_val (name, exp) ->
let init = Set ((self, name), lower_exp env exp) in let init = Set ((self, name), lower_exp env exp) in
fns, name :: vls, init :: ins fns, name :: vls, init :: ins, clos
| Ast.Item_obj (name, items) -> | Ast.Item_obj (name, items) ->
(* TODO: it would be ideal if we could construct the empty versions of obj's (* TODO: it would be ideal if we could construct the empty versions of obj's
in a sort of "pre-init" phase, before assigning field values. but for now, in a sort of "pre-init" phase, before assigning field values. but for now,
obj items are identical to val's where the rhs is an obj expression. *) obj items are identical to val's where the rhs is an obj expression. *)
let init = Set ((self, name), lower_block env items) in let init = Set ((self, name), lower_block env items) in
fns, name :: vls, init :: ins fns, name :: vls, init :: ins, clos
| Ast.Item_fun (name, args, body) -> | Ast.Item_fun (name, args, body) ->
let fn = (name, lower_lambda self env' args body) in let lam, clos' = lower_lambda self env args body in
fn :: fns, vls, ins) (name, lam) :: fns, vls, ins, union clos clos')
([], [], []) ([], [], [], [])
items items
in in
(* if [is_scope], return the last expr, otherwise return the object itself *) (* if [is_scope], return the last expr, otherwise return the object (self) *)
let ret, inits_r = match is_scope, inits_r with let ret, inits_r = match is_scope, inits_r with
| true, init :: inits -> init, inits | true, init :: inits -> init, inits
| _, inits -> Var self, inits | _, inits -> Var self, inits
@ -225,6 +233,7 @@ let lower ~lib (modl : Ast.modl) =
Obj { Obj {
funs = List.rev funs_r; funs = List.rev funs_r;
vals = List.rev vals_r; vals = List.rev vals_r;
clos;
}, },
List.fold_left List.fold_left
(fun a b -> Seq (b, a)) (fun a b -> Seq (b, a))
@ -234,11 +243,25 @@ let lower ~lib (modl : Ast.modl) =
and lower_lambda self env args body = and lower_lambda self env args body =
let args = List.map (fun a -> a, new_id a) args in let args = List.map (fun a -> a, new_id a) args in
let env = Env.Fun { args; pred = env } in let clos = Hashtbl.create 32 in
(* TODO: closure conversion *) let env = Env.Fun { args; clos; pred = env } in
let body = lower_exp env body in let body = lower_exp env body in
(* wrap body in let bindings to read from the closure *)
let body, clos =
Hashtbl.fold
(fun id () (ir, clos) ->
if id = self then
(* [self] isn't "captured"; it IS the closure! *)
ir, clos
else
Let (id, Open id, ir), id :: clos)
clos
(body, [])
in
let args = List.map snd args in let args = List.map snd args in
{ self; args; body } { self; args; body }, clos
in in