Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Refactor struct message generation #412

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sylvia-derive/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl<'a> ContractInput<'a> {
}

fn emit_struct_msg(&self, msg_ty: MsgType) -> TokenStream {
StructMessage::new(self.item, msg_ty, &self.generics, &self.custom)
StructMessage::new(self.item, msg_ty, &self.generics, &self.error, &self.custom)
.map_or(quote! {}, |msg| msg.emit())
}

Expand Down
173 changes: 69 additions & 104 deletions sylvia-derive/src/contract/communication/struct_msg.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
use crate::crate_module;
use crate::parser::attributes::MsgAttrForwarding;
use crate::parser::check_generics::CheckGenerics;
use crate::parser::{Custom, MsgAttr, MsgType, ParsedSylviaAttributes};
use crate::parser::variant_descs::AsVariantDescs;
use crate::parser::{ContractErrorAttr, Custom, MsgType, ParsedSylviaAttributes};
use crate::types::msg_field::MsgField;
use crate::utils::{as_where_clause, emit_bracketed_generics, filter_wheres, process_fields};
use proc_macro2::{Span, TokenStream};
use crate::types::msg_variant::MsgVariants;
use crate::utils::{as_where_clause, emit_bracketed_generics, filter_wheres};
use proc_macro2::TokenStream;
use proc_macro_error::emit_error;
use quote::quote;
use syn::spanned::Spanned;
use syn::{
GenericParam, Ident, ImplItem, ImplItemFn, ItemImpl, ReturnType, Type, WhereClause,
WherePredicate,
};
use syn::{GenericParam, ItemImpl, Type};

/// Representation of single struct message
pub struct StructMessage<'a> {
source: &'a ItemImpl,
contract_type: &'a Type,
fields: Vec<MsgField<'a>>,
function_name: &'a Ident,
generics: Vec<&'a GenericParam>,
unused_generics: Vec<&'a GenericParam>,
wheres: Vec<&'a WherePredicate>,
full_where: Option<&'a WhereClause>,
result: &'a ReturnType,
msg_attr: MsgAttr,
variants: MsgVariants<'a, GenericParam>,
generics: &'a [&'a GenericParam],
error: &'a ContractErrorAttr,
custom: &'a Custom,
msg_attrs_to_forward: Vec<MsgAttrForwarding>,
}
Expand All @@ -32,138 +26,109 @@
/// Creates new struct message of given type from impl block
pub fn new(
source: &'a ItemImpl,
ty: MsgType,
msg_ty: MsgType,
generics: &'a [&'a GenericParam],
error: &'a ContractErrorAttr,
custom: &'a Custom,
) -> Option<StructMessage<'a>> {
let mut generics_checker = CheckGenerics::new(generics);

let contract_type = &source.self_ty;

let parsed = Self::parse_struct_message(source, ty);
let (method, msg_attr) = parsed?;
let variants = MsgVariants::new(
source.as_variants(),
msg_ty,
generics,
&source.generics.where_clause,
);

let function_name = &method.sig.ident;
let fields = process_fields(&method.sig, &mut generics_checker);
let (used_generics, unused_generics) = generics_checker.used_unused();
let wheres = filter_wheres(&source.generics.where_clause, generics, &used_generics);
if variants.variants().count() == 0 && variants.msg_ty() == MsgType::Instantiate {
emit_error!(
source.span(), "Missing instantiation message.";
note = source.span() => "`sylvia::contract` requires exactly one method marked with `#[sv::msg(instantiation)]` attribute."

Check warning on line 46 in sylvia-derive/src/contract/communication/struct_msg.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/struct_msg.rs#L44-L46

Added lines #L44 - L46 were not covered by tests
);
return None;

Check warning on line 48 in sylvia-derive/src/contract/communication/struct_msg.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/struct_msg.rs#L48

Added line #L48 was not covered by tests
} else if variants.variants().count() > 1 {
let mut variants = variants.variants();
let first_method = variants.next().map(|v| v.function_name());
let obsolete = variants.next().map(|v| v.function_name());
emit_error!(
first_method.span(), "More than one instantiation or migration message";
note = obsolete.span() => "Instantiation/Migration message previously defined here"

Check warning on line 55 in sylvia-derive/src/contract/communication/struct_msg.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/struct_msg.rs#L50-L55

Added lines #L50 - L55 were not covered by tests
);
return None;

Check warning on line 57 in sylvia-derive/src/contract/communication/struct_msg.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/struct_msg.rs#L57

Added line #L57 was not covered by tests
}

let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
.msg_attrs_forward
.into_iter()
.filter(|attr| attr.msg_type == ty)
.filter(|attr| attr.msg_type == msg_ty)
.collect();

