Skip to content

Commit

Permalink
feat: Support generic interface implemented on contract
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Oct 6, 2023
1 parent eb685d4 commit 57f4389
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 155 deletions.
56 changes: 42 additions & 14 deletions sylvia-derive/src/check_generics.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,55 @@
use syn::visit::Visit;
use syn::GenericParam;
use syn::{GenericArgument, GenericParam, Type};

pub trait AsIdent {
fn as_ident(&self) -> Option<&syn::Ident>;
}

impl AsIdent for GenericParam {
fn as_ident(&self) -> Option<&syn::Ident> {
match self {
GenericParam::Type(ty) => Some(&ty.ident),
GenericParam::Lifetime(lt) => Some(&lt.lifetime.ident),
GenericParam::Const(c) => Some(&c.ident),
}
}
}

impl AsIdent for GenericArgument {
fn as_ident(&self) -> Option<&syn::Ident> {
match self {
GenericArgument::Type(Type::Path(path)) => path.path.get_ident(),
GenericArgument::Lifetime(lt) => Some(&lt.ident),
GenericArgument::Binding(b) => Some(&b.ident),
GenericArgument::Constraint(c) => Some(&c.ident),
_ => None,
}
}
}

#[derive(Debug)]
pub struct CheckGenerics<'g> {
generics: &'g [&'g GenericParam],
used: Vec<&'g GenericParam>,
pub struct CheckGenerics<'g, Generic> {
generics: &'g [&'g Generic],
used: Vec<&'g Generic>,
}

impl<'g> CheckGenerics<'g> {
pub fn new(generics: &'g [&'g GenericParam]) -> Self {
impl<'g, Generic> CheckGenerics<'g, Generic>
where
Generic: AsIdent + PartialEq,
{
pub fn new(generics: &'g [&'g Generic]) -> Self {
Self {
generics,
used: vec![],
}
}

pub fn used(self) -> Vec<&'g GenericParam> {
pub fn used(self) -> Vec<&'g Generic> {
self.used
}

/// Returns split between used and unused generics
pub fn used_unused(self) -> (Vec<&'g GenericParam>, Vec<&'g GenericParam>) {
pub fn used_unused(self) -> (Vec<&'g Generic>, Vec<&'g Generic>) {
let unused = self
.generics
.iter()
Expand All @@ -32,14 +61,13 @@ impl<'g> CheckGenerics<'g> {
}
}

impl<'ast, 'g> Visit<'ast> for CheckGenerics<'g> {
impl<'ast, 'g, Generic> Visit<'ast> for CheckGenerics<'g, Generic>
where
Generic: AsIdent + PartialEq,
{
fn visit_path(&mut self, p: &'ast syn::Path) {
if let Some(p) = p.get_ident() {
if let Some(gen) = self
.generics
.iter()
.find(|gen| matches!(gen, GenericParam::Type(ty) if ty.ident == *p))
{
if let Some(gen) = self.generics.iter().find(|gen| gen.as_ident() == Some(p)) {
if !self.used.contains(gen) {
self.used.push(gen);
}
Expand Down
65 changes: 45 additions & 20 deletions sylvia-derive/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use proc_macro_error::emit_error;
use quote::quote;
use syn::parse::{Parse, Parser};
use syn::spanned::Spanned;
use syn::{parse_quote, GenericParam, Ident, ItemImpl, ItemTrait, TraitItem, Type};
use syn::{parse_quote, GenericParam, Ident, ItemImpl, ItemTrait, PathArguments, TraitItem, Type};

use crate::crate_module;
use crate::interfaces::Interfaces;
Expand Down Expand Up @@ -156,42 +156,47 @@ impl<'a> ImplInput<'a> {
}

pub fn process(&self) -> TokenStream {
let is_trait = self.item.trait_.is_some();
let Self {
item,
generics,
error,
custom,
override_entry_points,
interfaces,
..
} = self;
let is_trait = item.trait_.is_some();
let multitest_helpers = if cfg!(feature = "mt") {
MultitestHelpers::new(
self.item,
item,
is_trait,
&self.error,
&self.generics,
&self.custom,
&self.override_entry_points,
&self.interfaces,
error,
generics,
custom,
override_entry_points,
interfaces,
)
.emit()
} else {
quote! {}
};

let unbonded_generics = &vec![];
let where_clause = &item.generics.where_clause;
let variants = MsgVariants::new(
self.item.as_variants(),
MsgType::Query,
unbonded_generics,
&None,
generics,
where_clause,
);

match is_trait {
true => self.process_interface(variants, multitest_helpers),
true => self.process_interface(multitest_helpers),
false => self.process_contract(variants, multitest_helpers),
}
}

fn process_interface(
&self,
variants: MsgVariants<'a>,
multitest_helpers: TokenStream,
) -> TokenStream {
let querier_bound_for_impl = self.emit_querier_for_bound_impl(variants);
fn process_interface(&self, multitest_helpers: TokenStream) -> TokenStream {
let querier_bound_for_impl = self.emit_querier_for_bound_impl();

#[cfg(not(tarpaulin_include))]
quote! {
Expand All @@ -203,7 +208,7 @@ impl<'a> ImplInput<'a> {

fn process_contract(
&self,
variants: MsgVariants<'a>,
variants: MsgVariants<'a, GenericParam>,
multitest_helpers: TokenStream,
) -> TokenStream {
let messages = self.emit_messages();
Expand Down Expand Up @@ -285,14 +290,34 @@ impl<'a> ImplInput<'a> {
.emit()
}

fn emit_querier_for_bound_impl(&self, variants: MsgVariants<'a>) -> TokenStream {
fn emit_querier_for_bound_impl(&self) -> TokenStream {
let trait_module = self
.interfaces
.interfaces()
.first()
.map(|interface| &interface.module);
let contract_module = self.attributes.module.as_ref();

// TODO: Try to not unwrap here
let interface_generics = &self
.item
.trait_
.as_ref()
.unwrap()
.1
.segments
.last()
.unwrap()
.arguments;

let generics = match interface_generics {
PathArguments::AngleBracketed(args) => {
args.args.pairs().map(|pair| *pair.value()).collect()
}
_ => vec![],
};
let variants = MsgVariants::new(self.item.as_variants(), MsgType::Query, &generics, &None);

variants.emit_querier_for_bound_impl(trait_module, contract_module)
}
}
26 changes: 11 additions & 15 deletions sylvia-derive/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,30 +83,26 @@ impl Interfaces {
.collect()
}

pub fn emit_glue_message_variants(
&self,
msg_ty: &MsgType,
msg_name: &Ident,
) -> Vec<TokenStream> {
pub fn emit_glue_message_variants(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let ContractMessageAttr {
module,
exec_generic_params,
query_generic_params,
variant,
generics,
..
} = interface;

let generics = match msg_ty {
MsgType::Exec => exec_generic_params.as_slice(),
MsgType::Query => query_generic_params.as_slice(),
_ => &[],
};

let enum_name = Self::merge_module_with_name(interface, msg_name);
quote! { #variant(#module :: #enum_name<#(#generics,)*>) }
let interface_enum =
quote! { <#module ::InterfaceTypes < #(#generics,)* > as #sylvia ::types::InterfaceMessages> };
if msg_ty == &MsgType::Query {
quote! { #variant ( #interface_enum :: Query) }
} else {
quote! { #variant ( #interface_enum :: Exec)}
}
})
.collect()
}
Expand Down
Loading

0 comments on commit 57f4389

Please sign in to comment.