Skip to content

Commit

Permalink
refactor: make AlignedBorrow accept arbitrary no of const generics (#…
Browse files Browse the repository at this point in the history
…368)

Signed-off-by: 0xkanekiken <[email protected]>
  • Loading branch information
0xkanekiken authored Mar 12, 2024
1 parent 4528b78 commit 3e2b07f
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,55 @@ use quote::quote;
use syn::parse_macro_input;
use syn::parse_quote;
use syn::Data;
use syn::DeriveInput;
use syn::GenericParam;
use syn::ItemFn;

#[proc_macro_derive(AlignedBorrow)]
pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();

// Get struct name from ast
let ast = parse_macro_input!(input as DeriveInput);
let name = &ast.ident;
let methods = quote! {
impl<T: Copy> core::borrow::Borrow<#name<T>> for [T] {
fn borrow(&self) -> &#name<T> {
debug_assert_eq!(self.len(), core::mem::size_of::<#name<u8>>());
let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name<T>>() };
debug_assert!(prefix.is_empty(), "Alignment should match");
debug_assert_eq!(shorts.len(), 1);
&shorts[0]

// Ensure the first generic parameter is the type generic, and rest all are const generics.
let mut generics_iter = ast.generics.params.iter();

// Extract the first generic parameter and ensure it's a type.
let type_generic = match generics_iter.next().expect("No generic parameters found") {
GenericParam::Type(type_param) => &type_param.ident,
_ => panic!("The first generic parameter must be a type."),
};

// Collect the remaining generic parameters, ensuring they are all const generics.
let const_generics: Vec<_> = generics_iter.map(|param| match param {
GenericParam::Const(const_param) => &const_param.ident,
_ => panic!("`AlignedBorrow` supports only a type as the first generic parameter and const generics after that"),
}).collect();

let methods = {
quote! {
impl<#type_generic: Copy #(, const #const_generics: usize)*> core::borrow::Borrow<#name<#type_generic #(, #const_generics)*>> for [#type_generic] {
fn borrow(&self) -> &#name<#type_generic #(, #const_generics)*> {
debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #const_generics)*>>());
let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name<#type_generic #(, #const_generics)*>>() };
debug_assert!(prefix.is_empty(), "Alignment should match");
debug_assert_eq!(shorts.len(), 1);
&shorts[0]
}
}
}

impl<T: Copy> core::borrow::BorrowMut<#name<T>> for [T] {
fn borrow_mut(&mut self) -> &mut #name<T> {
debug_assert_eq!(self.len(), core::mem::size_of::<#name<u8>>());
let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name<T>>() };
debug_assert!(prefix.is_empty(), "Alignment should match");
debug_assert_eq!(shorts.len(), 1);
&mut shorts[0]
impl<#type_generic: Copy #(, const #const_generics: usize)*> core::borrow::BorrowMut<#name<#type_generic #(, #const_generics)*>> for [#type_generic] {
fn borrow_mut(&mut self) -> &mut #name<#type_generic #(, #const_generics)*> {
debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #const_generics)*>>());
let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name<#type_generic #(, #const_generics)*>>() };
debug_assert!(prefix.is_empty(), "Alignment should match");
debug_assert_eq!(shorts.len(), 1);
&mut shorts[0]
}
}
}
};
methods.into()

TokenStream::from(methods)
}

#[proc_macro_derive(MachineAir, attributes(sp1_core_path, execution_record_path))]
Expand Down

0 comments on commit 3e2b07f

Please sign in to comment.