diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 67633bf2..6be7c033 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -648,6 +648,138 @@ impl<'a> MsgVariant<'a> { } } + pub fn emit_multitest_proxy_methods( + &self, + msg_ty: &MsgType, + custom_msg: &Type, + mt_app: &Type, + error_type: &Type, + ) -> TokenStream { + let sylvia = crate_module(); + let Self { + name, + fields, + return_type, + .. + } = self; + + let params = fields.iter().map(|field| field.emit_method_field()); + let arguments = fields.iter().map(MsgField::name); + let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + + match msg_ty { + MsgType::Exec => quote! { + #[track_caller] + pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, ExecMsg, #mt_app, #custom_msg> { + let msg = ExecMsg:: #name ( #(#arguments),* ); + + #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) + } + }, + MsgType::Migrate => quote! { + #[track_caller] + pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::<#error_type, MigrateMsg, #mt_app, #custom_msg> { + let msg = MigrateMsg::new( #(#arguments),* ); + + #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app) + } + }, + MsgType::Query => quote! { + pub fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { + let msg = QueryMsg:: #name ( #(#arguments),* ); + + (*self.app) + .app() + .wrap() + .query_wasm_smart(self.contract_addr.clone(), &msg) + .map_err(Into::into) + } + }, + _ => quote! {}, + } + } + + pub fn emit_interface_multitest_proxy_methods( + &self, + msg_ty: &MsgType, + custom_msg: &Type, + mt_app: &Type, + error_type: &Type, + generics: &[&Generics], + module: &TokenStream, + ) -> TokenStream + where + Generics: ToTokens, + { + let sylvia = crate_module(); + let Self { + name, + fields, + return_type, + .. + } = self; + + let params = fields.iter().map(|field| field.emit_method_field()); + let arguments = fields.iter().map(MsgField::name); + let bracketed_generics = emit_bracketed_generics(generics); + let interface_enum = quote! { < #module InterfaceTypes #bracketed_generics as #sylvia ::types::InterfaceMessages> }; + let type_name = msg_ty.as_accessor_name(); + let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + + match msg_ty { + MsgType::Exec => quote! { + #[track_caller] + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, #mt_app, #custom_msg> { + let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); + + #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) + } + }, + MsgType::Query => quote! { + fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { + let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); + + (*self.app) + .app() + .wrap() + .query_wasm_smart(self.contract_addr.clone(), &msg) + .map_err(Into::into) + } + }, + _ => quote! {}, + } + } + + pub fn emit_proxy_methods_declarations( + &self, + msg_ty: &MsgType, + custom_msg: &Type, + error_type: &Type, + interface_enum: &TokenStream, + ) -> TokenStream { + let sylvia = crate_module(); + let Self { + name, + fields, + return_type, + .. + } = self; + + let params = fields.iter().map(|field| field.emit_method_field()); + let type_name = msg_ty.as_accessor_name(); + let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + + match msg_ty { + MsgType::Exec => quote! { + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, MtApp, #custom_msg>; + }, + MsgType::Query => quote! { + fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>; + }, + _ => quote! {}, + } + } + pub fn as_fields_names(&self) -> Vec<&Ident> { self.fields.iter().map(MsgField::name).collect() } @@ -655,6 +787,10 @@ impl<'a> MsgVariant<'a> { pub fn emit_fields(&self) -> Vec { self.fields.iter().map(MsgField::emit).collect() } + + pub fn name(&self) -> &Ident { + &self.name + } } #[derive(Debug)] @@ -884,6 +1020,64 @@ where } } } + pub fn emit_multitest_proxy_methods( + &self, + custom_msg: &Type, + mt_app: &Type, + error_type: &Type, + ) -> Vec { + self.variants + .iter() + .map(|variant| { + variant.emit_multitest_proxy_methods(&self.msg_ty, custom_msg, mt_app, error_type) + }) + .collect() + } + + pub fn emit_interface_multitest_proxy_methods( + &self, + custom_msg: &Type, + mt_app: &Type, + error_type: &Type, + generics: &[&Generics], + module: &TokenStream, + ) -> Vec + where + Generics: ToTokens, + { + self.variants + .iter() + .map(|variant| { + variant.emit_interface_multitest_proxy_methods( + &self.msg_ty, + custom_msg, + mt_app, + error_type, + generics, + module, + ) + }) + .collect() + } + + pub fn emit_proxy_methods_declarations( + &self, + custom_msg: &Type, + error_type: &Type, + interface_enum: &TokenStream, + ) -> Vec { + self.variants + .iter() + .map(|variant| { + variant.emit_proxy_methods_declarations( + &self.msg_ty, + custom_msg, + error_type, + interface_enum, + ) + }) + .collect() + } pub fn emit_dispatch_legs(&self) -> impl Iterator + '_ { self.variants diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index 4b1e40ab..ad5ebe63 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -1,16 +1,14 @@ +use convert_case::{Casing, Case}; use proc_macro2::{Ident, TokenStream}; -use proc_macro_error::emit_error; use quote::{quote, ToTokens}; -use syn::parse::{Parse, Parser}; -use syn::spanned::Spanned; -use syn::{parse_quote, FnArg, ImplItem, ItemImpl, ItemTrait, Pat, PatType, Path, Type}; +use syn::{parse_quote, ImplItem, ItemImpl, ItemTrait, Path, Type}; use crate::check_generics::AsIdent; use crate::crate_module; use crate::interfaces::Interfaces; use crate::message::{MsgVariant, MsgVariants}; -use crate::parser::{Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints}; -use crate::utils::{emit_bracketed_generics, extract_return_type}; +use crate::parser::{Custom, MsgType, OverrideEntryPoint, OverrideEntryPoints}; +use crate::utils::emit_bracketed_generics; use crate::variant_descs::AsVariantDescs; fn interface_name(source: &ItemImpl) -> &Ident { @@ -34,20 +32,10 @@ fn extract_contract_name(contract: &Type) -> &Ident { &segment.ident } -struct MessageSignature<'a> { - pub name: &'a Ident, - pub params: Vec, - pub arguments: Vec<&'a Ident>, - pub msg_ty: MsgType, - pub return_type: TokenStream, -} - pub struct MultitestHelpers<'a, Generics> { - messages: Vec>, - error_type: TokenStream, + error_type: Type, contract: &'a Type, is_trait: bool, - reply: Option, source: &'a ItemImpl, generics: &'a [&'a Generics], contract_name: &'a Ident, @@ -59,6 +47,7 @@ pub struct MultitestHelpers<'a, Generics> { exec_variants: MsgVariants<'a, Generics>, query_variants: MsgVariants<'a, Generics>, migrate_variants: MsgVariants<'a, Generics>, + reply_variants: MsgVariants<'a, Generics>, } impl<'a, Generics> MultitestHelpers<'a, Generics> @@ -74,9 +63,6 @@ where override_entry_points: &'a OverrideEntryPoints, interfaces: &'a Interfaces, ) -> Self { - let mut reply = None; - let sylvia = crate_module(); - let where_clause = &source.generics.where_clause; let instantiate_variants = MsgVariants::new( source.as_variants(), @@ -94,88 +80,10 @@ where generics, where_clause, ); + let reply_variants = + MsgVariants::new(source.as_variants(), MsgType::Reply, generics, where_clause); - let messages: Vec<_> = source - .items - .iter() - .filter_map(|item| match item { - ImplItem::Method(method) => { - let msg_attr = method.attrs.iter().find(|attr| attr.path.is_ident("msg"))?; - let attr = match MsgAttr::parse.parse2(msg_attr.tokens.clone()) { - Ok(attr) => attr, - Err(err) => { - emit_error!(method.span(), err); - return None; - } - }; - let msg_ty = attr.msg_type(); - - if msg_ty == MsgType::Reply { - reply = Some(method.sig.ident.clone()); - return None; - } else if ![MsgType::Query, MsgType::Exec, MsgType::Migrate].contains(&msg_ty) { - return None; - } - - let sig = &method.sig; - let return_type = if let MsgAttr::Query { resp_type } = attr { - match resp_type { - Some(resp_type) => quote! {#resp_type}, - None => { - let return_type = extract_return_type(&sig.output); - quote! {#return_type} - } - } - } else { - quote! { #sylvia ::cw_multi_test::AppResponse } - }; - - let name = &sig.ident; - let params: Vec<_> = sig - .inputs - .iter() - .skip(2) - .filter_map(|arg| match arg { - FnArg::Typed(ty) => { - let name = match ty.pat.as_ref() { - Pat::Ident(ident) => &ident.ident, - _ => return None, - }; - let ty = &ty.ty; - Some(quote! {#name : #ty}) - } - _ => None, - }) - .collect(); - let arguments: Vec<_> = sig - .inputs - .iter() - .skip(2) - .filter_map(|arg| match arg { - FnArg::Typed(item) => { - let PatType { pat, .. } = item; - let Pat::Ident(ident) = pat.as_ref() else { - unreachable!() - }; - Some(&ident.ident) - } - _ => None, - }) - .collect(); - - Some(MessageSignature { - name, - params, - arguments, - msg_ty, - return_type, - }) - } - _ => None, - }) - .collect(); - - let error_type = if is_trait { + let error_type: Type = if is_trait { let error_type: Vec<_> = source .items .iter() @@ -198,9 +106,9 @@ where assert!(!error_type.is_empty()); let error_type = error_type[0]; - quote! {#error_type} + parse_quote! {#error_type} } else { - quote! {#contract_error} + parse_quote! {#contract_error} }; let contract = &source.self_ty; @@ -214,11 +122,9 @@ where }; Self { - messages, error_type, contract, is_trait, - reply, source, generics, contract_name, @@ -230,17 +136,20 @@ where exec_variants, query_variants, migrate_variants, + reply_variants, } } pub fn emit(&self) -> TokenStream { let Self { - messages, error_type, proxy_name, is_trait, custom, interfaces, + exec_variants, + query_variants, + migrate_variants, .. } = self; let sylvia = crate_module(); @@ -265,49 +174,12 @@ where > }; - #[cfg(not(tarpaulin_include))] - let messages = messages.iter().map(|msg| { - let MessageSignature { - name, - params, - arguments, - msg_ty, - return_type, - } = msg; - if msg_ty == &MsgType::Exec { - quote! { - #[track_caller] - pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, ExecMsg, #mt_app, #custom_msg> { - let msg = ExecMsg:: #name ( #(#arguments),* ); - - #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) - } - } - } else if msg_ty == &MsgType::Migrate { - quote! { - #[track_caller] - pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::<#error_type, MigrateMsg, #mt_app, #custom_msg> { - let msg = MigrateMsg::new( #(#arguments),* ); - - #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app) - } - } - } else if msg_ty == &MsgType::Query { - quote! { - pub fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { - let msg = QueryMsg:: #name ( #(#arguments),* ); - - (*self.app) - .app() - .wrap() - .query_wasm_smart(self.contract_addr.clone(), &msg) - .map_err(Into::into) - } - } - } else { - quote! {} - } - }); + let exec_methods = + exec_variants.emit_multitest_proxy_methods(&custom_msg, &mt_app, error_type); + let query_methods = + query_variants.emit_multitest_proxy_methods(&custom_msg, &mt_app, error_type); + let migrate_methods = + migrate_variants.emit_multitest_proxy_methods(&custom_msg, &mt_app, error_type); let contract_block = self.generate_contract_helpers(); @@ -353,9 +225,10 @@ where #proxy_name{ contract_addr, app } } - #(#messages)* - - #(#proxy_accessors)* + #( #exec_methods )* + #( #migrate_methods )* + #( #query_methods )* + #( #proxy_accessors )* } impl<'app, BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT> @@ -397,11 +270,12 @@ where fn impl_trait_on_proxy(&self) -> TokenStream { let Self { - messages, error_type, custom, interfaces, generics, + exec_variants, + query_variants, .. } = self; @@ -422,7 +296,7 @@ where let custom_msg = custom.msg_or_default(); #[cfg(not(tarpaulin_include))] - let mt_app = quote! { + let mt_app = parse_quote! { #sylvia ::cw_multi_test::App< BankT, ApiT, @@ -439,60 +313,27 @@ where let bracketed_generics = emit_bracketed_generics(generics); let interface_enum = quote! { < #module InterfaceTypes #bracketed_generics as #sylvia ::types::InterfaceMessages> }; - #[cfg(not(tarpaulin_include))] - let methods_definitions = messages.iter().map(|msg| { - let MessageSignature { - name, - params, - arguments, - msg_ty, - return_type, - } = msg; - let type_name = msg_ty.as_accessor_name(); - if msg_ty == &MsgType::Exec { - quote! { - #[track_caller] - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, #mt_app, #custom_msg> { - let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); - - #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) - } - } - } else { - quote! { - fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { - let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); - - (*self.app) - .app() - .wrap() - .query_wasm_smart(self.contract_addr.clone(), &msg) - .map_err(Into::into) - } - } - } - }); - - #[cfg(not(tarpaulin_include))] - let methods_declarations = messages.iter().map(|msg| { - let MessageSignature { - name, - params, - msg_ty, - return_type, - .. - } = msg; - let type_name = msg_ty.as_accessor_name(); - if msg_ty == &MsgType::Exec { - quote! { - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, MtApp, #custom_msg>; - } - } else { - quote! { - fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>; - } - } - }); + let exec_methods = exec_variants.emit_interface_multitest_proxy_methods( + &custom_msg, + &mt_app, + error_type, + generics, + &module, + ); + let query_methods = query_variants.emit_interface_multitest_proxy_methods( + &custom_msg, + &mt_app, + error_type, + generics, + &module, + ); + let exec_methods_declarations = + exec_variants.emit_proxy_methods_declarations(&custom_msg, error_type, &interface_enum); + let query_methods_declarations = query_variants.emit_proxy_methods_declarations( + &custom_msg, + error_type, + &interface_enum, + ); #[cfg(not(tarpaulin_include))] { @@ -501,7 +342,8 @@ where use super::*; pub trait #trait_name { - #(#methods_declarations)* + #(#query_methods_declarations)* + #(#exec_methods_declarations)* } impl #trait_name< #mt_app > for #module trait_utils:: #proxy_name<'_, #mt_app > @@ -525,8 +367,8 @@ where CustomT::QueryT: #sylvia:: cw_std::CustomQuery + #sylvia ::serde::de::DeserializeOwned + 'static, #mt_app : #sylvia ::cw_multi_test::Executor< #custom_msg > { - - #(#methods_definitions)* + #(#query_methods)* + #(#exec_methods)* } } } @@ -705,6 +547,7 @@ where exec_variants, query_variants, migrate_variants, + reply_variants, .. } = self; let sylvia = crate_module(); @@ -747,12 +590,14 @@ where let reply_body = match override_entry_points.get_entry_point(MsgType::Reply) { Some(entry_point) => entry_point.emit_multitest_dispatch(), - None => self - .reply + None => reply_variants + .get_only_variant() .as_ref() .map(|reply| { + let reply_name = reply.name(); + let reply_name = Ident::new(&reply_name.to_string().to_case(Case::Snake), reply_name.span()); quote! { - self. #reply((deps, env).into(), msg).map_err(Into::into) + self. #reply_name ((deps, env).into(), msg).map_err(Into::into) } }) .unwrap_or_else(|| {