From 629f3a2fe1bcde3d40c01b7809ac70deb54c6bc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Wed, 27 Sep 2023 17:45:12 +0200 Subject: [PATCH] feat: Support generic interface implemented on contract --- sylvia-derive/src/check_generics.rs | 68 ++++++++++++++----- sylvia-derive/src/input.rs | 71 ++++++++++++++------ sylvia-derive/src/interfaces.rs | 48 ++++++++----- sylvia-derive/src/message.rs | 100 ++++++++++++++++------------ sylvia-derive/src/multitest.rs | 65 ++++++++---------- sylvia-derive/src/parser.rs | 61 +++-------------- sylvia-derive/src/utils.rs | 25 ++++--- sylvia/tests/generics.rs | 54 ++++++++++++++- 8 files changed, 291 insertions(+), 201 deletions(-) diff --git a/sylvia-derive/src/check_generics.rs b/sylvia-derive/src/check_generics.rs index edba0ce4..c31ca402 100644 --- a/sylvia-derive/src/check_generics.rs +++ b/sylvia-derive/src/check_generics.rs @@ -1,26 +1,57 @@ use syn::visit::Visit; -use syn::GenericParam; +use syn::{parse_quote, GenericArgument, GenericParam, Type}; + +pub trait GetPath { + fn get_path(&self) -> Option; +} + +impl GetPath for GenericParam { + fn get_path(&self) -> Option { + match self { + GenericParam::Type(ty) => { + let ident = &ty.ident; + Some(parse_quote! { #ident }) + } + _ => None, + } + } +} + +impl GetPath for GenericArgument { + fn get_path(&self) -> Option { + match self { + GenericArgument::Type(Type::Path(path)) => { + let path = &path.path; + Some(parse_quote! { #path }) + } + _ => None, + } + } +} #[derive(Debug)] -pub struct CheckGenerics<'g> { - generics: &'g [&'g GenericParam], - used: Vec<&'g GenericParam>, +pub struct CheckGenerics<'g, Generic> { + generics: &'g [&'g Generic], + used: Vec<&'g Generic>, } -impl<'g> CheckGenerics<'g> { - pub fn new(generics: &'g [&'g GenericParam]) -> Self { +impl<'g, Generic> CheckGenerics<'g, Generic> +where + Generic: GetPath + PartialEq, +{ + pub fn new(generics: &'g [&'g Generic]) -> Self { Self { generics, used: vec![], } } - pub fn used(self) -> Vec<&'g GenericParam> { + pub fn used(self) -> Vec<&'g Generic> { self.used } /// Returns split between used and unused generics - pub fn used_unused(self) -> (Vec<&'g GenericParam>, Vec<&'g GenericParam>) { + pub fn used_unused(self) -> (Vec<&'g Generic>, Vec<&'g Generic>) { let unused = self .generics .iter() @@ -32,17 +63,18 @@ impl<'g> CheckGenerics<'g> { } } -impl<'ast, 'g> Visit<'ast> for CheckGenerics<'g> { +impl<'ast, 'g, Generic> Visit<'ast> for CheckGenerics<'g, Generic> +where + Generic: GetPath + PartialEq, +{ fn visit_path(&mut self, p: &'ast syn::Path) { - if let Some(p) = p.get_ident() { - if let Some(gen) = self - .generics - .iter() - .find(|gen| matches!(gen, GenericParam::Type(ty) if ty.ident == *p)) - { - if !self.used.contains(gen) { - self.used.push(gen); - } + if let Some(gen) = self + .generics + .iter() + .find(|gen| gen.get_path().as_ref() == Some(p)) + { + if !self.used.contains(gen) { + self.used.push(gen); } } diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 718cb043..80b8929a 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -3,7 +3,10 @@ use proc_macro_error::emit_error; use quote::quote; use syn::parse::{Parse, Parser}; use syn::spanned::Spanned; -use syn::{parse_quote, GenericParam, Ident, ItemImpl, ItemTrait, TraitItem, Type}; +use syn::{ + parse_quote, GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, PathArguments, + TraitItem, Type, +}; use crate::crate_module; use crate::interfaces::Interfaces; @@ -156,42 +159,48 @@ impl<'a> ImplInput<'a> { } pub fn process(&self) -> TokenStream { - let is_trait = self.item.trait_.is_some(); + let Self { + item, + generics, + error, + custom, + override_entry_points, + interfaces, + .. + } = self; + let is_trait = item.trait_.is_some(); let multitest_helpers = if cfg!(feature = "mt") { + let interface_generics = self.extract_generic_argument(); MultitestHelpers::new( - self.item, + item, is_trait, - &self.error, - &self.generics, - &self.custom, - &self.override_entry_points, - &self.interfaces, + error, + &interface_generics, + custom, + override_entry_points, + interfaces, ) .emit() } else { quote! {} }; - let unbonded_generics = &vec![]; + let where_clause = &item.generics.where_clause; let variants = MsgVariants::new( self.item.as_variants(), MsgType::Query, - unbonded_generics, - &None, + generics, + where_clause, ); match is_trait { - true => self.process_interface(variants, multitest_helpers), + true => self.process_interface(multitest_helpers), false => self.process_contract(variants, multitest_helpers), } } - fn process_interface( - &self, - variants: MsgVariants<'a>, - multitest_helpers: TokenStream, - ) -> TokenStream { - let querier_bound_for_impl = self.emit_querier_for_bound_impl(variants); + fn process_interface(&self, multitest_helpers: TokenStream) -> TokenStream { + let querier_bound_for_impl = self.emit_querier_for_bound_impl(); #[cfg(not(tarpaulin_include))] quote! { @@ -203,7 +212,7 @@ impl<'a> ImplInput<'a> { fn process_contract( &self, - variants: MsgVariants<'a>, + variants: MsgVariants<'a, GenericParam>, multitest_helpers: TokenStream, ) -> TokenStream { let messages = self.emit_messages(); @@ -285,13 +294,31 @@ impl<'a> ImplInput<'a> { .emit() } - fn emit_querier_for_bound_impl(&self, variants: MsgVariants<'a>) -> TokenStream { + /// This method should only be called for trait impl block + fn extract_generic_argument(&self) -> Vec<&GenericArgument> { + let interface_generics = &self.item.trait_.as_ref(); + let args = match interface_generics { + Some((_, path, _)) => path.segments.last().map(|segment| &segment.arguments), + None => None, + }; + + match args { + Some(PathArguments::AngleBracketed(args)) => { + args.args.pairs().map(|pair| *pair.value()).collect() + } + _ => vec![], + } + } + + fn emit_querier_for_bound_impl(&self) -> TokenStream { let trait_module = self .interfaces - .interfaces() - .first() + .get_only_interface() .map(|interface| &interface.module); let contract_module = self.attributes.module.as_ref(); + let generics = self.extract_generic_argument(); + + let variants = MsgVariants::new(self.item.as_variants(), MsgType::Query, &generics, &None); variants.emit_querier_for_bound_impl(trait_module, contract_module) } diff --git a/sylvia-derive/src/interfaces.rs b/sylvia-derive/src/interfaces.rs index 94f0d575..f1d2a909 100644 --- a/sylvia-derive/src/interfaces.rs +++ b/sylvia-derive/src/interfaces.rs @@ -83,30 +83,23 @@ impl Interfaces { .collect() } - pub fn emit_glue_message_variants( - &self, - msg_ty: &MsgType, - msg_name: &Ident, - ) -> Vec { + pub fn emit_glue_message_variants(&self, msg_ty: &MsgType) -> Vec { + let sylvia = crate_module(); + self.interfaces .iter() .map(|interface| { let ContractMessageAttr { - module, - exec_generic_params, - query_generic_params, - variant, - .. + module, variant, .. } = interface; - let generics = match msg_ty { - MsgType::Exec => exec_generic_params.as_slice(), - MsgType::Query => query_generic_params.as_slice(), - _ => &[], - }; - - let enum_name = Self::merge_module_with_name(interface, msg_name); - quote! { #variant(#module :: #enum_name<#(#generics,)*>) } + let interface_enum = + quote! { <#module ::InterfaceTypes as #sylvia ::types::InterfaceMessages> }; + if msg_ty == &MsgType::Query { + quote! { #variant ( #interface_enum :: Query) } + } else { + quote! { #variant ( #interface_enum :: Exec)} + } }) .collect() } @@ -158,4 +151,23 @@ impl Interfaces { pub fn as_modules(&self) -> impl Iterator { self.interfaces.iter().map(|interface| &interface.module) } + + pub fn get_only_interface(&self) -> Option<&ContractMessageAttr> { + let interfaces = &self.interfaces; + match interfaces.len() { + 0 => None, + 1 => Some(&interfaces[0]), + _ => { + let first = &interfaces[0]; + for redefined in &interfaces[1..] { + emit_error!( + redefined.module, "The attribute `messages` is redefined"; + note = first.module.span() => "Previous definition of the attribute `messsages`"; + note = "Only one `messages` attribute can exist on an interface implementation on contract" + ); + } + None + } + } + } } diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 9825f8e8..3e46a990 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -1,4 +1,4 @@ -use crate::check_generics::CheckGenerics; +use crate::check_generics::{CheckGenerics, GetPath}; use crate::crate_module; use crate::interfaces::Interfaces; use crate::parser::{ @@ -13,7 +13,7 @@ use crate::variant_descs::{AsVariantDescs, VariantDescs}; use convert_case::{Case, Casing}; use proc_macro2::{Span, TokenStream}; use proc_macro_error::emit_error; -use quote::quote; +use quote::{quote, ToTokens}; use syn::fold::Fold; use syn::parse::{Parse, Parser}; use syn::spanned::Spanned; @@ -269,11 +269,13 @@ impl<'a> EnumMessage<'a> { quote! {} } else if MsgType::Query == *msg_ty { quote! { + #[serde(skip)] #[returns((#(#generics,)*))] _Phantom(std::marker::PhantomData<( #(#generics,)* )>), } } else { quote! { + #[serde(skip)] _Phantom(std::marker::PhantomData<( #(#generics,)* )>), } }; @@ -281,7 +283,7 @@ impl<'a> EnumMessage<'a> { let match_arms = if !generics.is_empty() { quote! { #(#match_arms,)* - _Phantom(_) => unreachable!(), + _Phantom(_) => Err(#sylvia ::cw_std::StdError::generic_err("Phantom message should not be constructed.")).map_err(Into::into), } } else { quote! { @@ -489,11 +491,14 @@ pub struct MsgVariant<'a> { impl<'a> MsgVariant<'a> { /// Creates new message variant from trait method - pub fn new( + pub fn new( sig: &'a Signature, - generics_checker: &mut CheckGenerics, + generics_checker: &mut CheckGenerics, msg_attr: MsgAttr, - ) -> MsgVariant<'a> { + ) -> MsgVariant<'a> + where + Generic: GetPath + PartialEq, + { let function_name = &sig.ident; let name = Ident::new( @@ -622,10 +627,10 @@ impl<'a> MsgVariant<'a> { } } - pub fn emit_querier_impl( + pub fn emit_querier_impl( &self, trait_module: Option<&Path>, - unbonded_generics: &Vec<&GenericParam>, + unbonded_generics: &Vec<&Generic>, ) -> TokenStream { let sylvia = crate_module(); let Self { @@ -680,17 +685,20 @@ impl<'a> MsgVariant<'a> { } } -pub struct MsgVariants<'a> { +pub struct MsgVariants<'a, Generic> { variants: Vec>, - unbonded_generics: Vec<&'a GenericParam>, + used_generics: Vec<&'a Generic>, where_predicates: Vec<&'a WherePredicate>, } -impl<'a> MsgVariants<'a> { +impl<'a, Generic> MsgVariants<'a, Generic> +where + Generic: GetPath + PartialEq + ToTokens, +{ pub fn new( source: VariantDescs<'a>, msg_type: MsgType, - generics: &'a [&'a GenericParam], + generics: &'a [&'a Generic], unfiltered_where_clause: &'a Option, ) -> Self { let mut generics_checker = CheckGenerics::new(generics); @@ -717,12 +725,12 @@ impl<'a> MsgVariants<'a> { }) .collect(); - let (unbonded_generics, _) = generics_checker.used_unused(); - let where_predicates = filter_wheres(unfiltered_where_clause, generics, &unbonded_generics); + let (used_generics, _) = generics_checker.used_unused(); + let where_predicates = filter_wheres(unfiltered_where_clause, generics, &used_generics); Self { variants, - unbonded_generics, + used_generics, where_predicates, } } @@ -744,7 +752,7 @@ impl<'a> MsgVariants<'a> { let sylvia = crate_module(); let Self { variants, - unbonded_generics, + used_generics, .. } = self; let where_clause = self.where_clause(); @@ -752,14 +760,14 @@ impl<'a> MsgVariants<'a> { let methods_impl = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) - .map(|variant| variant.emit_querier_impl(None, unbonded_generics)); + .map(|variant| variant.emit_querier_impl(None, used_generics)); let methods_declaration = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) .map(MsgVariant::emit_querier_declaration); - let braced_generics = emit_bracketed_generics(unbonded_generics); + let braced_generics = emit_bracketed_generics(used_generics); let querier = quote! { Querier #braced_generics }; #[cfg(not(tarpaulin_include))] @@ -784,7 +792,7 @@ impl<'a> MsgVariants<'a> { } } - impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for BoundQuerier<'a, C> #where_clause { + impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#used_generics,)*> #querier for BoundQuerier<'a, C> #where_clause { #(#methods_impl)* } @@ -803,7 +811,7 @@ impl<'a> MsgVariants<'a> { let sylvia = crate_module(); let Self { variants, - unbonded_generics, + used_generics, .. } = self; let where_clause = self.where_clause(); @@ -811,23 +819,25 @@ impl<'a> MsgVariants<'a> { let methods_impl = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) - .map(|variant| variant.emit_querier_impl(trait_module, unbonded_generics)); + .map(|variant| variant.emit_querier_impl(trait_module, used_generics)); - let mut querier = trait_module + let querier = trait_module .map(|module| quote! { #module ::Querier }) .unwrap_or_else(|| quote! { Querier }); let bound_querier = contract_module .map(|module| quote! { #module ::BoundQuerier}) .unwrap_or_else(|| quote! { BoundQuerier }); - if !unbonded_generics.is_empty() { - querier = quote! { #querier < #(#unbonded_generics,)* > }; - } + let querier = if !used_generics.is_empty() { + quote! { #querier < #(#used_generics,)* > } + } else { + quote! { #querier } + }; #[cfg(not(tarpaulin_include))] { quote! { - impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for #bound_querier<'a, C> #where_clause { + impl <'a, C: #sylvia ::cw_std::CustomQuery> #querier for #bound_querier<'a, C > #where_clause { #(#methods_impl)* } } @@ -844,7 +854,13 @@ pub struct MsgField<'a> { impl<'a> MsgField<'a> { /// Creates new field from trait method argument - pub fn new(item: &'a PatType, generics_checker: &mut CheckGenerics) -> Option> { + pub fn new( + item: &'a PatType, + generics_checker: &mut CheckGenerics, + ) -> Option> + where + Generic: GetPath + PartialEq, + { let name = match &*item.pat { Pat::Ident(p) => Some(&p.ident), pat => { @@ -950,7 +966,7 @@ impl<'a> GlueMessage<'a> { let contract = StripGenerics.fold_type((*contract).clone()); let enum_name = Ident::new(&format!("Contract{}", name), name.span()); - let variants = interfaces.emit_glue_message_variants(msg_ty, name); + let variants = interfaces.emit_glue_message_variants(msg_ty); let msg_name = quote! {#contract ( #name)}; let mut messages_call_on_all_variants: Vec = @@ -1101,8 +1117,8 @@ impl<'a> GlueMessage<'a> { } pub struct InterfaceMessages<'a> { - exec_variants: MsgVariants<'a>, - query_variants: MsgVariants<'a>, + exec_variants: MsgVariants<'a, GenericParam>, + query_variants: MsgVariants<'a, GenericParam>, generics: &'a [&'a GenericParam], } @@ -1137,8 +1153,8 @@ impl<'a> InterfaceMessages<'a> { generics, } = self; - let exec_generics = &exec_variants.unbonded_generics; - let query_generics = &query_variants.unbonded_generics; + let exec_generics = &exec_variants.used_generics; + let query_generics = &query_variants.used_generics; let bracket_generics = emit_bracketed_generics(generics); let exec_bracketed_generics = emit_bracketed_generics(exec_generics); @@ -1195,15 +1211,17 @@ impl<'a> EntryPoints<'a> { ) .unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError }); - let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &[], &None) - .variants() - .is_empty(); - - let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &[], &None) - .variants() - .iter() - .map(|variant| variant.function_name.clone()) - .next(); + let has_migrate = + !MsgVariants::::new(source.as_variants(), MsgType::Migrate, &[], &None) + .variants() + .is_empty(); + + let reply = + MsgVariants::::new(source.as_variants(), MsgType::Reply, &[], &None) + .variants() + .iter() + .map(|variant| variant.function_name.clone()) + .next(); let custom = Custom::new(&source.attrs); Self { diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index 3183e1ef..2600f92a 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -1,20 +1,18 @@ use proc_macro2::{Ident, TokenStream}; use proc_macro_error::emit_error; -use quote::quote; +use quote::{quote, ToTokens}; use syn::parse::{Parse, Parser}; use syn::spanned::Spanned; -use syn::{ - parse_quote, FnArg, GenericParam, ImplItem, ItemImpl, ItemTrait, Pat, PatType, Path, Type, -}; +use syn::{parse_quote, FnArg, ImplItem, ItemImpl, ItemTrait, Pat, PatType, Path, Type}; -use crate::check_generics::CheckGenerics; +use crate::check_generics::{CheckGenerics, GetPath}; use crate::crate_module; use crate::interfaces::Interfaces; use crate::message::MsgField; use crate::parser::{ parse_struct_message, Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints, }; -use crate::utils::{extract_return_type, process_fields}; +use crate::utils::{emit_bracketed_generics, extract_return_type, process_fields}; struct MessageSignature<'a> { pub name: &'a Ident, @@ -24,7 +22,7 @@ struct MessageSignature<'a> { pub return_type: TokenStream, } -pub struct MultitestHelpers<'a> { +pub struct MultitestHelpers<'a, Generics> { messages: Vec>, error_type: TokenStream, contract: &'a Type, @@ -32,7 +30,7 @@ pub struct MultitestHelpers<'a> { is_migrate: bool, reply: Option, source: &'a ItemImpl, - generics: &'a [&'a GenericParam], + generics: &'a [&'a Generics], contract_name: &'a Ident, proxy_name: Ident, custom: &'a Custom<'a>, @@ -61,12 +59,15 @@ fn extract_contract_name(contract: &Type) -> &Ident { &segment.ident } -impl<'a> MultitestHelpers<'a> { +impl<'a, Generics> MultitestHelpers<'a, Generics> +where + Generics: ToTokens + PartialEq + GetPath, +{ pub fn new( source: &'a ItemImpl, is_trait: bool, contract_error: &'a Type, - generics: &'a [&'a GenericParam], + generics: &'a [&'a Generics], custom: &'a Custom, override_entry_points: &'a OverrideEntryPoints, interfaces: &'a Interfaces, @@ -380,6 +381,7 @@ impl<'a> MultitestHelpers<'a> { error_type, custom, interfaces, + generics, .. } = self; @@ -389,29 +391,13 @@ impl<'a> MultitestHelpers<'a> { let proxy_name = &self.proxy_name; let trait_name = Ident::new(&format!("{}", interface_name), interface_name.span()); - let modules: Vec<&Path> = interfaces.as_modules().collect(); - - #[cfg(not(tarpaulin_include))] - let module = match modules.len() { - 0 => { - quote! {} - } - 1 => { - let module = &modules[0]; - quote! {#module ::} - } - _ => { - let first = &modules[0]; - for redefined in &modules[1..] { - emit_error!( - redefined, "The attribute `messages` is redefined"; - note = first.span() => "Previous definition of the attribute `messsages`"; - note = "Only one `messages` attribute can exist on an interface implementation on contract" - ); - } - quote! {} - } - }; + let module = interfaces + .get_only_interface() + .map(|interface| { + let module = &interface.module; + quote! { #module :: } + }) + .unwrap_or(quote! {}); let custom_msg = custom.msg_or_default(); @@ -430,6 +416,9 @@ impl<'a> MultitestHelpers<'a> { > }; + let bracketed_generics = emit_bracketed_generics(generics); + let interface_enum = quote! { < #module InterfaceTypes #bracketed_generics as #sylvia ::types::InterfaceMessages> }; + #[cfg(not(tarpaulin_include))] let methods_definitions = messages.iter().map(|msg| { let MessageSignature { @@ -439,11 +428,12 @@ impl<'a> MultitestHelpers<'a> { msg_ty, return_type, } = msg; + let type_name = msg_ty.as_accessor_name(); if msg_ty == &MsgType::Exec { quote! { #[track_caller] - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #module ExecMsg, #mt_app, #custom_msg> { - let msg = #module ExecMsg:: #name ( #(#arguments),* ); + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, #mt_app, #custom_msg> { + let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) } @@ -451,7 +441,7 @@ impl<'a> MultitestHelpers<'a> { } else { quote! { fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { - let msg = #module QueryMsg:: #name ( #(#arguments),* ); + let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); (*self.app) .app() @@ -472,9 +462,10 @@ impl<'a> MultitestHelpers<'a> { return_type, .. } = msg; + let type_name = msg_ty.as_accessor_name(); if msg_ty == &MsgType::Exec { quote! { - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #module ExecMsg, MtApp, #custom_msg>; + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, MtApp, #custom_msg>; } } else { quote! { diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index 5a3b48be..0993ca5b 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -152,6 +152,14 @@ impl MsgType { MsgType::Sudo => todo!(), } } + + pub fn as_accessor_name(&self) -> Option { + match self { + MsgType::Exec => Some(parse_quote! { Exec }), + MsgType::Query => Some(parse_quote! { Query }), + _ => None, + } + } } impl PartialEq for MsgAttr { @@ -238,36 +246,10 @@ pub struct Customs { #[derive(Debug)] pub struct ContractMessageAttr { pub module: Path, - pub exec_generic_params: Vec, - pub query_generic_params: Vec, pub variant: Ident, pub customs: Customs, } -#[cfg(not(tarpaulin_include))] -// False negative. Called in function below -fn parse_generics(content: &ParseBuffer) -> Result> { - let _: Token![<] = content.parse()?; - let mut params = vec![]; - - loop { - let param: Path = content.parse()?; - params.push(param); - - let generics_close: Option]> = content.parse()?; - if generics_close.is_some() { - break; - } - - let comma: Option = content.parse()?; - if comma.is_none() { - return Err(Error::new(content.span(), "Expected comma or `>`")); - } - } - - Ok(params) -} - fn interface_has_custom(content: ParseStream) -> Result { let mut customs = Customs { has_msg: false, @@ -312,31 +294,6 @@ impl Parse for ContractMessageAttr { let module = content.parse()?; - let generics_open: Option = content.parse()?; - let mut exec_generic_params = vec![]; - let mut query_generic_params = vec![]; - - if generics_open.is_some() { - loop { - let ty: Ident = content.parse()?; - let params = if ty == "exec" { - &mut exec_generic_params - } else if ty == "query" { - &mut query_generic_params - } else { - return Err(Error::new(ty.span(), "Invalid message type")); - }; - - *params = parse_generics(&content)?; - - if content.peek(Token![as]) { - break; - } - - let _: Token![,] = content.parse()?; - } - } - let _: Token![as] = content.parse()?; let variant = content.parse()?; @@ -351,8 +308,6 @@ impl Parse for ContractMessageAttr { Ok(Self { module, - exec_generic_params, - query_generic_params, variant, customs, }) diff --git a/sylvia-derive/src/utils.rs b/sylvia-derive/src/utils.rs index 16857d4d..276deb3d 100644 --- a/sylvia-derive/src/utils.rs +++ b/sylvia-derive/src/utils.rs @@ -1,21 +1,21 @@ use proc_macro2::TokenStream; use proc_macro_error::emit_error; -use quote::quote; +use quote::{quote, ToTokens}; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - parse_quote, FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature, - Type, WhereClause, WherePredicate, + parse_quote, FnArg, GenericArgument, Path, PathArguments, ReturnType, Signature, Type, + WhereClause, WherePredicate, }; -use crate::check_generics::CheckGenerics; +use crate::check_generics::{CheckGenerics, GetPath}; use crate::message::MsgField; #[cfg(not(tarpaulin_include))] -pub fn filter_wheres<'a>( +pub fn filter_wheres<'a, Generic: GetPath + PartialEq>( clause: &'a Option, - generics: &[&GenericParam], - used_generics: &[&GenericParam], + generics: &[&Generic], + used_generics: &[&Generic], ) -> Vec<&'a WherePredicate> { clause .as_ref() @@ -36,10 +36,13 @@ pub fn filter_wheres<'a>( .unwrap_or_default() } -pub fn process_fields<'s>( +pub fn process_fields<'s, Generic>( sig: &'s Signature, - generics_checker: &mut CheckGenerics, -) -> Vec> { + generics_checker: &mut CheckGenerics, +) -> Vec> +where + Generic: GetPath + PartialEq, +{ sig.inputs .iter() .skip(2) @@ -94,7 +97,7 @@ pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option TokenStream { +pub fn emit_bracketed_generics(unbonded_generics: &[&Generic]) -> TokenStream { match unbonded_generics.is_empty() { true => quote! {}, false => quote! { < #(#unbonded_generics,)* > }, diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs index f6721c67..bc9e982d 100644 --- a/sylvia/tests/generics.rs +++ b/sylvia/tests/generics.rs @@ -29,6 +29,58 @@ pub mod cw1 { } } +pub mod cw1_contract { + use cosmwasm_std::{Response, StdResult}; + use sylvia::types::InstantiateCtx; + use sylvia_derive::contract; + + pub struct Cw1Contract; + + #[contract] + impl Cw1Contract { + pub const fn new() -> Self { + Self + } + + #[msg(instantiate)] + pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult { + Ok(Response::new()) + } + } +} + +pub mod impl_cw1 { + use cosmwasm_std::{CosmosMsg, Response, StdError}; + use sylvia::types::{ExecCtx, QueryCtx}; + use sylvia_derive::contract; + + use crate::{cw1::Cw1, cw1_contract::Cw1Contract, ExternalMsg}; + + #[contract(module = crate::cw1_contract)] + #[messages(crate::cw1 as Cw1)] + impl Cw1 for Cw1Contract { + type Error = StdError; + + #[msg(exec)] + fn execute( + &self, + _ctx: ExecCtx, + _msgs: Vec>, + ) -> Result, Self::Error> { + Ok(Response::new()) + } + + #[msg(query)] + fn some_query( + &self, + _ctx: QueryCtx, + _param: crate::ExternalMsg, + ) -> Result { + Ok(crate::ExternalQuery {}) + } + } +} + #[cw_serde] pub struct ExternalMsg; impl cosmwasm_std::CustomMsg for ExternalMsg {} @@ -37,7 +89,7 @@ impl cosmwasm_std::CustomMsg for ExternalMsg {} pub struct ExternalQuery; impl cosmwasm_std::CustomQuery for ExternalQuery {} -#[cfg(test)] +#[cfg(all(test, feature = "mt"))] mod tests { use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper};