From a5b250783c4610897ae21a7b662f92ee621a49fc Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Thu, 9 Jan 2025 12:37:20 +0100 Subject: [PATCH] Properly substitute trait generics --- lib/AstOfLlbc.ml | 190 +++++++++++++++++++++++++++-------------------- 1 file changed, 109 insertions(+), 81 deletions(-) diff --git a/lib/AstOfLlbc.ml b/lib/AstOfLlbc.ml index dbb79d2..2917a1a 100644 --- a/lib/AstOfLlbc.ml +++ b/lib/AstOfLlbc.ml @@ -7,6 +7,7 @@ module C = struct include Charon.Expressions include Charon.Values include Charon.GAstUtils + include Charon.Substitute (* Fails if the variable is bound *) let expect_free_var = function @@ -54,6 +55,14 @@ type var_id = | ConstGenericVar of C.const_generic_var_id | Var of C.var_id * C.ety (* the ety aids code-generation, sometimes *) +(* Charon can have type-level variables (types, clause and const generics) + bound at two levels: each item has its own `generic_params`, and + additionally each method within a trait has a `generic_params` for the + method-specific generics. Charon distinguishes them with its `de_bruijn_var` + but we don't need that level of precision. In expression scope, only + `ItemBound` can exist. *) +type bound_var_origin = ItemBound | MethodBound + type env = { (* Lookup functions to resolve various id's into actual declarations. *) get_nth_function : C.FunDeclId.id -> C.fun_decl; @@ -103,8 +112,8 @@ type env = { "x": TCgArray (TBound 0, 0); (* types use the cg scope *) ], EBound 2 (* expressions refer to the copy of the cg var as an expression var *) *) - cg_binders : (C.const_generic_var_id * K.typ) list; - type_binders : C.type_var_id list; + cg_binders : ((bound_var_origin * C.const_generic_var_id) * K.typ) list; + type_binders : (bound_var_origin * C.type_var_id) list; binders : (var_id * K.typ) list; (* For printing. *) format_env : Charon.PrintLlbcAst.fmt_env; @@ -131,16 +140,29 @@ let snd3 (_, x, _) = x let thd3 (_, _, x) = x (* Suitable in types -- in expressions, use lookup_cg_in_expressions. *) -let lookup_cg_in_types env v1 = - let i, (_, t) = findi (fun (v2, _) -> v1 = v2) env.cg_binders in +let lookup_cg_in_types env (var : C.const_generic_var_id C.de_bruijn_var) = + let origin, id = + match var with + | Free id -> ItemBound, id + | Bound (_, id) -> MethodBound, id + in + let i, (_, t) = findi (fun (v2, _) -> (origin, id) = v2) env.cg_binders in i, t -let lookup_typ env (v1 : C.type_var_id) = - let i, _ = findi (( = ) v1) env.type_binders in +let lookup_typ env (var : C.type_var_id C.de_bruijn_var) = + let origin, id = + match var with + | Free id -> ItemBound, id + | Bound (_, id) -> MethodBound, id + in + let i, _ = findi (( = ) (origin, id)) env.type_binders in i -let push_type_binder env (t : C.type_var) = { env with type_binders = t.index :: env.type_binders } -let push_type_binders env (ts : C.type_var list) = List.fold_left push_type_binder env ts +let push_type_binder (origin : bound_var_origin) env (t : C.type_var) = + { env with type_binders = (origin, t.index) :: env.type_binders } + +let push_type_binders env (origin : bound_var_origin) (ts : C.type_var list) = + List.fold_left (push_type_binder origin) env ts (** Helpers: types *) @@ -321,7 +343,7 @@ let assert_cg_scalar = function let cg_of_const_generic env cg = match cg with - | C.CgVar var -> K.CgVar (fst (lookup_cg_in_types env (C.expect_free_var var))) + | C.CgVar var -> K.CgVar (fst (lookup_cg_in_types env var)) | C.CgValue (VScalar sv) -> CgConst (constant_of_scalar_value sv) | _ -> failwith @@ -336,7 +358,7 @@ let typ_of_literal_ty (_env : env) (ty : Charon.Types.literal_type) : K.typ = let rec typ_of_ty (env : env) (ty : Charon.Types.ty) : K.typ = match ty with - | TVar var -> K.TBound (lookup_typ env (C.expect_free_var var)) + | TVar var -> K.TBound (lookup_typ env var) | TLiteral t -> typ_of_literal_ty env t | TNever -> failwith "Impossible: Never" | TDynTrait _ -> failwith "TODO: dyn Trait" @@ -400,7 +422,7 @@ and maybe_cg_array env t cg = match cg with | CgValue _ -> K.TArray (typ_of_ty env t, constant_of_scalar_value (assert_cg_scalar cg)) | CgVar var -> - let id, cg_t = lookup_cg_in_types env (C.expect_free_var var) in + let id, cg_t = lookup_cg_in_types env var in assert (cg_t = K.TInt SizeT); K.TCgArray (typ_of_ty env t, id) | _ -> failwith "TODO: CgGlobal" @@ -442,14 +464,15 @@ let lookup_with_original_type env v1 = (* Const generic binders *) -let push_cg_binder env (t : C.const_generic_var) = +let push_cg_binder (origin : bound_var_origin) env (t : C.const_generic_var) = { env with - cg_binders = (t.index, typ_of_literal_ty env t.ty) :: env.cg_binders; + cg_binders = ((origin, t.index), typ_of_literal_ty env t.ty) :: env.cg_binders; binders = (ConstGenericVar t.index, typ_of_literal_ty env t.ty) :: env.binders; } -let push_cg_binders env (ts : C.const_generic_var list) = List.fold_left push_cg_binder env ts +let push_cg_binders env (origin : bound_var_origin) (ts : C.const_generic_var list) = + List.fold_left (push_cg_binder origin) env ts let push_binder env (t : C.var) = { env with binders = (Var (t.index, t.var_ty), typ_of_ty env t.var_ty) :: env.binders } @@ -746,8 +769,7 @@ let blocklisted_trait_decls = transitively, possibly called by this function, based on the trait bounds in its signature. *) type trait_clause_entry = - | ClauseMethod of - (C.generic_args * K.type_scheme * Charon.Types.name (* trait name *) * C.fun_sig) + | ClauseMethod of (Charon.Types.name (* trait name *) * C.fun_sig) | ClauseConstant of Charon.Types.name (* trait name *) * C.ty type trait_clause_mapping = ((C.trait_instance_id * string) * trait_clause_entry) list @@ -791,29 +813,68 @@ let rec build_trait_clause_mapping env (trait_clauses : C.trait_clause list) : t (List.length trait_decl.C.required_methods) (List.length trait_decl.C.provided_methods); + let clause_ref = C.Clause (Free clause_id) in + (* Substitute the value with the trait arguments. *) + let apply_trait_args : 'a. (C.subst -> 'a -> 'a) -> 'a -> 'a = + fun substitutor x -> + let bound_val = { C.item_binder_params = trait_decl.generics; item_binder_value = x } in + C.apply_args_to_item_binder clause_ref decl_generics substitutor bound_val + in + (* 1. Associated constants *) List.map (fun (item_name, typ) -> - (C.Clause (Free clause_id), item_name), ClauseConstant (trait_decl.C.item_meta.name, typ)) + let typ = apply_trait_args C.st_substitute_visitor#visit_ty typ in + (clause_ref, item_name), ClauseConstant (trait_decl.C.item_meta.name, typ)) trait_decl.C.consts (* 2. Trait methods *) @ List.map (fun (item_name, bound_fn) -> - let fun_decl_id = bound_fn.C.binder_value.C.fun_id in - let decl = env.get_nth_function fun_decl_id in - let ts = - { - K.n = List.length trait_decl.generics.types; - n_cgs = List.length trait_decl.generics.const_generics; - } + let bound_fn : C.fun_decl_ref C.binder = + apply_trait_args + (C.st_substitute_visitor#visit_binder C.st_substitute_visitor#visit_fun_decl_ref) + bound_fn in - ( (C.Clause (Free clause_id), item_name), - ClauseMethod (decl_generics, ts, trait_decl.C.item_meta.name, decl.C.signature) )) + + (* `bound_fn` is a mapping from the method generics to the full + generics of a fun_decl that contains the signature we care + about. E.g.: + + trait Hash + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K]; + + fn use_hash>() { ... } + + gives us a `bound_fn` that binds `LEN` and gives us `(32, + Clause 0, LEN)` to pass to the appropriate function item. We + retrieve its signature and substitute it to obtain the + signature we want: + + (LEN: size_t) -> &[[u8; 33]; 32] -> [[u8; LEN]; 32] + *) + let signature = + (* `fn_ref` uses variables bound in `bound_fn.binder_params`. *) + let fn_ref = bound_fn.binder_value in + (* `fn_decl.signature` uses variables bound in `fn_decl.signature.generics`. *) + let fn_decl = env.get_nth_function fn_ref.C.fun_id in + let subst = + C.make_subst_from_generics fn_decl.C.signature.C.generics fn_ref.fun_generics + in + (* Make `signature` use variables bound in `bound_fn.binder_params.` *) + let signature = C.st_substitute_visitor#visit_fun_sig subst fn_decl.C.signature in + (* This is nowa valid signature with the parameters we want. *) + { signature with generics = bound_fn.binder_params } + in + + (clause_ref, item_name), ClauseMethod (trait_decl.C.item_meta.name, signature)) (trait_decl.C.required_methods @ trait_decl.C.provided_methods) (* 1 + 2, recursively, for parent traits *) @ List.flatten (List.mapi (fun _i (parent_clause : C.trait_clause) -> + let parent_clause = + apply_trait_args C.st_substitute_visitor#visit_trait_clause parent_clause + in (* Mapping of the methods of the parent clause *) let m = build_trait_clause_mapping env [ parent_clause ] in List.map @@ -825,7 +886,7 @@ let rec build_trait_clause_mapping env (trait_clauses : C.trait_clause list) : t | _ -> fail "not a clause??" in let id : C.trait_instance_id = - ParentClause (Clause (Free clause_id), trait_decl_id, clause_id') + ParentClause (clause_ref, trait_decl_id, clause_id') in (id, m), v) m) @@ -853,7 +914,7 @@ let maybe_ts ts t = else t -let rec lookup_signature env depth signature : lookup_result = +let rec lookup_signature env (origin : bound_var_origin) depth signature : lookup_result = let { C.generics = { types = type_params; const_generics; trait_clauses; _ }; inputs; output; _ } = signature @@ -861,8 +922,8 @@ let rec lookup_signature env depth signature : lookup_result = L.log "Calls" "%s--> args: %s, ret: %s" depth (String.concat " ++ " (List.map (Charon.PrintTypes.ty_to_string env.format_env) inputs)) (Charon.PrintTypes.ty_to_string env.format_env output); - let env = push_cg_binders env const_generics in - let env = push_type_binders env type_params in + let env = push_cg_binders env origin const_generics in + let env = push_type_binders env origin type_params in let clause_mapping = build_trait_clause_mapping env trait_clauses in debug_trait_clause_mapping env clause_mapping; @@ -891,50 +952,18 @@ and mk_clause_binders_and_args env (clause_mapping : trait_clause_mapping) : (va List.map (fun ((clause_id, item_name), clause_entry) -> match clause_entry with - | ClauseMethod - ((clause_generics : C.generic_args), trait_ts, trait_name, (signature : C.fun_sig)) -> - (* Polymorphic signature for trait method has const generic for BOTH - trait-level generics and fn-level generics. Consider: - - trait Hash - fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K]; - - which gives the signature: - - size_t -> size_t -> uint8_t[33size_t]* -> uint8_t[$0][$1] - *) - let _, t = typ_of_signature env signature in - (* We are in a function that has a trait clause of the form e.g. Hash. - cgs contains FOO, that's it. *) - let cgs = List.map (cg_of_const_generic env) clause_generics.C.const_generics in - let ts = List.map (typ_of_ty env) clause_generics.C.types in - (* A little bit of math to compute how many of these are on the method - itself *) - let f_ts = + | ClauseMethod (trait_name, signature) -> + let ts = { - K.n_cgs = List.length signature.C.generics.const_generics - List.length cgs; - n = List.length signature.C.generics.types - List.length ts; + K.n = List.length signature.generics.types; + K.n_cgs = List.length signature.generics.const_generics; } in - L.log "TraitClauses" "%s has %d fn-level const generics" item_name f_ts.n_cgs; - L.log "TraitClauses" "%s has %d fn-level type params" item_name f_ts.n; - L.log "TraitClauses" "About to substitute cgs=%a, ts=%a into %a" pcgs cgs ptyps ts ptyp t; - let t = Krml.DeBruijn.(subst_tn' f_ts.n ts (subst_ctn'' f_ts.n_cgs cgs t)) in - L.log "TraitClauses" "After subtitution t=%a" ptyp t; - let ret, args = Krml.Helpers.flatten_arrow t in - let _, args = Krml.KList.split trait_ts.K.n_cgs args in - let t = Krml.Helpers.fold_arrow args ret in - L.log "TraitClauses" "After chopping t=%a" ptyp t; - let t = maybe_ts f_ts t in + let _, t = typ_of_signature env MethodBound signature in + let t = maybe_ts ts t in L.log "TraitClauses" "After ts addition t=%a" ptyp t; let pretty_name = string_of_name env trait_name ^ "_" ^ item_name in - let ts = - { - K.n = List.length signature.generics.types - trait_ts.n; - K.n_cgs = List.length signature.generics.const_generics - trait_ts.n_cgs; - } - in (* TODO: figure out why this fails for e.g. Iterator.rev *) assert (ts.n_cgs >= 0 && ts.n >= 0); TraitClauseMethod { pretty_name; clause_id; item_name; ts }, t @@ -946,9 +975,9 @@ and mk_clause_binders_and_args env (clause_mapping : trait_clause_mapping) : (va (* Transforms a lookup result into a usable type, taking into account the fact that the internal Ast is ML-style and does not have zero-argument functions. *) -and typ_of_signature env signature = +and typ_of_signature env (origin : bound_var_origin) signature = let { cg_types = const_generics_ts; arg_types = inputs; ret_type = output; ts; _ } = - lookup_signature env "" signature + lookup_signature env origin "" signature in let adjusted_inputs = const_generics_ts @ inputs in @@ -964,12 +993,11 @@ and debug_trait_clause_mapping env (mapping : trait_clause_mapping) = List.iter (fun ((clause_id, item_name), clause_entry) -> match clause_entry with - | ClauseMethod (_, ts, trait_name, signature) -> - let _, t = typ_of_signature env signature in - L.log "TraitClauses" - "%s (a.k.a. %s)::%s: %a has trait-level %d const generics, %d type vars" + | ClauseMethod (trait_name, signature) -> + let _, t = typ_of_signature env MethodBound signature in + L.log "TraitClauses" "%s (a.k.a. %s)::%s: %a" (Charon.PrintTypes.trait_instance_id_to_string env.format_env clause_id) - (string_of_name env trait_name) item_name ptyp t ts.K.n_cgs ts.n + (string_of_name env trait_name) item_name ptyp t | ClauseConstant (trait_name, t) -> let t = typ_of_ty env t in L.log "TraitClauses" "%s (a.k.a. %s)::%s: associated constant %a" @@ -999,7 +1027,7 @@ let lookup_fun (env : env) depth (f : C.fn_ptr) : K.expr' * lookup_result = let { C.item_meta; signature; _ } = env.get_nth_function fun_id in let lid = lid_of_name env item_meta.name in L.log "Calls" "%s--> name: %a" depth plid lid; - K.EQualified lid, lookup_signature env depth signature + K.EQualified lid, lookup_signature env ItemBound depth signature in match f.func with @@ -1735,8 +1763,8 @@ let decl_of_id (env : env) (id : C.any_decl_id) : K.decl option = assert (def_id = id); let name = lid_of_name env name in - let env = push_cg_binders env const_generics in - let env = push_type_binders env type_params in + let env = push_cg_binders env ItemBound const_generics in + let env = push_type_binders env ItemBound type_params in match kind with | Union _ | Opaque | TError _ -> None @@ -1819,14 +1847,14 @@ let decl_of_id (env : env) (id : C.any_decl_id) : K.decl option = None | None, _ -> (* Opaque function *) - let { K.n_cgs; n }, t = typ_of_signature env signature in + let { K.n_cgs; n }, t = typ_of_signature env ItemBound signature in Some (K.DExternal (None, [], n_cgs, n, name, t, [])) | Some { locals; body; _ }, _ -> if Option.is_some decl.is_global_initializer then None else - let env = push_cg_binders env signature.C.generics.const_generics in - let env = push_type_binders env signature.C.generics.types in + let env = push_cg_binders env ItemBound signature.C.generics.const_generics in + let env = push_type_binders env ItemBound signature.C.generics.types in L.log "AstOfLlbc" "ty of locals: %s" (String.concat " ++ "