diff --git a/derive/src/lib.rs b/derive/src/lib.rs index ef7606f60f..3835701f1b 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -29,36 +29,55 @@ use quote::quote; use syn::parse_macro_input; use syn::parse_quote; use syn::Data; +use syn::DeriveInput; +use syn::GenericParam; use syn::ItemFn; #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { - let ast: syn::DeriveInput = syn::parse(input).unwrap(); - - // Get struct name from ast + let ast = parse_macro_input!(input as DeriveInput); let name = &ast.ident; - let methods = quote! { - impl core::borrow::Borrow<#name> for [T] { - fn borrow(&self) -> &#name { - debug_assert_eq!(self.len(), core::mem::size_of::<#name>()); - let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name>() }; - debug_assert!(prefix.is_empty(), "Alignment should match"); - debug_assert_eq!(shorts.len(), 1); - &shorts[0] + + // Ensure the first generic parameter is the type generic, and rest all are const generics. + let mut generics_iter = ast.generics.params.iter(); + + // Extract the first generic parameter and ensure it's a type. + let type_generic = match generics_iter.next().expect("No generic parameters found") { + GenericParam::Type(type_param) => &type_param.ident, + _ => panic!("The first generic parameter must be a type."), + }; + + // Collect the remaining generic parameters, ensuring they are all const generics. + let const_generics: Vec<_> = generics_iter.map(|param| match param { + GenericParam::Const(const_param) => &const_param.ident, + _ => panic!("`AlignedBorrow` supports only a type as the first generic parameter and const generics after that"), + }).collect(); + + let methods = { + quote! { + impl<#type_generic: Copy #(, const #const_generics: usize)*> core::borrow::Borrow<#name<#type_generic #(, #const_generics)*>> for [#type_generic] { + fn borrow(&self) -> &#name<#type_generic #(, #const_generics)*> { + debug_assert_eq!(self.len(), std::mem::size_of::<#name>()); + let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name<#type_generic #(, #const_generics)*>>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } } - } - impl core::borrow::BorrowMut<#name> for [T] { - fn borrow_mut(&mut self) -> &mut #name { - debug_assert_eq!(self.len(), core::mem::size_of::<#name>()); - let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name>() }; - debug_assert!(prefix.is_empty(), "Alignment should match"); - debug_assert_eq!(shorts.len(), 1); - &mut shorts[0] + impl<#type_generic: Copy #(, const #const_generics: usize)*> core::borrow::BorrowMut<#name<#type_generic #(, #const_generics)*>> for [#type_generic] { + fn borrow_mut(&mut self) -> &mut #name<#type_generic #(, #const_generics)*> { + debug_assert_eq!(self.len(), std::mem::size_of::<#name>()); + let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name<#type_generic #(, #const_generics)*>>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } } } }; - methods.into() + + TokenStream::from(methods) } #[proc_macro_derive(MachineAir, attributes(sp1_core_path, execution_record_path))]