Skip to content

Commit

Permalink
feat: Check interfaces return type for used generics
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Sep 22, 2023
1 parent f01f822 commit 5eb4c3f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 47 deletions.
84 changes: 46 additions & 38 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use crate::parser::{
Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints,
};
use crate::strip_generics::StripGenerics;
use crate::utils::{extract_return_type, filter_wheres, process_fields};
use crate::utils::{
as_where_clause, brace_generics, extract_return_type, filter_wheres, process_fields,
};
use crate::variant_descs::{AsVariantDescs, VariantDescs};
use convert_case::{Case, Casing};
use proc_macro2::{Span, TokenStream};
Expand Down Expand Up @@ -100,14 +102,6 @@ impl<'a> StructMessage<'a> {
custom,
} = self;

let where_clause = if !wheres.is_empty() {
quote! {
where #(#wheres,)*
}
} else {
quote! {}
};

let ctx_type = msg_attr
.msg_type()
.emit_ctx_type(&custom.query_or_default());
Expand All @@ -119,21 +113,9 @@ impl<'a> StructMessage<'a> {
});
let fields = fields.iter().map(MsgField::emit);

let generics = if generics.is_empty() {
quote! {}
} else {
quote! {
<#(#generics,)*>
}
};

let unused_generics = if unused_generics.is_empty() {
quote! {}
} else {
quote! {
<#(#unused_generics,)*>
}
};
let where_clause = as_where_clause(wheres);
let generics = brace_generics(generics);
let unused_generics = brace_generics(unused_generics);

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -282,18 +264,33 @@ impl<'a> EnumMessage<'a> {
let ctx_type = msg_ty.emit_ctx_type(query_type);
let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(C::Error));

