diff --git a/Cargo.toml b/Cargo.toml index 785aeb6..dcbf517 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,4 +18,4 @@ quote = "1.0" synstructure = "0.12" [dev-dependencies] -abomonation = "0.7" +abomonation = { git = "https://github.com/HadrienG2/abomonation.git", branch = "alignment" } diff --git a/src/lib.rs b/src/lib.rs index 53aea19..074c49a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ #![recursion_limit="128"] -use synstructure::decl_derive; use quote::quote; +use std::collections::HashSet; +use synstructure::decl_derive; decl_derive!([Abomonation, attributes(unsafe_abomonate_ignore)] => derive_abomonation); @@ -14,38 +15,77 @@ fn derive_abomonation(mut s: synstructure::Structure) -> proc_macro2::TokenStrea }); let entomb = s.each(|bi| quote! { - ::abomonation::Abomonation::entomb(#bi, _write)?; + ::abomonation::Entomb::entomb(#bi, _write)?; }); - let extent = s.each(|bi| quote! { - sum += ::abomonation::Abomonation::extent(#bi); - }); + // T::alignment() is the max of mem::align_of and the U::alignment()s of + // every U type used as a struct member or inside of an enum variant (which + // includes the alignment of recursively abomonated data) + // + // Unfortunately, we cannot use synstructure's nice `fold()` convenience + // here because it's based on generating match arms for a `self` value, + // whereas here we're trying to implement an inherent type method without + // having such a `self` handy. + // + // We can, however, use Structure::variants() and VariantInfo::bindings() + // to enumerate all the types which _would appear_ in such match arms' + // inner bindings. + // + let mut alignment = vec![ + quote!(let mut align = ::std::mem::align_of::();) + ]; + let mut probed_types = HashSet::new(); + for variant_info in s.variants() { + for binding_info in variant_info.bindings() { + let binding_type = &binding_info.ast().ty; + // Do not query a type's alignment() multiple times + if probed_types.insert(binding_type) { + alignment.push( + quote!(align = align.max(<#binding_type as ::abomonation::Entomb>::alignment());) + ); + } + } + } + alignment.push(quote!(align)); s.bind_with(|_| synstructure::BindStyle::RefMut); let exhume = s.each(|bi| quote! { - let temp = bytes; - bytes = ::abomonation::Abomonation::exhume(#bi, temp)?; + ::abomonation::Exhume::exhume(From::from(#bi), reader)?; }); - s.bound_impl(quote!(abomonation::Abomonation), quote! { - #[inline] unsafe fn entomb(&self, _write: &mut W) -> ::std::io::Result<()> { - match *self { #entomb } - Ok(()) - } - #[allow(unused_mut)] - #[inline] fn extent(&self) -> usize { - let mut sum = 0; - match *self { #extent } - sum + s.gen_impl(quote! { + extern crate abomonation; + extern crate std; + + gen unsafe impl abomonation::Entomb for @Self { + unsafe fn entomb( + &self, + _write: &mut ::abomonation::align::AlignedWriter + ) -> ::std::io::Result<()> { + match *self { #entomb } + Ok(()) + } + + fn alignment() -> usize { + #(#alignment)* + } } - #[allow(unused_mut)] - #[inline] unsafe fn exhume<'a,'b>( - &'a mut self, - mut bytes: &'b mut [u8] - ) -> Option<&'b mut [u8]> { - match *self { #exhume } - Some(bytes) + + gen unsafe impl<'de> abomonation::Exhume<'de> for @Self + where Self: 'de, + { + #[allow(unused_mut)] + unsafe fn exhume( + self_: std::ptr::NonNull, + reader: &mut ::abomonation::align::AlignedReader<'de> + ) -> Option<&'de mut Self> { + // FIXME: This (briefly) constructs an &mut _ to invalid data + // (via "ref mut"), which is UB. The proposed &raw mut + // operator would allow avoiding this. + match *self_.as_ptr() { #exhume } + Some(&mut *self_.as_ptr()) + } } }) } diff --git a/tests/test.rs b/tests/test.rs index 7cc4a54..1368a2c 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -6,6 +6,7 @@ extern crate abomonation_derive; #[cfg(test)] mod tests { use abomonation::*; + use abomonation::align::AlignedBytes; #[derive(Eq, PartialEq, Abomonation)] pub struct Struct { @@ -26,6 +27,7 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); @@ -47,6 +49,7 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); @@ -68,6 +71,7 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); @@ -89,6 +93,7 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::>>::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::>>(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); @@ -115,6 +120,7 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); @@ -126,7 +132,7 @@ mod tests { pub enum DataEnum { A(String, u64, Vec), B, - C(String, String, String) + C(String, String, u16) } #[test] @@ -141,6 +147,7 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); @@ -197,9 +204,32 @@ mod tests { assert_eq!(bytes.len(), measure(&record)); // decode from binary data + let mut bytes = AlignedBytes::::new(&mut bytes); if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { assert!(result == &record); assert!(rest.len() == 0); } } + + #[derive(Eq, PartialEq, Abomonation)] + pub struct StructWithRef<'a> { + a: &'a str, + b: bool, + } + + #[test] + fn test_struct_with_ref() { + let record = StructWithRef { a: &"test", b: true }; + + let mut bytes = Vec::new(); + unsafe { encode(&record, &mut bytes).unwrap(); } + + assert_eq!(bytes.len(), measure(&record)); + + let mut bytes = AlignedBytes::::new(&mut bytes); + if let Some((result, rest)) = unsafe { decode::(&mut bytes) } { + assert!(result == &record); + assert!(rest.len() == 0); + } + } }