Some(Self {
source,
contract_type,
fields,
function_name,
generics: used_generics,
unused_generics,
wheres,
full_where: source.generics.where_clause.as_ref(),
result: &method.sig.output,
msg_attr,
variants,
generics,
error,
custom,
msg_attrs_to_forward,
})
}

fn parse_struct_message(source: &ItemImpl, ty: MsgType) -> Option<(&ImplItemFn, MsgAttr)> {
let mut methods = source.items.iter().filter_map(|item| match item {
ImplItem::Fn(method) => {
let attr = ParsedSylviaAttributes::new(method.attrs.iter()).msg_attr?;
if attr == ty {
Some((method, attr))
} else {
None
}
}
_ => None,
});

let (method, msg_attr) = if let Some(method) = methods.next() {
method
} else {
if ty == MsgType::Instantiate {
emit_error!(
source.span(), "Missing instantiation message.";
note = source.span() => "`sylvia::contract` requires exactly one method marked with `#[sv::msg(instantiation)]` attribute."
);
}
return None;
};

if let Some((obsolete, _)) = methods.next() {
emit_error!(
obsolete.span(), "More than one instantiation or migration message";
note = method.span() => "Instantiation/Migration message previously defined here"
);
}
Some((method, msg_attr))
}

pub fn emit(&self) -> TokenStream {
use MsgAttr::*;

let instantiate_msg = Ident::new("InstantiateMsg", self.function_name.span());
let migrate_msg = Ident::new("MigrateMsg", self.function_name.span());

match &self.msg_attr {
Instantiate { .. } => self.emit_struct(&instantiate_msg),
Migrate { .. } => self.emit_struct(&migrate_msg),
_ => {
emit_error!(Span::mixed_site(), "Invalid message type");
quote! {}
}
}
}

pub fn emit_struct(&self, name: &Ident) -> TokenStream {
let sylvia = crate_module();

let Self {
source,
contract_type,
fields,
function_name,
variants,
generics,
unused_generics,
wheres,
full_where,
result,
msg_attr,
error,
custom,
msg_attrs_to_forward,
} = self;

let ctx_type = msg_attr
let Some(variant) = variants.get_only_variant() else {
return quote! {};
};

let used_generics = variants.used_generics();
let unused_generics = variants.unused_generics();
let full_where = &source.generics.where_clause;
let wheres = filter_wheres(full_where, generics, used_generics);
let where_clause = as_where_clause(&wheres);
let bracketed_used_generics = emit_bracketed_generics(used_generics);
let bracketed_unused_generics = emit_bracketed_generics(unused_generics);

let ret_type = variant
.msg_type()
.emit_ctx_type(&custom.query_or_default());
let fields_names: Vec<_> = fields.iter().map(MsgField::name).collect();
let parameters = fields.iter().map(MsgField::emit_method_field);
let fields = fields.iter().map(MsgField::emit_pub);

let where_clause = as_where_clause(wheres);
let generics = emit_bracketed_generics(generics);
let unused_generics = emit_bracketed_generics(unused_generics);
.emit_result_type(&custom.msg_or_default(), &error.error);
let name = variant.msg_type().emit_msg_name();
let function_name = variant.function_name();
let mut msg_name = variant.msg_type().emit_msg_name();
msg_name.set_span(function_name.span());

let ctx_type = variant.msg_type().emit_ctx_type(&custom.query_or_default());
let fields_names: Vec<_> = variant.fields().iter().map(MsgField::name).collect();
let parameters = variant.fields().iter().map(MsgField::emit_method_field);
let fields = variant.fields().iter().map(MsgField::emit_pub);

let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);

quote! {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
#( #[ #msg_attrs_to_forward ] )*
#[serde(rename_all="snake_case")]
pub struct #name #generics {
pub struct #name #bracketed_used_generics {
#(#fields,)*
}

