From 252758dfd3639ec08cd244996f04601392407fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Thu, 18 Apr 2024 14:03:43 +0200 Subject: [PATCH] chore: Move `Multitest` related `MsgVariant` logic to trait in multitest module --- sylvia-derive/src/message.rs | 162 +++----------------------- sylvia-derive/src/multitest.rs | 204 ++++++++++++++++++++++++++------- sylvia-derive/src/querier.rs | 4 - 3 files changed, 181 insertions(+), 189 deletions(-) diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 42746e48..a2dd3825 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -607,115 +607,6 @@ impl<'a> MsgVariant<'a> { } } - pub fn emit_mt_method_definition( - &self, - msg_ty: &MsgType, - custom_msg: &Type, - mt_app: &Type, - error_type: &Type, - api: &TokenStream, - ) -> TokenStream { - let sylvia = crate_module(); - let Self { - name, - fields, - return_type, - .. - } = self; - - let params: Vec<_> = fields - .iter() - .map(|field| field.emit_method_field_folded()) - .collect(); - let arguments = fields.iter().map(MsgField::name); - let type_name = msg_ty.as_accessor_name(); - let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); - - match msg_ty { - MsgType::Exec => quote! { - #[track_caller] - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api :: #type_name, #mt_app, #custom_msg> { - let msg = #api :: #type_name :: #name ( #(#arguments),* ); - - #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) - } - }, - MsgType::Query => { - quote! { - fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { - let msg = #api :: #type_name :: #name ( #(#arguments),* ); - - (*self.app) - .querier() - .query_wasm_smart(self.contract_addr.clone(), &msg) - .map_err(Into::into) - } - } - } - MsgType::Sudo => quote! { - fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type> { - let msg = #api :: #type_name :: #name ( #(#arguments),* ); - - (*self.app) - .app_mut() - .wasm_sudo(self.contract_addr.clone(), &msg) - .map_err(|err| err.downcast().unwrap()) - } - }, - MsgType::Migrate => quote! { - #[track_caller] - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name , #mt_app, #custom_msg> { - let msg = #api :: #type_name ::new( #(#arguments),* ); - - #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app) - } - }, - _ => quote! {}, - } - } - - pub fn emit_mt_method_declaration( - &self, - msg_ty: &MsgType, - custom_msg: &Type, - error_type: &Type, - api: &TokenStream, - ) -> TokenStream { - let sylvia = crate_module(); - let Self { - name, - fields, - return_type, - .. - } = self; - - let params: Vec<_> = fields - .iter() - .map(|field| field.emit_method_field_folded()) - .collect(); - let type_name = msg_ty.as_accessor_name(); - let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); - - match msg_ty { - MsgType::Exec => quote! { - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api:: #type_name, MtApp, #custom_msg>; - }, - MsgType::Query => { - quote! { - fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>; - } - } - MsgType::Sudo => quote! { - fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type>; - }, - MsgType::Migrate => quote! { - #[track_caller] - fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name, MtApp, #custom_msg>; - }, - _ => quote! {}, - } - } - pub fn as_fields_names(&self) -> Vec<&Ident> { self.fields.iter().map(MsgField::name).collect() } @@ -727,6 +618,18 @@ impl<'a> MsgVariant<'a> { pub fn name(&self) -> &Ident { &self.name } + + pub fn fields(&self) -> &Vec { + &self.fields + } + + pub fn msg_type(&self) -> &MsgType { + &self.msg_type + } + + pub fn return_type(&self) -> &TokenStream { + &self.return_type + } } #[derive(Debug)] @@ -786,8 +689,8 @@ where } } - pub fn variants(&self) -> &Vec> { - &self.variants + pub fn variants(&self) -> impl Iterator { + self.variants.iter() } pub fn used_generics(&self) -> &Vec<&'a Generic> { @@ -833,35 +736,6 @@ where } } - pub fn emit_mt_method_definitions( - &self, - custom_msg: &Type, - mt_app: &Type, - error_type: &Type, - api: &TokenStream, - ) -> Vec { - self.variants - .iter() - .map(|variant| { - variant.emit_mt_method_definition(&self.msg_ty, custom_msg, mt_app, error_type, api) - }) - .collect() - } - - pub fn emit_mt_method_declarations( - &self, - custom_msg: &Type, - error_type: &Type, - api: &TokenStream, - ) -> Vec { - self.variants - .iter() - .map(|variant| { - variant.emit_mt_method_declaration(&self.msg_ty, custom_msg, error_type, api) - }) - .collect() - } - pub fn emit_phantom_match_arm(&self) -> TokenStream { let sylvia = crate_module(); let Self { used_generics, .. } = self; @@ -1349,9 +1223,10 @@ impl<'a> ContractApi<'a> { let contract_query_bracketed_generics = emit_bracketed_generics(&contract_query_generics); let contract_sudo_bracketed_generics = emit_bracketed_generics(&contract_sudo_generics); - let migrate_type = match !migrate_variants.variants().is_empty() { - true => quote! { type Migrate = MigrateMsg #migrate_bracketed_generics; }, - false => quote! { type Migrate = #sylvia ::cw_std::Empty; }, + let migrate_type = if migrate_variants.variants().count() != 0 { + quote! { type Migrate = MigrateMsg #migrate_bracketed_generics; } + } else { + quote! { type Migrate = #sylvia ::cw_std::Empty; } }; let custom_query = custom.query_or_default(); @@ -1531,7 +1406,6 @@ impl<'a> EntryPoints<'a> { let reply = MsgVariants::::new(source.as_variants(), MsgType::Reply, &[], &None) .variants() - .iter() .map(|variant| variant.function_name.clone()) .next(); let sudo_variants = diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index 135287b6..67b5abf2 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -140,23 +140,31 @@ impl<'a> ContractMtHelpers<'a> { }; let api = quote! { < #contract_name as #sylvia ::types::ContractApi> }; - let exec_methods = - exec_variants.emit_mt_method_definitions(&custom_msg, &mt_app, error_type, &api); - let query_methods = - query_variants.emit_mt_method_definitions(&custom_msg, &mt_app, error_type, &api); - let sudo_methods = - sudo_variants.emit_mt_method_definitions(&custom_msg, &mt_app, error_type, &api); - let migrate_methods = - migrate_variants.emit_mt_method_definitions(&custom_msg, &mt_app, error_type, &api); - - let exec_methods_declarations = - exec_variants.emit_mt_method_declarations(&custom_msg, error_type, &api); - let query_methods_declarations = - query_variants.emit_mt_method_declarations(&custom_msg, error_type, &api); - let sudo_methods_declarations = - sudo_variants.emit_mt_method_declarations(&custom_msg, error_type, &api); - let migrate_methods_declarations = - migrate_variants.emit_mt_method_declarations(&custom_msg, error_type, &api); + let exec_methods = exec_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, error_type, &api) + }); + let query_methods = query_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, error_type, &api) + }); + let sudo_methods = sudo_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, error_type, &api) + }); + let migrate_methods = migrate_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, error_type, &api) + }); + + let exec_methods_declarations = exec_variants + .variants() + .map(|variant| variant.emit_mt_method_declaration(&custom_msg, error_type, &api)); + let query_methods_declarations = query_variants + .variants() + .map(|variant| variant.emit_mt_method_declaration(&custom_msg, error_type, &api)); + let sudo_methods_declarations = sudo_variants + .variants() + .map(|variant| variant.emit_mt_method_declaration(&custom_msg, error_type, &api)); + let migrate_methods_declarations = migrate_variants + .variants() + .map(|variant| variant.emit_mt_method_declaration(&custom_msg, error_type, &api)); let where_predicates = where_clause .as_ref() @@ -650,31 +658,25 @@ impl<'a> TraitMtHelpers<'a> { let associated_types_declaration = associated_types.without_error(); - let exec_methods = exec_variants.emit_mt_method_definitions( - &custom_msg, - &mt_app, - &prefixed_error_type, - &api, - ); - let query_methods = query_variants.emit_mt_method_definitions( - &custom_msg, - &mt_app, - &prefixed_error_type, - &api, - ); - let sudo_methods = sudo_variants.emit_mt_method_definitions( - &custom_msg, - &mt_app, - &prefixed_error_type, - &api, - ); + let exec_methods = exec_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api) + }); + let query_methods = query_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api) + }); + let sudo_methods = sudo_variants.variants().map(|variant| { + variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api) + }); - let exec_methods_declarations = - exec_variants.emit_mt_method_declarations(&custom_msg, &prefixed_error_type, &api); - let query_methods_declarations = - query_variants.emit_mt_method_declarations(&custom_msg, &prefixed_error_type, &api); - let sudo_methods_declarations = - sudo_variants.emit_mt_method_declarations(&custom_msg, &prefixed_error_type, &api); + let exec_methods_declarations = exec_variants.variants().map(|variant| { + variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api) + }); + let query_methods_declarations = query_variants.variants().map(|variant| { + variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api) + }); + let sudo_methods_declarations = sudo_variants.variants().map(|variant| { + variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api) + }); let where_predicates = where_clause .as_ref() @@ -737,3 +739,123 @@ fn emit_default_dispatch(msg_ty: &MsgType, contract_name: &Type) -> TokenStream .map_err(Into::into) } } + +trait EmitMethods { + fn emit_mt_method_definition( + &self, + custom_msg: &Type, + mt_app: &Type, + error_type: &Type, + api: &TokenStream, + ) -> TokenStream; + + fn emit_mt_method_declaration( + &self, + custom_msg: &Type, + error_type: &Type, + api: &TokenStream, + ) -> TokenStream; +} + +impl EmitMethods for MsgVariant<'_> { + fn emit_mt_method_definition( + &self, + custom_msg: &Type, + mt_app: &Type, + error_type: &Type, + api: &TokenStream, + ) -> TokenStream { + let sylvia = crate_module(); + + let name = self.name(); + let return_type = self.return_type(); + + let params: Vec<_> = self + .fields() + .iter() + .map(|field| field.emit_method_field_folded()) + .collect(); + let arguments = self.as_fields_names(); + let type_name = self.msg_type().as_accessor_name(); + let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + + match self.msg_type() { + MsgType::Exec => quote! { + #[track_caller] + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api :: #type_name, #mt_app, #custom_msg> { + let msg = #api :: #type_name :: #name ( #(#arguments),* ); + + #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) + } + }, + MsgType::Query => { + quote! { + fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { + let msg = #api :: #type_name :: #name ( #(#arguments),* ); + + (*self.app) + .querier() + .query_wasm_smart(self.contract_addr.clone(), &msg) + .map_err(Into::into) + } + } + } + MsgType::Sudo => quote! { + fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type> { + let msg = #api :: #type_name :: #name ( #(#arguments),* ); + + (*self.app) + .app_mut() + .wasm_sudo(self.contract_addr.clone(), &msg) + .map_err(|err| err.downcast().unwrap()) + } + }, + MsgType::Migrate => quote! { + #[track_caller] + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name , #mt_app, #custom_msg> { + let msg = #api :: #type_name ::new( #(#arguments),* ); + + #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app) + } + }, + _ => quote! {}, + } + } + + fn emit_mt_method_declaration( + &self, + custom_msg: &Type, + error_type: &Type, + api: &TokenStream, + ) -> TokenStream { + let sylvia = crate_module(); + + let name = self.name(); + let return_type = self.return_type(); + + let params: Vec<_> = self + .fields() + .iter() + .map(|field| field.emit_method_field_folded()) + .collect(); + let type_name = self.msg_type().as_accessor_name(); + let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + + match self.msg_type() { + MsgType::Exec => quote! { + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api:: #type_name, MtApp, #custom_msg>; + }, + MsgType::Query => quote! { + fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>; + }, + MsgType::Sudo => quote! { + fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type>; + }, + MsgType::Migrate => quote! { + #[track_caller] + fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name, MtApp, #custom_msg>; + }, + _ => quote! {}, + } + } +} diff --git a/sylvia-derive/src/querier.rs b/sylvia-derive/src/querier.rs index 3e1f61d3..8558d0cb 100644 --- a/sylvia-derive/src/querier.rs +++ b/sylvia-derive/src/querier.rs @@ -53,13 +53,11 @@ where .collect(); let methods_trait_impl = variants .variants() - .iter() .map(|variant| variant.emit_trait_querier_impl(&assoc_types)) .collect::>(); let methods_declaration = variants .variants() - .iter() .map(|variant| variant.emit_querier_declaration()); let types_declaration = associated_types.filtered(); @@ -112,12 +110,10 @@ impl<'a> ContractQuerier<'a> { let api_path = quote! { < #contract as #sylvia ::types::ContractApi>:: #accessor }; let methods_impl = variants .variants() - .iter() .map(|variant| variant.emit_querier_impl::(&api_path)); let methods_declaration = variants .variants() - .iter() .map(|variant| variant.emit_querier_declaration()); let types_declaration = where_clause