Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Risc-v: proper management of the return address #939

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions compiler/src/arch_full.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ open Arch_extra
open Prog

type 'a callstyle =
| StackDirect (* call instruction push the return address on top of the stack *)
| ByReg of 'a option (* call instruction store the return address on a register,
(Some r) neams that the register is forced to be r *)
| StackDirect
(* call instruction push the return address on top of the stack *)
| ByReg of { call : 'a option; return : bool }
(* call instruction store the return address on a register,
- call: (Some r) means that the register is forced to be r
- return:
+ true means that the register is also used for the return
+ false means that there is no constraint (stack is also ok) *)

(* TODO: check that we cannot use sth already defined on the Coq side *)

Expand Down Expand Up @@ -191,7 +196,7 @@ module Arch_from_Core_arch (A : Core_arch) :
let callstyle =
match A.callstyle with
| StackDirect -> StackDirect
| ByReg o -> ByReg (Option.map var_of_reg o)
| ByReg { call; return } -> ByReg { call = Option.map var_of_reg call; return }

let arch_info = Pretyping.{
pd = reg_size;
Expand Down
12 changes: 9 additions & 3 deletions compiler/src/arch_full.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@ open Arch_extra
open Prog

type 'a callstyle =
| StackDirect (* call instruction push the return address on top of the stack *)
| ByReg of 'a option (* call instruction store the return address on a register,
(Some r) neams that the register is forced to be r *)
| StackDirect
(* call instruction push the return address on top of the stack *)
| ByReg of { call : 'a option; return : bool }
(* call instruction store the return address on a register,
- call: (Some r) means that the register is forced to be r
- return:
+ true means that the register is also used for the return
+ false means that there is no constraint (stack is also ok) *)

(* x86 : StackDirect
arm v7 : ByReg (Some ra)
riscV : ByReg (can it be StackDirect too ?)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/arm_arch_full.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ module Arm (Lowering_params : Arm_input) : Arch_full.Core_arch = struct

let pp_asm = Pp_arm_m4.print_prog

let callstyle = Arch_full.ByReg (Some LR)
let callstyle = Arch_full.ByReg { call = Some LR; return = false }
end
46 changes: 24 additions & 22 deletions compiler/src/pp_riscv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -173,33 +173,25 @@ let pp_instr fn i =
[ LInstr ("j", [ pp_remote_label lbl ]) ]

| JMPI arg ->
(* TODO_RISCV: Review. *)
let lbl =
match arg with
| Reg r -> pp_register r
| _ -> failwith "TODO_RISCV: pp_instr jmpi"
in
[ LInstr ("jr", [ lbl ]) ]
begin match arg with
| Reg RA -> [LInstr ("ret", [])]
| Reg r -> [ LInstr ("jr", [ pp_register r ]) ]
| _ -> failwith "TODO_RISCV: pp_instr jmpi"
end

| Jcc (lbl, ct) ->
let iname = pp_condition_kind ct.cond_kind in
let cond_fst = pp_cond_arg ct.cond_fst in
let cond_snd = pp_cond_arg ct.cond_snd in
[ LInstr (iname, [ cond_fst; cond_snd; pp_label fn lbl ]) ]

| CALL lbl ->
[LInstr ("call", [pp_remote_label lbl])]

| JAL (reg, lbl) ->
begin match reg with
| RA -> [LInstr ("call", [pp_remote_label lbl] )]
| _ -> [LInstr ("jalr", [pp_register reg; pp_remote_label lbl] )]
end
| JAL (RA, lbl) ->
[LInstr ("call", [pp_remote_label lbl])]

| POPPC ->
[ LInstr ("lw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None]);
LInstr ("addi", [ pp_register SP; pp_register SP; "4"]);
LInstr ("ret", [ ]) ]
| JAL _
| CALL _
| POPPC ->
assert false

| SysCall op ->
[LInstr ("call", [ pp_syscall op ])]
Expand Down Expand Up @@ -240,11 +232,21 @@ let pp_fun (fn, fd) =
else []
in let pre =
let fn = escape fn in
if fd.asm_fd_export then [ LLabel (mangle fn); LLabel fn; LInstr ("addi", [ pp_register SP; pp_register SP; "-4"]); LInstr ("sw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None])] else []
if fd.asm_fd_export then
[ LLabel (mangle fn);
LLabel fn;
LInstr ("addi", [ pp_register SP; pp_register SP; "-4"]);
LInstr ("sw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None])]
else []
in
let body = pp_body fn fd.asm_fd_body in
(* TODO_RISCV: Review. *)
let pos = if fd.asm_fd_export then pp_instr fn POPPC else [] in
let pos =
if fd.asm_fd_export then
[ LInstr ("lw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None]);
LInstr ("addi", [ pp_register SP; pp_register SP; "4"]);
LInstr ("ret", [ ]) ]
else []
in
head @ pre @ body @ pos

let pp_funcs funs = List.concat_map pp_fun funs
Expand Down
16 changes: 12 additions & 4 deletions compiler/src/printer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,22 @@ let pp_saved_stack ~debug fmt = function
let pp_tmp_option ~debug =
Format.pp_print_option (fun fmt x -> Format.fprintf fmt " [tmp = %a]" (pp_var ~debug) (Conv.var_of_cvar x))

let pp_ra_call ~debug =
Format.pp_print_option (fun fmt ra_call -> Format.fprintf fmt "%a -> " (pp_var ~debug) (Conv.var_of_cvar ra_call))

let pp_ra_return ~debug =
Format.pp_print_option (fun fmt ra_return -> Format.fprintf fmt " -> %a" (pp_var ~debug) (Conv.var_of_cvar ra_return))

let pp_return_address ~debug fmt = function
| Expr.RAreg (x, o) ->
Format.fprintf fmt "%a%a" (pp_var ~debug) (Conv.var_of_cvar x) (pp_tmp_option ~debug) o

| Expr.RAstack(Some x, z, o) ->
Format.fprintf fmt "%a, RSP + %a%a" (pp_var ~debug) (Conv.var_of_cvar x) Z.pp_print (Conv.z_of_cz z) (pp_tmp_option ~debug) o
| Expr.RAstack(None, z, o) ->
Format.fprintf fmt "RSP + %a%a" Z.pp_print (Conv.z_of_cz z) (pp_tmp_option ~debug) o
| Expr.RAstack(ra_call, ra_return, z, o) ->
Format.fprintf fmt "%aRSP + %a%a%a"
(pp_ra_call ~debug) ra_call Z.pp_print (Conv.z_of_cz z)
(pp_tmp_option ~debug) o
(pp_ra_return ~debug) ra_return

| Expr.RAnone -> Format.fprintf fmt "_"

let pp_sprog ~debug pd asmOp fmt ((funcs, p_extra):('info, 'asm) Prog.sprog) =
Expand Down
107 changes: 79 additions & 28 deletions compiler/src/regalloc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -454,21 +454,36 @@ let collect_variables_cb ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int)
let n = fresh () in
Hv.add tbl v n

let collect_variables_aux ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) (tbl: int Hv.t) (extra: var option) (f: ('info, 'asm) func) : unit =
let collect_variables_aux ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) (tbl: int Hv.t) (extra: Sv.t) (f: ('info, 'asm) func) : unit =
let get v = collect_variables_cb ~allvars excluded fresh tbl v in
iter_variables get f;
match extra with Some x -> get x | None -> ()
Sv.iter get extra

let collect_variables ~(allvars: bool) (excluded:Sv.t) (f: ('info, 'asm) func) : int Hv.t * int =
let fresh, total = make_counter () in
let tbl : int Hv.t = Hv.create 97 in
collect_variables_aux ~allvars excluded fresh tbl None f;
collect_variables_aux ~allvars excluded fresh tbl Sv.empty f;
tbl, total ()

(* TODO: should StackDirect be just StackByReg (None, None, None)? *)
type retaddr =
| StackDirect
| StackByReg of var * var option
(* ra is passed on the stack and read from the stack *)
| StackByReg of var * var option * var option
(* StackByReg (ra_call, ra_return, tmp) *)
| ByReg of var * var option
(* ByReg (ra, tmp) *)

let vars_retaddr ra =
let oadd ov s =
match ov with
| None -> s
| Some v -> Sv.add v s
in
match ra with
| StackByReg (ra_call, ra_return, tmp) -> oadd tmp (oadd ra_return (Sv.singleton ra_call))
| ByReg (ra, tmp) -> oadd tmp (Sv.singleton ra)
| StackDirect -> Sv.empty

let collect_variables_in_prog
~(allvars: bool)
Expand All @@ -479,12 +494,8 @@ let collect_variables_in_prog
let fresh, total = make_counter () in
let tbl : int Hv.t = Hv.create 97 in
List.iter (fun f ->
let extra, tmp =
match Hf.find return_adresses f.f_name with
| StackByReg (v, tmp) | ByReg (v, tmp) -> Some v, tmp
| StackDirect -> None, None in
collect_variables_aux ~allvars excluded fresh tbl extra f;
Option.may (collect_variables_cb ~allvars excluded fresh tbl) tmp) f;
let extra = vars_retaddr (Hf.find return_adresses f.f_name) in
collect_variables_aux ~allvars excluded fresh tbl extra f) f;
List.iter (collect_variables_cb ~allvars excluded fresh tbl) all_reg;
tbl, total ()

Expand Down Expand Up @@ -694,10 +705,27 @@ let allocate_forced_registers return_addresses translate_var nv (vars: int Hv.t)
if FInfo.is_export f.f_cc then alloc_args loc identity f.f_args;
if FInfo.is_export f.f_cc then alloc_ret loc L.unloc f.f_ret;
alloc_stmt f.f_body;
match Hf.find return_addresses f.f_name, Arch.callstyle with
| (StackByReg (ra,_) | ByReg (ra, _)), Arch_full.ByReg (Some r) ->
match Arch.callstyle with
| Arch_full.ByReg { call = Some r; return } ->
begin match Hf.find return_addresses f.f_name with
| StackDirect -> ()
| StackByReg (ra_call, ra_return, _) ->
let i = Hv.find vars ra_call in
allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra_call i r a;
if return then begin
match ra_return with
| Some ra_return ->
let i = Hv.find vars ra_return in
allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra_return i r a
| None ->
(* calling convention requires the return address to be in a register,
but there is no booked register. This must not happen. *)
assert false
end
| ByReg (ra, _) ->
let i = Hv.find vars ra in
allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra i r a
end
| _ -> ()

(* Returns a variable from [regs] that is allocated to a friend variable of [i]. Defaults to [dflt]. *)
Expand Down Expand Up @@ -943,7 +971,7 @@ let subroutine_ra_by_stack f =
| Subroutine _ ->
match Arch.callstyle with
| Arch_full.StackDirect -> true
| Arch_full.ByReg oreg ->
| Arch_full.ByReg { call = oreg } ->
let dfl = oreg <> None && has_call_or_syscall f.f_body in
match f.f_annot.retaddr_kind with
| None -> dfl
Expand Down Expand Up @@ -1108,15 +1136,15 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
let preprocess f =
let f = f |> fill_in_missing_names |> Ssa.split_live_ranges false in
Hf.add liveness_table f.f_name (Liveness.live_fd true f);
(* compute where will be store the return address *)
(* compute where the return address will be stored *)
let ra =
match f.f_cc with
| Export _ -> StackDirect
| Internal -> assert false
| Subroutine _ ->
match Arch.callstyle with
| Arch_full.StackDirect -> StackDirect
| Arch_full.ByReg oreg ->
| Arch_full.ByReg { call = oreg; return } ->
let dfl = oreg <> None && has_call_or_syscall f.f_body in
let r = V.mk ("ra_"^f.f_name.fn_name) (Reg(Normal,Direct)) (tu Arch.reg_size) f.f_loc [] in
(* Fixme: Add an option in Arch to say when the tmp reg is needed *)
Expand All @@ -1130,7 +1158,14 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
match f.f_annot.retaddr_kind with
| None -> dfl
| Some k -> dfl || k = OnStack in
if rastack then StackByReg (r, tmp)
if rastack then
let r_return =
if return then
let r_return = V.mk ("ra_"^f.f_name.fn_name) (Reg(Normal,Direct)) (tu Arch.reg_size) f.f_loc [] in
Some r_return
else None
in
StackByReg (r, r_return, tmp)
else ByReg (r, tmp) in
Hf.add return_addresses f.f_name ra;
let written =
Expand All @@ -1139,10 +1174,7 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
match f.f_cc with
| (Export _ | Internal) -> written
| Subroutine _ ->
match ra with
| StackDirect -> written
| StackByReg (r, None) | ByReg (r, None) -> Sv.add r written
| StackByReg (r, Some t) | ByReg (r, Some t) -> Sv.add t (Sv.add r written)
Sv.union (vars_retaddr ra) written
in
let killed_by_calls =
Mf.fold (fun fn _locs acc -> Sv.union (killed fn) acc)
Expand Down Expand Up @@ -1217,14 +1249,32 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
List.fold_left (fun cnf x -> conflicts_add_one Arch.pointer_data Arch.reg_size Arch.asmOp vars tr Lnone ra x cnf) in
List.fold_left (fun a f ->
match Hf.find return_addresses f.f_name with
| ByReg (ra, None) | StackByReg(ra,None) ->
doit ra a f.f_args
| ByReg (ra, Some tmp) | StackByReg(ra,Some tmp) ->
| StackDirect -> a
| StackByReg (ra_call, ra_return, tmp) ->
(* ra_call conflicts with function arguments *)
let a = doit ra_call a f.f_args in
let a =
match ra_return with
| Some ra_return ->
(* ra_return conflicts with function results *)
doit ra_return a (List.map L.unloc f.f_ret)
| None -> a
in
begin match tmp with
| Some tmp ->
(* tmp register used to increment the stack conflicts with function arguments and results *)
let a = doit tmp a f.f_args in
doit tmp a (List.map L.unloc f.f_ret)
| None -> a
end
| ByReg (ra, tmp) ->
let a = doit ra a f.f_args in
(* tmp register used to increment the stack conflicts with function arguments and results *)
let a = doit tmp a f.f_args in
doit tmp a (List.map L.unloc f.f_ret)
| StackDirect -> a)
match tmp with
| Some tmp ->
(* tmp register used to increment the stack conflicts with function arguments and results *)
let a = doit tmp a f.f_args in
doit tmp a (List.map L.unloc f.f_ret)
| None -> a)
conflicts funcs in
(* Inter-procedural conflicts *)
let conflicts =
Expand Down Expand Up @@ -1304,7 +1354,8 @@ let alloc_prog translate_var (has_stack: ('info, 'asm) func -> 'a -> bool) get_i
let ro_return_address =
match Hf.find return_addresses f.f_name with
| StackDirect -> StackDirect
| StackByReg(r, tmp) -> StackByReg (subst r, Option.map subst tmp)
| StackByReg(ra_call, ra_return, tmp) ->
StackByReg (subst ra_call, Option.map subst ra_return, Option.map subst tmp)
| ByReg(r, tmp) -> ByReg (subst r, Option.map subst tmp) in
let ro_to_save = if FInfo.is_export f.f_cc then Sv.elements to_save else [] in
e, { ro_to_save ; ro_rsp ; ro_return_address }, f
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/regalloc.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ val fill_in_missing_names : ('info, 'asm) Prog.func -> ('info, 'asm) Prog.func

type retaddr =
| StackDirect
| StackByReg of var * var option
| StackByReg of var * var option * var option
| ByReg of var * var option

type reg_oracle_t = {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/riscv_arch_full.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ module Riscv (Lowering_params : Riscv_input) : Arch_full.Core_arch = struct

let pp_asm = Pp_riscv.print_prog

let callstyle = Arch_full.ByReg (Some RA)
let callstyle = Arch_full.ByReg { call = Some RA; return = true }
end
7 changes: 4 additions & 3 deletions compiler/src/stackAlloc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ let memory_analysis pp_err ~debug up =
sao_rsp = saved_stack;
sao_return_address =
(* This is a dummy value it will be fixed in fix_csao *)
RAstack (None, Conv.cz_of_int 0, None)
RAstack (None, None, Conv.cz_of_int 0, None)
} in
Hf.replace atbl fn csao
in
Expand All @@ -339,8 +339,9 @@ let memory_analysis pp_err ~debug up =
Stack_alloc.{ csao with
sao_return_address =
match ro.ro_return_address with
| StackDirect -> RAstack (None, Conv.cz_of_int 0, None) (* FIXME stackDirect should provide a tmp register *)
| StackByReg (r, tmp) -> RAstack (Some (Conv.cvar_of_var r), Conv.cz_of_int 0, Option.map Conv.cvar_of_var tmp)
| StackDirect -> RAstack (None, None, Conv.cz_of_int 0, None) (* FIXME stackDirect should provide a tmp register *)
| StackByReg (ra_call, ra_return, tmp) ->
RAstack (Some (Conv.cvar_of_var ra_call), Option.map Conv.cvar_of_var ra_return, Conv.cz_of_int 0, Option.map Conv.cvar_of_var tmp)
| ByReg (r, tmp) -> RAreg (Conv.cvar_of_var r, Option.map Conv.cvar_of_var tmp)
} in
Hf.replace atbl fn csao
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/varalloc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,13 @@ let alloc_stack_fd callstyle pd get_info gtbl fd =
false (* For export function ra is not counted in the frame *)
| Subroutine _ ->
match callstyle with
| Arch_full.StackDirect ->
| Arch_full.StackDirect ->
if fd.f_annot.retaddr_kind = Some OnReg then
Utils.warning Always (L.i_loc fd.f_loc [])
"for function %s, return address by reg not allowed for that architecture, annotation is ignored"
fd.f_name.fn_name;
true
| Arch_full.ByReg oreg -> (* oreg = Some r implies that all call use r,
| Arch_full.ByReg { call = oreg } -> (* oreg = Some r implies that all call use r,
so if the function performs some call r will be overwritten,
so ra need to be saved on stack *)
let dfl = oreg <> None && has_call_or_syscall fd.f_body in
Expand Down
Loading