From 6ae83c167aed04422a33165dd2d075da376d61cc Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 8 Dec 2023 16:09:21 +0000 Subject: [PATCH] Improve enum discriminant descriptor --- const-type-layout-derive/src/lib.rs | 37 +---- src/impls/core/option.rs | 17 +-- src/impls/core/result.rs | 14 +- src/lib.rs | 223 ++++++++++++---------------- src/ser.rs | 78 +++++++++- 5 files changed, 189 insertions(+), 180 deletions(-) diff --git a/const-type-layout-derive/src/lib.rs b/const-type-layout-derive/src/lib.rs index bb7d4ba..8ec27c0 100644 --- a/const-type-layout-derive/src/lib.rs +++ b/const-type-layout-derive/src/lib.rs @@ -42,18 +42,16 @@ pub fn derive_type_layout(input: TokenStream) -> TokenStream { let inner_types = extract_inner_types(&input.data); + let discriminant_ty = if let syn::Data::Enum(_) = input.data { + Some(quote! { ::Ty, }) + } else { + None + }; + let Generics { type_layout_input_generics, type_set_input_generics, - } = generate_generics( - &crate_path, - &ty_name, - &ty_generics, - &input.generics, - matches!(input.data, syn::Data::Enum(_)), - &extra_bounds, - &type_params, - ); + } = generate_generics(&crate_path, &input.generics, &extra_bounds, &type_params); let (type_layout_impl_generics, type_layout_ty_generics, type_layout_where_clause) = type_layout_input_generics.split_for_impl(); let (type_set_impl_generics, type_set_ty_generics, type_set_where_clause) = @@ -84,7 +82,7 @@ pub fn derive_type_layout(input: TokenStream) -> TokenStream { { type Output<__TypeSetRest: #crate_path::typeset::ExpandTypeSet> = #crate_path::typeset::Set; } } @@ -383,32 +381,13 @@ struct Generics { fn generate_generics( crate_path: &syn::Path, - ty_name: &syn::Ident, - ty_generics: &syn::TypeGenerics, generics: &syn::Generics, - is_enum: bool, extra_bounds: &[syn::WherePredicate], type_params: &[&syn::Ident], ) -> Generics { let mut type_layout_input_generics = generics.clone(); let mut type_set_input_generics = generics.clone(); - if is_enum { - type_layout_input_generics - .make_where_clause() - .predicates - .push(syn::parse_quote! { - [u8; ::core::mem::size_of::<::core::mem::Discriminant<#ty_name #ty_generics>>()]: - }); - - type_set_input_generics - .make_where_clause() - .predicates - .push(syn::parse_quote! { - [u8; ::core::mem::size_of::<::core::mem::Discriminant<#ty_name #ty_generics>>()]: - }); - } - for ty in type_params { type_layout_input_generics .make_where_clause() diff --git a/src/impls/core/option.rs b/src/impls/core/option.rs index 33d6e89..d80b2eb 100644 --- a/src/impls/core/option.rs +++ b/src/impls/core/option.rs @@ -3,10 +3,7 @@ use crate::{ Field, MaybeUninhabited, TypeLayout, TypeLayoutInfo, TypeStructure, Variant, }; -unsafe impl const TypeLayout for core::option::Option -where - [u8; core::mem::size_of::>()]:, -{ +unsafe impl const TypeLayout for core::option::Option { const TYPE_LAYOUT: TypeLayoutInfo<'static> = TypeLayoutInfo { name: ::core::any::type_name::(), size: ::core::mem::size_of::(), @@ -41,9 +38,11 @@ where } } -unsafe impl ComputeTypeSet for core::option::Option -where - [u8; core::mem::size_of::>()]:, -{ - type Output = Set; +unsafe impl ComputeTypeSet for core::option::Option { + type Output = Set< + Self, + tset![ + T, ::Ty, .. @ R + ], + >; } diff --git a/src/impls/core/result.rs b/src/impls/core/result.rs index 92d0aba..c3348a6 100644 --- a/src/impls/core/result.rs +++ b/src/impls/core/result.rs @@ -5,8 +5,6 @@ use crate::{ unsafe impl const TypeLayout for core::result::Result -where - [u8; core::mem::size_of::>()]:, { const TYPE_LAYOUT: TypeLayoutInfo<'static> = TypeLayoutInfo { name: ::core::any::type_name::(), @@ -58,9 +56,11 @@ where } } -unsafe impl ComputeTypeSet for core::result::Result -where - [u8; core::mem::size_of::>()]:, -{ - type Output = Set; +unsafe impl ComputeTypeSet for core::result::Result { + type Output = Set< + Self, + tset![ + T, E, ::Ty, .. @ R + ], + >; } diff --git a/src/lib.rs b/src/lib.rs index 6335d6a..43f45a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -165,6 +165,7 @@ #![feature(const_maybe_uninit_array_assume_init)] #![feature(c_variadic)] #![feature(ptr_from_ref)] +#![feature(discriminant_kind)] #![allow(incomplete_features)] #![feature(generic_const_exprs)] #![feature(specialization)] @@ -355,15 +356,77 @@ pub enum TypeStructure< #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] pub struct Variant<'a, F: Deref]> = &'a [Field<'a>]> { pub name: &'a str, - pub discriminant: MaybeUninhabited>, + pub discriminant: MaybeUninhabited, pub fields: F, } -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -#[repr(transparent)] -pub struct Discriminant<'a> { - pub big_endian_bytes: &'a [u8], +pub enum Discriminant { + I8(i8), + I16(i16), + I32(i32), + I64(i64), + I128(i128), + Isize(isize), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), + Usize(usize), +} + +#[const_trait] +pub trait ExtractDiscriminant { + type Ty: typeset::ComputeTypeSet; + + fn discriminant(&self) -> Discriminant; +} + +impl const ExtractDiscriminant for T { + type Ty = + ::Discriminant>>::Ty; + + fn discriminant(&self) -> Discriminant { + ::Discriminant>>::discriminant(self) + } +} + +#[doc(hidden)] +#[const_trait] +pub trait ExtractDiscriminantSpec { + type Ty: typeset::ComputeTypeSet; + + fn discriminant(&self) -> Discriminant; +} + +impl const ExtractDiscriminantSpec<::Discriminant> for T { + default type Ty = !; + + default fn discriminant(&self) -> Discriminant { + panic!("bug: unknown discriminant kind") + } +} + +macro_rules! impl_extract_discriminant { + ($variant:ident($ty:ty)) => { + impl> const ExtractDiscriminantSpec<$ty> for T { + type Ty = $ty; + + fn discriminant(&self) -> Discriminant { + Discriminant::$variant(core::intrinsics::discriminant_value(self)) + } + } + }; + ($($variant:ident($ty:ty)),*) => { + $(impl_extract_discriminant! { $variant($ty) })* + }; +} + +impl_extract_discriminant! { + I8(i8), I16(i16), I32(i32), I64(i64), I128(i128), Isize(isize), + U8(u8), U16(u16), U32(u32), U64(u64), U128(u128), Usize(usize) } #[derive(Clone, Copy, Debug, Hash)] @@ -468,28 +531,6 @@ impl<'a, F: Deref]> + PartialOrd> PartialOrd for Variant<'a, } } -impl<'a> fmt::Debug for Discriminant<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "0x")?; - - let mut is_zero = true; - - for byte in self.big_endian_bytes.iter().copied() { - if byte != 0_u8 { - is_zero = false; - - write!(fmt, "{byte:x}")?; - } - } - - if is_zero { - write!(fmt, "0")?; - } - - Ok(()) - } -} - impl<'a> PartialEq for Field<'a> { fn eq(&self, other: &Self) -> bool { self.name == other.name && self.offset == other.offset && core::ptr::eq(self.ty, other.ty) @@ -541,38 +582,16 @@ pub macro struct_field_offset($ty_name:ident => $ty:ty => (*$base:ident).$field: #[doc(hidden)] #[allow_internal_unstable(const_discriminant)] pub macro struct_variant_discriminant { - ($ty_name:ident => $ty:ty => $variant_name:ident) => { - $crate::MaybeUninhabited::Inhabited { - 0: $crate::Discriminant { - big_endian_bytes: &{ - let uninit: $ty = $ty_name::$variant_name; - - let system_endian_bytes: [u8; ::core::mem::size_of::<::core::mem::Discriminant<$ty>>()] = unsafe { - ::core::mem::transmute(::core::mem::discriminant(&uninit)) - }; + ($ty_name:ident => $ty:ty => $variant_name:ident) => {{ + let uninit: $ty = $ty_name::$variant_name; - #[allow(clippy::forget_non_drop, clippy::forget_copy)] - ::core::mem::forget(uninit); + let discriminant = <$ty as $crate::ExtractDiscriminant>::discriminant(&uninit); - let mut big_endian_bytes = [0_u8; ::core::mem::size_of::<::core::mem::Discriminant<$ty>>()]; + #[allow(clippy::forget_non_drop, clippy::forget_copy)] + ::core::mem::forget(uninit); - let mut i = 0; - - while i < system_endian_bytes.len() { - big_endian_bytes[i] = system_endian_bytes[if cfg!(target_endian = "big") { - i - } else { - system_endian_bytes.len() - i - 1 - }]; - - i += 1; - } - - big_endian_bytes - }, - }, - } - }, + $crate::MaybeUninhabited::Inhabited(discriminant) + }}, ($ty_name:ident => $ty:ty => $variant_name:ident($($field_name:ident: $field_ty:ty),* $(,)?)) => {{ #[allow(unused_parens)] if let ( @@ -580,42 +599,16 @@ pub macro struct_variant_discriminant { ) = ( $(unsafe { <$field_ty as $crate::TypeLayout>::uninit() }),* ) { - $crate::MaybeUninhabited::Inhabited { - 0: $crate::Discriminant { - big_endian_bytes: { - let uninit: $ty = $ty_name::$variant_name( - $(unsafe { $field_name.assume_init() }),* - ); - - let system_endian_bytes: [u8; ::core::mem::size_of::<::core::mem::Discriminant<$ty>>()] = unsafe { - ::core::mem::transmute(::core::mem::discriminant(&uninit)) - }; - - #[allow(clippy::forget_non_drop, clippy::forget_copy)] - ::core::mem::forget(uninit); - - let big_endian_bytes = unsafe { - &mut *$crate::impls::leak_uninit_ptr::< - [u8; ::core::mem::size_of::<::core::mem::Discriminant<$ty>>()] - >() - }; - - let mut i = 0; - - while i < system_endian_bytes.len() { - (*big_endian_bytes)[i] = system_endian_bytes[if cfg!(target_endian = "big") { - i - } else { - system_endian_bytes.len() - i - 1 - }]; - - i += 1; - } - - big_endian_bytes - } - }, - } + let uninit: $ty = $ty_name::$variant_name( + $(unsafe { $field_name.assume_init() }),* + ); + + let discriminant = <$ty as $crate::ExtractDiscriminant>::discriminant(&uninit); + + #[allow(clippy::forget_non_drop, clippy::forget_copy)] + ::core::mem::forget(uninit); + + $crate::MaybeUninhabited::Inhabited(discriminant) } else { $crate::MaybeUninhabited::Uninhabited } @@ -627,42 +620,16 @@ pub macro struct_variant_discriminant { ) = ( $(unsafe { <$field_ty as $crate::TypeLayout>::uninit() }),* ) { - $crate::MaybeUninhabited::Inhabited { - 0: $crate::Discriminant { - big_endian_bytes: { - let uninit: $ty = $ty_name::$variant_name { - $($field_name: unsafe { $field_name.assume_init() }),* - }; - - let system_endian_bytes: [u8; ::core::mem::size_of::<::core::mem::Discriminant<$ty>>()] = unsafe { - ::core::mem::transmute(::core::mem::discriminant(&uninit)) - }; - - #[allow(clippy::forget_non_drop, clippy::forget_copy)] - ::core::mem::forget(uninit); - - let big_endian_bytes = unsafe { - &mut *$crate::impls::leak_uninit_ptr::< - [u8; ::core::mem::size_of::<::core::mem::Discriminant<$ty>>()] - >() - }; - - let mut i = 0; - - while i < system_endian_bytes.len() { - (*big_endian_bytes)[i] = system_endian_bytes[if cfg!(target_endian = "big") { - i - } else { - system_endian_bytes.len() - i - 1 - }]; - - i += 1; - } - - big_endian_bytes - } - }, - } + let uninit: $ty = $ty_name::$variant_name { + $($field_name: unsafe { $field_name.assume_init() }),* + }; + + let discriminant = <$ty as $crate::ExtractDiscriminant>::discriminant(&uninit); + + #[allow(clippy::forget_non_drop, clippy::forget_copy)] + ::core::mem::forget(uninit); + + $crate::MaybeUninhabited::Inhabited(discriminant) } else { $crate::MaybeUninhabited::Uninhabited } diff --git a/src/ser.rs b/src/ser.rs index 9695857..b288328 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -1,8 +1,8 @@ use core::ops::Deref; use crate::{ - Asyncness, Constness, Discriminant, Field, MaybeUninhabited, Safety, TypeLayoutGraph, - TypeLayoutInfo, TypeStructure, Variant, + Asyncness, Constness, Discriminant, Field, MaybeUninhabited, Safety, TypeLayout, + TypeLayoutGraph, TypeLayoutInfo, TypeStructure, Variant, }; pub const fn serialise_str(bytes: &mut [u8], from: usize, value: &str) -> usize { @@ -159,9 +159,7 @@ pub const fn serialised_maybe_uninhabited_len(from: usize, _value: MaybeUninhabi from + 1 } -pub const fn serialise_discriminant(bytes: &mut [u8], from: usize, value: &Discriminant) -> usize { - let value_bytes = value.big_endian_bytes; - +const fn serialise_discriminant_bytes(bytes: &mut [u8], from: usize, value_bytes: &[u8]) -> usize { let mut leading_zeroes = 0; while leading_zeroes < value_bytes.len() { @@ -190,9 +188,43 @@ pub const fn serialise_discriminant(bytes: &mut [u8], from: usize, value: &Discr from + i - leading_zeroes } -pub const fn serialised_discriminant_len(from: usize, value: &Discriminant) -> usize { - let value_bytes = value.big_endian_bytes; +pub const fn serialise_discriminant(bytes: &mut [u8], from: usize, value: &Discriminant) -> usize { + let from = match value { + Discriminant::I8(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::I16(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::I32(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::I64(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::I128(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::Isize(_) => { + serialise_str(bytes, from, ::TYPE_LAYOUT.name) + }, + Discriminant::U8(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::U16(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::U32(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::U64(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::U128(_) => serialise_str(bytes, from, ::TYPE_LAYOUT.name), + Discriminant::Usize(_) => { + serialise_str(bytes, from, ::TYPE_LAYOUT.name) + }, + }; + + match value { + Discriminant::I8(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::I16(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::I32(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::I64(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::I128(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::Isize(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::U8(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::U16(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::U32(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::U64(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::U128(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + Discriminant::Usize(v) => serialise_discriminant_bytes(bytes, from, &v.to_be_bytes()), + } +} +const fn serialised_discriminant_bytes_len(from: usize, value_bytes: &[u8]) -> usize { let mut leading_zeroes = 0; while leading_zeroes < value_bytes.len() { @@ -208,6 +240,38 @@ pub const fn serialised_discriminant_len(from: usize, value: &Discriminant) -> u from + value_bytes.len() - leading_zeroes } +pub const fn serialised_discriminant_len(from: usize, value: &Discriminant) -> usize { + let from = match value { + Discriminant::I8(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::I16(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::I32(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::I64(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::I128(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::Isize(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::U8(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::U16(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::U32(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::U64(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::U128(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + Discriminant::Usize(_) => serialised_str_len(from, ::TYPE_LAYOUT.name), + }; + + match value { + Discriminant::I8(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::I16(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::I32(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::I64(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::I128(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::Isize(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::U8(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::U16(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::U32(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::U64(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::U128(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + Discriminant::Usize(v) => serialised_discriminant_bytes_len(from, &v.to_be_bytes()), + } +} + pub const fn serialise_field(bytes: &mut [u8], from: usize, value: &Field) -> usize { let from = serialise_str(bytes, from, value.name); let from = serialise_maybe_uninhabited(bytes, from, value.offset.map(()));