Skip to content

Commit

Permalink
support Decimal256 type in datafusion-proto (#11606)
Browse files Browse the repository at this point in the history
  • Loading branch information
leoyvens authored Jul 23, 2024
1 parent deef834 commit 77311a5
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 5 deletions.
7 changes: 7 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ message Decimal{
int32 scale = 4;
}

message Decimal256Type{
reserved 1, 2;
uint32 precision = 3;
int32 scale = 4;
}

message List{
Field field_type = 1;
}
Expand Down Expand Up @@ -335,6 +341,7 @@ message ArrowType{
TimeUnit TIME64 = 22 ;
IntervalUnit INTERVAL = 23 ;
Decimal DECIMAL = 24 ;
Decimal256Type DECIMAL256 = 36;
List LIST = 25;
List LARGE_LIST = 26;
FixedSizeList FIXED_SIZE_LIST = 27;
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType {
precision,
scale,
}) => DataType::Decimal128(*precision as u8, *scale as i8),
arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type {
precision,
scale,
}) => DataType::Decimal256(*precision as u8, *scale as i8),
arrow_type::ArrowTypeEnum::List(list) => {
let list_type =
list.as_ref().field_type.as_deref().required("field_type")?;
Expand Down
125 changes: 125 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType {
arrow_type::ArrowTypeEnum::Decimal(v) => {
struct_ser.serialize_field("DECIMAL", v)?;
}
arrow_type::ArrowTypeEnum::Decimal256(v) => {
struct_ser.serialize_field("DECIMAL256", v)?;
}
arrow_type::ArrowTypeEnum::List(v) => {
struct_ser.serialize_field("LIST", v)?;
}
Expand Down Expand Up @@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
"TIME64",
"INTERVAL",
"DECIMAL",
"DECIMAL256",
"LIST",
"LARGE_LIST",
"LARGELIST",
Expand Down Expand Up @@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
Time64,
Interval,
Decimal,
Decimal256,
List,
LargeList,
FixedSizeList,
Expand Down Expand Up @@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
"TIME64" => Ok(GeneratedField::Time64),
"INTERVAL" => Ok(GeneratedField::Interval),
"DECIMAL" => Ok(GeneratedField::Decimal),
"DECIMAL256" => Ok(GeneratedField::Decimal256),
"LIST" => Ok(GeneratedField::List),
"LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList),
"FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList),
Expand Down Expand Up @@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
return Err(serde::de::Error::duplicate_field("DECIMAL"));
}
arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal)
;
}
GeneratedField::Decimal256 => {
if arrow_type_enum__.is_some() {
return Err(serde::de::Error::duplicate_field("DECIMAL256"));
}
arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256)
;
}
GeneratedField::List => {
Expand Down Expand Up @@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 {
deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for Decimal256Type {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut len = 0;
if self.precision != 0 {
len += 1;
}
if self.scale != 0 {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?;
if self.precision != 0 {
struct_ser.serialize_field("precision", &self.precision)?;
}
if self.scale != 0 {
struct_ser.serialize_field("scale", &self.scale)?;
}
struct_ser.end()
}
}
impl<'de> serde::Deserialize<'de> for Decimal256Type {
#[allow(deprecated)]
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
"precision",
"scale",
];

#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Precision,
Scale,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GeneratedVisitor;

impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = GeneratedField;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "expected one of: {:?}", &FIELDS)
}

#[allow(unused_variables)]
fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
where
E: serde::de::Error,
{
match value {
"precision" => Ok(GeneratedField::Precision),
"scale" => Ok(GeneratedField::Scale),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(GeneratedVisitor)
}
}
struct GeneratedVisitor;
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = Decimal256Type;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct datafusion_common.Decimal256Type")
}

fn visit_map<V>(self, mut map_: V) -> std::result::Result<Decimal256Type, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut precision__ = None;
let mut scale__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Precision => {
if precision__.is_some() {
return Err(serde::de::Error::duplicate_field("precision"));
}
precision__ =
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
;
}
GeneratedField::Scale => {
if scale__.is_some() {
return Err(serde::de::Error::duplicate_field("scale"));
}
scale__ =
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
;
}
}
}
Ok(Decimal256Type {
precision: precision__.unwrap_or_default(),
scale: scale__.unwrap_or_default(),
})
}
}
deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for DfField {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
Expand Down
12 changes: 11 additions & 1 deletion datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ pub struct Decimal {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Decimal256Type {
#[prost(uint32, tag = "3")]
pub precision: u32,
#[prost(int32, tag = "4")]
pub scale: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct List {
#[prost(message, optional, boxed, tag = "1")]
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
Expand Down Expand Up @@ -446,7 +454,7 @@ pub struct Decimal256 {
pub struct ArrowType {
#[prost(
oneof = "arrow_type::ArrowTypeEnum",
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
)]
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
}
Expand Down Expand Up @@ -516,6 +524,8 @@ pub mod arrow_type {
Interval(i32),
#[prost(message, tag = "24")]
Decimal(super::Decimal),
#[prost(message, tag = "36")]
Decimal256(super::Decimal256Type),
#[prost(message, tag = "25")]
List(::prost::alloc::boxed::Box<super::List>),
#[prost(message, tag = "26")]
Expand Down
7 changes: 4 additions & 3 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum {
precision: *precision as u32,
scale: *scale as i32,
}),
DataType::Decimal256(_, _) => {
return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned()))
}
DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type {
precision: *precision as u32,
scale: *scale as i32,
}),
DataType::Map(field, sorted) => {
Self::Map(Box::new(
protobuf::Map {
Expand Down
12 changes: 11 additions & 1 deletion datafusion/proto/src/generated/datafusion_proto_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ pub struct Decimal {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Decimal256Type {
#[prost(uint32, tag = "3")]
pub precision: u32,
#[prost(int32, tag = "4")]
pub scale: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct List {
#[prost(message, optional, boxed, tag = "1")]
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
Expand Down Expand Up @@ -446,7 +454,7 @@ pub struct Decimal256 {
pub struct ArrowType {
#[prost(
oneof = "arrow_type::ArrowTypeEnum",
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
)]
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
}
Expand Down Expand Up @@ -516,6 +524,8 @@ pub mod arrow_type {
Interval(i32),
#[prost(message, tag = "24")]
Decimal(super::Decimal),
#[prost(message, tag = "36")]
Decimal256(super::Decimal256Type),
#[prost(message, tag = "25")]
List(::prost::alloc::boxed::Box<super::List>),
#[prost(message, tag = "26")]
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use arrow::array::{
use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
DECIMAL256_MAX_PRECISION,
};
use prost::Message;

Expand Down Expand Up @@ -1379,6 +1380,7 @@ fn round_trip_datatype() {
DataType::Utf8,
DataType::LargeUtf8,
DataType::Decimal128(7, 12),
DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0),
// Recursive list tests
DataType::List(new_arc_field("Level1", DataType::Binary, true)),
DataType::List(new_arc_field(
Expand Down

0 comments on commit 77311a5

Please sign in to comment.