From 77311a5896272c7ed252d8cd53d48ec6ea7c0ccf Mon Sep 17 00:00:00 2001 From: Leonardo Yvens Date: Tue, 23 Jul 2024 11:18:00 +0100 Subject: [PATCH] support Decimal256 type in datafusion-proto (#11606) --- .../proto/datafusion_common.proto | 7 + datafusion/proto-common/src/from_proto/mod.rs | 4 + .../proto-common/src/generated/pbjson.rs | 125 ++++++++++++++++++ .../proto-common/src/generated/prost.rs | 12 +- datafusion/proto-common/src/to_proto/mod.rs | 7 +- .../src/generated/datafusion_proto_common.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + 7 files changed, 164 insertions(+), 5 deletions(-) diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index ca95136dadd9..8e8fd2352c6c 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -130,6 +130,12 @@ message Decimal{ int32 scale = 4; } +message Decimal256Type{ + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + message List{ Field field_type = 1; } @@ -335,6 +341,7 @@ message ArrowType{ TimeUnit TIME64 = 22 ; IntervalUnit INTERVAL = 23 ; Decimal DECIMAL = 24 ; + Decimal256Type DECIMAL256 = 36; List LIST = 25; List LARGE_LIST = 26; FixedSizeList FIXED_SIZE_LIST = 27; diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 9191ff185a04..5fe9d937f7c4 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { precision, scale, }) => DataType::Decimal128(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type { + precision, + scale, + }) => DataType::Decimal256(*precision as u8, *scale as i8), arrow_type::ArrowTypeEnum::List(list) => { let list_type = list.as_ref().field_type.as_deref().required("field_type")?; diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 4b34660ae2ef..511072f3cb55 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Decimal(v) => { struct_ser.serialize_field("DECIMAL", v)?; } + arrow_type::ArrowTypeEnum::Decimal256(v) => { + struct_ser.serialize_field("DECIMAL256", v)?; + } arrow_type::ArrowTypeEnum::List(v) => { struct_ser.serialize_field("LIST", v)?; } @@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME64", "INTERVAL", "DECIMAL", + "DECIMAL256", "LIST", "LARGE_LIST", "LARGELIST", @@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Time64, Interval, Decimal, + Decimal256, List, LargeList, FixedSizeList, @@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME64" => Ok(GeneratedField::Time64), "INTERVAL" => Ok(GeneratedField::Interval), "DECIMAL" => Ok(GeneratedField::Decimal), + "DECIMAL256" => Ok(GeneratedField::Decimal256), "LIST" => Ok(GeneratedField::List), "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList), "FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList), @@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("DECIMAL")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) +; + } + GeneratedField::Decimal256 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL256")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256) ; } GeneratedField::List => { @@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Decimal256Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal256Type") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256Type { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DfField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 9a2770997f15..62919e218b13 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -140,6 +140,14 @@ pub struct Decimal { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -446,7 +454,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -516,6 +524,8 @@ pub mod arrow_type { Interval(i32), #[prost(message, tag = "24")] Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 9dcb65444a47..c15da2895b7c 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { precision: *precision as u32, scale: *scale as i32, }), - DataType::Decimal256(_, _) => { - return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned())) - } + DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }), DataType::Map(field, sorted) => { Self::Map(Box::new( protobuf::Map { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 9a2770997f15..62919e218b13 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -140,6 +140,14 @@ pub struct Decimal { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -446,7 +454,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -516,6 +524,8 @@ pub mod arrow_type { Interval(i32), #[prost(message, tag = "24")] Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3476d5d042cc..f6557c7b2d8f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -27,6 +27,7 @@ use arrow::array::{ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DECIMAL256_MAX_PRECISION, }; use prost::Message; @@ -1379,6 +1380,7 @@ fn round_trip_datatype() { DataType::Utf8, DataType::LargeUtf8, DataType::Decimal128(7, 12), + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), // Recursive list tests DataType::List(new_arc_field("Level1", DataType::Binary, true)), DataType::List(new_arc_field(