Skip to content

Commit

Permalink
chore: Refactor struct message generation
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Aug 9, 2024
1 parent 4c5a678 commit 99e236c
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 133 deletions.
2 changes: 1 addition & 1 deletion sylvia-derive/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl<'a> ContractInput<'a> {
}

fn emit_struct_msg(&self, msg_ty: MsgType) -> TokenStream {
StructMessage::new(self.item, msg_ty, &self.generics, &self.custom)
StructMessage::new(self.item, msg_ty, &self.generics, &self.error, &self.custom)
.map_or(quote! {}, |msg| msg.emit())
}

Expand Down
170 changes: 66 additions & 104 deletions sylvia-derive/src/contract/communication/struct_msg.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
use crate::crate_module;
use crate::parser::attributes::MsgAttrForwarding;
use crate::parser::check_generics::CheckGenerics;
use crate::parser::{Custom, MsgAttr, MsgType, ParsedSylviaAttributes};
use crate::parser::variant_descs::AsVariantDescs;
use crate::parser::{ContractErrorAttr, Custom, MsgType, ParsedSylviaAttributes};
use crate::types::msg_field::MsgField;
use crate::utils::{as_where_clause, emit_bracketed_generics, filter_wheres, process_fields};
use proc_macro2::{Span, TokenStream};
use crate::types::msg_variant::MsgVariants;
use crate::utils::{as_where_clause, emit_bracketed_generics, filter_wheres};
use proc_macro2::TokenStream;
use proc_macro_error::emit_error;
use quote::quote;
use syn::spanned::Spanned;
use syn::{
GenericParam, Ident, ImplItem, ImplItemFn, ItemImpl, ReturnType, Type, WhereClause,
WherePredicate,
};
use syn::{GenericParam, ItemImpl, Type};

/// Representation of single struct message
pub struct StructMessage<'a> {
source: &'a ItemImpl,
contract_type: &'a Type,
fields: Vec<MsgField<'a>>,
function_name: &'a Ident,
generics: Vec<&'a GenericParam>,
unused_generics: Vec<&'a GenericParam>,
wheres: Vec<&'a WherePredicate>,
full_where: Option<&'a WhereClause>,
result: &'a ReturnType,
msg_attr: MsgAttr,
variants: MsgVariants<'a, GenericParam>,
generics: &'a [&'a GenericParam],
error: &'a ContractErrorAttr,
custom: &'a Custom,
msg_attrs_to_forward: Vec<MsgAttrForwarding>,
}
Expand All @@ -32,138 +26,106 @@ impl<'a> StructMessage<'a> {
/// Creates new struct message of given type from impl block
pub fn new(
source: &'a ItemImpl,
ty: MsgType,
msg_ty: MsgType,
generics: &'a [&'a GenericParam],
error: &'a ContractErrorAttr,
custom: &'a Custom,
) -> Option<StructMessage<'a>> {
let mut generics_checker = CheckGenerics::new(generics);

let contract_type = &source.self_ty;

let parsed = Self::parse_struct_message(source, ty);
let (method, msg_attr) = parsed?;
let variants = MsgVariants::new(
source.as_variants(),
msg_ty,
generics,
&source.generics.where_clause,
);

let function_name = &method.sig.ident;
let fields = process_fields(&method.sig, &mut generics_checker);
let (used_generics, unused_generics) = generics_checker.used_unused();
let wheres = filter_wheres(&source.generics.where_clause, generics, &used_generics);
if variants.variants().count() == 0 && variants.msg_ty() == MsgType::Instantiate {
emit_error!(
source.span(), "Missing instantiation message.";
note = source.span() => "`sylvia::contract` requires exactly one method marked with `#[sv::msg(instantiation)]` attribute."
);
return None;
} else if variants.variants().count() > 1 {
emit_error!(
source.span(), "More than one instantiation or migration message";
note = source.span() => "Instantiation/Migration message previously defined here"
);
return None;
}

let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
.msg_attrs_forward
.into_iter()
.filter(|attr| attr.msg_type == ty)
.filter(|attr| attr.msg_type == msg_ty)
.collect();

