diff --git a/Cargo.lock b/Cargo.lock index 5ddf2429..9c44f981 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1115,6 +1115,7 @@ dependencies = [ "cw-storage-plus", "cw-utils", "derivative", + "itertools 0.13.0", "konst", "schemars", "serde", diff --git a/sylvia-derive/src/contract.rs b/sylvia-derive/src/contract.rs index ac428777..36bfc6d5 100644 --- a/sylvia-derive/src/contract.rs +++ b/sylvia-derive/src/contract.rs @@ -2,6 +2,7 @@ use communication::api::Api; use communication::enum_msg::EnumMessage; use communication::executor::Executor; use communication::querier::Querier; +use communication::reply::Reply; use communication::struct_msg::StructMessage; use communication::wrapper_msg::GlueMessage; use mt::MtHelpers; @@ -76,23 +77,11 @@ impl<'a> ContractInput<'a> { .. } = self; let multitest_helpers = self.emit_multitest_helpers(); - - let executor_variants = MsgVariants::new(item.as_variants(), MsgType::Exec, &[], &None); - let querier_variants = MsgVariants::new(item.as_variants(), MsgType::Query, &[], &None); - let executor = Executor::new( - item.generics.clone(), - *item.self_ty.clone(), - executor_variants, - ) - .emit(); - let querier = Querier::new( - item.generics.clone(), - *item.self_ty.clone(), - querier_variants, - ) - .emit(); let messages = self.emit_messages(); let contract_api = Api::new(item, generics, custom).emit(); + let querier = self.emit_querier(); + let executor = self.emit_executor(); + let reply = self.emit_reply(); quote! { pub mod sv { @@ -106,6 +95,8 @@ impl<'a> ContractInput<'a> { #executor + #reply + #contract_api } } @@ -175,4 +166,23 @@ impl<'a> ContractInput<'a> { let generic_params = &self.generics; MtHelpers::new(item, generic_params, custom, override_entry_points.clone()).emit() } + + fn emit_executor(&self) -> TokenStream { + let item = self.item; + let variants = MsgVariants::new(item.as_variants(), MsgType::Exec, &[], &None); + + Executor::new(item.generics.clone(), *item.self_ty.clone(), variants).emit() + } + fn emit_querier(&self) -> TokenStream { + let item = self.item; + let variants = MsgVariants::new(item.as_variants(), MsgType::Query, &[], &None); + + Querier::new(item.generics.clone(), *item.self_ty.clone(), variants).emit() + } + + fn emit_reply(&self) -> TokenStream { + let variants = MsgVariants::new(self.item.as_variants(), MsgType::Reply, &[], &None); + + Reply::new(&variants).emit() + } } diff --git a/sylvia-derive/src/contract/communication/mod.rs b/sylvia-derive/src/contract/communication/mod.rs index 012695c5..a6618959 100644 --- a/sylvia-derive/src/contract/communication/mod.rs +++ b/sylvia-derive/src/contract/communication/mod.rs @@ -2,5 +2,6 @@ pub mod api; pub mod enum_msg; pub mod executor; pub mod querier; +pub mod reply; pub mod struct_msg; pub mod wrapper_msg; diff --git a/sylvia-derive/src/contract/communication/reply.rs b/sylvia-derive/src/contract/communication/reply.rs new file mode 100644 index 00000000..256a8b7c --- /dev/null +++ b/sylvia-derive/src/contract/communication/reply.rs @@ -0,0 +1,60 @@ +use convert_case::{Case, Casing}; +use itertools::Itertools; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{GenericParam, Ident}; + +use crate::types::msg_variant::MsgVariants; + +pub struct Reply<'a> { + variants: &'a MsgVariants<'a, GenericParam>, +} + +impl<'a> Reply<'a> { + pub fn new(variants: &'a MsgVariants<'a, GenericParam>) -> Self { + Self { variants } + } + + pub fn emit(&self) -> TokenStream { + let unique_handlers = self.emit_reply_ids(); + + quote! { + #(#unique_handlers)* + } + } + + fn emit_reply_ids(&self) -> impl Iterator + 'a { + self.variants + .as_reply_ids() + .enumerate() + .map(|(id, reply_id)| { + let id = id as u64; + + quote! { + pub const #reply_id : u64 = #id ; + } + }) + } +} + +trait ReplyVariants<'a> { + fn as_reply_ids(&'a self) -> impl Iterator + 'a; +} + +impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> { + fn as_reply_ids(&'a self) -> impl Iterator + 'a { + self.variants() + .flat_map(|variant| { + if variant.msg_attr().handlers().is_empty() { + return vec![variant.function_name()]; + } + variant.msg_attr().handlers().iter().collect() + }) + .unique() + .map(|handler| { + let reply_id = + format! {"{}_REPLY_ID", handler.to_string().to_case(Case::UpperSnake)}; + Ident::new(&reply_id, handler.span()) + }) + } +} diff --git a/sylvia-derive/src/parser/attributes/msg.rs b/sylvia-derive/src/parser/attributes/msg.rs index 5c946301..288ccfa7 100644 --- a/sylvia-derive/src/parser/attributes/msg.rs +++ b/sylvia-derive/src/parser/attributes/msg.rs @@ -94,7 +94,7 @@ impl ReplyOn { pub struct MsgAttr { msg_type: MsgType, query_resp_type: Option, - _reply_handlers: Vec, + reply_handlers: Vec, _reply_on: ReplyOn, } @@ -113,6 +113,10 @@ impl MsgAttr { pub fn resp_type(&self) -> &Option { &self.query_resp_type } + + pub fn handlers(&self) -> &[Ident] { + &self.reply_handlers + } } impl PartialEq for MsgAttr { @@ -134,7 +138,7 @@ impl Parse for MsgAttr { Ok(Self { msg_type, query_resp_type, - _reply_handlers: reply_handlers, + reply_handlers, _reply_on: reply_on.unwrap_or_default(), }) } diff --git a/sylvia-derive/src/types/msg_variant.rs b/sylvia-derive/src/types/msg_variant.rs index 0432f7a2..807f63ba 100644 --- a/sylvia-derive/src/types/msg_variant.rs +++ b/sylvia-derive/src/types/msg_variant.rs @@ -170,6 +170,10 @@ impl<'a> MsgVariant<'a> { self.msg_attr.msg_type() } + pub fn msg_attr(&self) -> &MsgAttr { + &self.msg_attr + } + pub fn return_type(&self) -> &Option { &self.return_type } diff --git a/sylvia/Cargo.toml b/sylvia/Cargo.toml index 65068448..92df9d7a 100644 --- a/sylvia/Cargo.toml +++ b/sylvia/Cargo.toml @@ -60,6 +60,7 @@ cw-storage-plus = { workspace = true } cw-utils = { workspace = true } thiserror = { workspace = true } trybuild = "1.0.91" +itertools = "0.13.0" [package.metadata.docs.rs] all-features = true diff --git a/sylvia/tests/reply.rs b/sylvia/tests/reply.rs index 574c80de..e8599da8 100644 --- a/sylvia/tests/reply.rs +++ b/sylvia/tests/reply.rs @@ -44,3 +44,31 @@ impl Contract { Ok(Response::new()) } } + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + #[test] + fn reply_id_generation() { + // Assert IDs uniqueness + let unique_ids_count = [ + super::sv::CLEAN_REPLY_ID, + super::sv::HANDLER_ONE_REPLY_ID, + super::sv::HANDLER_TWO_REPLY_ID, + super::sv::REPLY_ON_REPLY_ID, + super::sv::REPLY_ON_ALWAYS_REPLY_ID, + ] + .iter() + .unique() + .count(); + + assert_eq!(unique_ids_count, 5); + + assert_eq!(super::sv::CLEAN_REPLY_ID, 0); + assert_eq!(super::sv::HANDLER_ONE_REPLY_ID, 1); + assert_eq!(super::sv::HANDLER_TWO_REPLY_ID, 2); + assert_eq!(super::sv::REPLY_ON_REPLY_ID, 3); + assert_eq!(super::sv::REPLY_ON_ALWAYS_REPLY_ID, 4); + } +}