diff --git a/Cargo.lock b/Cargo.lock index 5ddf2429..9c44f981 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1115,6 +1115,7 @@ dependencies = [ "cw-storage-plus", "cw-utils", "derivative", + "itertools 0.13.0", "konst", "schemars", "serde", diff --git a/sylvia-derive/src/contract.rs b/sylvia-derive/src/contract.rs index ac428777..36bfc6d5 100644 --- a/sylvia-derive/src/contract.rs +++ b/sylvia-derive/src/contract.rs @@ -2,6 +2,7 @@ use communication::api::Api; use communication::enum_msg::EnumMessage; use communication::executor::Executor; use communication::querier::Querier; +use communication::reply::Reply; use communication::struct_msg::StructMessage; use communication::wrapper_msg::GlueMessage; use mt::MtHelpers; @@ -76,23 +77,11 @@ impl<'a> ContractInput<'a> { .. } = self; let multitest_helpers = self.emit_multitest_helpers(); - - let executor_variants = MsgVariants::new(item.as_variants(), MsgType::Exec, &[], &None); - let querier_variants = MsgVariants::new(item.as_variants(), MsgType::Query, &[], &None); - let executor = Executor::new( - item.generics.clone(), - *item.self_ty.clone(), - executor_variants, - ) - .emit(); - let querier = Querier::new( - item.generics.clone(), - *item.self_ty.clone(), - querier_variants, - ) - .emit(); let messages = self.emit_messages(); let contract_api = Api::new(item, generics, custom).emit(); + let querier = self.emit_querier(); + let executor = self.emit_executor(); + let reply = self.emit_reply(); quote! { pub mod sv { @@ -106,6 +95,8 @@ impl<'a> ContractInput<'a> { #executor + #reply + #contract_api } } @@ -175,4 +166,23 @@ impl<'a> ContractInput<'a> { let generic_params = &self.generics; MtHelpers::new(item, generic_params, custom, override_entry_points.clone()).emit() } + + fn emit_executor(&self) -> TokenStream { + let item = self.item; + let variants = MsgVariants::new(item.as_variants(), MsgType::Exec, &[], &None); + + Executor::new(item.generics.clone(), *item.self_ty.clone(), variants).emit() + } + fn emit_querier(&self) -> TokenStream { + let item = self.item; + let variants = MsgVariants::new(item.as_variants(), MsgType::Query, &[], &None); + + Querier::new(item.generics.clone(), *item.self_ty.clone(), variants).emit() + } + + fn emit_reply(&self) -> TokenStream { + let variants = MsgVariants::new(self.item.as_variants(), MsgType::Reply, &[], &None); + + Reply::new(&variants).emit() + } } diff --git a/sylvia-derive/src/contract/communication/mod.rs b/sylvia-derive/src/contract/communication/mod.rs index 012695c5..a6618959 100644 --- a/sylvia-derive/src/contract/communication/mod.rs +++ b/sylvia-derive/src/contract/communication/mod.rs @@ -2,5 +2,6 @@ pub mod api; pub mod enum_msg; pub mod executor; pub mod querier; +pub mod reply; pub mod struct_msg; pub mod wrapper_msg; diff --git a/sylvia-derive/src/contract/communication/reply.rs b/sylvia-derive/src/contract/communication/reply.rs new file mode 100644 index 00000000..256a8b7c --- /dev/null +++ b/sylvia-derive/src/contract/communication/reply.rs @@ -0,0 +1,60 @@ +use convert_case::{Case, Casing}; +use itertools::Itertools; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{GenericParam, Ident}; + +use crate::types::msg_variant::MsgVariants; + +pub struct Reply<'a> { + variants: &'a MsgVariants<'a, GenericParam>, +} + +impl<'a> Reply<'a> { + pub fn new(variants: &'a MsgVariants<'a, GenericParam>) -> Self { + Self { variants } + } + + pub fn emit(&self) -> TokenStream { + let unique_handlers = self.emit_reply_ids(); + + quote! { + #(#unique_handlers)* + } + } + + fn emit_reply_ids(&self) -> impl Iterator + 'a { + self.variants + .as_reply_ids() + .enumerate() + .map(|(id, reply_id)| { + let id = id as u64; + + quote! { + pub const #reply_id : u64 = #id ; + } + }) + } +} + +trait ReplyVariants<'a> { + fn as_reply_ids(&'a self) -> impl Iterator + 'a; +} + +impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> { + fn as_reply_ids(&'a self) -> impl Iterator + 'a { + self.variants() + .flat_map(|variant| { + if variant.msg_attr().handlers().is_empty() { + return vec![variant.function_name()]; + } + variant.msg_attr().handlers().iter().collect() + }) + .unique() + .map(|handler| { + let reply_id = + format! {"{}_REPLY_ID", handler.to_string().to_case(Case::UpperSnake)}; + Ident::new(&reply_id, handler.span()) + }) + } +} diff --git a/sylvia-derive/src/contract/communication/struct_msg.rs b/sylvia-derive/src/contract/communication/struct_msg.rs index d28834cd..b8122b02 100644 --- a/sylvia-derive/src/contract/communication/struct_msg.rs +++ b/sylvia-derive/src/contract/communication/struct_msg.rs @@ -96,20 +96,24 @@ impl<'a> StructMessage<'a> { let bracketed_unused_generics = emit_bracketed_generics(unused_generics); let ret_type = variant + .msg_attr() .msg_type() .emit_result_type(&custom.msg_or_default(), &error.error); - let name = variant.msg_type().emit_msg_name(); + let name = variant.msg_attr().msg_type().emit_msg_name(); let function_name = variant.function_name(); - let mut msg_name = variant.msg_type().emit_msg_name(); + let mut msg_name = variant.msg_attr().msg_type().emit_msg_name(); msg_name.set_span(function_name.span()); - let ctx_type = variant.msg_type().emit_ctx_type(&custom.query_or_default()); + let ctx_type = variant + .msg_attr() + .msg_type() + .emit_ctx_type(&custom.query_or_default()); let fields_names: Vec<_> = variant.fields().iter().map(MsgField::name).collect(); let parameters = variant.fields().iter().map(MsgField::emit_method_field); let fields = variant.fields().iter().map(MsgField::emit_pub); let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs); - let derive_call = variant.msg_type().emit_derive_call(); + let derive_call = variant.msg_attr().msg_type().emit_derive_call(); quote! { #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/sylvia-derive/src/contract/mt.rs b/sylvia-derive/src/contract/mt.rs index 46161460..3a9ecb19 100644 --- a/sylvia-derive/src/contract/mt.rs +++ b/sylvia-derive/src/contract/mt.rs @@ -659,10 +659,10 @@ impl EmitMethods for MsgVariant<'_> { .map(|field| field.emit_method_field_folded()) .collect(); let arguments = self.as_fields_names(); - let type_name = self.msg_type().as_accessor_name(); + let type_name = self.msg_attr().msg_type().as_accessor_name(); let name = name.to_case(Case::Snake); - match self.msg_type() { + match self.msg_attr().msg_type() { MsgType::Exec => quote! { #[track_caller] fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api :: #type_name, #mt_app, #custom_msg> { @@ -721,10 +721,10 @@ impl EmitMethods for MsgVariant<'_> { .iter() .map(|field| field.emit_method_field_folded()) .collect(); - let type_name = self.msg_type().as_accessor_name(); + let type_name = self.msg_attr().msg_type().as_accessor_name(); let name = name.to_case(Case::Snake); - match self.msg_type() { + match self.msg_attr().msg_type() { MsgType::Exec => quote! { fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api:: #type_name, MtApp, #custom_msg>; }, diff --git a/sylvia-derive/src/interface/mt.rs b/sylvia-derive/src/interface/mt.rs index ad5aaf4a..0fb4f348 100644 --- a/sylvia-derive/src/interface/mt.rs +++ b/sylvia-derive/src/interface/mt.rs @@ -199,10 +199,10 @@ impl EmitMethods for MsgVariant<'_> { .map(|field| field.emit_method_field_folded()) .collect(); let arguments = self.as_fields_names(); - let type_name = self.msg_type().as_accessor_name(); + let type_name = self.msg_attr().msg_type().as_accessor_name(); let name = name.to_case(Case::Snake); - match self.msg_type() { + match self.msg_attr().msg_type() { MsgType::Exec => quote! { #[track_caller] fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api :: #type_name, #mt_app, #custom_msg> { @@ -261,10 +261,10 @@ impl EmitMethods for MsgVariant<'_> { .iter() .map(|field| field.emit_method_field_folded()) .collect(); - let type_name = self.msg_type().as_accessor_name(); + let type_name = self.msg_attr().msg_type().as_accessor_name(); let name = name.to_case(Case::Snake); - match self.msg_type() { + match self.msg_attr().msg_type() { MsgType::Exec => quote! { fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api:: #type_name, MtApp, #custom_msg>; }, diff --git a/sylvia-derive/src/parser/attributes/msg.rs b/sylvia-derive/src/parser/attributes/msg.rs index faf22436..16e1f037 100644 --- a/sylvia-derive/src/parser/attributes/msg.rs +++ b/sylvia-derive/src/parser/attributes/msg.rs @@ -95,7 +95,7 @@ impl ReplyOn { pub struct MsgAttr { msg_type: MsgType, query_resp_type: Option, - _reply_handlers: Vec, + reply_handlers: Vec, _reply_on: ReplyOn, } @@ -114,6 +114,10 @@ impl MsgAttr { pub fn resp_type(&self) -> &Option { &self.query_resp_type } + + pub fn handlers(&self) -> &[Ident] { + &self.reply_handlers + } } impl PartialEq for MsgAttr { @@ -135,7 +139,7 @@ impl Parse for MsgAttr { Ok(Self { msg_type, query_resp_type, - _reply_handlers: reply_handlers, + reply_handlers, _reply_on: reply_on.unwrap_or_default(), }) } diff --git a/sylvia-derive/src/types/msg_variant.rs b/sylvia-derive/src/types/msg_variant.rs index 0432f7a2..7edb1a02 100644 --- a/sylvia-derive/src/types/msg_variant.rs +++ b/sylvia-derive/src/types/msg_variant.rs @@ -166,8 +166,8 @@ impl<'a> MsgVariant<'a> { &self.fields } - pub fn msg_type(&self) -> MsgType { - self.msg_attr.msg_type() + pub fn msg_attr(&self) -> &MsgAttr { + &self.msg_attr } pub fn return_type(&self) -> &Option { diff --git a/sylvia/Cargo.toml b/sylvia/Cargo.toml index 65068448..92df9d7a 100644 --- a/sylvia/Cargo.toml +++ b/sylvia/Cargo.toml @@ -60,6 +60,7 @@ cw-storage-plus = { workspace = true } cw-utils = { workspace = true } thiserror = { workspace = true } trybuild = "1.0.91" +itertools = "0.13.0" [package.metadata.docs.rs] all-features = true diff --git a/sylvia/tests/reply.rs b/sylvia/tests/reply.rs index 64593f29..e0b57634 100644 --- a/sylvia/tests/reply.rs +++ b/sylvia/tests/reply.rs @@ -44,3 +44,31 @@ impl Contract { Ok(Response::new()) } } + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + #[test] + fn reply_id_generation() { + // Assert IDs uniqueness + let unique_ids_count = [ + super::sv::CLEAN_REPLY_ID, + super::sv::HANDLER_ONE_REPLY_ID, + super::sv::HANDLER_TWO_REPLY_ID, + super::sv::REPLY_ON_REPLY_ID, + super::sv::REPLY_ON_ALWAYS_REPLY_ID, + ] + .iter() + .unique() + .count(); + + assert_eq!(unique_ids_count, 5); + + assert_eq!(super::sv::CLEAN_REPLY_ID, 0); + assert_eq!(super::sv::HANDLER_ONE_REPLY_ID, 1); + assert_eq!(super::sv::HANDLER_TWO_REPLY_ID, 2); + assert_eq!(super::sv::REPLY_ON_REPLY_ID, 3); + assert_eq!(super::sv::REPLY_ON_ALWAYS_REPLY_ID, 4); + } +}