Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement mutually recursive closures #20

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/mutual_recursion.hugorm
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
let rec even = lambda(n) -> if n = 0 then true else odd(n - 1)
and odd = lambda(n) -> if n = 0 then false else even(n - 1)
in even(42)
134 changes: 70 additions & 64 deletions lib/anf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ and 'a adecl = ADFun of string * string list * 'a aexpr * 'a

and 'a aexpr =
| ALet of string * 'a cexpr * 'a aexpr * 'a
| ALetRec of string * 'a cexpr * 'a aexpr * 'a
| ALetRec of (string * string list * 'a aexpr * 'a) list * 'a aexpr * 'a
| ACExpr of 'a cexpr

and 'a cexpr =
Expand All @@ -25,71 +25,71 @@ and 'a immexpr =

type ctx = (string * unit cexpr * bool) list

let anf (e : tag expr) : unit aexpr =
let rec help_a (e : tag expr) : unit aexpr =
let cexpr, bindings = help_c e in
(* Enclose `cexpr` in nested let-bindings *)
List.fold_right
(fun (x, e, is_rec) body ->
if is_rec then ALetRec (x, e, body, ()) else ALet (x, e, body, ()))
bindings (ACExpr cexpr)
and help_c (e : tag expr) : unit cexpr * ctx =
let rec anf (e : 'a expr) : unit aexpr =
let ( let* ) = ( @@ ) in
let rec map f lst k =
match lst with
| [] -> k []
| x :: xs ->
let* x = f x in
let* xs = map f xs in
k (x :: xs)
in
let rec anf_aexpr e k =
match e with
| ELet (x, e, body, _) ->
let* cexpr = anf_cexpr e in
k (ALet (x, cexpr, anf body, ()))
| ELetRec (bindings, body, _) ->
let anf_binding (name, params, body, _) k =
k (name, params, anf body, ())
in
let* bindings = map anf_binding bindings in
let* aexpr = anf_aexpr body in
k (ALetRec (bindings, aexpr, ()))
| _ ->
let* cexpr = anf_cexpr e in
k (ACExpr cexpr)
and anf_cexpr e k =
match e with
| (ENumber _ | EBool _ | EId _) as e ->
let imm, ctx = help_i e in
(CImmExpr imm, ctx)
| EPrim1 (op, e, _) ->
let imm, ctx = help_i e in
(CPrim1 (op, imm, ()), ctx)
| EPrim2 (op, l, r, _) ->
let l_imm, l_ctx = help_i l in
let r_imm, r_ctx = help_i r in
(CPrim2 (op, l_imm, r_imm, ()), l_ctx @ r_ctx)
let* imm = anf_immexpr e in
k (CPrim1 (op, imm, ()))
| EPrim2 (op, left, right, _) ->
let* left_imm = anf_immexpr left in
let* right_imm = anf_immexpr right in
k (CPrim2 (op, left_imm, right_imm, ()))
| EIf (cond, thn, els, _) ->
let cond_imm, cond_ctx = help_i cond in
let thn_aexpr = help_a thn in
let els_aexpr = help_a els in
(CIf (cond_imm, thn_aexpr, els_aexpr, ()), cond_ctx)
| ELet (x, e, body, _) ->
let e_cexpr, e_ctx = help_c e in
let body_cexpr, body_ctx = help_c body in
(body_cexpr, e_ctx @ [ (x, e_cexpr, false) ] @ body_ctx)
| EApp (f, args, _) ->
let f_imm, f_ctx = help_i f in
let arg_imms, arg_ctxs = args |> List.map help_i |> List.split in
let args_ctx = List.concat arg_ctxs in
(CApp (f_imm, arg_imms, ()), f_ctx @ args_ctx)
| ETuple (elements, _) ->
let element_imms, ctxs = elements |> List.map help_i |> List.split in
let ctx = List.concat ctxs in
(CTuple (element_imms, ()), ctx)
| EGetItem (tuple, index, _) ->
let tuple_imm, tuple_ctx = help_i tuple in
let index_imm, index_ctx = help_i index in
(CGetItem (tuple_imm, index_imm, ()), tuple_ctx @ index_ctx)
| ELambda (params, body, _) ->
let body_aexpr = help_a body in
(CLambda (params, body_aexpr, ()), [])
| ELetRec (x, e, body, _) ->
let e_cexpr, e_ctx = help_c e in
let body_cexpr, body_ctx = help_c body in
(body_cexpr, e_ctx @ [ (x, e_cexpr, true) ] @ body_ctx)
and help_i (e : tag expr) : unit immexpr * ctx =
let* cond_imm = anf_immexpr cond in
k (CIf (cond_imm, anf thn, anf els, ()))
| EApp (func, args, _) ->
let* func_imm = anf_immexpr func in
let* args_imms = map anf_immexpr args in
k (CApp (func_imm, args_imms, ()))
| ETuple (exprs, _) ->
let* imms = map anf_immexpr exprs in
k (CTuple (imms, ()))
| EGetItem (expr, idx, _) ->
let* expr_imm = anf_immexpr expr in
let* idx_imm = anf_immexpr idx in
k (CGetItem (expr_imm, idx_imm, ()))
| ELambda (args, body, _) ->
let aexpr = anf body in
k (CLambda (args, aexpr, ()))
| _ ->
let* imm = anf_immexpr e in
k (CImmExpr imm)
and anf_immexpr e k =
match e with
| ENumber (n, _) -> (ImmNum (n, ()), [])
| EBool (b, _) -> (ImmBool (b, ()), [])
| EId (x, _) -> (ImmId (x, ()), [])
| ( EPrim1 _ | EPrim2 _ | EIf _ | ELet _ | EApp _ | ETuple _ | EGetItem _
| ELambda _ | ELetRec _ ) as e ->
imm_of_cexpr e
and imm_of_cexpr (e : tag expr) : unit immexpr * ctx =
let tag = tag_of_expr e in
let x = "x" ^ string_of_int tag in
let cexpr, ctx = help_c e in
let is_rec = match e with ELetRec _ -> true | _ -> false in
(ImmId (x, ()), ctx @ [ (x, cexpr, is_rec) ])
| ENumber (n, _) -> k (ImmNum (n, ()))
| EBool (b, _) -> k (ImmBool (b, ()))
| EId (x, _) -> k (ImmId (x, ()))
| _ ->
let x = "tmp$" ^ string_of_int (tag_of_expr e) in
let* cexpr = anf_cexpr e in
ALet (x, cexpr, k (ImmId (x, ())), ())
in
help_a e
anf_aexpr e Fun.id

let anf_decl (d : tag decl) : unit adecl =
match d with
Expand All @@ -114,10 +114,16 @@ let tag (e : unit aprogram) : tag aprogram =
let e, tag = tag_cexpr e tag in
let body, tag = tag_aexpr body tag in
(ALet (x, e, body, tag), tag + 1)
| ALetRec (x, e, body, ()) ->
let e, tag = tag_cexpr e tag in
| ALetRec (bindings, body, ()) ->
let bindings, tag =
List.fold_right
(fun (name, params, body, ()) (bindings, tag) ->
let body, tag = tag_aexpr body tag in
((name, params, body, tag) :: bindings, tag))
bindings ([], tag)
in
let body, tag = tag_aexpr body tag in
(ALetRec (x, e, body, tag), tag + 1)
(ALetRec (bindings, body, tag), tag + 1)
| ACExpr cexpr ->
let cexpr, tag = tag_cexpr cexpr tag in
(ACExpr cexpr, tag + 1)
Expand Down
110 changes: 63 additions & 47 deletions lib/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ open Fvs

exception Unreachable of string

let find x env =
try List.assoc x env with Not_found -> raise (Unreachable __LOC__)

(* TODO: Make StringMap *)
(* TODO: Make StringSet*)
let find x env = try List.assoc x env with Not_found -> raise (Unreachable x)
let word_size = 8
let entry_label = "our_code_starts_here"
let min_int = Int64.div Int64.min_int 2L
let max_int = Int64.div Int64.max_int 2L
let const_true = Const 0xFFFFFFFFFFFFFFFFL
let const_false = Const 0x7FFFFFFFFFFFFFFFL
let bool_mask = Const 0x8000000000000000L
Expand All @@ -33,11 +32,15 @@ let count_let (e : 'a aexpr) =
match aexpr with
| ALet (_, bind, body, _) ->
max (count_let_cexpr bind) (1 + count_let_aexpr body)
| ALetRec (_, bind, body, _) ->
1 + max (count_let_cexpr bind) (count_let_aexpr body)
| ALetRec (bindings, body, _) ->
let max_let_bindings =
List.fold_left
(fun acc (_, _, body, _) -> max acc (count_let_aexpr body))
0 bindings
in
1 + max max_let_bindings (count_let_aexpr body)
| ACExpr cexpr -> count_let_cexpr cexpr
in

count_let_aexpr e

(* We align heap and stack before function calls *)
Expand All @@ -50,13 +53,6 @@ type env = (string * int) list
let empty_env : env = []
let fresh_label ?(prefix = "L") tag = prefix ^ string_of_int tag

let compile_immexpr (env : env) (immexpr : 'a immexpr) : arg =
match immexpr with
| ImmNum (n, _) -> Const (Int64.shift_left n 1)
| ImmBool (true, _) -> const_true
| ImmBool (false, _) -> const_false
| ImmId (x, _) -> RegOffset (RBP, find x env)

let rec compile_cexpr (env : env) (stack_index : int) (cexpr : tag cexpr) : asm
=
match cexpr with
Expand All @@ -70,6 +66,13 @@ let rec compile_cexpr (env : env) (stack_index : int) (cexpr : tag cexpr) : asm
| CLambda (params, body, tag) ->
compile_clambda env stack_index params body tag

and compile_immexpr (env : env) (immexpr : 'a immexpr) : arg =
match immexpr with
| ImmNum (n, _) -> Const (Int64.shift_left n 1)
| ImmBool (true, _) -> const_true
| ImmBool (false, _) -> const_false
| ImmId (x, _) -> RegOffset (RBP, find x env)

and compile_cimmexpr env immexpr =
let arg = compile_immexpr env immexpr in
[ IMov (Reg RAX, arg) ]
Expand Down Expand Up @@ -191,20 +194,20 @@ and compile_cgetitem env tuple index =
IMov (Reg RAX, RegBaseOffset (RAX, R11, 8, 8));
]

and allocate_closure fvs =
and allocate_closure num_fvs =
let to_allocate =
let size = 8 * (1 + List.length fvs) in
let size = 8 * (1 + num_fvs) in
size + stack_alignment_padding size
in
[
(* Tag the closure *)
IMov (Reg RAX, heap_reg);
(* Bump the header pointer *)
IAdd (Reg RAX, closure_tag);
IAdd (heap_reg, Const (Int64.of_int to_allocate));
]

and compile_closure env fvs label =
and compile_closure env fvs tag =
(* Assumes that the closure is already allocated, with the header pointer in RAX *)
let label = fresh_label ~prefix:"lambda" tag in
let move_fvs_to_heap_asm =
fvs
|> List.mapi (fun i fv ->
Expand All @@ -214,11 +217,17 @@ and compile_closure env fvs label =
])
|> List.concat
in
[ ILea (scratch_reg, Label label); IMov (RegOffset (RAX, 0), scratch_reg) ]
[
ISub (Reg RAX, closure_tag);
ILea (scratch_reg, Label label);
IMov (RegOffset (RAX, 0), scratch_reg);
]
@ move_fvs_to_heap_asm
@ [ IAdd (Reg RAX, closure_tag) ]

and compile_lambda_body env stack_index fvs params body =
and compile_lambda_body env stack_index fvs params body tag =
let lambda_label = fresh_label ~prefix:"lambda" tag in
let lambda_end_label = fresh_label ~prefix:"lambda_end" tag in
let frame_size =
let max_stack_size = 8 * (count_let body + List.length fvs) in
max_stack_size + stack_alignment_padding max_stack_size
Expand All @@ -238,6 +247,8 @@ and compile_lambda_body env stack_index fvs params body =
|> List.concat
in
[
IJmp lambda_end_label;
ILabel lambda_label;
IPush (Reg RBP);
IMov (Reg RBP, Reg RSP);
ISub (Reg RSP, Const (Int64.of_int frame_size));
Expand All @@ -248,16 +259,13 @@ and compile_lambda_body env stack_index fvs params body =
@ move_fvs_to_stack_asm
@ compile_aexpr env (stack_index - (8 * List.length fvs)) body
@ [ IMov (Reg RSP, Reg RBP); IPop (Reg RBP); IRet ]
@ [ ILabel lambda_end_label ]

and compile_clambda env stack_index params body tag =
let fvs = S.elements (S.diff (fvs_aexpr body) (S.of_list params)) in
let lambda_label = fresh_label ~prefix:"lambda" tag in
let lambda_end_label = fresh_label ~prefix:"lambda_end" tag in
[ IJmp lambda_end_label; ILabel lambda_label ]
@ compile_lambda_body env stack_index fvs params body
@ [ ILabel lambda_end_label ]
@ allocate_closure fvs
@ compile_closure env fvs lambda_label
compile_lambda_body env stack_index fvs params body tag
@ allocate_closure (List.length fvs)
@ compile_closure env fvs tag

and compile_body (body : tag aexpr) : asm =
let frame_size =
Expand All @@ -284,24 +292,31 @@ and compile_aexpr env stack_index aexpr =
compile_cexpr env stack_index bind
@ [ IMov (RegOffset (RBP, stack_index'), Reg RAX) ]
@ compile_aexpr env' stack_index' body
| ALetRec (f, (CLambda (params, body, tag) as lambda), body', _) ->
let fvs = S.elements (fvs_cexpr lambda) in
let lambda_label = fresh_label ~prefix:"lambda" tag in
let lambda_end_label = fresh_label ~prefix:"lambda_end" tag in
let stack_index' = stack_index - 8 in
let env' = (f, stack_index') :: env in
[ IJmp lambda_end_label; ILabel lambda_label ]
@ compile_lambda_body env' stack_index' fvs params body
@ [ ILabel lambda_end_label ]
@ allocate_closure fvs
@ [
IMov (scratch_reg, Reg RAX);
IAdd (scratch_reg, closure_tag);
IMov (RegOffset (RBP, stack_index'), scratch_reg);
]
@ compile_closure env' fvs lambda_label
@ compile_aexpr env' stack_index' body'
| ALetRec _ -> raise (Unreachable __LOC__)
| ALetRec (bindings, body, _) ->
let env' =
List.mapi (fun i (name, _, _, _) -> (name, -8 * (i + 1))) bindings @ env
in
let stack_index' = stack_index - (8 * List.length bindings) in
let allocate_closures_asm =
bindings
|> List.map (fun (x, params, body, tag) ->
let fvs = S.diff (fvs_aexpr body) (S.of_list params) in
allocate_closure (S.cardinal fvs)
@ [ IMov (RegOffset (RBP, find x env'), Reg RAX) ]
@ compile_lambda_body env' stack_index' (S.elements fvs) params
body tag)
|> List.concat
in
let build_closure_asm =
bindings
|> List.map (fun (x, params, body, tag) ->
let fvs = S.diff (fvs_aexpr body) (S.of_list params) in
[ IMov (Reg RAX, RegOffset (RBP, find x env')) ]
@ compile_closure env' (S.elements fvs) tag)
|> List.concat
in
allocate_closures_asm @ build_closure_asm
@ compile_aexpr env' stack_index' body

let compile_adecl (adecl : 'a adecl) : asm =
match adecl with
Expand All @@ -317,6 +332,7 @@ let compile (prog : 'a program) : asm =
Well_formedness.well_formed prog.body;
let tagged = Syntax.tag prog in
let renamed = Rename.rename_program tagged in
(* print_endline (Syntax.show_program (fun _ _ -> ()) renamed) ; *)
let anfed = Anf.anf_program renamed in
let retagged = Anf.tag anfed in
compile_aprog retagged
17 changes: 13 additions & 4 deletions lib/fvs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@ and fvs_cexpr cexpr =
| CIf (imm, aexpr1, aexpr2, _) ->
S.union (fvs_immexpr imm) (S.union (fvs_aexpr aexpr1) (fvs_aexpr aexpr2))
| CApp (imm, imms, _) ->
List.fold_left (fun acc imm -> S.union acc (fvs_immexpr imm)) (fvs_immexpr imm) imms
List.fold_left
(fun acc imm -> S.union acc (fvs_immexpr imm))
(fvs_immexpr imm) imms
| CTuple (imms, _) ->
List.fold_left (fun acc imm -> S.union acc (fvs_immexpr imm)) S.empty imms
| CGetItem (imm1, imm2, _) -> S.union (fvs_immexpr imm1) (fvs_immexpr imm2)
| CLambda (xs, aexpr, _) -> S.diff (fvs_aexpr aexpr) (S.of_list xs)
| CLambda (params, body, _) -> S.diff (fvs_aexpr body) (S.of_list params)
| CImmExpr imm -> fvs_immexpr imm

and fvs_aexpr aexpr =
match aexpr with
| ALet (x, cexpr, aexpr, _) ->
S.union (fvs_cexpr cexpr) (S.remove x (fvs_aexpr aexpr))
| ALetRec (x, cexpr, aexpr, _) ->
S.remove x (S.union (fvs_cexpr cexpr) (fvs_aexpr aexpr))
| ALetRec (bindings, aexpr, _) ->
let xs = List.map (fun (x, _, _, _) -> x) bindings in
let bindings_fvs =
List.fold_left
(fun acc (_, params, body, _) ->
S.union acc (S.diff (fvs_aexpr body) (S.of_list params)))
S.empty bindings
in
S.diff (S.union bindings_fvs (fvs_aexpr aexpr)) (S.of_list xs)
| ACExpr cexpr -> fvs_cexpr cexpr
Loading