diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs index fe609fd4873bb..88c5eb646e13a 100644 --- a/datafusion/common/src/types/logical.rs +++ b/datafusion/common/src/types/logical.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use super::NativeType; +use crate::Result; +use arrow_schema::DataType; use core::fmt; use std::{cmp::Ordering, hash::Hash, sync::Arc}; -use super::NativeType; - /// Signature that uniquely identifies a type among other types. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum TypeSignature<'a> { @@ -75,8 +76,17 @@ pub type LogicalTypeRef = Arc; /// } /// ``` pub trait LogicalType: Sync + Send { + /// Get the native backing type of this logical type. fn native(&self) -> &NativeType; + /// Get the unique type signature for this logical type. Logical types with identical + /// signatures are considered equal. fn signature(&self) -> TypeSignature<'_>; + + /// Get the default physical type to cast `origin` to in order to obtain a physical type + /// that is logically compatible with this logical type. + fn default_cast_for(&self, origin: &DataType) -> Result { + self.native().default_cast_for(origin) + } } impl fmt::Debug for dyn LogicalType { @@ -90,7 +100,7 @@ impl fmt::Debug for dyn LogicalType { impl PartialEq for dyn LogicalType { fn eq(&self, other: &Self) -> bool { - self.native().eq(other.native()) && self.signature().eq(&other.signature()) + self.signature().eq(&other.signature()) } } diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 66e2e6feae6b0..8d67b24e4268b 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; - use super::{ - LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature, + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, + TypeSignature, }; +use crate::{internal_err, Result}; +use arrow_schema::{DataType, Field, FieldRef, IntervalUnit, TimeUnit}; +use std::sync::Arc; /// Representation of a type that DataFusion can handle natively. It is a subset /// of the physical variants in Arrow's native [`DataType`]. @@ -188,6 +188,145 @@ impl LogicalType for NativeType { fn signature(&self) -> TypeSignature<'_> { TypeSignature::Native(self) } + + fn default_cast_for(&self, origin: &DataType) -> Result { + use DataType::*; + + fn default_field_cast(to: &LogicalField, from: &Field) -> Result { + Ok(Arc::new(Field::new( + to.name.clone(), + to.logical_type.default_cast_for(from.data_type())?, + to.nullable, + ))) + } + + Ok(match (self, origin) { + (Self::Null, _) => Null, + (Self::Boolean, _) => Boolean, + (Self::Int8, _) => Int8, + (Self::Int16, _) => Int16, + (Self::Int32, _) => Int32, + (Self::Int64, _) => Int64, + (Self::UInt8, _) => UInt8, + (Self::UInt16, _) => UInt16, + (Self::UInt32, _) => UInt32, + (Self::UInt64, _) => UInt64, + (Self::Float16, _) => Float16, + (Self::Float32, _) => Float32, + (Self::Float64, _) => Float64, + (Self::Decimal(p, s), _) if p <= &38 => Decimal128(p.clone(), s.clone()), + (Self::Decimal(p, s), _) => Decimal256(p.clone(), s.clone()), + (Self::Timestamp(tu, tz), _) => Timestamp(tu.clone(), tz.clone()), + (Self::Date, _) => Date32, + (Self::Time(tu), _) => match tu { + TimeUnit::Second | TimeUnit::Millisecond => Time32(tu.clone()), + TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(tu.clone()), + }, + (Self::Duration(tu), _) => Duration(tu.clone()), + (Self::Interval(iu), _) => Interval(iu.clone()), + (Self::Binary, LargeUtf8) => LargeBinary, + (Self::Binary, Utf8View) => BinaryView, + (Self::Binary, _) => Binary, + (Self::FixedSizeBinary(size), _) => FixedSizeBinary(size.clone()), + (Self::Utf8, LargeBinary) => LargeUtf8, + (Self::Utf8, BinaryView) => Utf8View, + (Self::Utf8, _) => Utf8, + (Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => { + List(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), LargeList(from_field)) => { + LargeList(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), ListView(from_field)) => { + ListView(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), LargeListView(from_field)) => { + LargeListView(default_field_cast(to_field, from_field)?) + } + // List array where each element is a len 1 list of the origin type + (Self::List(field), _) => List(Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(origin)?, + field.nullable, + ))), + ( + Self::FixedSizeList(to_field, to_size), + FixedSizeList(from_field, from_size), + ) if from_size == to_size => { + FixedSizeList(default_field_cast(to_field, from_field)?, to_size.clone()) + } + ( + Self::FixedSizeList(to_field, size), + List(from_field) + | LargeList(from_field) + | ListView(from_field) + | LargeListView(from_field), + ) => FixedSizeList(default_field_cast(to_field, from_field)?, size.clone()), + // FixedSizeList array where each element is a len 1 list of the origin type + (Self::FixedSizeList(field, size), _) => FixedSizeList( + Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(origin)?, + field.nullable, + )), + size.clone(), + ), + // From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196 + (Self::Struct(to_fields), Struct(from_fields)) + if from_fields.len() == to_fields.len() => + { + Struct( + from_fields + .iter() + .zip(to_fields.iter()) + .map(|(from, to)| default_field_cast(to, from)) + .collect()?, + ) + } + (Self::Struct(to_fields), Null) => Struct( + to_fields + .iter() + .map(|field| { + Ok(Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(&Null)?, + field.nullable, + ))) + }) + .collect()?, + ), + (Self::Map(to_field), Map(from_field, sorted)) => { + Map(default_field_cast(to_field, from_field)?, sorted.clone()) + } + (Self::Map(field), Null) => Map( + Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(&Null)?, + field.nullable, + )), + false, + ), + (Self::Union(to_fields), Union(from_fields, mode)) + if from_fields.len() == to_fields.len() => + { + Union( + from_fields + .iter() + .zip(to_fields.iter()) + .map(|((_, from), (i, to))| { + (i.clone(), default_field_cast(to, from)) + }) + .collect()?, + mode.clone(), + ) + } + _ => return internal_err!( + "Unavailable default cast for native type {:?} from physical type {:?}", + self, + origin + ), + }) + } } // The following From, From, ... implementations are temporary @@ -230,9 +369,9 @@ impl From for NativeType { DataType::Union(union_fields, _) => { Union(LogicalUnionFields::from(&union_fields)) } - DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s), DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())), + DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), DataType::RunEndEncoded(_, field) => field.data_type().clone().into(), } }