From 7b6d2ced10f0d91ebf72de45d6d0d017845c2ec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Mon, 29 Jul 2024 14:48:28 +0200 Subject: [PATCH] chore: Cleanup in entry_points macro --- sylvia-derive/src/lib.rs | 3 +- sylvia-derive/src/message.rs | 235 +++++++++------------ sylvia-derive/src/parser/entry_point.rs | 27 ++- sylvia/tests/ui/macros/entry_points.rs | 68 ++++-- sylvia/tests/ui/macros/entry_points.stderr | 19 ++ 5 files changed, 200 insertions(+), 152 deletions(-) create mode 100644 sylvia/tests/ui/macros/entry_points.stderr diff --git a/sylvia-derive/src/lib.rs b/sylvia-derive/src/lib.rs index 81709030..b51149fc 100644 --- a/sylvia-derive/src/lib.rs +++ b/sylvia-derive/src/lib.rs @@ -29,6 +29,7 @@ mod variant_descs; use strip_input::StripInput; use crate::message::EntryPoints; +use crate::parser::EntryPointArgs; #[cfg(not(test))] pub(crate) fn crate_module() -> Path { @@ -757,8 +758,8 @@ pub fn entry_points(attr: TokenStream, item: TokenStream) -> TokenStream { fn entry_points_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 { fn inner(attr: TokenStream2, item: TokenStream2) -> syn::Result { - let attrs: parser::EntryPointArgs = parse2(attr)?; let input: ItemImpl = parse2(item)?; + let attrs = EntryPointArgs::new(&attr, &input)?; let expanded = EntryPoints::new(&input, attrs).emit(); Ok(quote! { diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 1259887e..a3a69aea 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -19,12 +19,11 @@ use proc_macro2::{Span, TokenStream}; use proc_macro_error::emit_error; use quote::{quote, ToTokens}; use syn::fold::Fold; -use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - parse_quote, Attribute, GenericArgument, GenericParam, Ident, ImplItem, ImplItemFn, ItemImpl, - ItemTrait, Pat, PatType, ReturnType, Signature, Token, Type, WhereClause, WherePredicate, + parse_quote, Attribute, GenericParam, Ident, ImplItem, ImplItemFn, ItemImpl, ItemTrait, Pat, + PatType, ReturnType, Signature, Type, WhereClause, WherePredicate, }; /// Representation of single struct message @@ -660,41 +659,6 @@ where &self.unused_generics } - pub fn emit_default_entry_point( - &self, - custom_msg: &Type, - custom_query: &Type, - name: &Type, - error: &Type, - contract_generics: &Option>, - ) -> TokenStream { - let Self { msg_ty, .. } = self; - let sylvia = crate_module(); - - let resp_type = match msg_ty { - MsgType::Query => quote! { #sylvia ::cw_std::Binary }, - _ => quote! { #sylvia ::cw_std::Response < #custom_msg > }, - }; - let params = msg_ty.emit_ctx_params(custom_query); - let values = msg_ty.emit_ctx_values(); - let ep_name = msg_ty.emit_ep_name(); - let bracketed_generics = match &contract_generics { - Some(generics) => quote! { ::< #generics > }, - None => quote! {}, - }; - let associated_name = msg_ty.as_accessor_wrapper_name(); - - quote! { - #[#sylvia ::cw_std::entry_point] - pub fn #ep_name ( - #params , - msg: < #name < #contract_generics > as #sylvia ::types::ContractApi> :: #associated_name, - ) -> Result<#resp_type, #error> { - msg.dispatch(&#name #bracketed_generics ::new() , ( #values )).map_err(Into::into) - } - } - } - pub fn emit_phantom_match_arm(&self) -> TokenStream { let sylvia = crate_module(); let Self { used_generics, .. } = self; @@ -1296,6 +1260,7 @@ pub struct EntryPoints<'a> { source: &'a ItemImpl, name: Type, error: Type, + reply: Option, override_entry_points: Vec, generics: Vec<&'a GenericParam>, where_clause: &'a Option, @@ -1313,10 +1278,17 @@ impl<'a> EntryPoints<'a> { let generics: Vec<_> = source.generics.params.iter().collect(); let where_clause = &source.generics.where_clause; + let reply = + MsgVariants::::new(source.as_variants(), MsgType::Reply, &[], &None) + .variants() + .map(|variant| variant.function_name.clone()) + .next(); + Self { source, name, error, + reply, override_entry_points, generics, where_clause, @@ -1327,119 +1299,120 @@ impl<'a> EntryPoints<'a> { pub fn emit(&self) -> TokenStream { let Self { source, - name, - error, + reply, override_entry_points, generics, where_clause, - attrs, + .. } = self; - let sylvia = crate_module(); - let bracketed_generics = attrs - .generics - .as_ref() - .map(|generics| match generics.is_empty() { - true => quote! {}, - false => quote! { < #generics > }, - }) - .unwrap_or(quote! {}); - - let custom_msg = parse_quote! { < #name #bracketed_generics as #sylvia ::types::ContractApi > :: CustomMsg }; - let custom_query = parse_quote! { < #name #bracketed_generics as #sylvia ::types::ContractApi > :: CustomQuery }; - - let instantiate_variants = MsgVariants::new( - source.as_variants(), + let entry_points = [ MsgType::Instantiate, - generics, - where_clause, + MsgType::Exec, + MsgType::Query, + MsgType::Sudo, + ] + .into_iter() + .map( + |msg_ty| match override_entry_points.get_entry_point(msg_ty) { + Some(_) => quote! {}, + None => self.emit_default_entry_point(msg_ty), + }, ); - let exec_variants = - MsgVariants::new(source.as_variants(), MsgType::Exec, generics, where_clause); - let query_variants = - MsgVariants::new(source.as_variants(), MsgType::Query, generics, where_clause); - let migrate_variants = MsgVariants::new( + + let is_migrate = MsgVariants::new( source.as_variants(), MsgType::Migrate, generics, where_clause, - ); - let reply = - MsgVariants::::new(source.as_variants(), MsgType::Reply, &[], &None) - .variants() - .map(|variant| variant.function_name.clone()) - .next(); - let sudo_variants = - MsgVariants::new(source.as_variants(), MsgType::Sudo, generics, where_clause); - let contract_generics = match &attrs.generics { - Some(generics) => quote! { ::< #generics > }, - None => quote! {}, + ) + .get_only_variant() + .is_some(); + + let migrate_not_overridden = override_entry_points + .get_entry_point(MsgType::Migrate) + .is_none(); + + let migrate = if migrate_not_overridden && is_migrate { + self.emit_default_entry_point(MsgType::Migrate) + } else { + quote! {} }; - { - let entry_points = [ - instantiate_variants, - exec_variants, - query_variants, - sudo_variants, - ] - .into_iter() - .map(|variants| { - match override_entry_points.get_entry_point(variants.msg_ty) { - Some(_) => quote! {}, - None => variants.emit_default_entry_point( - &custom_msg, - &custom_query, - name, - error, - &attrs.generics, - ), + let reply_ep = override_entry_points + .get_entry_point(MsgType::Reply) + .map(|_| quote! {}) + .unwrap_or_else(|| { + if reply.is_some() { + self.emit_default_entry_point(MsgType::Reply) + } else { + quote! {} } }); - let migrate_not_overridden = override_entry_points - .get_entry_point(MsgType::Migrate) - .is_none(); - - let migrate = if migrate_not_overridden && migrate_variants.get_only_variant().is_some() - { - migrate_variants.emit_default_entry_point( - &custom_msg, - &custom_query, - name, - error, - &attrs.generics, - ) - } else { - quote! {} - }; - - let reply_ep = override_entry_points.get_entry_point(MsgType::Reply) - .map(|_| quote! {}) - .unwrap_or_else(|| match reply { - Some(reply) => quote! { - #[#sylvia ::cw_std::entry_point] - pub fn reply( - deps: #sylvia ::cw_std::DepsMut< #custom_query >, - env: #sylvia ::cw_std::Env, - msg: #sylvia ::cw_std::Reply, - ) -> Result<#sylvia ::cw_std::Response < #custom_msg >, #error> { - #name #contract_generics ::new(). #reply((deps, env).into(), msg).map_err(Into::into) - } - }, - _ => quote! {}, - }); + quote! { + pub mod entry_points { + use super::*; - quote! { - pub mod entry_points { - use super::*; + #(#entry_points)* - #(#entry_points)* + #migrate - #migrate + #reply_ep + } + } + } - #reply_ep - } + pub fn emit_default_entry_point(&self, msg_ty: MsgType) -> TokenStream { + let Self { + name, + error, + attrs, + reply, + .. + } = self; + let sylvia = crate_module(); + + let attr_generics = &attrs.generics; + let (contract, contract_turbo) = if attr_generics.is_empty() { + (quote! { #name }, quote! { #name }) + } else { + ( + quote! { #name < #attr_generics > }, + quote! { #name :: < #attr_generics > }, + ) + }; + + let custom_msg: Type = + parse_quote! { < #contract as #sylvia ::types::ContractApi > :: CustomMsg }; + let custom_query: Type = + parse_quote! { < #contract as #sylvia ::types::ContractApi > :: CustomQuery }; + + let result = msg_ty.emit_result_type(&custom_msg, error); + let params = msg_ty.emit_ctx_params(&custom_query); + let values = msg_ty.emit_ctx_values(); + let ep_name = msg_ty.emit_ep_name(); + let associated_name = msg_ty.as_accessor_wrapper_name(); + let msg = match msg_ty { + MsgType::Reply => quote! { msg: #sylvia ::cw_std::Reply }, + _ => quote! { msg: < #contract as #sylvia ::types::ContractApi> :: #associated_name }, + }; + let dispatch = match msg_ty { + MsgType::Reply => quote! { + #contract_turbo ::new(). #reply((deps, env).into(), msg).map_err(Into::into) + }, + _ => quote! { + msg.dispatch(& #contract_turbo ::new() , ( #values )).map_err(Into::into) + }, + }; + + quote! { + #[#sylvia ::cw_std::entry_point] + pub fn #ep_name ( + #params , + #msg + ) -> #result { + #dispatch } } } diff --git a/sylvia-derive/src/parser/entry_point.rs b/sylvia-derive/src/parser/entry_point.rs index 231bb0b5..87bceff3 100644 --- a/sylvia-derive/src/parser/entry_point.rs +++ b/sylvia-derive/src/parser/entry_point.rs @@ -1,7 +1,9 @@ +use proc_macro2::TokenStream as TokenStream2; +use proc_macro_error::emit_error; use syn::parse::{Error, Nothing, Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{GenericArgument, Path, Result, Token}; +use syn::{parse2, GenericArgument, ItemImpl, Path, Result, Token}; use super::extract_generics_from_path; @@ -9,7 +11,26 @@ use super::extract_generics_from_path; #[derive(Default)] pub struct EntryPointArgs { /// Types used in place of contracts generics. - pub generics: Option>, + pub generics: Punctuated, +} + +impl EntryPointArgs { + pub fn new(attr: &TokenStream2, source: &ItemImpl) -> Result { + let args: Self = parse2(attr.clone()).map_err(|err| { + emit_error!(attr, err); + err + })?; + + if args.generics.len() != source.generics.params.len() { + emit_error!( + attr.span(), + "Missing concrete types."; + note = "For every generic type in the contract, a concrete type must be provided in `#[entry_points(generics)]`."; + ); + } + + Ok(args) + } } impl Parse for EntryPointArgs { @@ -22,7 +43,7 @@ impl Parse for EntryPointArgs { let generics: Path = input.parse()?; match generics.segments.last() { Some(segment) if segment.ident == "generics" => { - entry_points_args.generics = Some(extract_generics_from_path(&generics)) + entry_points_args.generics = extract_generics_from_path(&generics) } _ => return Err(Error::new(generics.span(), "Expected `generics`.")), }; diff --git a/sylvia/tests/ui/macros/entry_points.rs b/sylvia/tests/ui/macros/entry_points.rs index a131966c..36ceabb6 100644 --- a/sylvia/tests/ui/macros/entry_points.rs +++ b/sylvia/tests/ui/macros/entry_points.rs @@ -1,28 +1,62 @@ +#![allow(unused_imports)] + use sylvia::cw_std::{Response, StdResult}; use sylvia::types::{CustomMsg, CustomQuery, InstantiateCtx}; use sylvia::{contract, entry_points}; -pub struct Contract { - _phantom: std::marker::PhantomData<(E, Q)>, -} +pub mod no_generics { + use super::*; + + pub struct Contract { + _phantom: std::marker::PhantomData<(E, Q)>, + } -#[entry_points] -#[contract] -#[sv::custom(msg = E, query = Q)] -impl Contract -where - E: CustomMsg + 'static, - Q: CustomQuery + 'static, -{ - pub fn new() -> Self { - Self { - _phantom: std::marker::PhantomData, + #[entry_points] + #[contract] + #[sv::custom(msg = E, query = Q)] + impl Contract + where + E: CustomMsg + 'static, + Q: CustomQuery + 'static, + { + pub fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } + + #[sv::msg(instantiate)] + fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult> { + Ok(Response::new()) } } +} - #[sv::msg(instantiate)] - fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult> { - Ok(Response::new()) +pub mod missing_generics { + use super::*; + + pub struct Contract { + _phantom: std::marker::PhantomData<(E, Q)>, + } + + #[entry_points(generics)] + #[contract] + #[sv::custom(msg = E, query = Q)] + impl Contract + where + E: CustomMsg + 'static, + Q: CustomQuery + 'static, + { + pub fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } + + #[sv::msg(instantiate)] + fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult> { + Ok(Response::new()) + } } } diff --git a/sylvia/tests/ui/macros/entry_points.stderr b/sylvia/tests/ui/macros/entry_points.stderr new file mode 100644 index 00000000..8bf7e60f --- /dev/null +++ b/sylvia/tests/ui/macros/entry_points.stderr @@ -0,0 +1,19 @@ +error: Missing concrete types. + + = note: For every generic type in the contract, a concrete type must be provided in `#[entry_points(generics)]`. + + --> tests/ui/macros/entry_points.rs:14:5 + | +14 | #[entry_points] + | ^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `entry_points` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: Missing concrete types. + + = note: For every generic type in the contract, a concrete type must be provided in `#[entry_points(generics)]`. + + --> tests/ui/macros/entry_points.rs:42:20 + | +42 | #[entry_points(generics)] + | ^^^^^^^^