diff --git a/ensemble/src/value/de.rs b/ensemble/src/value/de.rs new file mode 100644 index 0000000..845a0ad --- /dev/null +++ b/ensemble/src/value/de.rs @@ -0,0 +1,667 @@ +use std::{ + fmt::{self, Debug}, + vec::IntoIter, +}; + +use rbs::{value::map::ValueMap, Value}; +use serde::{ + de::{self, IntoDeserializer, Unexpected, Visitor}, + forward_to_deserialize_any, Deserialize, Deserializer, +}; + +#[inline] +pub fn deserialize_value<'de, T: Deserialize<'de>>(val: rbs::Value) -> Result { + Deserialize::deserialize(ValueDeserializer(val)) +} + +#[repr(transparent)] +struct ValueDeserializer(rbs::Value); + +trait ValueBase<'de>: Deserializer<'de, Error = rbs::Error> { + type Item: ValueBase<'de>; + type MapDeserializer: Deserializer<'de>; + type Iter: ExactSizeIterator; + type MapIter: Iterator; + + fn is_null(&self) -> bool; + fn unexpected(&self) -> Unexpected<'_>; + + fn into_iter(self) -> Result; + fn into_map_iter(self) -> Result; +} + +impl<'de> ValueBase<'de> for Value { + type Item = ValueDeserializer; + type Iter = IntoIter; + type MapIter = IntoIter<(Self::Item, Self::Item)>; + type MapDeserializer = MapDeserializer; + + #[inline] + fn is_null(&self) -> bool { + matches!(self, Self::Null) + } + + #[inline] + fn into_iter(self) -> Result { + match self { + Self::Array(v) => Ok(v + .into_iter() + .map(ValueDeserializer) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[inline] + fn into_map_iter(self) -> Result { + match self { + Self::Map(v) => Ok(v + .0 + .into_iter() + .map(|(k, v)| (ValueDeserializer(k), ValueDeserializer(v))) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[cold] + fn unexpected(&self) -> Unexpected<'_> { + match *self { + Self::Null => Unexpected::Unit, + Self::Map(..) => Unexpected::Map, + Self::F64(v) => Unexpected::Float(v), + Self::Bool(v) => Unexpected::Bool(v), + Self::I64(v) => Unexpected::Signed(v), + Self::U64(v) => Unexpected::Unsigned(v), + Self::Ext(..) | Self::Array(..) => Unexpected::Seq, + Self::F32(v) => Unexpected::Float(f64::from(v)), + Self::I32(v) => Unexpected::Signed(i64::from(v)), + Self::Binary(ref v) => Unexpected::Bytes(v), + Self::U32(v) => Unexpected::Unsigned(u64::from(v)), + Self::String(ref v) => Unexpected::Bytes(v.as_bytes()), + } + } +} + +impl<'de> ValueBase<'de> for ValueDeserializer { + type Item = Self; + type Iter = IntoIter; + type MapIter = IntoIter<(Self::Item, Self::Item)>; + type MapDeserializer = MapDeserializer; + + #[inline] + fn is_null(&self) -> bool { + self.0.is_null() + } + + #[inline] + fn into_iter(self) -> Result { + match self.0 { + Value::Array(v) => Ok(v + .into_iter() + .map(ValueDeserializer) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[inline] + fn into_map_iter(self) -> Result { + match self.0 { + Value::Map(v) => Ok(v + .0 + .into_iter() + .map(|(k, v)| (Self(k), Self(v))) + .collect::>() + .into_iter()), + other => Err(other.into()), + } + } + + #[cold] + fn unexpected(&self) -> Unexpected<'_> { + match self.0 { + Value::Null => Unexpected::Unit, + Value::Map(..) => Unexpected::Map, + Value::F64(v) => Unexpected::Float(v), + Value::I64(v) => Unexpected::Signed(v), + Value::Bool(v) => Unexpected::Bool(v), + Value::U64(v) => Unexpected::Unsigned(v), + Value::Ext(..) | Value::Array(..) => Unexpected::Seq, + Value::F32(v) => Unexpected::Float(f64::from(v)), + Value::I32(v) => Unexpected::Signed(i64::from(v)), + Value::Binary(ref v) => Unexpected::Bytes(v), + Value::U32(v) => Unexpected::Unsigned(u64::from(v)), + Value::String(ref v) => Unexpected::Bytes(v.as_bytes()), + } + } +} + +impl From for ValueDeserializer { + #[inline] + fn from(value: Value) -> Self { + Self(value) + } +} + +impl From for Value { + #[inline] + fn from(value: ValueDeserializer) -> Self { + value.0 + } +} + +impl<'de> Deserialize<'de> for ValueDeserializer { + #[inline] + #[allow(clippy::too_many_lines)] + fn deserialize(de: D) -> Result + where + D: de::Deserializer<'de>, + { + struct ValueVisitor; + + impl<'de> serde::de::Visitor<'de> for ValueVisitor { + type Value = ValueDeserializer; + + #[cold] + fn expecting(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + "any valid MessagePack value".fmt(fmt) + } + + #[inline] + fn visit_some(self, de: D) -> Result + where + D: de::Deserializer<'de>, + { + Deserialize::deserialize(de) + } + + #[inline] + fn visit_none(self) -> Result { + Ok(Value::Null.into()) + } + + #[inline] + fn visit_unit(self) -> Result { + Ok(Value::Null.into()) + } + + #[inline] + fn visit_bool(self, value: bool) -> Result { + Ok(Value::Bool(value).into()) + } + + fn visit_u32(self, v: u32) -> Result { + Ok(Value::U32(v).into()) + } + + #[inline] + fn visit_u64(self, value: u64) -> Result { + Ok(Value::U64(value).into()) + } + + fn visit_i32(self, v: i32) -> Result { + Ok(Value::I32(v).into()) + } + + #[inline] + fn visit_i64(self, value: i64) -> Result { + Ok(Value::I64(value).into()) + } + + #[inline] + fn visit_f32(self, value: f32) -> Result { + Ok(Value::F32(value).into()) + } + + #[inline] + fn visit_f64(self, value: f64) -> Result { + Ok(Value::F64(value).into()) + } + + #[inline] + fn visit_string(self, value: String) -> Result { + Ok(Value::String(value).into()) + } + + #[inline] + fn visit_str(self, value: &str) -> Result { + self.visit_string(String::from(value)) + } + + #[inline] + fn visit_seq>( + self, + mut visitor: V, + ) -> Result { + let mut vec = { + visitor + .size_hint() + .map_or_else(Vec::new, Vec::with_capacity) + }; + while let Some(elem) = visitor.next_element::()? { + vec.push(elem.into()); + } + Ok(Value::Array(vec).into()) + } + + #[inline] + fn visit_bytes(self, v: &[u8]) -> Result { + Ok(Value::Binary(v.to_owned()).into()) + } + + #[inline] + fn visit_byte_buf(self, v: Vec) -> Result { + Ok(Value::Binary(v).into()) + } + + #[inline] + fn visit_map>( + self, + mut visitor: V, + ) -> Result { + let mut pairs = { + visitor + .size_hint() + .map_or_else(Vec::new, Vec::with_capacity) + }; + while let Some(key) = visitor.next_key::()? { + let val = visitor.next_value::()?; + pairs.push((key.into(), val.into())); + } + + Ok(Value::Map(ValueMap(pairs)).into()) + } + + fn visit_newtype_struct>( + self, + deserializer: D, + ) -> Result { + deserializer.deserialize_newtype_struct("", self) + } + } + + de.deserialize_any(ValueVisitor) + } +} + +impl<'de> Deserializer<'de> for ValueDeserializer { + type Error = rbs::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.into() { + Value::Null => visitor.visit_unit(), + Value::I32(v) => visitor.visit_i32(v), + Value::I64(v) => visitor.visit_i64(v), + Value::U32(v) => visitor.visit_u32(v), + Value::U64(v) => visitor.visit_u64(v), + Value::F32(v) => visitor.visit_f32(v), + Value::F64(v) => visitor.visit_f64(v), + Value::Bool(v) => visitor.visit_bool(v), + Value::String(v) => visitor.visit_string(v), + Value::Binary(v) => visitor.visit_byte_buf(v), + Value::Array(v) => { + let len = v.len(); + let mut de = SeqDeserializer { + iter: v.into_iter().map(ValueDeserializer), + }; + let seq = visitor.visit_seq(&mut de)?; + if de.iter.len() == 0 { + Ok(seq) + } else { + Err(de::Error::invalid_length(len, &"fewer elements in array")) + } + } + Value::Map(v) => { + let len = v.len(); + let mut de = MapDeserializer { + val: None, + iter: v.0.into_iter().map(|(k, v)| (Self(k), Self(v))), + }; + let map = visitor.visit_map(&mut de)?; + if de.iter.len() == 0 { + Ok(map) + } else { + Err(de::Error::invalid_length(len, &"fewer elements in map")) + } + } + Value::Ext(_tag, data) => Deserializer::deserialize_any(Self(*data), visitor), + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.0.is_null() { + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + #[inline] + fn deserialize_enum( + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + match self.0 { + Value::String(variant) => visitor.visit_enum(variant.into_deserializer()), + Value::Array(iter) => { + let mut iter = iter.into_iter(); + if !(iter.len() == 1 || iter.len() == 2) { + return Err(de::Error::invalid_length( + iter.len(), + &"array with one or two elements", + )); + } + + let id = match iter.next() { + Some(id) => deserialize_value(id)?, + None => { + return Err(de::Error::invalid_value( + Unexpected::Seq, + &"array with one or two elements", + )); + } + }; + + visitor.visit_enum(EnumDeserializer { + id, + value: iter.next(), + }) + } + other => Err(de::Error::invalid_type( + other.unexpected(), + &"string, array, map or int", + )), + } + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + #[inline] + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + match self.0 { + Value::Array(iter) => { + let iter = iter.into_iter(); + + if iter.len() == 0 { + visitor.visit_unit() + } else { + Err(de::Error::invalid_type(Unexpected::Seq, &"empty array")) + } + } + other => Err(de::Error::invalid_type(other.unexpected(), &"empty array")), + } + } + + forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit seq + bytes byte_buf map tuple_struct struct + identifier tuple ignored_any + } +} + +struct SeqDeserializer { + iter: I, +} + +impl<'de, I, U> de::SeqAccess<'de> for SeqDeserializer +where + I: Iterator, + U: Deserializer<'de, Error = rbs::Error>, +{ + type Error = rbs::Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + self.iter + .next() + .map_or_else(|| Ok(None), |val| seed.deserialize(val).map(Some)) + } +} + +impl<'de, I, U> Deserializer<'de> for SeqDeserializer +where + I: ExactSizeIterator, + U: Deserializer<'de, Error = rbs::Error>, +{ + type Error = rbs::Error; + + #[inline] + fn deserialize_any(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let len = self.iter.len(); + if len == 0 { + visitor.visit_unit() + } else { + let value = visitor.visit_seq(&mut self)?; + + if self.iter.len() == 0 { + Ok(value) + } else { + Err(de::Error::invalid_length(len, &"fewer elements in array")) + } + } + } + + forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option + seq bytes byte_buf map unit_struct newtype_struct + tuple_struct struct identifier tuple enum ignored_any + } +} + +struct MapDeserializer { + iter: I, + val: Option, +} + +impl<'de, I, U> de::MapAccess<'de> for MapDeserializer +where + I: Iterator, + U: ValueBase<'de>, +{ + type Error = rbs::Error; + + fn next_key_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + match self.iter.next() { + Some((key, val)) => { + self.val = Some(val); + seed.deserialize(key).map(Some) + } + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + Option::take(&mut self.val).map_or_else( + || Err(de::Error::custom("value is missing")), + |val| seed.deserialize(val), + ) + } +} + +impl<'de, I, U> Deserializer<'de> for MapDeserializer +where + U: ValueBase<'de>, + I: Iterator, +{ + type Error = rbs::Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(self) + } + + forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option + seq bytes byte_buf map unit_struct newtype_struct + tuple_struct struct identifier tuple enum ignored_any + } +} + +struct EnumDeserializer { + id: u32, + value: Option, +} + +impl<'de, U: ValueBase<'de>> de::EnumAccess<'de> for EnumDeserializer { + type Error = rbs::Error; + type Variant = VariantDeserializer; + + fn variant_seed>( + self, + seed: V, + ) -> Result<(V::Value, Self::Variant), Self::Error> { + let variant = self.id.into_deserializer(); + let visitor = VariantDeserializer { value: self.value }; + seed.deserialize(variant).map(|v| (v, visitor)) + } +} + +struct VariantDeserializer { + value: Option, +} + +impl<'de, U: ValueBase<'de>> de::VariantAccess<'de> for VariantDeserializer { + type Error = rbs::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + // Can accept only [u32]. + self.value.map_or(Ok(()), |v| match v.into_iter() { + Ok(ref v) if v.len() == 0 => Ok(()), + Ok(..) => Err(de::Error::invalid_value(Unexpected::Seq, &"empty array")), + Err(v) => Err(de::Error::invalid_value(v.unexpected(), &"empty array")), + }) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + // Can accept both [u32, T...] and [u32, [T]] cases. + match self.value { + Some(v) => match v.into_iter() { + Ok(mut iter) => { + if iter.len() > 1 { + seed.deserialize(SeqDeserializer { iter }) + } else { + let val = match iter.next() { + Some(val) => seed.deserialize(val), + None => { + return Err(de::Error::invalid_value( + Unexpected::Seq, + &"array with one element", + )); + } + }; + + if iter.next().is_some() { + Err(de::Error::invalid_value( + Unexpected::Seq, + &"array with one element", + )) + } else { + val + } + } + } + Err(v) => seed.deserialize(v), + }, + None => Err(de::Error::invalid_type( + Unexpected::UnitVariant, + &"newtype variant", + )), + } + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + // Can accept [u32, [T...]]. + self.value.map_or_else( + || { + Err(de::Error::invalid_type( + Unexpected::UnitVariant, + &"tuple variant", + )) + }, + |v| match v.into_iter() { + Ok(v) => Deserializer::deserialize_any(SeqDeserializer { iter: v }, visitor), + Err(v) => Err(de::Error::invalid_type(v.unexpected(), &"tuple variant")), + }, + ) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.value.map_or_else( + || { + Err(de::Error::invalid_type( + Unexpected::UnitVariant, + &"struct variant", + )) + }, + |v| match v.into_iter() { + Ok(iter) => Deserializer::deserialize_any(SeqDeserializer { iter }, visitor), + Err(v) => match v.into_map_iter() { + Ok(iter) => { + Deserializer::deserialize_any(MapDeserializer { iter, val: None }, visitor) + } + Err(v) => Err(de::Error::invalid_type(v.unexpected(), &"struct variant")), + }, + }, + ) + } +} diff --git a/ensemble/src/value/mod.rs b/ensemble/src/value/mod.rs new file mode 100644 index 0000000..be3477e --- /dev/null +++ b/ensemble/src/value/mod.rs @@ -0,0 +1,25 @@ +use serde::Serialize; + +use self::{de::deserialize_value, ser::fast_serialize}; +use crate::Model; + +mod de; +mod ser; + +/// Serialize a model for the database. +/// +/// # Errors +/// +/// Returns an error if serialization fails. +pub fn for_db(value: T) -> Result { + fast_serialize(value) +} + +/// Deserialize a model from the database. +/// +/// # Errors +/// +/// Returns an error if deserialization fails. +pub(crate) fn from(value: rbs::Value) -> Result { + deserialize_value::(value) +} diff --git a/ensemble/src/value.rs b/ensemble/src/value/ser.rs similarity index 92% rename from ensemble/src/value.rs rename to ensemble/src/value/ser.rs index ba9e581..329deca 100644 --- a/ensemble/src/value.rs +++ b/ensemble/src/value/ser.rs @@ -1,22 +1,7 @@ use rbs::{value::map::ValueMap, Value}; use serde::{ser, Serialize}; -use crate::Model; - -pub(crate) fn from(value: Value) -> Result { - rbs::from_value::(value) -} - -/// Serialize a model for the database. -/// -/// # Errors -/// -/// Returns an error if serialization fails. -pub fn for_db(value: T) -> Result { - value.serialize(Serializer) -} - -fn fast_serialize(mut value: T) -> Result { +pub fn fast_serialize(mut value: T) -> Result { let type_name = std::any::type_name::(); if type_name == std::any::type_name::() { let addr = std::ptr::addr_of_mut!(value); @@ -454,7 +439,7 @@ mod tests { }; assert_eq!( - for_db(test).unwrap(), + fast_serialize(test).unwrap(), rbs::to_value! { "a" : 1, "b" : "test", @@ -472,10 +457,13 @@ mod tests { #[test] fn test_serialize_enum() { - assert_eq!(for_db(Status::Ok).unwrap(), rbs::to_value!("Ok")); - assert_eq!(for_db(Status::Error).unwrap(), rbs::to_value!("Error")); + assert_eq!(fast_serialize(Status::Ok).unwrap(), rbs::to_value!("Ok")); assert_eq!( - for_db(Status::ThirdThing).unwrap(), + fast_serialize(Status::Error).unwrap(), + rbs::to_value!("Error") + ); + assert_eq!( + fast_serialize(Status::ThirdThing).unwrap(), rbs::to_value!("ThirdThing") ); } @@ -490,10 +478,13 @@ mod tests { #[test] fn test_serialize_enum_with_custom_config() { - assert_eq!(for_db(StatusV2::Ok).unwrap(), rbs::to_value!("ok")); - assert_eq!(for_db(StatusV2::Error).unwrap(), rbs::to_value!("error")); + assert_eq!(fast_serialize(StatusV2::Ok).unwrap(), rbs::to_value!("ok")); + assert_eq!( + fast_serialize(StatusV2::Error).unwrap(), + rbs::to_value!("error") + ); assert_eq!( - for_db(StatusV2::ThirdThing).unwrap(), + fast_serialize(StatusV2::ThirdThing).unwrap(), rbs::to_value!("third_thing") ); } @@ -503,7 +494,7 @@ mod tests { let datetime = DateTime::now(); assert_eq!( - for_db(&datetime).unwrap(), + fast_serialize(&datetime).unwrap(), Value::Ext("DateTime", Box::new(rbs::to_value!(datetime.0))) ); } @@ -513,7 +504,7 @@ mod tests { let uuid = Uuid::new(); assert_eq!( - for_db(&uuid).unwrap(), + fast_serialize(&uuid).unwrap(), Value::Ext("Uuid", Box::new(Value::String(uuid.to_string()))) ); } @@ -522,7 +513,10 @@ mod tests { fn properly_serializes_hashed() { let hashed = Hashed::new("hello-world"); - assert_eq!(for_db(&hashed).unwrap(), Value::String(hashed.to_string())); + assert_eq!( + fast_serialize(&hashed).unwrap(), + Value::String(hashed.to_string()) + ); } #[test] @@ -533,7 +527,7 @@ mod tests { })); assert_eq!( - for_db(&json).unwrap(), + fast_serialize(&json).unwrap(), Value::Ext("Json", Box::new(Value::String(json.to_string()))) ); }