diff --git a/Cargo.toml b/Cargo.toml index ad15f01..0e0816f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,7 @@ members = [ "tree_hash_derive", ] resolver = "2" + +[patch."https://github.com/macladson/tree_hash"] +tree_hash = { path = "./tree_hash" } +tree_hash_derive = { path = "./tree_hash_derive" } diff --git a/tree_hash/Cargo.toml b/tree_hash/Cargo.toml index e3a8ca9..abf1724 100644 --- a/tree_hash/Cargo.toml +++ b/tree_hash/Cargo.toml @@ -18,8 +18,10 @@ smallvec = "1.6.1" [dev-dependencies] rand = "0.8.5" tree_hash_derive = { path = "../tree_hash_derive", version = "0.6.0" } -ethereum_ssz = "0.5" -ethereum_ssz_derive = "0.5" +ethereum_ssz = { git = "https://github.com/macladson/ethereum_ssz", branch = "stable-container" } +ethereum_ssz_derive = { git = "https://github.com/macladson/ethereum_ssz", branch = "stable-container" } +ssz_types = { git = "https://github.com/macladson/ssz_types", branch = "stable-container" } +typenum = "1.12.0" [features] arbitrary = ["ethereum-types/arbitrary"] diff --git a/tree_hash/src/impls.rs b/tree_hash/src/impls.rs index 277aedf..ffb5e24 100644 --- a/tree_hash/src/impls.rs +++ b/tree_hash/src/impls.rs @@ -205,6 +205,30 @@ impl TreeHash for Arc { } } +impl TreeHash for Option { + fn tree_hash_type() -> TreeHashType { + T::tree_hash_type() + } + + fn tree_hash_packed_encoding(&self) -> PackedEncoding { + match self { + Some(inner) => inner.tree_hash_packed_encoding(), + None => unreachable!(), + } + } + + fn tree_hash_packing_factor() -> usize { + T::tree_hash_packing_factor() + } + + fn tree_hash_root(&self) -> Hash256 { + match self { + Some(inner) => inner.tree_hash_root(), + None => unreachable!(), + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/tree_hash/src/lib.rs b/tree_hash/src/lib.rs index 2c0a5c5..ea91f1a 100644 --- a/tree_hash/src/lib.rs +++ b/tree_hash/src/lib.rs @@ -92,6 +92,13 @@ pub fn mix_in_selector(root: &Hash256, selector: u8) -> Option { Some(Hash256::from_slice(&root)) } +pub fn 
mix_in_aux(root: &Hash256, aux: &Hash256) -> Hash256 { + Hash256::from_slice(&ethereum_hashing::hash32_concat( + root.as_bytes(), + aux.as_bytes(), + )) +} + /// Returns a cached padding node for a given height. fn get_zero_hash(height: usize) -> &'static [u8] { if height <= ZERO_HASHES_MAX_INDEX { @@ -107,6 +114,7 @@ pub enum TreeHashType { Vector, List, Container, + StableContainer, } pub trait TreeHash { diff --git a/tree_hash/tests/tests.rs b/tree_hash/tests/tests.rs index b831614..ec53bce 100644 --- a/tree_hash/tests/tests.rs +++ b/tree_hash/tests/tests.rs @@ -1,6 +1,8 @@ use ssz_derive::Encode; -use tree_hash::{Hash256, MerkleHasher, PackedEncoding, TreeHash, BYTES_PER_CHUNK}; +use ssz_types::BitVector; +use tree_hash::{self, Hash256, MerkleHasher, PackedEncoding, TreeHash, BYTES_PER_CHUNK}; use tree_hash_derive::TreeHash; +use typenum::Unsigned; #[derive(Encode)] struct HashVec { @@ -126,3 +128,88 @@ fn variable_union() { mix_in_selector(u8_hash_concat(2, 1), 1) ); } + +#[derive(TreeHash)] +#[tree_hash(struct_behaviour = "stable_container")] +#[tree_hash(max_fields = "typenum::U8")] +struct Shape { + side: Option<u16>, + color: Option<u8>, + radius: Option<u16>, +} + +#[derive(TreeHash, Clone)] +#[tree_hash(struct_behaviour = "profile")] +#[tree_hash(max_fields = "typenum::U8")] +struct Square { + // We always start with a stable_index of 0. + side: u16, + color: u8, +} + +#[derive(TreeHash, Clone)] +#[tree_hash(struct_behaviour = "profile")] +#[tree_hash(max_fields = "typenum::U8")] +struct Circle { + #[tree_hash(stable_index = 1)] + color: u8, + #[tree_hash(skip_hashing)] + phantom: u8, + // Note that we do not need to specify `stable_index = 2` here since + // we always increment by 1 from the previous index.
+ radius: u16, +} + +#[derive(TreeHash)] +#[tree_hash(enum_behaviour = "transparent_stable")] +enum ShapeEnum { + SquareVariant(Square), + CircleVariant(Circle), +} + +#[test] +fn shape_1() { + let shape_1 = Shape { + side: Some(16), + color: Some(2), + radius: None, + }; + + let square = Square { side: 16, color: 2 }; + + assert_eq!(shape_1.tree_hash_root(), square.tree_hash_root()); +} + +#[test] +fn shape_2() { + let shape_2 = Shape { + side: None, + color: Some(1), + radius: Some(42), + }; + + let circle = Circle { + color: 1, + phantom: 6, + radius: 42, + }; + + assert_eq!(shape_2.tree_hash_root(), circle.tree_hash_root()); +} + +#[test] +fn shape_enum() { + let square = Square { side: 16, color: 2 }; + + let circle = Circle { + color: 1, + phantom: 6, + radius: 14, + }; + + let enum_square = ShapeEnum::SquareVariant(square.clone()); + let enum_circle = ShapeEnum::CircleVariant(circle.clone()); + + assert_eq!(square.tree_hash_root(), enum_square.tree_hash_root()); + assert_eq!(circle.tree_hash_root(), enum_circle.tree_hash_root()); +} diff --git a/tree_hash_derive/Cargo.toml b/tree_hash_derive/Cargo.toml index 7d4fce3..1b9f48f 100644 --- a/tree_hash_derive/Cargo.toml +++ b/tree_hash_derive/Cargo.toml @@ -17,3 +17,5 @@ proc-macro = true syn = "1.0.42" quote = "1.0.7" darling = "0.13.0" +proc-macro2 = "1.0.23" +ssz_types = { git = "https://github.com/macladson/ssz_types", branch = "stable-container" } diff --git a/tree_hash_derive/src/lib.rs b/tree_hash_derive/src/lib.rs index 21ff324..1dc59de 100644 --- a/tree_hash_derive/src/lib.rs +++ b/tree_hash_derive/src/lib.rs @@ -1,9 +1,9 @@ #![recursion_limit = "256"] -use darling::FromDeriveInput; +use darling::{FromDeriveInput, FromMeta}; use proc_macro::TokenStream; use quote::quote; use std::convert::TryInto; -use syn::{parse_macro_input, Attribute, DataEnum, DataStruct, DeriveInput, Meta}; +use syn::{parse_macro_input, Attribute, DataEnum, DataStruct, DeriveInput, Expr, Meta}; /// The highest possible union 
selector value (higher values are reserved for backwards compatible /// extensions). @@ -12,18 +12,58 @@ const MAX_UNION_SELECTOR: u8 = 127; #[derive(Debug, FromDeriveInput)] #[darling(attributes(tree_hash))] struct StructOpts { + #[darling(default)] + struct_behaviour: Option, #[darling(default)] enum_behaviour: Option, + #[darling(default)] + max_fields: Option, } +/// Field-level configuration. +#[derive(Debug, Default, FromMeta)] +struct FieldOpts { + #[darling(default)] + skip_hashing: bool, + #[darling(default)] + stable_index: Option, +} + +const STRUCT_CONTAINER: &str = "container"; +const STRUCT_STABLE_CONTAINER: &str = "stable_container"; +const STRUCT_PROFILE: &str = "profile"; +const STRUCT_VARIANTS: &[&str] = &[STRUCT_CONTAINER, STRUCT_STABLE_CONTAINER, STRUCT_PROFILE]; + const ENUM_TRANSPARENT: &str = "transparent"; +const ENUM_TRANSPARENT_STABLE: &str = "transparent_stable"; const ENUM_UNION: &str = "union"; const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION]; const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require an \"enum_behaviour\" attribute, \ e.g., #[tree_hash(enum_behaviour = \"transparent\")]"; +enum StructBehaviour { + Container, + StableContainer, + Profile, +} + +impl StructBehaviour { + pub fn new(s: Option) -> Option { + s.map(|s| match s.as_ref() { + STRUCT_CONTAINER => StructBehaviour::Container, + STRUCT_STABLE_CONTAINER => StructBehaviour::StableContainer, + STRUCT_PROFILE => StructBehaviour::Profile, + other => panic!( + "{} is an invalid struct_behaviour, use one of: {:?}", + other, STRUCT_VARIANTS + ), + }) + } +} + enum EnumBehaviour { Transparent, + TransparentStable, Union, } @@ -31,6 +71,7 @@ impl EnumBehaviour { pub fn new(s: Option) -> Option { s.map(|s| match s.as_ref() { ENUM_TRANSPARENT => EnumBehaviour::Transparent, + ENUM_TRANSPARENT_STABLE => EnumBehaviour::TransparentStable, ENUM_UNION => EnumBehaviour::Union, other => panic!( "{} is an invalid enum_behaviour, use either {:?}", @@ -113,6 +154,43 @@ fn 
should_skip_hashing(field: &syn::Field) -> bool { }) } +fn parse_tree_hash_fields( + struct_data: &syn::DataStruct, +) -> Vec<(&syn::Type, Option<&syn::Ident>, FieldOpts)> { + struct_data + .fields + .iter() + .map(|field| { + let ty = &field.ty; + let ident = field.ident.as_ref(); + + let field_opts_candidates = field + .attrs + .iter() + .filter(|attr| { + attr.path + .get_ident() + .map_or(false, |ident| *ident == "tree_hash") + }) + .collect::<Vec<_>>(); + + if field_opts_candidates.len() > 1 { + panic!("more than one field-level \"tree_hash\" attribute provided") + } + + let field_opts = field_opts_candidates + .first() + .map(|attr| { + let meta = attr.parse_meta().unwrap(); + FieldOpts::from_meta(&meta).unwrap() + }) + .unwrap_or_default(); + + (ty, ident, field_opts) + }) + .collect() +} + /// Implements `tree_hash::TreeHash` for some `struct`. /// /// Fields are hashed in the order they are defined. @@ -121,23 +199,66 @@ pub fn tree_hash_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); let opts = StructOpts::from_derive_input(&item).unwrap(); let enum_opt = EnumBehaviour::new(opts.enum_behaviour); + let struct_opt = StructBehaviour::new(opts.struct_behaviour); match &item.data { syn::Data::Struct(s) => { if enum_opt.is_some() { - panic!("enum_behaviour is invalid for structs"); + panic!("cannot use \"enum_behaviour\" for a struct"); + } + match struct_opt { + Some(StructBehaviour::Container) => tree_hash_derive_struct_container(&item, s), + Some(StructBehaviour::StableContainer) => { + if let Some(max_fields_string) = opts.max_fields { + let max_fields_ref = max_fields_string.as_ref(); + let max_fields_ty: Expr = syn::parse_str(max_fields_ref) + .expect("\"max_fields\" is not a valid type."); + let max_fields: proc_macro2::TokenStream = quote!
{ #max_fields_ty }; + + tree_hash_derive_struct_stable_container(&item, s, max_fields) + } else { + panic!("stable_container requires \"max_fields\"") + } + } + Some(StructBehaviour::Profile) => { + if let Some(max_fields_string) = opts.max_fields { + let max_fields_ref = max_fields_string.as_ref(); + let max_fields_ty: Expr = syn::parse_str(max_fields_ref) + .expect("\"max_fields\" is not a valid type."); + let max_fields: proc_macro2::TokenStream = quote! { #max_fields_ty }; + + tree_hash_derive_struct_profile(&item, s, max_fields) + } else { + panic!("profile requires \"max_fields\"") + } + } + // Default to container. + None => tree_hash_derive_struct_container(&item, s), + } + } + syn::Data::Enum(s) => { + if struct_opt.is_some() { + panic!("cannot use \"struct_behaviour\" for an enum"); + } + match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { + EnumBehaviour::Transparent => tree_hash_derive_enum_transparent( + &item, + s, + syn::parse_str("Container").unwrap(), + ), + EnumBehaviour::TransparentStable => tree_hash_derive_enum_transparent( + &item, + s, + syn::parse_str("StableContainer").unwrap(), + ), + EnumBehaviour::Union => tree_hash_derive_enum_union(&item, s), } - tree_hash_derive_struct(&item, s) } - syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { - EnumBehaviour::Transparent => tree_hash_derive_enum_transparent(&item, s), - EnumBehaviour::Union => tree_hash_derive_enum_union(&item, s), - }, _ => panic!("tree_hash_derive only supports structs and enums."), } } -fn tree_hash_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream { +fn tree_hash_derive_struct_container(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream { let name = &item.ident; let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); @@ -173,6 +294,162 @@ fn tree_hash_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Toke output.into() } +fn tree_hash_derive_struct_stable_container( + item: 
&DeriveInput, + struct_data: &DataStruct, + max_fields: proc_macro2::TokenStream, +) -> TokenStream { + let name = &item.ident; + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + + let idents = get_hashable_fields(struct_data); + + let output = quote! { + impl #impl_generics tree_hash::TreeHash for #name #ty_generics #where_clause { + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::StableContainer + } + + fn tree_hash_packed_encoding(&self) -> tree_hash::PackedEncoding { + unreachable!("Struct should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Struct should never be packed.") + } + + fn tree_hash_root(&self) -> tree_hash::Hash256 { + // Construct BitVector + let mut active_fields = BitVector::<#max_fields>::new(); + + let mut working_field: usize = 0; + + #( + if self.#idents.is_some() { + active_fields.set(working_field, true).expect("Should not be out of bounds"); + } + working_field += 1; + )* + + // Hash according to `max_fields` regardless of the actual number of fields on the struct. + let mut hasher = tree_hash::MerkleHasher::with_leaves(#max_fields::to_usize()); + + #( + if self.#idents.is_some() { + hasher.write(self.#idents.tree_hash_root().as_bytes()) + .expect("tree hash derive should not apply too many leaves"); + } + )* + + let hash = hasher.finish().expect("tree hash derive should not have a remaining buffer"); + + tree_hash::mix_in_aux(&hash, &active_fields.tree_hash_root()) + } + } + }; + output.into() +} + +fn tree_hash_derive_struct_profile( + item: &DeriveInput, + struct_data: &DataStruct, + max_fields: proc_macro2::TokenStream, +) -> TokenStream { + let name = &item.ident; + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + + let set_active_fields = &mut vec![]; + let hashes = &mut vec![]; + + // Assume a starting index of 0. 
+ let mut index = 0; + + for (ty, ident, field_opt) in parse_tree_hash_fields(struct_data) { + let mut is_optional = false; + if field_opt.skip_hashing { + continue; + } + + let ident = match ident { + Some(ref ident) => ident, + _ => { + panic!("#[tree_hash(struct_behaviour = \"profile\")] only supports named struct fields.") + } + }; + + if let Some(new_index) = field_opt.stable_index { + index = new_index; + } + + if ty_inner_type("Option", ty).is_some() { + is_optional = true; + } + + if is_optional { + set_active_fields.push(quote! { + if self.#ident.is_some() { + active_fields.set(#index, true).expect("Should not be out of bounds"); + } + }); + + hashes.push(quote! { + if active_fields.get(index) { + hasher.write(self.#ident.tree_hash_root().as_bytes()) + .expect("tree hash derive should not apply too many leaves"); + } + }); + } else { + set_active_fields.push(quote! { + active_fields.set(#index, true).expect("Should not be out of bounds"); + }); + hashes.push(quote! { + hasher.write(self.#ident.tree_hash_root().as_bytes()) + .expect("tree hash derive should not apply too many leaves"); + }); + } + + // Increment the index. + index += 1; + } + + let output = quote! { + impl #impl_generics tree_hash::TreeHash for #name #ty_generics #where_clause { + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::StableContainer + } + + fn tree_hash_packed_encoding(&self) -> tree_hash::PackedEncoding { + unreachable!("Struct should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Struct should never be packed.") + } + + fn tree_hash_root(&self) -> tree_hash::Hash256 { + // Construct BitVector + let mut active_fields = BitVector::<#max_fields>::new(); + + #( + #set_active_fields + )* + + // Hash according to `max_fields` regardless of the actual number of fields on the struct. 
+ let mut hasher = tree_hash::MerkleHasher::with_leaves(#max_fields::to_usize()); + + #( + #hashes + )* + + let hash = hasher.finish().expect("tree hash derive should not have a remaining buffer"); + + tree_hash::mix_in_aux(&hash, &active_fields.tree_hash_root()) + } + } + }; + output.into() +} + /// Derive `TreeHash` for an enum in the "transparent" method. /// /// The "transparent" method is distinct from the "union" method specified in the SSZ specification. @@ -192,6 +469,7 @@ fn tree_hash_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Toke fn tree_hash_derive_enum_transparent( derive_input: &DeriveInput, enum_data: &DataEnum, + inner_container_type: Expr, ) -> TokenStream { let name = &derive_input.ident; let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); @@ -224,11 +502,11 @@ fn tree_hash_derive_enum_transparent( #( assert_eq!( #type_exprs, - tree_hash::TreeHashType::Container, + tree_hash::TreeHashType::#inner_container_type, "all variants must be of container type" ); )* - tree_hash::TreeHashType::Container + tree_hash::TreeHashType::#inner_container_type } fn tree_hash_packed_encoding(&self) -> tree_hash::PackedEncoding { @@ -335,3 +613,23 @@ fn compute_union_selectors(num_variants: usize) -> Vec { union_selectors } + +fn ty_inner_type<'a>(wrapper: &str, ty: &'a syn::Type) -> Option<&'a syn::Type> { + if let syn::Type::Path(ref p) = ty { + if p.path.segments.len() != 1 || p.path.segments[0].ident != wrapper { + return None; + } + + if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments { + if inner_ty.args.len() != 1 { + return None; + } + + let inner_ty = inner_ty.args.first().unwrap(); + if let syn::GenericArgument::Type(ref t) = inner_ty { + return Some(t); + } + } + } + None +}