impl #generics #name #generics #where_clause {
impl #bracketed_used_generics #name #bracketed_used_generics #where_clause {
kulikthebird marked this conversation as resolved.
Show resolved Hide resolved
pub fn new(#(#parameters,)*) -> Self {
Self { #(#fields_names,)* }
}

pub fn dispatch #unused_generics(self, contract: &#contract_type, ctx: #ctx_type)
#result #full_where
pub fn dispatch #bracketed_unused_generics (self, contract: &#contract_type, ctx: #ctx_type) -> #ret_type #full_where
{
let Self { #(#fields_names,)* } = self;
contract.#function_name(Into::into(ctx), #(#fields_names,)*).map_err(Into::into)
Expand Down
28 changes: 26 additions & 2 deletions sylvia-derive/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@
ContractErrorAttr, ContractMessageAttr, Custom, Customs, FilteredOverrideEntryPoints, MsgAttr,
MsgType, OverrideEntryPoint, ParsedSylviaAttributes, SylviaAttribute,
};
use check_generics::{CheckGenerics, GetPath};
pub use entry_point::EntryPointArgs;

use proc_macro_error::emit_error;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
parse_quote, GenericArgument, Ident, ImplItem, ItemImpl, ItemTrait, Path, PathArguments, Token,
TraitItem, Type,
parse_quote, FnArg, GenericArgument, Ident, ImplItem, ItemImpl, ItemTrait, Path, PathArguments,
Signature, Token, TraitItem, Type,
};

use crate::types::msg_field::MsgField;

fn extract_generics_from_path(module: &Path) -> Punctuated<GenericArgument, Token![,]> {
let generics = module
.segments
Expand Down Expand Up @@ -64,3 +67,24 @@
_ => (),
}
}

pub fn process_fields<'s, Generic>(
sig: &'s Signature,
generics_checker: &mut CheckGenerics<Generic>,
) -> Vec<MsgField<'s>>
where
Generic: GetPath + PartialEq,
{
sig.inputs
.iter()
.skip(2)
.filter_map(|arg| match arg {
FnArg::Receiver(item) => {
emit_error!(item.span(), "Unexpected `self` argument");
None

Check warning on line 84 in sylvia-derive/src/parser/mod.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/parser/mod.rs#L82-L84

Added lines #L82 - L84 were not covered by tests
}

FnArg::Typed(item) => MsgField::new(item, generics_checker),
})
.collect()
}
8 changes: 6 additions & 2 deletions sylvia-derive/src/types/msg_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::fold::StripSelfPath;
use crate::parser::attributes::VariantAttrForwarding;
use crate::parser::check_generics::{CheckGenerics, GetPath};
use crate::parser::variant_descs::VariantDescs;
use crate::parser::{MsgAttr, MsgType};
use crate::utils::{extract_return_type, filter_wheres, process_fields, SvCasing};
use crate::parser::{process_fields, MsgAttr, MsgType};
use crate::utils::{extract_return_type, filter_wheres, SvCasing};
use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
Expand Down Expand Up @@ -246,6 +246,10 @@ where
&self.unused_generics
}

pub fn msg_ty(&self) -> MsgType {
self.msg_ty
}

pub fn emit_phantom_match_arm(&self) -> TokenStream {
let sylvia = crate_module();
let Self { used_generics, .. } = self;
Expand Down
26 changes: 2 additions & 24 deletions sylvia-derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use quote::{quote, ToTokens};
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{
parse_quote, FnArg, GenericArgument, Ident, Path, PathArguments, ReturnType, Signature, Type,
WhereClause, WherePredicate,
parse_quote, GenericArgument, Ident, Path, PathArguments, ReturnType, Type, WhereClause,
WherePredicate,
};

use crate::parser::check_generics::{CheckGenerics, GetPath};
use crate::types::msg_field::MsgField;

pub fn filter_wheres<'a, Generic: GetPath + PartialEq>(
clause: &'a Option<WhereClause>,
Expand All @@ -36,27 +35,6 @@ pub fn filter_wheres<'a, Generic: GetPath + PartialEq>(
.unwrap_or_default()
}

pub fn process_fields<'s, Generic>(
sig: &'s Signature,
generics_checker: &mut CheckGenerics<Generic>,
) -> Vec<MsgField<'s>>
where
Generic: GetPath + PartialEq,
{
sig.inputs
.iter()
.skip(2)
.filter_map(|arg| match arg {
FnArg::Receiver(item) => {
emit_error!(item.span(), "Unexpected `self` argument");
None
}

FnArg::Typed(item) => MsgField::new(item, generics_checker),
})
.collect()
}

pub fn extract_return_type(ret_type: &ReturnType) -> &Path {
let ReturnType::Type(_, ty) = ret_type else {
unreachable!()
Expand Down
6 changes: 3 additions & 3 deletions sylvia/tests/ui/missing_method/msgs_misused.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ error: More than one instantiation or migration message

= note: Instantiation/Migration message previously defined here

--> tests/ui/missing_method/msgs_misused.rs:35:5
--> tests/ui/missing_method/msgs_misused.rs:31:12
|
35 | #[sv::msg(instantiate)]
| ^
31 | pub fn instantiate(&self, ctx: InstantiateCtx) -> StdResult<Response> {
| ^^^^^^^^^^^
Loading