Some(Self {
source,
contract_type,
fields,
function_name,
generics: used_generics,
unused_generics,
wheres,
full_where: source.generics.where_clause.as_ref(),
result: &method.sig.output,
msg_attr,
variants,
generics,
error,
custom,
msg_attrs_to_forward,
})
}

fn parse_struct_message(source: &ItemImpl, ty: MsgType) -> Option<(&ImplItemFn, MsgAttr)> {
let mut methods = source.items.iter().filter_map(|item| match item {
ImplItem::Fn(method) => {
let attr = ParsedSylviaAttributes::new(method.attrs.iter()).msg_attr?;
if attr == ty {
Some((method, attr))
} else {
None
}
}
_ => None,
});

let (method, msg_attr) = if let Some(method) = methods.next() {
method
} else {
if ty == MsgType::Instantiate {
emit_error!(
source.span(), "Missing instantiation message.";
note = source.span() => "`sylvia::contract` requires exactly one method marked with `#[sv::msg(instantiation)]` attribute."
);
}
return None;
};

if let Some((obsolete, _)) = methods.next() {
emit_error!(
obsolete.span(), "More than one instantiation or migration message";
note = method.span() => "Instantiation/Migration message previously defined here"
);
}
Some((method, msg_attr))
}

pub fn emit(&self) -> TokenStream {
use MsgAttr::*;

let instantiate_msg = Ident::new("InstantiateMsg", self.function_name.span());
let migrate_msg = Ident::new("MigrateMsg", self.function_name.span());

match &self.msg_attr {
Instantiate { .. } => self.emit_struct(&instantiate_msg),
Migrate { .. } => self.emit_struct(&migrate_msg),
_ => {
emit_error!(Span::mixed_site(), "Invalid message type");
quote! {}
}
}
}

pub fn emit_struct(&self, name: &Ident) -> TokenStream {
let sylvia = crate_module();

let Self {
source,
contract_type,
fields,
function_name,
variants,
generics,
unused_generics,
wheres,
full_where,
result,
msg_attr,
error,
custom,
msg_attrs_to_forward,
} = self;

let ctx_type = msg_attr
let Some(variant) = variants.get_only_variant() else {
return quote! {};
};

let used_generics = variants.used_generics();
let unused_generics = variants.unused_generics();
let full_where = &source.generics.where_clause;
let wheres = filter_wheres(full_where, generics, used_generics);
let where_clause = as_where_clause(&wheres);
let bracketed_used_generics = emit_bracketed_generics(used_generics);
let bracketed_unused_generics = emit_bracketed_generics(unused_generics);

let ret_type = variant
.msg_type()
.emit_ctx_type(&custom.query_or_default());
let fields_names: Vec<_> = fields.iter().map(MsgField::name).collect();
let parameters = fields.iter().map(MsgField::emit_method_field);
let fields = fields.iter().map(MsgField::emit_pub);

let where_clause = as_where_clause(wheres);
let generics = emit_bracketed_generics(generics);
let unused_generics = emit_bracketed_generics(unused_generics);
.emit_result_type(&custom.msg_or_default(), &error.error);
let name = variant.msg_type().emit_msg_name();
let function_name = variant.function_name();
let mut msg_name = variant.msg_type().emit_msg_name();
msg_name.set_span(function_name.span());

let ctx_type = variant.msg_type().emit_ctx_type(&custom.query_or_default());
let fields_names: Vec<_> = variant.fields().iter().map(MsgField::name).collect();
let parameters = variant.fields().iter().map(MsgField::emit_method_field);
let fields = variant.fields().iter().map(MsgField::emit_pub);

let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);

quote! {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
#( #[ #msg_attrs_to_forward ] )*
#[serde(rename_all="snake_case")]
pub struct #name #generics {
pub struct #name #bracketed_used_generics {
#(#fields,)*
}

