From 8afdaca8689d688518451f9db0fd0b515592f427 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Mon, 25 Sep 2023 12:36:35 +0200 Subject: [PATCH] feat: Emit InterfaceTypes --- sylvia-derive/src/input.rs | 8 ++- sylvia-derive/src/message.rs | 111 +++++++++++++++++++++++++++-------- sylvia-derive/src/utils.rs | 2 +- sylvia/src/types.rs | 5 ++ sylvia/tests/generics.rs | 12 ++++ 5 files changed, 113 insertions(+), 25 deletions(-) diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 83e05237..718cb043 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -7,7 +7,9 @@ use syn::{parse_quote, GenericParam, Ident, ItemImpl, ItemTrait, TraitItem, Type use crate::crate_module; use crate::interfaces::Interfaces; -use crate::message::{ContractEnumMessage, EnumMessage, GlueMessage, MsgVariants, StructMessage}; +use crate::message::{ + ContractEnumMessage, EnumMessage, GlueMessage, InterfaceMessages, MsgVariants, StructMessage, +}; use crate::multitest::{MultitestHelpers, TraitMultitestHelpers}; use crate::parser::{ContractArgs, ContractErrorAttr, Custom, MsgType, OverrideEntryPoints}; use crate::remote::Remote; @@ -71,6 +73,8 @@ impl<'a> TraitInput<'a> { ) .emit_querier(); + let interface_messages = InterfaceMessages::new(self.item, &self.generics).emit(); + #[cfg(not(tarpaulin_include))] { quote! { @@ -81,6 +85,8 @@ impl<'a> TraitInput<'a> { #remote #querier + + #interface_messages } } } diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 67bbbe3f..9825f8e8 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -7,7 +7,7 @@ use crate::parser::{ }; use crate::strip_generics::StripGenerics; use crate::utils::{ - as_where_clause, brace_generics, extract_return_type, filter_wheres, process_fields, + as_where_clause, emit_bracketed_generics, extract_return_type, filter_wheres, process_fields, }; use crate::variant_descs::{AsVariantDescs, VariantDescs}; use convert_case::{Case, Casing}; @@ -114,8 +114,8 @@ impl<'a> StructMessage<'a> { let fields = fields.iter().map(MsgField::emit); let where_clause = as_where_clause(wheres); - let generics = brace_generics(generics); - let unused_generics = brace_generics(unused_generics); + let generics = emit_bracketed_generics(generics); + let unused_generics = emit_bracketed_generics(unused_generics); #[cfg(not(tarpaulin_include))] { @@ -264,7 +264,7 @@ impl<'a> EnumMessage<'a> { let ctx_type = msg_ty.emit_ctx_type(query_type); let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(C::Error)); - let all_generics = brace_generics(all_generics); + let all_generics = emit_bracketed_generics(all_generics); let phantom = if generics.is_empty() { quote! {} } else if MsgType::Query == *msg_ty { @@ -289,7 +289,7 @@ impl<'a> EnumMessage<'a> { } }; - let generics = brace_generics(generics); + let generics = emit_bracketed_generics(generics); let unique_enum_name = Ident::new(&format!("{}{}", trait_name, name), name.span()); @@ -683,14 +683,14 @@ impl<'a> MsgVariant<'a> { pub struct MsgVariants<'a> { variants: Vec>, unbonded_generics: Vec<&'a GenericParam>, - where_clause: Option, + where_predicates: Vec<&'a WherePredicate>, } impl<'a> MsgVariants<'a> { pub fn new( source: VariantDescs<'a>, msg_type: MsgType, - generics: &'a Vec<&'a GenericParam>, + generics: &'a [&'a GenericParam], unfiltered_where_clause: &'a Option, ) -> Self { let mut generics_checker = CheckGenerics::new(generics); @@ -718,21 +718,21 @@ impl<'a> MsgVariants<'a> { .collect(); let (unbonded_generics, _) = generics_checker.used_unused(); - let wheres = filter_wheres( - unfiltered_where_clause, - generics.as_slice(), - &unbonded_generics, - ); - let where_clause = if !wheres.is_empty() { - Some(parse_quote! { where #(#wheres),* }) - } else { - None - }; + let where_predicates = filter_wheres(unfiltered_where_clause, generics, &unbonded_generics); Self { variants, unbonded_generics, - where_clause, + where_predicates, + } + } + + pub fn where_clause(&self) -> Option { + let where_predicates = &self.where_predicates; + if !where_predicates.is_empty() { + Some(parse_quote! { where #(#where_predicates),* }) + } else { + None } } @@ -745,9 +745,9 @@ impl<'a> MsgVariants<'a> { let Self { variants, unbonded_generics, - where_clause, .. } = self; + let where_clause = self.where_clause(); let methods_impl = variants .iter() @@ -759,7 +759,7 @@ impl<'a> MsgVariants<'a> { .filter(|variant| variant.msg_type == MsgType::Query) .map(MsgVariant::emit_querier_declaration); - let braced_generics = brace_generics(unbonded_generics); + let braced_generics = emit_bracketed_generics(unbonded_generics); let querier = quote! { Querier #braced_generics }; #[cfg(not(tarpaulin_include))] @@ -804,9 +804,9 @@ impl<'a> MsgVariants<'a> { let Self { variants, unbonded_generics, - where_clause, .. } = self; + let where_clause = self.where_clause(); let methods_impl = variants .iter() @@ -1100,6 +1100,71 @@ impl<'a> GlueMessage<'a> { } } +pub struct InterfaceMessages<'a> { + exec_variants: MsgVariants<'a>, + query_variants: MsgVariants<'a>, + generics: &'a [&'a GenericParam], +} + +impl<'a> InterfaceMessages<'a> { + pub fn new(source: &'a ItemTrait, generics: &'a [&'a GenericParam]) -> Self { + let exec_variants = MsgVariants::new( + source.as_variants(), + MsgType::Exec, + generics, + &source.generics.where_clause, + ); + + let query_variants = MsgVariants::new( + source.as_variants(), + MsgType::Query, + generics, + &source.generics.where_clause, + ); + + Self { + exec_variants, + query_variants, + generics, + } + } + + pub fn emit(&self) -> TokenStream { + let sylvia = crate_module(); + let Self { + exec_variants, + query_variants, + generics, + } = self; + + let exec_generics = &exec_variants.unbonded_generics; + let query_generics = &query_variants.unbonded_generics; + + let bracket_generics = emit_bracketed_generics(generics); + let exec_bracketed_generics = emit_bracketed_generics(exec_generics); + let query_bracketed_generics = emit_bracketed_generics(query_generics); + + let phantom = if !generics.is_empty() { + quote! { + _phantom: std::marker::PhantomData<( #(#generics,)* )>, + } + } else { + quote! {} + }; + + quote! { + pub struct InterfaceTypes #bracket_generics { + #phantom + } + + impl #bracket_generics #sylvia ::types::InterfaceMessages for InterfaceTypes #bracket_generics { + type Exec = ExecMsg #exec_bracketed_generics; + type Query = QueryMsg #query_bracketed_generics; + } + } + } +} + pub struct EntryPoints<'a> { name: Type, error: Type, @@ -1130,11 +1195,11 @@ impl<'a> EntryPoints<'a> { ) .unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError }); - let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &vec![], &None) + let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &[], &None) .variants() .is_empty(); - let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &vec![], &None) + let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &[], &None) .variants() .iter() .map(|variant| variant.function_name.clone()) diff --git a/sylvia-derive/src/utils.rs b/sylvia-derive/src/utils.rs index 90ced524..16857d4d 100644 --- a/sylvia-derive/src/utils.rs +++ b/sylvia-derive/src/utils.rs @@ -94,7 +94,7 @@ pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option TokenStream { +pub fn emit_bracketed_generics(unbonded_generics: &[&GenericParam]) -> TokenStream { match unbonded_generics.is_empty() { true => quote! {}, false => quote! { < #(#unbonded_generics,)* > }, diff --git a/sylvia/src/types.rs b/sylvia/src/types.rs index 71495870..d1bb8535 100644 --- a/sylvia/src/types.rs +++ b/sylvia/src/types.rs @@ -96,3 +96,8 @@ impl<'a, C: CustomQuery> From<(Deps<'a, C>, Env)> for QueryCtx<'a, C> { } pub trait CustomMsg: cosmwasm_std::CustomMsg + DeserializeOwned {} + +pub trait InterfaceMessages { + type Exec; + type Query; +} diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs index a10f743b..4f4213dd 100644 --- a/sylvia/tests/generics.rs +++ b/sylvia/tests/generics.rs @@ -44,6 +44,8 @@ mod tests { use crate::{cw1::Querier, ExternalMsg, ExternalQuery}; + use crate::cw1::InterfaceTypes; + use sylvia::types::InterfaceMessages; #[test] fn construct_messages() { let contract = Addr::unchecked("contract"); @@ -59,5 +61,15 @@ mod tests { let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier); let _: Result = Querier::some_query(&cw1_querier, ExternalMsg {}); let _: Result = cw1_querier.some_query(ExternalMsg {}); + + // Construct messages with Interface extension + let _ = + as InterfaceMessages>::Query::some_query( + ExternalMsg {}, + ); + let _= + as InterfaceMessages>::Exec::execute(vec![ + CosmosMsg::Custom(ExternalMsg {}), + ]); } }