From 58ad1ac9d63fdec169f581a2d02bbefaf6c7c562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Wed, 25 Oct 2023 19:03:22 +0200 Subject: [PATCH] feat: Support generic types in entry points --- .../generic_contract/src/contract.rs | 30 +++++++++--- .../src/custom_and_generic.rs | 4 +- .../contracts/generic_contract/src/cw1.rs | 4 +- .../contracts/generic_contract/src/generic.rs | 4 +- .../generic_iface_on_contract/src/contract.rs | 4 ++ sylvia-derive/src/interfaces.rs | 2 +- sylvia-derive/src/lib.rs | 5 +- sylvia-derive/src/message.rs | 49 ++++++++++++------- sylvia-derive/src/parser.rs | 45 ++++++++++++++--- 9 files changed, 111 insertions(+), 36 deletions(-) diff --git a/examples/contracts/generic_contract/src/contract.rs b/examples/contracts/generic_contract/src/contract.rs index 4fffc93e..12f605f9 100644 --- a/examples/contracts/generic_contract/src/contract.rs +++ b/examples/contracts/generic_contract/src/contract.rs @@ -1,4 +1,5 @@ use cosmwasm_std::{Reply, Response, StdResult}; +use cw_storage_plus::Item; use serde::de::DeserializeOwned; use serde::Deserialize; use sylvia::types::{ @@ -6,32 +7,48 @@ use sylvia::types::{ }; use sylvia::{contract, schemars}; -pub struct GenericContract( - std::marker::PhantomData<( +#[cfg(not(feature = "library"))] +use sylvia::entry_points; + +pub struct GenericContract< + InstantiateParam, + ExecParam, + QueryParam, + MigrateParam, + RetType, + FieldType, +> { + _field: Item<'static, FieldType>, + _phantom: std::marker::PhantomData<( InstantiateParam, ExecParam, QueryParam, MigrateParam, RetType, )>, -); +} +#[cfg_attr(not(feature = "library"), entry_points(generics))] #[contract] #[messages(cw1 as Cw1: custom(msg))] #[messages(generic as Generic: custom(msg))] #[messages(custom_and_generic as CustomAndGeneric)] #[sv::custom(msg=SvCustomMsg)] -impl - GenericContract +impl + GenericContract where for<'msg_de> InstantiateParam: CustomMsg + Deserialize<'msg_de> + 'msg_de, ExecParam: CustomMsg + DeserializeOwned + 'static, QueryParam: CustomMsg + DeserializeOwned + 'static, MigrateParam: CustomMsg + DeserializeOwned + 'static, RetType: CustomMsg + DeserializeOwned + 'static, + FieldType: 'static, { pub const fn new() -> Self { - Self(std::marker::PhantomData) + Self { + _field: Item::new("field"), + _phantom: std::marker::PhantomData, + } } #[msg(instantiate)] @@ -92,6 +109,7 @@ mod tests { SvCustomMsg, super::SvCustomMsg, super::SvCustomMsg, + String, _, > = CodeId::store_code(&app); diff --git a/examples/contracts/generic_contract/src/custom_and_generic.rs b/examples/contracts/generic_contract/src/custom_and_generic.rs index d1ce0b12..2dfadcd0 100644 --- a/examples/contracts/generic_contract/src/custom_and_generic.rs +++ b/examples/contracts/generic_contract/src/custom_and_generic.rs @@ -6,7 +6,7 @@ use sylvia::types::{ExecCtx, QueryCtx, SvCustomMsg}; #[contract(module = crate::contract)] #[messages(custom_and_generic as CustomAndGeneric)] #[sv::custom(msg=sylvia::types::SvCustomMsg)] -impl +impl CustomAndGeneric for crate::contract::GenericContract< InstantiateParam, @@ -14,6 +14,7 @@ impl QueryParam, MigrateParam, RetType, + FieldType, > { type Error = StdError; @@ -52,6 +53,7 @@ mod tests { SvCustomMsg, SvCustomMsg, sylvia::types::SvCustomMsg, + String, _, >::store_code(&app); diff --git a/examples/contracts/generic_contract/src/cw1.rs b/examples/contracts/generic_contract/src/cw1.rs index 4c5c125c..410bd13d 100644 --- a/examples/contracts/generic_contract/src/cw1.rs +++ b/examples/contracts/generic_contract/src/cw1.rs @@ -6,13 +6,14 @@ use sylvia::types::{ExecCtx, QueryCtx}; #[contract(module = crate::contract)] #[messages(cw1 as Cw1)] #[sv::custom(msg=sylvia::types::SvCustomMsg)] -impl Cw1 +impl Cw1 for crate::contract::GenericContract< InstantiateParam, ExecParam, QueryParam, MigrateParam, RetType, + FieldType, > { type Error = StdError; @@ -49,6 +50,7 @@ mod tests { SvCustomMsg, SvCustomMsg, sylvia::types::SvCustomMsg, + String, _, >::store_code(&app); diff --git a/examples/contracts/generic_contract/src/generic.rs b/examples/contracts/generic_contract/src/generic.rs index 8504c8ae..27176d94 100644 --- a/examples/contracts/generic_contract/src/generic.rs +++ b/examples/contracts/generic_contract/src/generic.rs @@ -6,7 +6,7 @@ use sylvia::types::{ExecCtx, QueryCtx, SvCustomMsg}; #[contract(module = crate::contract)] #[messages(generic as Generic)] #[sv::custom(msg=SvCustomMsg)] -impl +impl Generic for crate::contract::GenericContract< InstantiateParam, @@ -14,6 +14,7 @@ impl QueryParam, MigrateParam, RetType, + FieldType, > { type Error = StdError; @@ -58,6 +59,7 @@ mod tests { SvCustomMsg, SvCustomMsg, sylvia::types::SvCustomMsg, + String, _, > = CodeId::store_code(&app); diff --git a/examples/contracts/generic_iface_on_contract/src/contract.rs b/examples/contracts/generic_iface_on_contract/src/contract.rs index 0a94398b..a8f92805 100644 --- a/examples/contracts/generic_iface_on_contract/src/contract.rs +++ b/examples/contracts/generic_iface_on_contract/src/contract.rs @@ -2,8 +2,12 @@ use cosmwasm_std::{Response, StdResult}; use sylvia::types::{InstantiateCtx, SvCustomMsg}; use sylvia::{contract, schemars}; +#[cfg(not(feature = "library"))] +use sylvia::entry_points; + pub struct NonGenericContract; +#[cfg_attr(not(feature = "library"), entry_points)] #[contract] #[messages(generic as Generic: custom(msg))] #[messages(custom_and_generic as CustomAndGeneric)] diff --git a/sylvia-derive/src/interfaces.rs b/sylvia-derive/src/interfaces.rs index 0e75edd1..fbfe9718 100644 --- a/sylvia-derive/src/interfaces.rs +++ b/sylvia-derive/src/interfaces.rs @@ -159,7 +159,7 @@ impl Interfaces { quote! {} }; - let type_name = msg_ty.as_accessor_name(); + let type_name = msg_ty.as_accessor_name(false); quote! { <#module ::sv::Api #generics as #sylvia ::types::InterfaceApi> :: #type_name :: response_schemas_impl() } diff --git a/sylvia-derive/src/lib.rs b/sylvia-derive/src/lib.rs index a002c044..a5c6ca8e 100644 --- a/sylvia-derive/src/lib.rs +++ b/sylvia-derive/src/lib.rs @@ -258,9 +258,10 @@ pub fn entry_points(attr: TokenStream, item: TokenStream) -> TokenStream { #[cfg(not(tarpaulin_include))] fn entry_points_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 { - fn inner(_attr: TokenStream2, item: TokenStream2) -> syn::Result { + fn inner(attr: TokenStream2, item: TokenStream2) -> syn::Result { + let attrs: parser::EntryPointArgs = parse2(attr)?; let input: ItemImpl = parse2(item)?; - let expanded = EntryPoints::new(&input).emit(); + let expanded = EntryPoints::new(&input, attrs).emit(); Ok(quote! { #input diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index fceb8c58..9a67a37f 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -3,7 +3,7 @@ use crate::crate_module; use crate::interfaces::Interfaces; use crate::parser::{ parse_associated_custom_type, parse_struct_message, ContractErrorAttr, ContractMessageAttr, - Custom, MsgAttr, MsgType, OverrideEntryPoints, + Custom, EntryPointArgs, MsgAttr, MsgType, OverrideEntryPoints, }; use crate::strip_generics::StripGenerics; use crate::utils::{ @@ -16,11 +16,12 @@ use proc_macro_error::emit_error; use quote::{quote, ToTokens}; use syn::fold::Fold; use syn::parse::{Parse, Parser}; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - parse_quote, Attribute, GenericParam, Ident, ItemImpl, ItemTrait, Pat, PatType, Path, - ReturnType, Signature, TraitItem, Type, WhereClause, WherePredicate, + parse_quote, Attribute, GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, Pat, + PatType, Path, ReturnType, Signature, Token, TraitItem, Type, WhereClause, WherePredicate, }; /// Representation of single struct message @@ -747,7 +748,7 @@ impl<'a> MsgVariant<'a> { let bracketed_generics = emit_bracketed_generics(generics); let interface_enum = quote! { < #module sv::Api #bracketed_generics as #sylvia ::types::InterfaceApi> }; - let type_name = msg_ty.as_accessor_name(); + let type_name = msg_ty.as_accessor_name(false); let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); match msg_ty { @@ -790,7 +791,7 @@ impl<'a> MsgVariant<'a> { } = self; let params = fields.iter().map(|field| field.emit_method_field()); - let type_name = msg_ty.as_accessor_name(); + let type_name = msg_ty.as_accessor_name(false); let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); match msg_ty { @@ -1023,12 +1024,9 @@ where custom_query: &Type, name: &Type, error: &Type, + contract_generics: &Option>, ) -> TokenStream { - let Self { - used_generics, - msg_ty, - .. - } = self; + let Self { msg_ty, .. } = self; let sylvia = crate_module(); let resp_type = match msg_ty { @@ -1038,16 +1036,19 @@ where let params = msg_ty.emit_ctx_params(custom_query); let values = msg_ty.emit_ctx_values(); let ep_name = msg_ty.emit_ep_name(); - let msg_name = msg_ty.emit_msg_name(true); - let bracketed_generics = emit_bracketed_generics(used_generics); + let bracketed_generics = match &contract_generics { + Some(generics) => quote! { ::< #generics > }, + None => quote! {}, + }; + let associated_name = msg_ty.as_accessor_name(true); quote! { #[#sylvia ::cw_std::entry_point] pub fn #ep_name ( #params , - msg: sv:: #msg_name #bracketed_generics, + msg: < #name < #contract_generics > as #sylvia ::types::ContractApi> :: #associated_name, ) -> Result<#resp_type, #error> { - msg.dispatch(&#name ::new() , ( #values )).map_err(Into::into) + msg.dispatch(&#name #bracketed_generics ::new() , ( #values )).map_err(Into::into) } } } @@ -1608,10 +1609,11 @@ pub struct EntryPoints<'a> { override_entry_points: OverrideEntryPoints, generics: Vec<&'a GenericParam>, where_clause: &'a Option, + attrs: EntryPointArgs, } impl<'a> EntryPoints<'a> { - pub fn new(source: &'a ItemImpl) -> Self { + pub fn new(source: &'a ItemImpl, attrs: EntryPointArgs) -> Self { let sylvia = crate_module(); let name = StripGenerics.fold_type(*source.self_ty.clone()); let override_entry_points = OverrideEntryPoints::new(&source.attrs); @@ -1643,6 +1645,7 @@ impl<'a> EntryPoints<'a> { override_entry_points, generics, where_clause, + attrs, } } @@ -1655,6 +1658,7 @@ impl<'a> EntryPoints<'a> { override_entry_points, generics, where_clause, + attrs, } = self; let sylvia = crate_module(); @@ -1683,6 +1687,10 @@ impl<'a> EntryPoints<'a> { .iter() .map(|variant| variant.function_name.clone()) .next(); + let contract_generics = match &attrs.generics { + Some(generics) => quote! { ::< #generics > }, + None => quote! {}, + }; #[cfg(not(tarpaulin_include))] { @@ -1696,6 +1704,7 @@ impl<'a> EntryPoints<'a> { &custom_query, name, error, + &attrs.generics, ), }, ); @@ -1706,7 +1715,13 @@ impl<'a> EntryPoints<'a> { let migrate = if migrate_not_overridden && migrate_variants.get_only_variant().is_some() { - migrate_variants.emit_default_entry_point(&custom_msg, &custom_query, name, error) + migrate_variants.emit_default_entry_point( + &custom_msg, + &custom_query, + name, + error, + &attrs.generics, + ) } else { quote! {} }; @@ -1722,7 +1737,7 @@ impl<'a> EntryPoints<'a> { env: #sylvia ::cw_std::Env, msg: #sylvia ::cw_std::Reply, ) -> Result<#sylvia ::cw_std::Response < #custom_msg >, #error> { - #name ::new(). #reply((deps, env).into(), msg).map_err(Into::into) + #name #contract_generics ::new(). #reply((deps, env).into(), msg).map_err(Into::into) } }, _ => quote! {}, diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index 709ed404..569e9f83 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -13,9 +13,10 @@ use syn::{ use crate::crate_module; use crate::strip_generics::StripGenerics; -/// Parser arguments for `contract` macro +/// Parsed arguments for `contract` macro pub struct ContractArgs { - /// Module name wrapping generated messages, by default no additional module is created + /// Module in which contract impl block is defined. + /// Used only while implementing `Interface` on `Contract`. pub module: Option, } @@ -46,6 +47,31 @@ impl Parse for ContractArgs { } } +/// Parsed arguments for `entry_points` macro +pub struct EntryPointArgs { + /// Types used in place of contracts generics. + pub generics: Option>, +} + +impl Parse for EntryPointArgs { + fn parse(input: ParseStream) -> Result { + if input.is_empty() { + return Ok(Self { generics: None }); + } + + let path: Path = input.parse()?; + + let generics = match path.segments.last() { + Some(segment) if segment.ident == "generics" => Some(extract_generics_from_path(&path)), + _ => return Err(Error::new(path.span(), "Expected `generics`")), + }; + + let _: Nothing = input.parse()?; + + Ok(Self { generics }) + } +} + /// Type of message to be generated #[derive(PartialEq, Eq, Debug, Clone, Copy)] pub enum MsgType { @@ -158,11 +184,16 @@ impl MsgType { } } - pub fn as_accessor_name(&self) -> Option { + pub fn as_accessor_name(&self, is_wrapper: bool) -> Option { match self { + MsgType::Exec if is_wrapper => Some(parse_quote! { ContractExec }), + MsgType::Query if is_wrapper => Some(parse_quote! { ContractQuery }), + MsgType::Instantiate => Some(parse_quote! { Instantiate }), MsgType::Exec => Some(parse_quote! { Exec }), MsgType::Query => Some(parse_quote! { Query }), - _ => None, + MsgType::Migrate => Some(parse_quote! { Migrate }), + MsgType::Sudo => Some(parse_quote! { Sudo }), + MsgType::Reply => Some(parse_quote! { Reply }), } } } @@ -295,7 +326,7 @@ fn interface_has_custom(content: ParseStream) -> Result { Ok(customs) } -fn extract_generics_from_path(module: &mut Path) -> Punctuated { +fn extract_generics_from_path(module: &Path) -> Punctuated { let generics = module.segments.last().map(|segment| { match segment.arguments.clone(){ PathArguments::AngleBracketed(generics) => { @@ -322,8 +353,8 @@ impl Parse for ContractMessageAttr { let content; parenthesized!(content in input); - let mut module = content.parse()?; - let generics = extract_generics_from_path(&mut module); + let module = content.parse()?; + let generics = extract_generics_from_path(&module); let module = StripGenerics.fold_path(module); let _: Token![as] = content.parse()?;