impl #generics #name #generics #where_clause {
impl #bracketed_used_generics #name #bracketed_used_generics #where_clause {
pub fn new(#(#parameters,)*) -> Self {
Self { #(#fields_names,)* }
}

pub fn dispatch #unused_generics(self, contract: &#contract_type, ctx: #ctx_type)
#result #full_where
pub fn dispatch #bracketed_unused_generics (self, contract: &#contract_type, ctx: #ctx_type) -> #ret_type #full_where
{
let Self { #(#fields_names,)* } = self;
contract.#function_name(Into::into(ctx), #(#fields_names,)*).map_err(Into::into)
Expand Down
28 changes: 26 additions & 2 deletions sylvia-derive/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ pub use attributes::{
ContractErrorAttr, ContractMessageAttr, Custom, Customs, FilteredOverrideEntryPoints, MsgAttr,
MsgType, OverrideEntryPoint, ParsedSylviaAttributes, SylviaAttribute,
};
use check_generics::{CheckGenerics, GetPath};
pub use entry_point::EntryPointArgs;

use proc_macro_error::emit_error;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
parse_quote, GenericArgument, Ident, ImplItem, ItemImpl, ItemTrait, Path, PathArguments, Token,
TraitItem, Type,
parse_quote, FnArg, GenericArgument, Ident, ImplItem, ItemImpl, ItemTrait, Path, PathArguments,
Signature, Token, TraitItem, Type,
};

use crate::types::msg_field::MsgField;

fn extract_generics_from_path(module: &Path) -> Punctuated<GenericArgument, Token![,]> {
let generics = module
.segments
Expand Down Expand Up @@ -64,3 +67,24 @@ pub fn assert_new_method_defined(item: &ItemImpl) {
_ => (),
}
}

pub fn process_fields<'s, Generic>(
sig: &'s Signature,
generics_checker: &mut CheckGenerics<Generic>,
) -> Vec<MsgField<'s>>
where
Generic: GetPath + PartialEq,
{
sig.inputs
.iter()
.skip(2)
.filter_map(|arg| match arg {
FnArg::Receiver(item) => {
emit_error!(item.span(), "Unexpected `self` argument");
None
}

FnArg::Typed(item) => MsgField::new(item, generics_checker),
})
.collect()
}
8 changes: 6 additions & 2 deletions sylvia-derive/src/types/msg_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::fold::StripSelfPath;
use crate::parser::attributes::VariantAttrForwarding;
use crate::parser::check_generics::{CheckGenerics, GetPath};
use crate::parser::variant_descs::VariantDescs;
use crate::parser::{MsgAttr, MsgType};
use crate::utils::{extract_return_type, filter_wheres, process_fields, SvCasing};
use crate::parser::{process_fields, MsgAttr, MsgType};
use crate::utils::{extract_return_type, filter_wheres, SvCasing};
use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
Expand Down Expand Up @@ -246,6 +246,10 @@ where
&self.unused_generics
}

pub fn msg_ty(&self) -> MsgType {
self.msg_ty
}

pub fn emit_phantom_match_arm(&self) -> TokenStream {
let sylvia = crate_module();
let Self { used_generics, .. } = self;
Expand Down
26 changes: 2 additions & 24 deletions sylvia-derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use quote::{quote, ToTokens};
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{
parse_quote, FnArg, GenericArgument, Ident, Path, PathArguments, ReturnType, Signature, Type,
WhereClause, WherePredicate,
parse_quote, GenericArgument, Ident, Path, PathArguments, ReturnType, Type, WhereClause,
WherePredicate,
};

use crate::parser::check_generics::{CheckGenerics, GetPath};
use crate::types::msg_field::MsgField;

pub fn filter_wheres<'a, Generic: GetPath + PartialEq>(
clause: &'a Option<WhereClause>,
Expand All @@ -36,27 +35,6 @@ pub fn filter_wheres<'a, Generic: GetPath + PartialEq>(
.unwrap_or_default()
}

pub fn process_fields<'s, Generic>(
sig: &'s Signature,
generics_checker: &mut CheckGenerics<Generic>,
) -> Vec<MsgField<'s>>
where
Generic: GetPath + PartialEq,
{
sig.inputs
.iter()
.skip(2)
.filter_map(|arg| match arg {
FnArg::Receiver(item) => {
emit_error!(item.span(), "Unexpected `self` argument");
None
}

FnArg::Typed(item) => MsgField::new(item, generics_checker),
})
.collect()
}

pub fn extract_return_type(ret_type: &ReturnType) -> &Path {
let ReturnType::Type(_, ty) = ret_type else {
unreachable!()
Expand Down

0 comments on commit 99e236c

Please sign in to comment.