diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 80b8929a..41c3d458 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -1,6 +1,6 @@ use proc_macro2::{Span, TokenStream}; use proc_macro_error::emit_error; -use quote::quote; +use quote::{quote, ToTokens}; use syn::parse::{Parse, Parser}; use syn::spanned::Spanned; use syn::{ @@ -8,6 +8,7 @@ use syn::{ TraitItem, Type, }; +use crate::check_generics::AsIdent; use crate::crate_module; use crate::interfaces::Interfaces; use crate::message::{ @@ -159,47 +160,17 @@ impl<'a> ImplInput<'a> { } pub fn process(&self) -> TokenStream { - let Self { - item, - generics, - error, - custom, - override_entry_points, - interfaces, - .. - } = self; - let is_trait = item.trait_.is_some(); - let multitest_helpers = if cfg!(feature = "mt") { - let interface_generics = self.extract_generic_argument(); - MultitestHelpers::new( - item, - is_trait, - error, - &interface_generics, - custom, - override_entry_points, - interfaces, - ) - .emit() - } else { - quote! {} - }; - - let where_clause = &item.generics.where_clause; - let variants = MsgVariants::new( - self.item.as_variants(), - MsgType::Query, - generics, - where_clause, - ); + let is_trait = self.item.trait_.is_some(); match is_trait { - true => self.process_interface(multitest_helpers), - false => self.process_contract(variants, multitest_helpers), + true => self.process_interface(), + false => self.process_contract(), } } - fn process_interface(&self, multitest_helpers: TokenStream) -> TokenStream { + fn process_interface(&self) -> TokenStream { + let interface_generics = self.extract_generic_argument(); + let multitest_helpers = self.emit_multitest_helpers(&interface_generics); let querier_bound_for_impl = self.emit_querier_for_bound_impl(); #[cfg(not(tarpaulin_include))] @@ -210,14 +181,19 @@ impl<'a> ImplInput<'a> { } } - fn process_contract( - &self, - variants: MsgVariants<'a, GenericParam>, - multitest_helpers: TokenStream, - ) -> TokenStream { - let messages = self.emit_messages(); - let remote = Remote::new(&self.interfaces).emit(); + fn process_contract(&self) -> TokenStream { + let Self { item, generics, .. } = self; + let multitest_helpers = self.emit_multitest_helpers(generics); + let where_clause = &item.generics.where_clause; + let variants = MsgVariants::new( + self.item.as_variants(), + MsgType::Query, + generics, + where_clause, + ); + let messages = self.emit_messages(&variants); + let remote = Remote::new(&self.interfaces).emit(); let querier = variants.emit_querier(); let querier_from_impl = self.interfaces.emit_querier_from_impl(); @@ -237,15 +213,23 @@ impl<'a> ImplInput<'a> { } } - fn emit_messages(&self) -> TokenStream { + fn emit_messages(&self, variants: &MsgVariants) -> TokenStream { let instantiate = self.emit_struct_msg(MsgType::Instantiate); let migrate = self.emit_struct_msg(MsgType::Migrate); let exec_impl = self.emit_enum_msg(&Ident::new("ExecMsg", Span::mixed_site()), MsgType::Exec); let query_impl = self.emit_enum_msg(&Ident::new("QueryMsg", Span::mixed_site()), MsgType::Query); - let exec = self.emit_glue_msg(&Ident::new("ExecMsg", Span::mixed_site()), MsgType::Exec); - let query = self.emit_glue_msg(&Ident::new("QueryMsg", Span::mixed_site()), MsgType::Query); + let exec = self.emit_glue_msg( + &Ident::new("ExecMsg", Span::mixed_site()), + MsgType::Exec, + variants, + ); + let query = self.emit_glue_msg( + &Ident::new("QueryMsg", Span::mixed_site()), + MsgType::Query, + variants, + ); #[cfg(not(tarpaulin_include))] { @@ -282,7 +266,12 @@ impl<'a> ImplInput<'a> { .emit() } - fn emit_glue_msg(&self, name: &Ident, msg_ty: MsgType) -> TokenStream { + fn emit_glue_msg( + &self, + name: &Ident, + msg_ty: MsgType, + variants: &MsgVariants, + ) -> TokenStream { GlueMessage::new( name, self.item, @@ -290,6 +279,7 @@ impl<'a> ImplInput<'a> { &self.error, &self.custom, &self.interfaces, + variants, ) .emit() } @@ -322,4 +312,35 @@ impl<'a> ImplInput<'a> { variants.emit_querier_for_bound_impl(trait_module, contract_module) } + + fn emit_multitest_helpers(&self, generics: &[&Generic]) -> TokenStream + where + Generic: ToTokens + PartialEq + AsIdent, + { + let Self { + item, + error, + custom, + override_entry_points, + interfaces, + .. + } = self; + + let is_trait = self.item.trait_.is_some(); + + if cfg!(feature = "mt") { + MultitestHelpers::new( + item, + is_trait, + error, + generics, + custom, + override_entry_points, + interfaces, + ) + .emit() + } else { + quote! {} + } + } } diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 926e9b74..67633bf2 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -3,7 +3,7 @@ use crate::crate_module; use crate::interfaces::Interfaces; use crate::parser::{ parse_associated_custom_type, parse_struct_message, ContractErrorAttr, ContractMessageAttr, - Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints, + Custom, MsgAttr, MsgType, OverrideEntryPoints, }; use crate::strip_generics::StripGenerics; use crate::utils::{ @@ -19,7 +19,7 @@ use syn::parse::{Parse, Parser}; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - parse_quote, Attribute, GenericParam, Ident, ImplItem, ItemImpl, ItemTrait, Pat, PatType, Path, + parse_quote, Attribute, GenericParam, Ident, ItemImpl, ItemTrait, Pat, PatType, Path, ReturnType, Signature, TraitItem, Type, WhereClause, WherePredicate, }; @@ -123,7 +123,7 @@ impl<'a> StructMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] #[serde(rename_all="snake_case")] - pub struct #name #generics #where_clause { + pub struct #name #generics { #(pub #fields,)* } @@ -242,9 +242,7 @@ impl<'a> EnumMessage<'a> { query_type, } = self; - let match_arms = variants - .iter() - .map(|variant| variant.emit_dispatch_leg(*msg_ty)); + let match_arms = variants.iter().map(|variant| variant.emit_dispatch_leg()); let mut msgs: Vec = variants .iter() .map(|var| var.name.to_string().to_case(Case::Snake)) @@ -349,7 +347,7 @@ impl<'a> EnumMessage<'a> { /// Representation of single enum message pub struct ContractEnumMessage<'a> { name: &'a Ident, - variants: Vec>, + variants: MsgVariants<'a, GenericParam>, msg_ty: MsgType, contract: &'a Type, error: &'a Type, @@ -360,40 +358,22 @@ impl<'a> ContractEnumMessage<'a> { pub fn new( name: &'a Ident, source: &'a ItemImpl, - ty: MsgType, + msg_ty: MsgType, generics: &'a [&'a GenericParam], error: &'a Type, custom: &'a Custom, ) -> Self { - let mut generics_checker = CheckGenerics::new(generics); - let variants: Vec<_> = source - .items - .iter() - .filter_map(|item| match item { - ImplItem::Method(method) => { - let msg_attr = method.attrs.iter().find(|attr| attr.path.is_ident("msg"))?; - let attr = match MsgAttr::parse.parse2(msg_attr.tokens.clone()) { - Ok(attr) => attr, - Err(err) => { - emit_error!(method.span(), err); - return None; - } - }; - - if attr == ty { - Some(MsgVariant::new(&method.sig, &mut generics_checker, attr)) - } else { - None - } - } - _ => None, - }) - .collect(); + let variants = MsgVariants::new( + source.as_variants(), + msg_ty, + generics, + &source.generics.where_clause, + ); Self { name, variants, - msg_ty: ty, + msg_ty, contract: &source.self_ty, error, custom, @@ -412,61 +392,45 @@ impl<'a> ContractEnumMessage<'a> { custom, } = self; - let match_arms = variants - .iter() - .map(|variant| variant.emit_dispatch_leg(*msg_ty)); - let mut msgs: Vec = variants - .iter() - .map(|var| var.name.to_string().to_case(Case::Snake)) - .collect(); - msgs.sort(); - let msgs_cnt = msgs.len(); - let variants_constructors = variants.iter().map(MsgVariant::emit_variants_constructors); - let variants = variants.iter().map(MsgVariant::emit); + let match_arms = variants.emit_dispatch_legs(); + let generic_name = variants.emit_generic_name(name); + let unused_generics = variants.unused_generics(); + let unused_generics = emit_bracketed_generics(unused_generics); + + let mut variant_names = variants.as_names_snake_cased(); + variant_names.sort(); + let variants_cnt = variant_names.len(); + let variants_constructors = variants.emit_constructors(); + let variants = variants.emit(); let ctx_type = msg_ty.emit_ctx_type(&custom.query_or_default()); - let contract = StripGenerics.fold_type((*contract).clone()); let ret_type = msg_ty.emit_result_type(&custom.msg_or_default(), error); - #[cfg(not(tarpaulin_include))] - let enum_declaration = match name.to_string().as_str() { - "QueryMsg" => { - quote! { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, cosmwasm_schema::QueryResponses)] - #[serde(rename_all="snake_case")] - pub enum #name { - #(#variants,)* - } - } - } - _ => { - quote! { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] - #[serde(rename_all="snake_case")] - pub enum #name { - #(#variants,)* - } - } - } + let derive_query = match msg_ty { + MsgType::Query => quote! { #sylvia ::cw_schema::QueryResponses }, + _ => quote! {}, }; #[cfg(not(tarpaulin_include))] { quote! { - #enum_declaration + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, #derive_query )] + #[serde(rename_all="snake_case")] + pub enum #generic_name { + #(#variants,)* + } - impl #name { - pub fn dispatch(self, contract: &#contract, ctx: #ctx_type) -> #ret_type { + impl #generic_name { + pub fn dispatch #unused_generics (self, contract: &#contract, ctx: #ctx_type) -> #ret_type { use #name::*; match self { #(#match_arms,)* } } - pub const fn messages() -> [&'static str; #msgs_cnt] { - [#(#msgs,)*] + pub const fn messages() -> [&'static str; #variants_cnt] { + [#(#variant_names,)*] } #(#variants_constructors)* @@ -477,6 +441,7 @@ impl<'a> ContractEnumMessage<'a> { } /// Representation of whole message variant +#[derive(Debug)] pub struct MsgVariant<'a> { name: Ident, function_name: &'a Ident, @@ -562,13 +527,14 @@ impl<'a> MsgVariant<'a> { /// Emits match leg dispatching against this variant. Assumes enum variants are imported into the /// scope. Dispatching is performed by calling the function this variant is build from on the /// `contract` variable, with `ctx` as its first argument - both of them should be in scope. - pub fn emit_dispatch_leg(&self, msg_type: MsgType) -> TokenStream { + pub fn emit_dispatch_leg(&self) -> TokenStream { use MsgType::*; let Self { name, fields, function_name, + msg_type, .. } = self; @@ -681,12 +647,23 @@ impl<'a> MsgVariant<'a> { } } } + + pub fn as_fields_names(&self) -> Vec<&Ident> { + self.fields.iter().map(MsgField::name).collect() + } + + pub fn emit_fields(&self) -> Vec { + self.fields.iter().map(MsgField::emit).collect() + } } +#[derive(Debug)] pub struct MsgVariants<'a, Generic> { variants: Vec>, used_generics: Vec<&'a Generic>, + unused_generics: Vec<&'a Generic>, where_predicates: Vec<&'a WherePredicate>, + msg_ty: MsgType, } impl<'a, Generic> MsgVariants<'a, Generic> @@ -695,11 +672,11 @@ where { pub fn new( source: VariantDescs<'a>, - msg_type: MsgType, - generics: &'a [&'a Generic], + msg_ty: MsgType, + all_generics: &'a [&'a Generic], unfiltered_where_clause: &'a Option, ) -> Self { - let mut generics_checker = CheckGenerics::new(generics); + let mut generics_checker = CheckGenerics::new(all_generics); let variants: Vec<_> = source .filter_map(|variant_desc| { let msg_attr = variant_desc.attr_msg()?; @@ -711,7 +688,7 @@ where } }; - if attr.msg_type() != msg_type { + if attr.msg_type() != msg_ty { return None; } @@ -723,13 +700,15 @@ where }) .collect(); - let (used_generics, _) = generics_checker.used_unused(); - let where_predicates = filter_wheres(unfiltered_where_clause, generics, &used_generics); + let (used_generics, unused_generics) = generics_checker.used_unused(); + let where_predicates = filter_wheres(unfiltered_where_clause, all_generics, &used_generics); Self { variants, used_generics, + unused_generics, where_predicates, + msg_ty, } } @@ -746,6 +725,18 @@ where &self.variants } + pub fn used_generics(&self) -> &Vec<&'a Generic> { + &self.used_generics + } + + pub fn unused_generics(&self) -> &Vec<&'a Generic> { + &self.unused_generics + } + + pub fn where_predicates(&'a self) -> &'a [&'a WherePredicate] { + &self.where_predicates + } + pub fn emit_querier(&self) -> TokenStream { let sylvia = crate_module(); let Self { @@ -841,9 +832,100 @@ where } } } + + pub fn emit_multitest_default_dispatch(&self) -> TokenStream { + let sylvia = crate_module(); + let Self { + msg_ty, + used_generics, + .. + } = self; + + let values = msg_ty.emit_ctx_values(); + let msg_name = msg_ty.emit_msg_name(used_generics.as_slice()); + + quote! { + #sylvia ::cw_std::from_slice::< #msg_name >(&msg)? + .dispatch(self, ( #values )) + .map_err(Into::into) + } + } + + pub fn emit_default_entry_point( + &self, + custom_msg: &Type, + custom_query: &Type, + name: &Type, + error: &Type, + ) -> TokenStream { + let Self { + used_generics, + msg_ty, + .. + } = self; + let sylvia = crate_module(); + + let resp_type = match msg_ty { + MsgType::Query => quote! { #sylvia ::cw_std::Binary }, + _ => quote! { #sylvia ::cw_std::Response < #custom_msg > }, + }; + let params = msg_ty.emit_ctx_params(custom_query); + let values = msg_ty.emit_ctx_values(); + let ep_name = msg_ty.emit_ep_name(); + let msg_name = msg_ty.emit_msg_name(used_generics); + + quote! { + #[#sylvia ::cw_std::entry_point] + pub fn #ep_name ( + #params , + msg: #msg_name, + ) -> Result<#resp_type, #error> { + msg.dispatch(&#name ::new() , ( #values )).map_err(Into::into) + } + } + } + + pub fn emit_dispatch_legs(&self) -> impl Iterator + '_ { + self.variants + .iter() + .map(|variant| variant.emit_dispatch_leg()) + } + + pub fn as_names_snake_cased(&self) -> Vec { + self.variants + .iter() + .map(|variant| variant.name.to_string().to_case(Case::Snake)) + .collect() + } + + pub fn emit_constructors(&self) -> impl Iterator + '_ { + self.variants + .iter() + .map(MsgVariant::emit_variants_constructors) + } + + pub fn emit_generic_name(&self, name: &Ident) -> TokenStream { + let generics = emit_bracketed_generics(&self.used_generics); + + #[cfg(not(tarpaulin_include))] + { + quote! { + #name #generics + } + } + } + + pub fn emit(&self) -> impl Iterator + '_ { + self.variants.iter().map(MsgVariant::emit) + } + + pub fn get_only_variant(&self) -> Option<&MsgVariant> { + self.variants.first() + } } /// Representation of single message variant field +#[derive(Debug)] pub struct MsgField<'a> { name: &'a Ident, ty: &'a Type, @@ -929,6 +1011,7 @@ pub struct GlueMessage<'a> { error: &'a Type, custom: &'a Custom<'a>, interfaces: &'a Interfaces, + variants: &'a MsgVariants<'a, GenericParam>, } impl<'a> GlueMessage<'a> { @@ -939,6 +1022,7 @@ impl<'a> GlueMessage<'a> { error: &'a Type, custom: &'a Custom, interfaces: &'a Interfaces, + variants: &'a MsgVariants<'a, GenericParam>, ) -> Self { GlueMessage { name, @@ -947,12 +1031,12 @@ impl<'a> GlueMessage<'a> { error, custom, interfaces, + variants, } } pub fn emit(&self) -> TokenStream { let sylvia = crate_module(); - let Self { name, contract, @@ -960,18 +1044,23 @@ impl<'a> GlueMessage<'a> { error, custom, interfaces, + variants, } = self; - let contract = StripGenerics.fold_type((*contract).clone()); + let contract_name = StripGenerics.fold_type((*contract).clone()); let enum_name = Ident::new(&format!("Contract{}", name), name.span()); + let used_generics = variants.used_generics(); + let used_generics = emit_bracketed_generics(used_generics); + let unused_generics = variants.unused_generics(); + let unused_generics = emit_bracketed_generics(unused_generics); + let where_clause = variants.where_clause(); let variants = interfaces.emit_glue_message_variants(msg_ty); - let msg_name = quote! {#contract ( #name)}; - let mut messages_call_on_all_variants: Vec = - interfaces.emit_messages_call(msg_ty); - messages_call_on_all_variants.push(quote! {&#name :: messages()}); + let contract_variant = quote! { #contract_name ( #name ) }; + let mut messages_call = interfaces.emit_messages_call(msg_ty); + messages_call.push(quote! { &#name :: messages() }); - let variants_cnt = messages_call_on_all_variants.len(); + let variants_cnt = messages_call.len(); let dispatch_arms = interfaces.interfaces().iter().map(|interface| { let ContractMessageAttr { @@ -1000,7 +1089,8 @@ impl<'a> GlueMessage<'a> { } }); - let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)}; + let dispatch_arm = + quote! {#enum_name :: #contract_name (msg) => msg.dispatch(contract, ctx)}; let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(msg_ty); @@ -1009,8 +1099,8 @@ impl<'a> GlueMessage<'a> { let msgs = &#name :: messages(); if msgs.into_iter().any(|msg| msg == &recv_msg_name) { match val.deserialize_into() { - Ok(msg) => return Ok(Self:: #contract (msg)), - Err(err) => return Err(D::Error::custom(err)).map(Self:: #contract) + Ok(msg) => return Ok(Self:: #contract_name (msg)), + Err(err) => return Err(D::Error::custom(err)).map(Self:: #contract_name ) }; } }; @@ -1047,19 +1137,19 @@ impl<'a> GlueMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] #[serde(rename_all="snake_case", untagged)] - pub enum #enum_name { + pub enum #enum_name #used_generics { #(#variants,)* - #msg_name + #contract_variant } - impl #enum_name { - pub fn dispatch( + impl #used_generics #enum_name #used_generics { + pub fn dispatch #unused_generics #where_clause ( self, contract: &#contract, ctx: #ctx_type, ) -> #ret_type { const _: () = { - let msgs: [&[&str]; #variants_cnt] = [#(#messages_call_on_all_variants),*]; + let msgs: [&[&str]; #variants_cnt] = [#(#messages_call),*]; #sylvia ::utils::assert_no_intersection(msgs); }; @@ -1095,7 +1185,7 @@ impl<'a> GlueMessage<'a> { #contract_deserialization_attempt } - let msgs: [&[&str]; #variants_cnt] = [#(#messages_call_on_all_variants),*]; + let msgs: [&[&str]; #variants_cnt] = [#(#messages_call),*]; let mut err_msg = msgs.into_iter().flatten().fold( // It might be better to forward the error or serialization, but we just // deserialized it from JSON, not reason to expect failure here. @@ -1180,12 +1270,13 @@ impl<'a> InterfaceMessages<'a> { } pub struct EntryPoints<'a> { + source: &'a ItemImpl, name: Type, error: Type, custom: Custom<'a>, override_entry_points: OverrideEntryPoints, - has_migrate: bool, - reply: Option, + generics: Vec<&'a GenericParam>, + where_clause: &'a Option, } impl<'a> EntryPoints<'a> { @@ -1209,56 +1300,71 @@ impl<'a> EntryPoints<'a> { ) .unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError }); - let has_migrate = - !MsgVariants::::new(source.as_variants(), MsgType::Migrate, &[], &None) - .variants() - .is_empty(); - - let reply = - MsgVariants::::new(source.as_variants(), MsgType::Reply, &[], &None) - .variants() - .iter() - .map(|variant| variant.function_name.clone()) - .next(); + let generics: Vec<_> = source.generics.params.iter().collect(); + let where_clause = &source.generics.where_clause; let custom = Custom::new(&source.attrs); Self { + source, name, error, custom, override_entry_points, - has_migrate, - reply, + generics, + where_clause, } } pub fn emit(&self) -> TokenStream { let Self { + source, name, error, custom, override_entry_points, - has_migrate, - reply, + generics, + where_clause, } = self; let sylvia = crate_module(); let custom_msg = custom.msg_or_default(); let custom_query = custom.query_or_default(); + let instantiate_variants = MsgVariants::new( + source.as_variants(), + MsgType::Instantiate, + generics, + where_clause, + ); + let exec_variants = + MsgVariants::new(source.as_variants(), MsgType::Exec, generics, where_clause); + let query_variants = + MsgVariants::new(source.as_variants(), MsgType::Query, generics, where_clause); + let migrate_variants = MsgVariants::new( + source.as_variants(), + MsgType::Migrate, + generics, + where_clause, + ); + let reply = + MsgVariants::::new(source.as_variants(), MsgType::Reply, &[], &None) + .variants() + .iter() + .map(|variant| variant.function_name.clone()) + .next(); + #[cfg(not(tarpaulin_include))] { - let entry_points = [MsgType::Instantiate, MsgType::Exec, MsgType::Query] + let entry_points = [instantiate_variants, exec_variants, query_variants] .into_iter() .map( - |msg_type| match override_entry_points.get_entry_point(msg_type) { + |variants| match override_entry_points.get_entry_point(variants.msg_ty) { Some(_) => quote! {}, - None => OverrideEntryPoint::emit_default_entry_point( + None => variants.emit_default_entry_point( &custom_msg, &custom_query, name, error, - msg_type, ), }, ); @@ -1267,14 +1373,9 @@ impl<'a> EntryPoints<'a> { .get_entry_point(MsgType::Migrate) .is_none(); - let migrate = if migrate_not_overridden && *has_migrate { - OverrideEntryPoint::emit_default_entry_point( - &custom_msg, - &custom_query, - name, - error, - MsgType::Migrate, - ) + let migrate = if migrate_not_overridden && migrate_variants.get_only_variant().is_some() + { + migrate_variants.emit_default_entry_point(&custom_msg, &custom_query, name, error) } else { quote! {} }; diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index f468789a..4b1e40ab 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -5,14 +5,34 @@ use syn::parse::{Parse, Parser}; use syn::spanned::Spanned; use syn::{parse_quote, FnArg, ImplItem, ItemImpl, ItemTrait, Pat, PatType, Path, Type}; -use crate::check_generics::{AsIdent, CheckGenerics}; +use crate::check_generics::AsIdent; use crate::crate_module; use crate::interfaces::Interfaces; -use crate::message::MsgField; -use crate::parser::{ - parse_struct_message, Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints, -}; -use crate::utils::{emit_bracketed_generics, extract_return_type, process_fields}; +use crate::message::{MsgVariant, MsgVariants}; +use crate::parser::{Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints}; +use crate::utils::{emit_bracketed_generics, extract_return_type}; +use crate::variant_descs::AsVariantDescs; + +fn interface_name(source: &ItemImpl) -> &Ident { + let trait_name = &source.trait_; + let Some(trait_name) = trait_name else { + unreachable!() + }; + let (_, Path { segments, .. }, _) = &trait_name; + assert!(!segments.is_empty()); + + &segments[0].ident +} + +fn extract_contract_name(contract: &Type) -> &Ident { + let Type::Path(type_path) = contract else { + unreachable!() + }; + let segments = &type_path.path.segments; + assert!(!segments.is_empty()); + let segment = &segments[0]; + &segment.ident +} struct MessageSignature<'a> { pub name: &'a Ident, @@ -27,7 +47,6 @@ pub struct MultitestHelpers<'a, Generics> { error_type: TokenStream, contract: &'a Type, is_trait: bool, - is_migrate: bool, reply: Option, source: &'a ItemImpl, generics: &'a [&'a Generics], @@ -36,27 +55,10 @@ pub struct MultitestHelpers<'a, Generics> { custom: &'a Custom<'a>, override_entry_points: &'a OverrideEntryPoints, interfaces: &'a Interfaces, -} - -fn interface_name(source: &ItemImpl) -> &Ident { - let trait_name = &source.trait_; - let Some(trait_name) = trait_name else { - unreachable!() - }; - let (_, Path { segments, .. }, _) = &trait_name; - assert!(!segments.is_empty()); - - &segments[0].ident -} - -fn extract_contract_name(contract: &Type) -> &Ident { - let Type::Path(type_path) = contract else { - unreachable!() - }; - let segments = &type_path.path.segments; - assert!(!segments.is_empty()); - let segment = &segments[0]; - &segment.ident + instantiate_variants: MsgVariants<'a, Generics>, + exec_variants: MsgVariants<'a, Generics>, + query_variants: MsgVariants<'a, Generics>, + migrate_variants: MsgVariants<'a, Generics>, } impl<'a, Generics> MultitestHelpers<'a, Generics> @@ -72,10 +74,27 @@ where override_entry_points: &'a OverrideEntryPoints, interfaces: &'a Interfaces, ) -> Self { - let mut is_migrate = false; let mut reply = None; let sylvia = crate_module(); + let where_clause = &source.generics.where_clause; + let instantiate_variants = MsgVariants::new( + source.as_variants(), + MsgType::Instantiate, + generics, + where_clause, + ); + let exec_variants = + MsgVariants::new(source.as_variants(), MsgType::Exec, generics, where_clause); + let query_variants = + MsgVariants::new(source.as_variants(), MsgType::Query, generics, where_clause); + let migrate_variants = MsgVariants::new( + source.as_variants(), + MsgType::Migrate, + generics, + where_clause, + ); + let messages: Vec<_> = source .items .iter() @@ -91,12 +110,10 @@ where }; let msg_ty = attr.msg_type(); - if msg_ty == MsgType::Migrate { - is_migrate = true; - } else if msg_ty == MsgType::Reply { + if msg_ty == MsgType::Reply { reply = Some(method.sig.ident.clone()); return None; - } else if msg_ty != MsgType::Query && msg_ty != MsgType::Exec { + } else if ![MsgType::Query, MsgType::Exec, MsgType::Migrate].contains(&msg_ty) { return None; } @@ -201,7 +218,6 @@ where error_type, contract, is_trait, - is_migrate, reply, source, generics, @@ -210,6 +226,10 @@ where custom, override_entry_points, interfaces, + instantiate_variants, + exec_variants, + query_variants, + migrate_variants, } } @@ -514,13 +534,15 @@ where } fn generate_contract_helpers(&self) -> TokenStream { + let sylvia = crate_module(); let Self { + source, error_type, is_trait, - source, generics, contract_name, proxy_name, + instantiate_variants, .. } = self; @@ -528,18 +550,34 @@ where return quote! {}; } - let sylvia = crate_module(); + let fields_names = instantiate_variants + .get_only_variant() + .map(MsgVariant::as_fields_names) + .unwrap_or(vec![]); - let mut generics_checker = CheckGenerics::new(generics); + let fields = instantiate_variants + .get_only_variant() + .map(MsgVariant::emit_fields) + .unwrap_or(vec![]); - let parsed = parse_struct_message(source, MsgType::Instantiate); - let Some((method, _)) = parsed else { - return quote! {}; + let used_generics = instantiate_variants.used_generics(); + let bracketed_used_generics = emit_bracketed_generics(used_generics); + let bracketed_generics = emit_bracketed_generics(generics); + let full_where_clause = &source.generics.where_clause; + + let where_predicates = instantiate_variants.where_predicates(); + let where_clause = instantiate_variants.where_clause(); + let contract = if !generics.is_empty() { + quote! { #contract_name ::< #(#generics,)* > } + } else { + quote! { #contract_name } }; - let instantiate_fields = process_fields(&method.sig, &mut generics_checker); - let fields_names: Vec<_> = instantiate_fields.iter().map(MsgField::name).collect(); - let fields = instantiate_fields.iter().map(MsgField::emit); + let instantiate_msg = if !used_generics.is_empty() { + quote! { InstantiateMsg::< #(#used_generics,)* > } + } else { + quote! { InstantiateMsg } + }; let impl_contract = self.generate_impl_contract(); @@ -582,10 +620,10 @@ where IbcT: #sylvia ::cw_multi_test::Ibc, GovT: #sylvia ::cw_multi_test::Gov, { - pub fn store_code(app: &'app #sylvia ::multitest::App< #mt_app >) -> Self { + pub fn store_code #bracketed_generics (app: &'app #sylvia ::multitest::App< #mt_app >) -> Self #full_where_clause { let code_id = app .app_mut() - .store_code(Box::new(#contract_name ::new())); + .store_code(Box::new(#contract ::new())); Self { code_id, app } } @@ -593,11 +631,11 @@ where self.code_id } - pub fn instantiate( + pub fn instantiate #bracketed_used_generics ( &self,#(#fields,)* - ) -> InstantiateProxy<'_, 'app, #mt_app > { - let msg = InstantiateMsg {#(#fields_names,)*}; - InstantiateProxy { + ) -> InstantiateProxy<'_, 'app, #mt_app, #(#used_generics,)* > #where_clause { + let msg = #instantiate_msg {#(#fields_names,)*}; + InstantiateProxy::<_, #(#used_generics,)* > { code_id: self, funds: &[], label: "Contract", @@ -607,17 +645,18 @@ where } } - pub struct InstantiateProxy<'a, 'app, MtApp> { + pub struct InstantiateProxy<'a, 'app, MtApp, #(#used_generics,)* > { code_id: &'a CodeId <'app, MtApp>, funds: &'a [#sylvia ::cw_std::Coin], label: &'a str, admin: Option, - msg: InstantiateMsg, + msg: InstantiateMsg #bracketed_used_generics, } - impl<'a, 'app, MtApp> InstantiateProxy<'a, 'app, MtApp> + impl<'a, 'app, MtApp, #(#used_generics,)* > InstantiateProxy<'a, 'app, MtApp, #(#used_generics,)* > where MtApp: Executor< #custom_msg >, + #(#where_predicates,)* { pub fn with_funds(self, funds: &'a [#sylvia ::cw_std::Coin]) -> Self { Self { funds, ..self } @@ -657,29 +696,35 @@ where fn generate_impl_contract(&self) -> TokenStream { let Self { + source, contract, custom, override_entry_points, + generics, + instantiate_variants, + exec_variants, + query_variants, + migrate_variants, .. } = self; let sylvia = crate_module(); + let bracketed_generics = emit_bracketed_generics(generics); + let full_where_clause = &source.generics.where_clause; let instantiate_body = override_entry_points .get_entry_point(MsgType::Instantiate) .map(OverrideEntryPoint::emit_multitest_dispatch) - .unwrap_or_else(|| { - OverrideEntryPoint::emit_multitest_default_dispatch(MsgType::Instantiate) - }); + .unwrap_or_else(|| instantiate_variants.emit_multitest_default_dispatch()); let exec_body = override_entry_points .get_entry_point(MsgType::Exec) .map(OverrideEntryPoint::emit_multitest_dispatch) - .unwrap_or_else(|| OverrideEntryPoint::emit_multitest_default_dispatch(MsgType::Exec)); + .unwrap_or_else(|| exec_variants.emit_multitest_default_dispatch()); let query_body = override_entry_points .get_entry_point(MsgType::Query) .map(OverrideEntryPoint::emit_multitest_dispatch) - .unwrap_or_else(|| OverrideEntryPoint::emit_multitest_default_dispatch(MsgType::Query)); + .unwrap_or_else(|| query_variants.emit_multitest_default_dispatch()); let sudo_body = override_entry_points .get_entry_point(MsgType::Sudo) @@ -692,8 +737,8 @@ where let migrate_body = match override_entry_points.get_entry_point(MsgType::Migrate) { Some(entry_point) => entry_point.emit_multitest_dispatch(), - None if self.is_migrate => { - OverrideEntryPoint::emit_multitest_default_dispatch(MsgType::Migrate) + None if migrate_variants.get_only_variant().is_some() => { + migrate_variants.emit_multitest_default_dispatch() } None => quote! { #sylvia ::anyhow::bail!("migrate not implemented for contract") @@ -723,7 +768,7 @@ where #[cfg(not(tarpaulin_include))] { quote! { - impl #sylvia ::cw_multi_test::Contract<#custom_msg, #custom_query> for #contract { + impl #bracketed_generics #sylvia ::cw_multi_test::Contract<#custom_msg, #custom_query> for #contract #full_where_clause { fn execute( &self, deps: #sylvia ::cw_std::DepsMut< #custom_query >, diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index fbbf67dc..51ca393c 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -1,6 +1,6 @@ use proc_macro2::{Punct, TokenStream}; use proc_macro_error::emit_error; -use quote::quote; +use quote::{quote, ToTokens}; use syn::fold::Fold; use syn::parse::{Error, Nothing, Parse, ParseBuffer, ParseStream, Parser}; use syn::punctuated::Punctuated; @@ -145,13 +145,21 @@ impl MsgType { } } - pub fn emit_msg_name(&self) -> Type { + pub fn emit_msg_name(&self, generics: &[&Generic]) -> Type + where + Generic: ToTokens, + { + let generics = if !generics.is_empty() { + quote! { ::< #(#generics,)* > } + } else { + quote! {} + }; match self { - MsgType::Exec => parse_quote! { ContractExecMsg }, - MsgType::Query => parse_quote! { ContractQueryMsg }, - MsgType::Instantiate => parse_quote! { InstantiateMsg }, - MsgType::Migrate => parse_quote! { MigrateMsg }, - MsgType::Reply => parse_quote! { ReplyMsg }, + MsgType::Exec => parse_quote! { ContractExecMsg #generics }, + MsgType::Query => parse_quote! { ContractQueryMsg #generics }, + MsgType::Instantiate => parse_quote! { InstantiateMsg #generics }, + MsgType::Migrate => parse_quote! { MigrateMsg #generics }, + MsgType::Reply => parse_quote! { ReplyMsg #generics }, MsgType::Sudo => todo!(), } } @@ -512,49 +520,6 @@ impl OverrideEntryPoint { .map_err(Into::into) } } - - pub fn emit_multitest_default_dispatch(ty: MsgType) -> TokenStream { - let sylvia = crate_module(); - - let values = ty.emit_ctx_values(); - let msg_name = ty.emit_msg_name(); - - quote! { - #sylvia ::cw_std::from_slice::< #msg_name >(&msg)? - .dispatch(self, ( #values )) - .map_err(Into::into) - } - } - - #[cfg(not(tarpaulin_include))] - pub fn emit_default_entry_point( - custom_msg: &Type, - custom_query: &Type, - name: &Type, - error: &Type, - msg_type: MsgType, - ) -> TokenStream { - let sylvia = crate_module(); - - let resp_type = match msg_type { - MsgType::Query => quote! { #sylvia ::cw_std::Binary }, - _ => quote! { #sylvia ::cw_std::Response < #custom_msg > }, - }; - let params = msg_type.emit_ctx_params(custom_query); - let values = msg_type.emit_ctx_values(); - let ep_name = msg_type.emit_ep_name(); - let msg_name = msg_type.emit_msg_name(); - - quote! { - #[#sylvia ::cw_std::entry_point] - pub fn #ep_name ( - #params , - msg: #msg_name, - ) -> Result<#resp_type, #error> { - msg.dispatch(&#name ::new() , ( #values )).map_err(Into::into) - } - } - } } #[derive(Debug)] diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs index 2e6648dc..40b666dd 100644 --- a/sylvia/tests/generics.rs +++ b/sylvia/tests/generics.rs @@ -79,6 +79,32 @@ pub mod non_generic { } } +pub mod generic_contract { + use cosmwasm_std::{CustomQuery, Response, StdResult}; + use serde::de::DeserializeOwned; + use serde::Deserialize; + use sylvia::types::{CustomMsg, InstantiateCtx}; + use sylvia_derive::contract; + + pub struct GenericContract(std::marker::PhantomData<(Msg, QueryRet)>); + + #[contract] + impl GenericContract + where + for<'msg_de> Msg: CustomMsg + Deserialize<'msg_de> + 'msg_de, + for<'a> QueryRet: CustomQuery + DeserializeOwned + 'a, + { + pub const fn new() -> Self { + Self(std::marker::PhantomData) + } + + #[msg(instantiate)] + pub fn instantiate(&self, _ctx: InstantiateCtx, _msg: Msg) -> StdResult { + Ok(Response::new()) + } + } +} + pub mod cw1_contract { use cosmwasm_std::{Response, StdResult}; use sylvia::types::InstantiateCtx; @@ -206,6 +232,7 @@ impl cosmwasm_std::CustomMsg for ExternalMsg {} #[cw_serde] pub struct ExternalQuery; impl cosmwasm_std::CustomQuery for ExternalQuery {} + #[cfg(all(test, feature = "mt"))] mod tests { use crate::cw1::{InterfaceTypes, Querier as Cw1Querier}; @@ -297,4 +324,21 @@ mod tests { .call(owner) .unwrap(); } + + #[test] + fn generic_contract() { + let app = App::default(); + let code_id = crate::generic_contract::multitest_utils::CodeId::store_code::< + ExternalMsg, + ExternalQuery, + >(&app); + + let owner = "owner"; + + code_id + .instantiate(ExternalMsg {}) + .with_label("GenericContract") + .call(owner) + .unwrap(); + } }