diff --git a/sylvia-derive/src/contract/communication/reply.rs b/sylvia-derive/src/contract/communication/reply.rs index 5864b039..7790dfac 100644 --- a/sylvia-derive/src/contract/communication/reply.rs +++ b/sylvia-derive/src/contract/communication/reply.rs @@ -6,11 +6,13 @@ use syn::{parse_quote, GenericParam, Ident, ItemImpl, Type}; use crate::crate_module; use crate::parser::attributes::msg::ReplyOn; -use crate::parser::{MsgType, ParsedSylviaAttributes, SylviaAttribute}; +use crate::parser::{MsgType, ParsedSylviaAttributes}; use crate::types::msg_field::MsgField; use crate::types::msg_variant::{MsgVariant, MsgVariants}; use crate::utils::emit_turbofish; +const NUMBER_OF_DATA_FIELDS: usize = 1; + pub struct Reply<'a> { source: &'a ItemImpl, generics: &'a [&'a GenericParam], @@ -173,7 +175,7 @@ impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> { }, ) } - Some(existing_data) => existing_data.add_second_handler(handler), + Some(existing_data) => existing_data.merge(handler), None => reply_data.push(ReplyData::new(reply_id, handler, handler_id)), } }); @@ -198,9 +200,13 @@ struct ReplyData<'a> { impl<'a> ReplyData<'a> { pub fn new(reply_id: Ident, variant: &'a MsgVariant<'a>, handler_id: &'a Ident) -> Self { - let data = variant.fields().first(); - // Skip the first field reserved for the `data`. - let payload = variant.fields().iter().skip(1).collect::>(); + let data = variant.as_data_field(); + let payload = variant.fields().iter(); + let payload = if data.is_some() || variant.msg_attr().reply_on() != ReplyOn::Success { + payload.skip(NUMBER_OF_DATA_FIELDS).collect::>() + } else { + payload.collect::>() + }; let method_name = variant.function_name(); let reply_on = variant.msg_attr().reply_on(); @@ -214,13 +220,15 @@ impl<'a> ReplyData<'a> { } /// Adds second handler to the reply data provdided their payload signature match. - pub fn add_second_handler(&mut self, new_handler: &'a MsgVariant<'a>) { + pub fn merge(&mut self, new_handler: &'a MsgVariant<'a>) { let (current_method_name, _) = match self.handlers.first() { Some(handler) => handler, _ => return, }; - if self.payload.len() != new_handler.fields().len() - 1 { + let new_reply_data = ReplyData::new(self.reply_id.clone(), new_handler, self.handler_id); + + if self.payload.len() != new_reply_data.payload.len() { emit_error!(current_method_name.span(), "Mismatched quantity of method parameters."; note = self.handler_id.span() => format!("Both `{}` handlers should have the same number of parameters.", self.handler_id); note = new_handler.function_name().span() => format!("Previous definition of {} handler.", self.handler_id) @@ -229,7 +237,7 @@ impl<'a> ReplyData<'a> { self.payload .iter() - .zip(new_handler.fields().iter().skip(1)) + .zip(new_reply_data.payload.iter()) .for_each(|(current_field, new_field)| { if current_field.ty() != new_field.ty() { @@ -377,6 +385,7 @@ impl<'a> ReplyData<'a> { let payload_values = self.payload.iter().map(|field| field.name()); let payload_deserialization = self.payload.emit_payload_deserialization(); let data_deserialization = self.data.map(DataField::emit_data_deserialization); + let data = self.data.map(|_| quote! { data, }); quote! { #sylvia ::cw_std::SubMsgResult::Ok(sub_msg_resp) => { @@ -385,7 +394,7 @@ impl<'a> ReplyData<'a> { #payload_deserialization #data_deserialization - #contract_turbofish ::new(). #method_name ((deps, env, gas_used, events, msg_responses).into(), data, #(#payload_values),* ) + #contract_turbofish ::new(). #method_name ((deps, env, gas_used, events, msg_responses).into(), #data #(#payload_values),* ) } } } @@ -462,6 +471,7 @@ impl<'a> ReplyData<'a> { trait ReplyVariant<'a> { fn as_variant_handlers_pair(&'a self) -> Vec<(&'a MsgVariant<'a>, &'a Ident)>; + fn as_data_field(&'a self) -> Option<&'a MsgField<'a>>; } impl<'a> ReplyVariant<'a> for MsgVariant<'a> { @@ -479,6 +489,22 @@ impl<'a> ReplyVariant<'a> for MsgVariant<'a> { variant_handler_id_pair } + + /// Returns `Some(MsgField)` if a field marked with `sv::data` attribute is present + /// and the `reply_on` attribute is set to `ReplyOn::Success`. + fn as_data_field(&'a self) -> Option<&'a MsgField<'a>> { + let data_attrs = self.fields().first().map(|field| { + ParsedSylviaAttributes::new(field.attrs().iter()) + .data + .is_some() + }); + match data_attrs { + Some(attrs) if attrs && self.msg_attr().reply_on() == ReplyOn::Success => { + self.fields().first() + } + _ => None, + } + } } pub trait DataField { @@ -489,10 +515,6 @@ impl DataField for MsgField<'_> { fn emit_data_deserialization(&self) -> TokenStream { let sylvia = crate_module(); let data = ParsedSylviaAttributes::new(self.attrs().iter()).data; - let is_data_attr = self - .attrs() - .iter() - .any(|attr| SylviaAttribute::new(attr) == Some(SylviaAttribute::Data)); let missing_data_err = "Missing reply data field."; let invalid_reply_data_err = quote! { format! {"Invalid reply data: {}\nSerde error while deserializing {}", data, err} @@ -555,7 +577,7 @@ impl DataField for MsgField<'_> { None => None, }; }, - None if is_data_attr => quote! { + Some(_) => quote! { let data = match data { Some(data) => { #execute_data_deserialization @@ -616,8 +638,11 @@ impl PayloadFields for Vec<&MsgField<'_>> { } fn is_payload_marked(&self) -> bool { - self.iter() - .any(|field| field.contains_attribute(SylviaAttribute::Payload)) + self.iter().any(|field| { + ParsedSylviaAttributes::new(field.attrs().iter()) + .payload + .is_some() + }) } } diff --git a/sylvia-derive/src/parser/attributes/mod.rs b/sylvia-derive/src/parser/attributes/mod.rs index 215bda73..9d3bedf4 100644 --- a/sylvia-derive/src/parser/attributes/mod.rs +++ b/sylvia-derive/src/parser/attributes/mod.rs @@ -3,6 +3,7 @@ use data::DataFieldParams; use features::SylviaFeatures; +use payload::PayloadFieldParam; use proc_macro_error::emit_error; use syn::spanned::Spanned; use syn::{Attribute, MetaList, PathSegment}; @@ -15,6 +16,7 @@ pub mod features; pub mod messages; pub mod msg; pub mod override_entry_point; +pub mod payload; pub use attr::{MsgAttrForwarding, VariantAttrForwarding}; pub use custom::Custom; @@ -79,6 +81,7 @@ pub struct ParsedSylviaAttributes { pub msg_attrs_forward: Vec, pub sv_features: SylviaFeatures, pub data: Option, + pub payload: Option, } impl ParsedSylviaAttributes { @@ -90,6 +93,14 @@ impl ParsedSylviaAttributes { if let (Some(sylvia_attr), Ok(attr)) = (sylvia_attr, &attr_content) { result.match_attribute(&sylvia_attr, attr); + } else if sylvia_attr == Some(SylviaAttribute::Data) { + // The `sv::data` attribute can be used without parameters. + result.data = Some(DataFieldParams::default()); + } else if sylvia_attr == Some(SylviaAttribute::Payload) { + emit_error!( + attr.span(), "Missing parameters for `sv::payload`"; + note = "Expected `#[sv::payload(raw)]`" + ); } } @@ -172,10 +183,9 @@ impl ParsedSylviaAttributes { } } SylviaAttribute::Payload => { - emit_error!( - attr, "The attribute `sv::payload` used in wrong context"; - note = attr.span() => "The `sv::payload` should be used as a prefix for `Binary` payload."; - ); + if let Ok(payload) = PayloadFieldParam::new(attr) { + self.payload = Some(payload); + } } SylviaAttribute::Data => { if let Ok(data) = DataFieldParams::new(attr) { diff --git a/sylvia-derive/src/parser/attributes/payload.rs b/sylvia-derive/src/parser/attributes/payload.rs new file mode 100644 index 00000000..1f47951d --- /dev/null +++ b/sylvia-derive/src/parser/attributes/payload.rs @@ -0,0 +1,44 @@ +use proc_macro_error::emit_error; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::{Error, Ident, MetaList, Result}; + +/// Type wrapping data parsed from `sv::payload` attribute. +#[derive(Default, Debug)] +pub struct PayloadFieldParam; + +impl PayloadFieldParam { + pub fn new(attr: &MetaList) -> Result { + let data = PayloadFieldParam::parse + .parse2(attr.tokens.clone()) + .map_err(|err| { + emit_error!(err.span(), err); + err + })?; + + Ok(data) + } +} + +impl Parse for PayloadFieldParam { + fn parse(input: ParseStream) -> Result { + let option: Ident = input.parse()?; + match option.to_string().as_str() { + "raw" => (), + _ => { + return Err(Error::new( + option.span(), + "Invalid payload parameter.\n= note: Expected [`raw`].\n", + )) + } + }; + + if !input.is_empty() { + return Err(Error::new( + input.span(), + "Unexpected tokens inside `sv::payload` attribute.\n= note: Expected parameters: [`raw`] `.\n", + )); + } + + Ok(Self) + } +} diff --git a/sylvia-derive/src/types/msg_field.rs b/sylvia-derive/src/types/msg_field.rs index 46bfffc4..ecaf9901 100644 --- a/sylvia-derive/src/types/msg_field.rs +++ b/sylvia-derive/src/types/msg_field.rs @@ -1,6 +1,5 @@ use crate::fold::StripSelfPath; use crate::parser::check_generics::{CheckGenerics, GetPath}; -use crate::parser::SylviaAttribute; use proc_macro2::TokenStream; use proc_macro_error::emit_error; use quote::quote; @@ -124,10 +123,4 @@ impl<'a> MsgField<'a> { pub fn attrs(&self) -> &'a Vec { self.attrs } - - pub fn contains_attribute(&self, sv_attr: SylviaAttribute) -> bool { - self.attrs - .iter() - .any(|attr| SylviaAttribute::new(attr) == Some(sv_attr)) - } } diff --git a/sylvia/tests/reply_data.rs b/sylvia/tests/reply_data.rs index f336720c..faea8699 100644 --- a/sylvia/tests/reply_data.rs +++ b/sylvia/tests/reply_data.rs @@ -122,7 +122,7 @@ impl Contract { // MultiTest version is released. // Payload is not currently forwarded in the MultiTest. // _instantiate_payload: InstantiatePayload, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result { let remote_addr = Addr::unchecked(data.contract_address); @@ -137,7 +137,7 @@ impl Contract { &self, _ctx: ReplyCtx, #[sv::data(instantiate, opt)] _data: Option, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result { Ok(Response::new()) } @@ -147,7 +147,7 @@ impl Contract { &self, _ctx: ReplyCtx, #[sv::data(raw, opt)] _data: Option, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result { Ok(Response::new()) } @@ -157,7 +157,7 @@ impl Contract { &self, _ctx: ReplyCtx, #[sv::data(raw)] _data: Binary, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result { Ok(Response::new()) } @@ -167,7 +167,7 @@ impl Contract { &self, _ctx: ReplyCtx, #[sv::data(opt)] _data: Option, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result { Ok(Response::new()) } @@ -177,13 +177,23 @@ impl Contract { &self, _ctx: ReplyCtx, #[sv::data] _data: String, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, + ) -> Result { + Ok(Response::new()) + } + + #[sv::msg(reply, reply_on=success)] + fn no_data( + &self, + _ctx: ReplyCtx, + #[sv::payload(raw)] _payload: Binary, ) -> Result { Ok(Response::new()) } } -mod tests { +#[test] +fn dispatch_replies() { use crate::noop_contract::sv::mt::CodeId as NoopCodeId; use crate::sv::mt::{CodeId, ContractProxy}; use crate::sv::{DATA_OPT_REPLY_ID, DATA_RAW_OPT_REPLY_ID, DATA_RAW_REPLY_ID, DATA_REPLY_ID}; @@ -192,88 +202,85 @@ mod tests { use sylvia::cw_multi_test::IntoBech32; use sylvia::multitest::App; - #[test] - fn dispatch_replies() { - let app = App::default(); - let code_id = CodeId::store_code(&app); - let noop_code_id = NoopCodeId::store_code(&app); - - let owner = "owner".into_bech32(); - let data = Some(to_json_binary(&String::from("some_data")).unwrap()); - let invalid_data = Some(Binary::from("InvalidData".as_bytes())); - - // Trigger remote instantiation reply - let contract = code_id - .instantiate(noop_code_id.code_id()) - .with_label("Contract") - .call(&owner) - .unwrap(); - - // Should forward `data` in every case - contract - .send_message_expecting_data(None, DATA_RAW_OPT_REPLY_ID) - .call(&owner) - .unwrap(); - - contract - .send_message_expecting_data(data.clone(), DATA_RAW_OPT_REPLY_ID) - .call(&owner) - .unwrap(); - - // Should forward `data` if `Some` and return error if `None` - let err = contract - .send_message_expecting_data(None, DATA_RAW_REPLY_ID) - .call(&owner) - .unwrap_err(); - assert_eq!( - err, - StdError::generic_err("Missing reply data field.").into() - ); - - contract - .send_message_expecting_data(data.clone(), DATA_RAW_REPLY_ID) - .call(&owner) - .unwrap(); - - // Should forward deserialized `data` if `Some` or None and return error if deserialization fails - contract - .send_message_expecting_data(None, DATA_OPT_REPLY_ID) - .call(&owner) - .unwrap(); - - let err = contract - .send_message_expecting_data(invalid_data.clone(), DATA_OPT_REPLY_ID) - .call(&owner) - .unwrap_err(); - assert_eq!( + let app = App::default(); + let code_id = CodeId::store_code(&app); + let noop_code_id = NoopCodeId::store_code(&app); + + let owner = "owner".into_bech32(); + let data = Some(to_json_binary(&String::from("some_data")).unwrap()); + let invalid_data = Some(Binary::from("InvalidData".as_bytes())); + + // Trigger remote instantiation reply + let contract = code_id + .instantiate(noop_code_id.code_id()) + .with_label("Contract") + .call(&owner) + .unwrap(); + + // Should forward `data` in every case + contract + .send_message_expecting_data(None, DATA_RAW_OPT_REPLY_ID) + .call(&owner) + .unwrap(); + + contract + .send_message_expecting_data(data.clone(), DATA_RAW_OPT_REPLY_ID) + .call(&owner) + .unwrap(); + + // Should forward `data` if `Some` and return error if `None` + let err = contract + .send_message_expecting_data(None, DATA_RAW_REPLY_ID) + .call(&owner) + .unwrap_err(); + assert_eq!( + err, + StdError::generic_err("Missing reply data field.").into() + ); + + contract + .send_message_expecting_data(data.clone(), DATA_RAW_REPLY_ID) + .call(&owner) + .unwrap(); + + // Should forward deserialized `data` if `Some` or None and return error if deserialization fails + contract + .send_message_expecting_data(None, DATA_OPT_REPLY_ID) + .call(&owner) + .unwrap(); + + let err = contract + .send_message_expecting_data(invalid_data.clone(), DATA_OPT_REPLY_ID) + .call(&owner) + .unwrap_err(); + assert_eq!( err, StdError::generic_err("Invalid reply data: SW52YWxpZERhdGE=\nSerde error while deserializing Error parsing into type alloc::string::String: Invalid type").into() ); - contract - .send_message_expecting_data(data.clone(), DATA_OPT_REPLY_ID) - .call(&owner) - .unwrap(); - - // Should forward deserialized `data` if `Some` and return error if `None` or if deserialization fails - let err = contract - .send_message_expecting_data(None, DATA_REPLY_ID) - .call(&owner) - .unwrap_err(); - assert_eq!( - err, - StdError::generic_err("Missing reply data field.").into() - ); - - let err = contract - .send_message_expecting_data(invalid_data, DATA_REPLY_ID) - .call(&owner) - .unwrap_err(); - assert_eq!(err, StdError::generic_err("Invalid reply data: SW52YWxpZERhdGE=\nSerde error while deserializing Error parsing into type alloc::string::String: Invalid type").into()); - - contract - .send_message_expecting_data(data, DATA_REPLY_ID) - .call(&owner) - .unwrap(); - } + contract + .send_message_expecting_data(data.clone(), DATA_OPT_REPLY_ID) + .call(&owner) + .unwrap(); + + // Should forward deserialized `data` if `Some` and return error if `None` or if deserialization fails + let err = contract + .send_message_expecting_data(None, DATA_REPLY_ID) + .call(&owner) + .unwrap_err(); + assert_eq!( + err, + StdError::generic_err("Missing reply data field.").into() + ); + + let err = contract + .send_message_expecting_data(invalid_data, DATA_REPLY_ID) + .call(&owner) + .unwrap_err(); + assert_eq!(err, StdError::generic_err("Invalid reply data: SW52YWxpZERhdGE=\nSerde error while deserializing Error parsing into type alloc::string::String: Invalid type").into()); + + contract + .send_message_expecting_data(data, DATA_REPLY_ID) + .call(&owner) + .unwrap(); } diff --git a/sylvia/tests/reply_dispatch.rs b/sylvia/tests/reply_dispatch.rs index 23569c5c..f9583aaf 100644 --- a/sylvia/tests/reply_dispatch.rs +++ b/sylvia/tests/reply_dispatch.rs @@ -219,7 +219,7 @@ where // MultiTest version is released. // Payload is not currently forwarded in the MultiTest. // _instantiate_payload: InstantiatePayload, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result, ContractError> { self.last_reply .save(ctx.deps.storage, &REMOTE_INSTANTIATED_REPLY_ID)?; @@ -237,7 +237,7 @@ where &self, ctx: ReplyCtx, #[sv::data(raw, opt)] _data: Option, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result, ContractError> { self.last_reply.save(ctx.deps.storage, &SUCCESS_REPLY_ID)?; @@ -249,7 +249,7 @@ where &self, ctx: ReplyCtx, _error: String, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> Result, ContractError> { self.last_reply.save(ctx.deps.storage, &FAILURE_REPLY_ID)?; @@ -261,7 +261,7 @@ where &self, ctx: ReplyCtx, _result: SubMsgResult, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, // _first_part_payload: u32, // _second_part_payload: String, ) -> Result, ContractError> { diff --git a/sylvia/tests/reply_generation.rs b/sylvia/tests/reply_generation.rs index afd65ad2..23a8ae1a 100644 --- a/sylvia/tests/reply_generation.rs +++ b/sylvia/tests/reply_generation.rs @@ -24,7 +24,7 @@ impl Contract { &self, _ctx: ReplyCtx, _result: SubMsgResult, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> StdResult { Ok(Response::new()) } @@ -35,7 +35,7 @@ impl Contract { &self, _ctx: ReplyCtx, _result: SubMsgResult, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> StdResult { Ok(Response::new()) } @@ -46,7 +46,7 @@ impl Contract { &self, _ctx: ReplyCtx, #[sv::data(raw, opt)] _data: Option, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> StdResult { Ok(Response::new()) } @@ -57,7 +57,7 @@ impl Contract { &self, _ctx: ReplyCtx, _result: SubMsgResult, - #[sv::payload] _payload: Binary, + #[sv::payload(raw)] _payload: Binary, ) -> StdResult { Ok(Response::new()) } @@ -68,7 +68,7 @@ impl Contract { &self, _ctx: ReplyCtx, _error: String, - #[sv::payload] _payload: Binary, + #[sv::payload()] _payload: Binary, ) -> StdResult { Ok(Response::new()) }