Skip to content

Commit

Permalink
RISC-V: proper management of the return address
Browse files Browse the repository at this point in the history
On RISC-V, the return address is read from register RA. This was not
reflected in the model and generated wrong programs.

The RAstack case of return_address_location is given an additional
optional argument ra_return that specifies what register to use (if any)
when returning from a function. linearization is adapted to generate the
right code. reg alloc and merge_varmaps take into account this potential
new register.
  • Loading branch information
eponier authored and bgregoir committed Nov 5, 2024
1 parent fd5e550 commit d03e378
Show file tree
Hide file tree
Showing 24 changed files with 728 additions and 499 deletions.
13 changes: 9 additions & 4 deletions compiler/src/arch_full.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ open Arch_extra
open Prog

type 'a callstyle =
| StackDirect (* call instruction push the return address on top of the stack *)
| ByReg of 'a option (* call instruction store the return address on a register,
(Some r) neams that the register is forced to be r *)
| StackDirect
(* call instruction push the return address on top of the stack *)
| ByReg of { call : 'a option; return : bool }
(* call instruction store the return address on a register,
- call: (Some r) means that the register is forced to be r
- return:
+ true means that the register is also used for the return
+ false means that there is no constraint (stack is also ok) *)

(* TODO: check that we cannot use sth already defined on the Coq side *)

Expand Down Expand Up @@ -191,7 +196,7 @@ module Arch_from_Core_arch (A : Core_arch) :
let callstyle =
match A.callstyle with
| StackDirect -> StackDirect
| ByReg o -> ByReg (Option.map var_of_reg o)
| ByReg { call; return } -> ByReg { call = Option.map var_of_reg call; return }

let arch_info = Pretyping.{
pd = reg_size;
Expand Down
12 changes: 9 additions & 3 deletions compiler/src/arch_full.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@ open Arch_extra
open Prog

type 'a callstyle =
| StackDirect (* call instruction push the return address on top of the stack *)
| ByReg of 'a option (* call instruction store the return address on a register,
(Some r) neams that the register is forced to be r *)
| StackDirect
(* call instruction push the return address on top of the stack *)
| ByReg of { call : 'a option; return : bool }
(* call instruction store the return address on a register,
- call: (Some r) means that the register is forced to be r
- return:
+ true means that the register is also used for the return
+ false means that there is no constraint (stack is also ok) *)

(* x86 : StackDirect
arm v7 : ByReg (Some ra)
riscV : ByReg (can it be StackDirect too ?)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/arm_arch_full.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ module Arm (Lowering_params : Arm_input) : Arch_full.Core_arch = struct

let pp_asm = Pp_arm_m4.print_prog

let callstyle = Arch_full.ByReg (Some LR)
let callstyle = Arch_full.ByReg { call = Some LR; return = false }
end
46 changes: 24 additions & 22 deletions compiler/src/pp_riscv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -173,33 +173,25 @@ let pp_instr fn i =
[ LInstr ("j", [ pp_remote_label lbl ]) ]

| JMPI arg ->
(* TODO_RISCV: Review. *)
let lbl =
match arg with
| Reg r -> pp_register r
| _ -> failwith "TODO_RISCV: pp_instr jmpi"
in
[ LInstr ("jr", [ lbl ]) ]
begin match arg with
| Reg RA -> [LInstr ("ret", [])]
| Reg r -> [ LInstr ("jr", [ pp_register r ]) ]
| _ -> failwith "TODO_RISCV: pp_instr jmpi"
end

| Jcc (lbl, ct) ->
let iname = pp_condition_kind ct.cond_kind in
let cond_fst = pp_cond_arg ct.cond_fst in
let cond_snd = pp_cond_arg ct.cond_snd in
[ LInstr (iname, [ cond_fst; cond_snd; pp_label fn lbl ]) ]

| CALL lbl ->
[LInstr ("call", [pp_remote_label lbl])]

| JAL (reg, lbl) ->
begin match reg with
| RA -> [LInstr ("call", [pp_remote_label lbl] )]
| _ -> [LInstr ("jalr", [pp_register reg; pp_remote_label lbl] )]
end
| JAL (RA, lbl) ->
[LInstr ("call", [pp_remote_label lbl])]

| POPPC ->
[ LInstr ("lw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None]);
LInstr ("addi", [ pp_register SP; pp_register SP; "4"]);
LInstr ("ret", [ ]) ]
| JAL _
| CALL _
| POPPC ->
assert false

| SysCall op ->
[LInstr ("call", [ pp_syscall op ])]
Expand Down Expand Up @@ -240,11 +232,21 @@ let pp_fun (fn, fd) =
else []
in let pre =
let fn = escape fn in
if fd.asm_fd_export then [ LLabel (mangle fn); LLabel fn; LInstr ("addi", [ pp_register SP; pp_register SP; "-4"]); LInstr ("sw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None])] else []
if fd.asm_fd_export then
[ LLabel (mangle fn);
LLabel fn;
LInstr ("addi", [ pp_register SP; pp_register SP; "-4"]);
LInstr ("sw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None])]
else []
in
let body = pp_body fn fd.asm_fd_body in
(* TODO_RISCV: Review. *)
let pos = if fd.asm_fd_export then pp_instr fn POPPC else [] in
let pos =
if fd.asm_fd_export then
[ LInstr ("lw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None]);
LInstr ("addi", [ pp_register SP; pp_register SP; "4"]);
LInstr ("ret", [ ]) ]
else []
in
head @ pre @ body @ pos

let pp_funcs funs = List.concat_map pp_fun funs
Expand Down
16 changes: 12 additions & 4 deletions compiler/src/printer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,22 @@ let pp_saved_stack ~debug fmt = function
let pp_tmp_option ~debug =
Format.pp_print_option (fun fmt x -> Format.fprintf fmt " [tmp = %a]" (pp_var ~debug) (Conv.var_of_cvar x))

let pp_ra_call ~debug =
Format.pp_print_option (fun fmt ra_call -> Format.fprintf fmt "%a -> " (pp_var ~debug) (Conv.var_of_cvar ra_call))

let pp_ra_return ~debug =
Format.pp_print_option (fun fmt ra_return -> Format.fprintf fmt " -> %a" (pp_var ~debug) (Conv.var_of_cvar ra_return))

let pp_return_address ~debug fmt = function
| Expr.RAreg (x, o) ->
Format.fprintf fmt "%a%a" (pp_var ~debug) (Conv.var_of_cvar x) (pp_tmp_option ~debug) o

| Expr.RAstack(Some x, z, o) ->
Format.fprintf fmt "%a, RSP + %a%a" (pp_var ~debug) (Conv.var_of_cvar x) Z.pp_print (Conv.z_of_cz z) (pp_tmp_option ~debug) o
| Expr.RAstack(None, z, o) ->
Format.fprintf fmt "RSP + %a%a" Z.pp_print (Conv.z_of_cz z) (pp_tmp_option ~debug) o
| Expr.RAstack(ra_call, ra_return, z, o) ->
Format.fprintf fmt "%aRSP + %a%a%a"
(pp_ra_call ~debug) ra_call Z.pp_print (Conv.z_of_cz z)
(pp_tmp_option ~debug) o
(pp_ra_return ~debug) ra_return

| Expr.RAnone -> Format.fprintf fmt "_"

let pp_sprog ~debug pd asmOp fmt ((funcs, p_extra):('info, 'asm) Prog.sprog) =
Expand Down
107 changes: 79 additions & 28 deletions compiler/src/regalloc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -454,21 +454,36 @@ let collect_variables_cb ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int)
let n = fresh () in
Hv.add tbl v n

let collect_variables_aux ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) (tbl: int Hv.t) (extra: var option) (f: ('info, 'asm) func) : unit =
let collect_variables_aux ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) (tbl: int Hv.t) (extra: Sv.t) (f: ('info, 'asm) func) : unit =
let get v = collect_variables_cb ~allvars excluded fresh tbl v in
iter_variables get f;
match extra with Some x -> get x | None -> ()
Sv.iter get extra

let collect_variables ~(allvars: bool) (excluded:Sv.t) (f: ('info, 'asm) func) : int Hv.t * int =
let fresh, total = make_counter () in
let tbl : int Hv.t = Hv.create 97 in
collect_variables_aux ~allvars excluded fresh tbl None f;
collect_variables_aux ~allvars excluded fresh tbl Sv.empty f;
tbl, total ()

(* TODO: should StackDirect be just StackByReg (None, None, None)? *)
type retaddr =
| StackDirect
| StackByReg of var * var option
(* ra is passed on the stack and read from the stack *)
| StackByReg of var * var option * var option
(* StackByReg (ra_call, ra_return, tmp) *)
| ByReg of var * var option
(* ByReg (ra, tmp) *)

let vars_retaddr ra =
let oadd ov s =
match ov with
| None -> s
| Some v -> Sv.add v s
in
match ra with
| StackByReg (ra_call, ra_return, tmp) -> oadd tmp (oadd ra_return (Sv.singleton ra_call))
| ByReg (ra, tmp) -> oadd tmp (Sv.singleton ra)
| StackDirect -> Sv.empty

let collect_variables_in_prog
~(allvars: bool)
Expand All @@ -479,12 +494,8 @@ let collect_variables_in_prog
let fresh, total = make_counter () in
let tbl : int Hv.t = Hv.create 97 in
List.iter (fun f ->
let extra, tmp =
match Hf.find return_adresses f.f_name with
| StackByReg (v, tmp) | ByReg (v, tmp) -> Some v, tmp
| StackDirect -> None, None in
collect_variables_aux ~allvars excluded fresh tbl extra f;
Option.may (collect_variables_cb ~allvars excluded fresh tbl) tmp) f;
let extra = vars_retaddr (Hf.find return_adresses f.f_name) in
collect_variables_aux ~allvars excluded fresh tbl extra f) f;
List.iter (collect_variables_cb ~allvars excluded fresh tbl) all_reg;
tbl, total ()

Expand Down Expand Up @@ -694,10 +705,27 @@ let allocate_forced_registers return_addresses translate_var nv (vars: int Hv.t)
if FInfo.is_export f.f_cc then alloc_args loc identity f.f_args;
if FInfo.is_export f.f_cc then alloc_ret loc L.unloc f.f_ret;
alloc_stmt f.f_body;
match Hf.find return_addresses f.f_name, Arch.callstyle with
| (StackByReg (ra,_) | ByReg (ra, _)), Arch_full.ByReg (Some r) ->
match Arch.callstyle with
| Arch_full.ByReg { call = Some r; return } ->
begin match Hf.find return_addresses f.f_name with
| StackDirect -> ()
| StackByReg (ra_call, ra_return, _) ->
let i = Hv.find vars ra_call in
allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra_call i r a;
if return then begin
match ra_return with
| Some ra_return ->
let i = Hv.find vars ra_return in
allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra_return i r a
| None ->
(* calling convention requires the return address to be in a register,
but there is no booked register. This must not happen. *)
assert false
end
| ByReg (ra, _) ->
let i = Hv.find vars ra in
allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra i r a
end
| _ -> ()

(* Returns a variable from [regs] that is allocated to a friend variable of [i]. Defaults to [dflt]. *)
Expand Down Expand Up @@ -943,7 +971,7 @@ let subroutine_ra_by_stack f =
| Subroutine _ ->
match Arch.callstyle with
| Arch_full.StackDirect -> true
| Arch_full.ByReg oreg ->
| Arch_full.ByReg { call = oreg } ->
let dfl = oreg <> None && has_call_or_syscall f.f_body in
match f.f_annot.retaddr_kind with
| None -> dfl
Expand Down Expand Up @@ -1108,15 +1136,15 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
let preprocess f =
let f = f |> fill_in_missing_names |> Ssa.split_live_ranges false in
Hf.add liveness_table f.f_name (Liveness.live_fd true f);
(* compute where will be store the return address *)
(* compute where the return address will be stored *)
let ra =
match f.f_cc with
| Export _ -> StackDirect
| Internal -> assert false
| Subroutine _ ->
match Arch.callstyle with
| Arch_full.StackDirect -> StackDirect
| Arch_full.ByReg oreg ->
| Arch_full.ByReg { call = oreg; return } ->
let dfl = oreg <> None && has_call_or_syscall f.f_body in
let r = V.mk ("ra_"^f.f_name.fn_name) (Reg(Normal,Direct)) (tu Arch.reg_size) f.f_loc [] in
(* Fixme: Add an option in Arch to say when the tmp reg is needed *)
Expand All @@ -1130,7 +1158,14 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
match f.f_annot.retaddr_kind with
| None -> dfl
| Some k -> dfl || k = OnStack in
if rastack then StackByReg (r, tmp)
if rastack then
let r_return =
if return then
let r_return = V.mk ("ra_"^f.f_name.fn_name) (Reg(Normal,Direct)) (tu Arch.reg_size) f.f_loc [] in
Some r_return
else None
in
StackByReg (r, r_return, tmp)
else ByReg (r, tmp) in
Hf.add return_addresses f.f_name ra;
let written =
Expand All @@ -1139,10 +1174,7 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
match f.f_cc with
| (Export _ | Internal) -> written
| Subroutine _ ->
match ra with
| StackDirect -> written
| StackByReg (r, None) | ByReg (r, None) -> Sv.add r written
| StackByReg (r, Some t) | ByReg (r, Some t) -> Sv.add t (Sv.add r written)
Sv.union (vars_retaddr ra) written
in
let killed_by_calls =
Mf.fold (fun fn _locs acc -> Sv.union (killed fn) acc)
Expand Down Expand Up @@ -1217,14 +1249,32 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func
List.fold_left (fun cnf x -> conflicts_add_one Arch.pointer_data Arch.reg_size Arch.asmOp vars tr Lnone ra x cnf) in
List.fold_left (fun a f ->
match Hf.find return_addresses f.f_name with
| ByReg (ra, None) | StackByReg(ra,None) ->
doit ra a f.f_args
| ByReg (ra, Some tmp) | StackByReg(ra,Some tmp) ->
| StackDirect -> a
| StackByReg (ra_call, ra_return, tmp) ->
(* ra_call conflicts with function arguments *)
let a = doit ra_call a f.f_args in
let a =
match ra_return with
| Some ra_return ->
(* ra_return conflicts with function results *)
doit ra_return a (List.map L.unloc f.f_ret)
| None -> a
in
begin match tmp with
| Some tmp ->
(* tmp register used to increment the stack conflicts with function arguments and results *)
let a = doit tmp a f.f_args in
doit tmp a (List.map L.unloc f.f_ret)
| None -> a
end
| ByReg (ra, tmp) ->
let a = doit ra a f.f_args in
(* tmp register used to increment the stack conflicts with function arguments and results *)
let a = doit tmp a f.f_args in
doit tmp a (List.map L.unloc f.f_ret)
| StackDirect -> a)
match tmp with
| Some tmp ->
(* tmp register used to increment the stack conflicts with function arguments and results *)
let a = doit tmp a f.f_args in
doit tmp a (List.map L.unloc f.f_ret)
| None -> a)
conflicts funcs in
(* Inter-procedural conflicts *)
let conflicts =
Expand Down Expand Up @@ -1304,7 +1354,8 @@ let alloc_prog translate_var (has_stack: ('info, 'asm) func -> 'a -> bool) get_i
let ro_return_address =
match Hf.find return_addresses f.f_name with
| StackDirect -> StackDirect
| StackByReg(r, tmp) -> StackByReg (subst r, Option.map subst tmp)
| StackByReg(ra_call, ra_return, tmp) ->
StackByReg (subst ra_call, Option.map subst ra_return, Option.map subst tmp)
| ByReg(r, tmp) -> ByReg (subst r, Option.map subst tmp) in
let ro_to_save = if FInfo.is_export f.f_cc then Sv.elements to_save else [] in
e, { ro_to_save ; ro_rsp ; ro_return_address }, f
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/regalloc.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ val fill_in_missing_names : ('info, 'asm) Prog.func -> ('info, 'asm) Prog.func

type retaddr =
| StackDirect
| StackByReg of var * var option
| StackByReg of var * var option * var option
| ByReg of var * var option

type reg_oracle_t = {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/riscv_arch_full.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ module Riscv (Lowering_params : Riscv_input) : Arch_full.Core_arch = struct

let pp_asm = Pp_riscv.print_prog

let callstyle = Arch_full.ByReg (Some RA)
let callstyle = Arch_full.ByReg { call = Some RA; return = true }
end
7 changes: 4 additions & 3 deletions compiler/src/stackAlloc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ let memory_analysis pp_err ~debug up =
sao_rsp = saved_stack;
sao_return_address =
(* This is a dummy value it will be fixed in fix_csao *)
RAstack (None, Conv.cz_of_int 0, None)
RAstack (None, None, Conv.cz_of_int 0, None)
} in
Hf.replace atbl fn csao
in
Expand All @@ -339,8 +339,9 @@ let memory_analysis pp_err ~debug up =
Stack_alloc.{ csao with
sao_return_address =
match ro.ro_return_address with
| StackDirect -> RAstack (None, Conv.cz_of_int 0, None) (* FIXME stackDirect should provide a tmp register *)
| StackByReg (r, tmp) -> RAstack (Some (Conv.cvar_of_var r), Conv.cz_of_int 0, Option.map Conv.cvar_of_var tmp)
| StackDirect -> RAstack (None, None, Conv.cz_of_int 0, None) (* FIXME stackDirect should provide a tmp register *)
| StackByReg (ra_call, ra_return, tmp) ->
RAstack (Some (Conv.cvar_of_var ra_call), Option.map Conv.cvar_of_var ra_return, Conv.cz_of_int 0, Option.map Conv.cvar_of_var tmp)
| ByReg (r, tmp) -> RAreg (Conv.cvar_of_var r, Option.map Conv.cvar_of_var tmp)
} in
Hf.replace atbl fn csao
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/varalloc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,13 @@ let alloc_stack_fd callstyle pd get_info gtbl fd =
false (* For export function ra is not counted in the frame *)
| Subroutine _ ->
match callstyle with
| Arch_full.StackDirect ->
| Arch_full.StackDirect ->
if fd.f_annot.retaddr_kind = Some OnReg then
Utils.warning Always (L.i_loc fd.f_loc [])
"for function %s, return address by reg not allowed for that architecture, annotation is ignored"
fd.f_name.fn_name;
true
| Arch_full.ByReg oreg -> (* oreg = Some r implies that all call use r,
| Arch_full.ByReg { call = oreg } -> (* oreg = Some r implies that all call use r,
so if the function performs some call r will be overwritten,
so ra need to be saved on stack *)
let dfl = oreg <> None && has_call_or_syscall fd.f_body in
Expand Down
Loading

0 comments on commit d03e378

Please sign in to comment.