let all_generics = if all_generics.is_empty() {
let all_generics = brace_generics(all_generics);
let phantom = if generics.is_empty() {
quote! {}
} else if MsgType::Query == *msg_ty {
quote! {
#[returns((#(#generics,)*))]
_Phantom(std::marker::PhantomData<( #(#generics,)* )>),
}
} else {
quote! { <#(#all_generics,)*> }
quote! {
_Phantom(std::marker::PhantomData<( #(#generics,)* )>),
}
};

let generics = if generics.is_empty() {
quote! {}
let match_arms = if !generics.is_empty() {
quote! {
#(#match_arms,)*
_Phantom(_) => unreachable!(),
}
} else {
quote! { <#(#generics,)*> }
quote! {
#(#match_arms,)*
}
};

let generics = brace_generics(generics);

let unique_enum_name = Ident::new(&format!("{}{}", trait_name, name), name.span());

#[cfg(not(tarpaulin_include))]
Expand All @@ -305,6 +302,7 @@ impl<'a> EnumMessage<'a> {
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #generics {
#(#variants,)*
#phantom
}
pub type #name #generics = #unique_enum_name #generics;
}
Expand All @@ -316,6 +314,7 @@ impl<'a> EnumMessage<'a> {
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #generics {
#(#variants,)*
#phantom
}
pub type #name #generics = #unique_enum_name #generics;
}
Expand All @@ -334,7 +333,7 @@ impl<'a> EnumMessage<'a> {
use #unique_enum_name::*;

match self {
#(#match_arms,)*
#match_arms
}
}
pub const fn messages() -> [&'static str; #msgs_cnt] {
Expand Down Expand Up @@ -507,10 +506,12 @@ impl<'a> MsgVariant<'a> {
let return_type = if let MsgAttr::Query { resp_type } = msg_attr {
match resp_type {
Some(resp_type) => {
generics_checker.visit_path(&parse_quote! { #resp_type });
quote! {#resp_type}
}
None => {
let return_type = extract_return_type(&sig.output);
generics_checker.visit_path(return_type);
quote! {#return_type}
}
}
Expand Down Expand Up @@ -621,7 +622,11 @@ impl<'a> MsgVariant<'a> {
}
}

pub fn emit_querier_impl(&self, trait_module: Option<&Path>) -> TokenStream {
pub fn emit_querier_impl(
&self,
trait_module: Option<&Path>,
unbonded_generics: &Vec<&GenericParam>,
) -> TokenStream {
let sylvia = crate_module();
let Self {
name,
Expand All @@ -637,6 +642,12 @@ impl<'a> MsgVariant<'a> {
.map(|module| quote! { #module ::QueryMsg })
.unwrap_or_else(|| quote! { QueryMsg });

let msg = if !unbonded_generics.is_empty() {
quote! { #msg ::< #(#unbonded_generics,)* > }
} else {
quote! { #msg }
};

#[cfg(not(tarpaulin_include))]
{
quote! {
Expand Down Expand Up @@ -741,18 +752,15 @@ impl<'a> MsgVariants<'a> {
let methods_impl = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(|variant| variant.emit_querier_impl(None));
.map(|variant| variant.emit_querier_impl(None, unbonded_generics));

let methods_declaration = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(MsgVariant::emit_querier_declaration);

let querier = if !unbonded_generics.is_empty() {
quote! { Querier < #(#unbonded_generics,)* > }
} else {
quote! { Querier }
};
let braced_generics = brace_generics(unbonded_generics);
let querier = quote! { Querier #braced_generics };

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -803,7 +811,7 @@ impl<'a> MsgVariants<'a> {
let methods_impl = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(|variant| variant.emit_querier_impl(trait_module));
.map(|variant| variant.emit_querier_impl(trait_module, unbonded_generics));

let mut querier = trait_module
.map(|module| quote! { #module ::Querier })
Expand Down
20 changes: 18 additions & 2 deletions sylvia-derive/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use proc_macro2::TokenStream;
use proc_macro_error::emit_error;
use quote::quote;
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{
FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature, Type,
WhereClause, WherePredicate,
parse_quote, FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature,
Type, WhereClause, WherePredicate,
};

use crate::check_generics::CheckGenerics;
Expand Down Expand Up @@ -84,3 +86,17 @@ pub fn extract_return_type(ret_type: &ReturnType) -> &Path {

&type_path.path
}

pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option<WhereClause> {
match where_predicates.is_empty() {
true => None,
false => Some(parse_quote! { where #(#where_predicates),* }),
}
}

pub fn brace_generics(unbonded_generics: &[&GenericParam]) -> TokenStream {
match unbonded_generics.is_empty() {
true => quote! {},
false => quote! { < #(#unbonded_generics,)* > },
}
}
19 changes: 12 additions & 7 deletions sylvia/tests/generics.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
use cosmwasm_schema::cw_serde;

pub mod cw1 {
use cosmwasm_std::{CosmosMsg, CustomMsg, Response, StdError};
use cosmwasm_std::{CosmosMsg, CustomMsg, CustomQuery, Response, StdError};

use serde::Deserialize;
use serde::{de::DeserializeOwned, Deserialize};
use sylvia::types::{ExecCtx, QueryCtx};
use sylvia_derive::interface;

#[interface(module=msg)]
pub trait Cw1<Msg, Param>
#[sv::custom(msg=Msg)]
pub trait Cw1<Msg, Param, QueryRet>
where
for<'msg_de> Msg: CustomMsg + Deserialize<'msg_de>,
Param: sylvia::types::CustomMsg,
for<'msg_de> QueryRet: CustomQuery + DeserializeOwned,
{
type Error: From<StdError>;

#[msg(exec)]
fn execute(&self, ctx: ExecCtx, msgs: Vec<CosmosMsg<Msg>>)
-> Result<Response, Self::Error>;
fn execute(
&self,
ctx: ExecCtx,
msgs: Vec<CosmosMsg<Msg>>,
) -> Result<Response<Msg>, Self::Error>;

#[msg(query)]
fn query(&self, ctx: QueryCtx, param: Param) -> Result<String, Self::Error>;
fn some_query(&self, ctx: QueryCtx, param: Param) -> Result<QueryRet, Self::Error>;
}
}

Expand All @@ -37,7 +42,7 @@ mod tests {

#[test]
fn construct_messages() {
let _ = crate::cw1::QueryMsg::query(ExternalMsg {});
let _ = crate::cw1::QueryMsg::<_, Empty>::some_query(ExternalMsg {});
let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(ExternalMsg {})]);
let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(Empty {})]);
}
Expand Down

0 comments on commit 5eb4c3f

Please sign in to comment.