Skip to content

Commit

Permalink
feat: Support generics on messages attribute in main contract macro
Browse files Browse the repository at this point in the history
call
  • Loading branch information
jawoznia committed Oct 18, 2023
1 parent f046d9e commit fd40f3d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 33 deletions.
74 changes: 53 additions & 21 deletions sylvia-derive/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ pub struct Interfaces {
}

impl Interfaces {
fn merge_module_with_name(message_attr: &ContractMessageAttr, name: &syn::Ident) -> syn::Ident {
// ContractMessageAttr will fail to parse empty `#[messsages()]` attribute so we can safely unwrap here
let syn::PathSegment { ident, .. } = &message_attr.module.segments.last().unwrap();
let module_name = ident.to_string().to_case(Case::UpperCamel);
syn::Ident::new(&format!("{}{}", module_name, name), name.span())
}

pub fn new(source: &ItemImpl) -> Self {
let interfaces: Vec<_> = source
.attrs
Expand Down Expand Up @@ -90,11 +83,19 @@ impl Interfaces {
.iter()
.map(|interface| {
let ContractMessageAttr {
module, variant, ..
module,
variant,
generics,
..
} = interface;
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let interface_enum =
quote! { <#module ::InterfaceTypes as #sylvia ::types::InterfaceMessages> };
quote! { <#module ::InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> };
if msg_ty == &MsgType::Query {
quote! { #variant ( #interface_enum :: Query) }
} else {
Expand All @@ -104,28 +105,46 @@ impl Interfaces {
.collect()
}

pub fn emit_messages_call(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_messages_call(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let enum_name = Self::merge_module_with_name(interface, msg_name);
let module = &interface.module;
quote! { &#module :: #enum_name :: messages()}
let ContractMessageAttr {
module, generics, ..
} = interface;
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};
let type_name = msg_ty.as_accessor_name();
quote! {
&<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages()
}
})
.collect()
}

pub fn emit_deserialization_attempts(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_deserialization_attempts(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let ContractMessageAttr {
module, variant, ..
module, variant, generics, ..
} = interface;
let enum_name = Self::merge_module_with_name(interface, msg_name);
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
quote! {
let msgs = &#module :: #enum_name ::messages();
let msgs = &<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages();
if msgs.into_iter().any(|msg| msg == &recv_msg_name) {
match val.deserialize_into() {
Ok(msg) => return Ok(Self:: #variant (msg)),
Expand All @@ -137,13 +156,26 @@ impl Interfaces {
.collect()
}

pub fn emit_response_schemas_calls(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_response_schemas_calls(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let enum_name = Self::merge_module_with_name(interface, msg_name);
let module = &interface.module;
quote! { #module :: #enum_name :: response_schemas_impl()}
let ContractMessageAttr {
module, generics, ..
} = interface;

let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
quote! {
<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: response_schemas_impl()
}
})
.collect()
}
Expand Down
6 changes: 3 additions & 3 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ impl<'a> GlueMessage<'a> {

let msg_name = quote! {#contract ( #name)};
let mut messages_call_on_all_variants: Vec<TokenStream> =
interfaces.emit_messages_call(name);
interfaces.emit_messages_call(msg_ty);
messages_call_on_all_variants.push(quote! {&#name :: messages()});

let variants_cnt = messages_call_on_all_variants.len();
Expand Down Expand Up @@ -1004,7 +1004,7 @@ impl<'a> GlueMessage<'a> {

let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)};

let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name);
let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(msg_ty);

#[cfg(not(tarpaulin_include))]
let contract_deserialization_attempt = quote! {
Expand All @@ -1020,7 +1020,7 @@ impl<'a> GlueMessage<'a> {
let ctx_type = msg_ty.emit_ctx_type(&custom.query_or_default());
let ret_type = msg_ty.emit_result_type(&custom.msg_or_default(), error);

let mut response_schemas_calls = interfaces.emit_response_schemas_calls(name);
let mut response_schemas_calls = interfaces.emit_response_schemas_calls(msg_ty);
response_schemas_calls.push(quote! {#name :: response_schemas_impl()});

let response_schemas = match name.to_string().as_str() {
Expand Down
34 changes: 31 additions & 3 deletions sylvia-derive/src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use proc_macro2::{Punct, TokenStream};
use proc_macro_error::emit_error;
use quote::quote;
use syn::fold::Fold;
use syn::parse::{Error, Nothing, Parse, ParseBuffer, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
parenthesized, parse_quote, Attribute, Ident, ImplItem, ImplItemMethod, ItemImpl, ItemTrait,
Path, Result, Token, TraitItem, Type,
parenthesized, parse_quote, Attribute, GenericArgument, Ident, ImplItem, ImplItemMethod,
ItemImpl, ItemTrait, Path, PathArguments, Result, Token, TraitItem, Type,
};

use crate::crate_module;
use crate::strip_generics::StripGenerics;

/// Parser arguments for `contract` macro
pub struct ContractArgs {
Expand Down Expand Up @@ -248,6 +251,7 @@ pub struct ContractMessageAttr {
pub module: Path,
pub variant: Ident,
pub customs: Customs,
pub generics: Punctuated<GenericArgument, Token![,]>,
}

fn interface_has_custom(content: ParseStream) -> Result<Customs> {
Expand Down Expand Up @@ -285,14 +289,36 @@ fn interface_has_custom(content: ParseStream) -> Result<Customs> {
Ok(customs)
}

fn extract_generics_from_path(module: &mut Path) -> Punctuated<GenericArgument, Token![,]> {
let generics = module.segments.last().map(|segment| {
match segment.arguments.clone(){
PathArguments::AngleBracketed(generics) => {
generics.args
},
PathArguments::None => Default::default(),
PathArguments::Parenthesized(generics) => {
emit_error!(
generics.span(), "Found paranthesis wrapping generics in `messages` attribute.";
note = "Expected `messages` attribute to be in form `#[messages(Path<generics> as Type)]`"

Check warning on line 302 in sylvia-derive/src/parser.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/parser.rs#L299-L302

Added lines #L299 - L302 were not covered by tests
);
Default::default()

Check warning on line 304 in sylvia-derive/src/parser.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/parser.rs#L304

Added line #L304 was not covered by tests
}
}
}).unwrap_or_default();

generics
}

#[cfg(not(tarpaulin_include))]
// False negative. It is being called in closure
impl Parse for ContractMessageAttr {
fn parse(input: ParseStream) -> Result<Self> {
let content;
parenthesized!(content in input);

let module = content.parse()?;
let mut module = content.parse()?;
let generics = extract_generics_from_path(&mut module);
let module = StripGenerics.fold_path(module);

let _: Token![as] = content.parse()?;
let variant = content.parse()?;
Expand All @@ -310,6 +336,7 @@ impl Parse for ContractMessageAttr {
module,
variant,
customs,
generics,
})
}
}
Expand Down Expand Up @@ -474,6 +501,7 @@ impl OverrideEntryPoint {
entry_point,
msg_name,
msg_type,
..
} = self;

let sylvia = crate_module();
Expand Down
20 changes: 14 additions & 6 deletions sylvia/tests/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,21 @@ pub mod cw1_contract {
use sylvia::types::InstantiateCtx;
use sylvia_derive::contract;

use crate::{ExternalMsg, ExternalQuery};

pub struct Cw1Contract;

#[contract]
#[messages(crate::cw1<ExternalMsg, ExternalMsg, ExternalQuery> as Cw1)]
/// Required if interface returns generic `Response`
#[sv::custom(msg=ExternalMsg)]
impl Cw1Contract {
pub const fn new() -> Self {
Self
}

#[msg(instantiate)]
pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult<Response> {
pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult<Response<ExternalMsg>> {

Check warning on line 51 in sylvia/tests/generics.rs

View check run for this annotation

Codecov / codecov/patch

sylvia/tests/generics.rs#L51

Added line #L51 was not covered by tests
Ok(Response::new())
}
}
Expand Down Expand Up @@ -91,12 +96,11 @@ impl cosmwasm_std::CustomQuery for ExternalQuery {}

#[cfg(all(test, feature = "mt"))]
mod tests {
use crate::cw1::{InterfaceTypes, Querier as Cw1Querier};
use crate::{ExternalMsg, ExternalQuery};
use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper};

use crate::{cw1::Querier, ExternalMsg, ExternalQuery};

use crate::cw1::InterfaceTypes;
use sylvia::types::InterfaceMessages;

#[test]
fn construct_messages() {
let contract = Addr::unchecked("contract");
Expand All @@ -110,9 +114,13 @@ mod tests {
let querier: QuerierWrapper<ExternalQuery> = QuerierWrapper::new(&deps.querier);

let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> =
crate::cw1::Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> = cw1_querier.some_query(ExternalMsg {});

let contract_querier = crate::cw1_contract::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = contract_querier.some_query(ExternalMsg {});

// Construct messages with Interface extension
let _ =
<InterfaceTypes<ExternalMsg, _, ExternalQuery> as InterfaceMessages>::Query::some_query(
Expand Down

0 comments on commit fd40f3d

Please sign in to comment.