From c080982044a193a2a000bd5e4011bb0c36f2a0cc Mon Sep 17 00:00:00 2001 From: tali Date: Wed, 13 Dec 2023 19:18:54 -0500 Subject: [PATCH] implement closure conversion on Ir.id's --- bin/main.ml | 8 +++++-- lib/compile/bcc.ml | 46 +++++++++++++++++++++++++++++++------ lib/compile/ir.ml | 57 ++++++++++++++++++++++++++++++++-------------- 3 files changed, 85 insertions(+), 26 deletions(-) diff --git a/bin/main.ml b/bin/main.ml index e680df0..95daddb 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -7,8 +7,12 @@ let () = try let ast = parse " val two = 2 - val zero = 0 - fun twice(x) x*two+zero + val one = 1 + fun twice(x) { + fun f(x) (x - one) * two + fun g() x + one + f(g()) + } println(twice(4)) " in Logs.debug (fun m -> m "[AST] %a" Ast.pp_modl ast); diff --git a/lib/compile/bcc.ml b/lib/compile/bcc.ml index de4c4a3..467defc 100644 --- a/lib/compile/bcc.ml +++ b/lib/compile/bcc.ml @@ -14,7 +14,7 @@ let undef_method = Value.Native_function (fun _ -> failwith "BUG: method undefined") -let rec compile_lambda (lam : Ir.lambda) = +let rec compile_lambda ?clos_map (lam : Ir.lambda) = let entrypoint = Code.make_block () in let currb = ref entrypoint in let emit i = Code.extend !currb i in @@ -104,20 +104,50 @@ let rec compile_lambda (lam : Ir.lambda) = in emit (CAL (sp, obj, mth, args)) - | Ir.Obj { vals; funs } -> - let n_slots = List.length vals in - let elems = Hashtbl.create (List.length vals + List.length funs) in - let mthds = Array.make (List.length funs) undef_method in + | Ir.Obj { vals; funs; clos } -> + (* assign each captured id to a slot *) + let clos_map = Hashtbl.create 64 in + let n_slots = + List.fold_left + (fun n id -> + Hashtbl.add clos_map id n; + n + 1) + (List.length vals) + clos + in + + (* assign each val to a slot *) + let elems = Hashtbl.create 64 in List.iteri (fun i name -> Hashtbl.add elems name (Value.Field i)) vals; + + (* compile methods and assign to an index *) + let mthds = Array.make (List.length funs) undef_method in List.iteri (fun i (name, lambda) -> Hashtbl.add elems name (Value.Method i); - mthds.(i) <- Code.Function (compile_lambda lambda)) + mthds.(i) <- Code.Function (compile_lambda lambda ~clos_map)) funs; - emit (CON (sp, { n_slots; elems; mthds })) + + (* construct object and save captured id's *) + emit (CON (sp, { n_slots; elems; mthds })); + Hashtbl.iter + (fun id idx -> + let obj = sp in + let loc = suc sp in + emit (LDI (loc, Value.of_int idx)); + emit (SET (get_reg id, obj, loc))) + clos_map + + | Ir.Open id -> + let idx = try Hashtbl.find (Option.get clos_map) id + with Not_found -> failwith "BUG: %S not captured" + | Invalid_argument _ -> failwith "BUG: no captured variables" + in + emit (LDI (sp, Value.of_int idx)); + emit (GET (sp, get_reg lam.self, sp)) | ir -> let rv = emit_exp sp ir in @@ -131,6 +161,8 @@ let rec compile_lambda (lam : Ir.lambda) = in + (* R0 = self *) + (* R(i+1) = args[i] *) set_reg lam.self (Code.R 0); let sp = List.fold_left diff --git a/lib/compile/ir.ml b/lib/compile/ir.ml index 5756243..62cde27 100644 --- a/lib/compile/ir.ml +++ b/lib/compile/ir.ml @@ -44,13 +44,14 @@ type exp = | Bop of bop * exp * exp | Call of path * exp list | Obj of obj + | Open of id and path = id * string and obj = { vals : string list; funs : (string * lambda) list; - (* clos : id list; *) + clos : id list; } and lambda = { @@ -73,17 +74,21 @@ module Env = struct | Fun of { pred : t; args : (string * id) list; - (* clos : ??? *) + clos : (id, unit) Hashtbl.t; } let rec find name = function | Empty -> raise Not_found - | Fun { pred; args } -> + | Fun { pred; args; clos } -> begin match List.assoc name args with | id -> id, None - | exception Not_found -> find name pred + | exception Not_found -> + let id, fld = find name pred in + (* mark id's from pred env as needing capture *) + Hashtbl.replace clos id (); + id, fld end | Obj { pred; self; elems } -> @@ -95,6 +100,10 @@ end let seq_r a b = Seq (b, a) +let union xs ys = + List.sort_uniq compare + (List.rev_append ys xs) + let lower ~lib (modl : Ast.modl) = let new_id = make_id_dispenser () in @@ -186,34 +195,33 @@ let lower ~lib (modl : Ast.modl) = let self = new_id "obj" in let env = Env.Obj { self; elems; pred = env } in - let env' = Env.Obj { self; elems; pred = Empty } in - let funs_r, vals_r, inits_r = + let funs_r, vals_r, inits_r, clos = List.fold_left - (fun (fns, vls, ins) -> function + (fun (fns, vls, ins, clos) -> function | Ast.Item_exp exp -> let init = lower_exp env exp in - fns, vls, init :: ins + fns, vls, init :: ins, clos | Ast.Item_val (name, exp) -> let init = Set ((self, name), lower_exp env exp) in - fns, name :: vls, init :: ins + fns, name :: vls, init :: ins, clos | Ast.Item_obj (name, items) -> (* TODO: it would be ideal if we could construct the empty versions of obj's in a sort of "pre-init" phase, before assigning field values. but for now, obj items are identical to val's where the rhs is an obj expression. *) let init = Set ((self, name), lower_block env items) in - fns, name :: vls, init :: ins + fns, name :: vls, init :: ins, clos | Ast.Item_fun (name, args, body) -> - let fn = (name, lower_lambda self env' args body) in - fn :: fns, vls, ins) - ([], [], []) + let lam, clos' = lower_lambda self env args body in + (name, lam) :: fns, vls, ins, union clos clos') + ([], [], [], []) items in - (* if [is_scope], return the last expr, otherwise return the object itself *) + (* if [is_scope], return the last expr, otherwise return the object (self) *) let ret, inits_r = match is_scope, inits_r with | true, init :: inits -> init, inits | _, inits -> Var self, inits @@ -225,6 +233,7 @@ let lower ~lib (modl : Ast.modl) = Obj { funs = List.rev funs_r; vals = List.rev vals_r; + clos; }, List.fold_left (fun a b -> Seq (b, a)) @@ -234,11 +243,25 @@ let lower ~lib (modl : Ast.modl) = and lower_lambda self env args body = let args = List.map (fun a -> a, new_id a) args in - let env = Env.Fun { args; pred = env } in - (* TODO: closure conversion *) + let clos = Hashtbl.create 32 in + let env = Env.Fun { args; clos; pred = env } in let body = lower_exp env body in + + (* wrap body in let bindings to read from the closure *) + let body, clos = + Hashtbl.fold + (fun id () (ir, clos) -> + if id = self then + (* [self] isn't "captured"; it IS the closure! *) + ir, clos + else + Let (id, Open id, ir), id :: clos) + clos + (body, []) + in + let args = List.map snd args in - { self; args; body } + { self; args; body }, clos in