diff --git a/Cargo.toml b/Cargo.toml
index 72e7dc4b078..4e083b60e3a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -102,6 +102,7 @@ ahash = "0.8"
 
 # Support conversion to/from arrow-rs
 arrow-buffer = { version = "35.0.0", optional = true }
+arrow-schema = { version = "35.0.0", optional = true }
 
 [target.wasm32-unknown-unknown.dependencies]
 getrandom = { version = "0.2", features = ["js"] }
@@ -158,7 +159,7 @@ full = [
     # parses timezones used in timestamp conversions
     "chrono-tz",
 ]
-arrow = ["arrow-buffer"]
+arrow = ["arrow-buffer", "arrow-schema"]
 io_odbc = ["odbc-api"]
 io_csv = ["io_csv_read", "io_csv_write"]
 io_csv_async = ["io_csv_read_async"]
diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs
index 07d1b760211..6807f2a54c2 100644
--- a/src/datatypes/field.rs
+++ b/src/datatypes/field.rs
@@ -52,3 +52,28 @@ impl Field {
         &self.data_type
     }
 }
+
+/// Converts an arrow2 [`Field`] into an arrow-rs [`arrow_schema::Field`],
+/// carrying over its name, data type, nullability and metadata.
+#[cfg(feature = "arrow")]
+impl From<Field> for arrow_schema::Field {
+    fn from(value: Field) -> Self {
+        Self::new(value.name, value.data_type.into(), value.is_nullable)
+            .with_metadata(value.metadata.into_iter().collect())
+    }
+}
+
+/// Converts an arrow-rs [`arrow_schema::Field`] into an arrow2 [`Field`],
+/// cloning its data type and metadata entries.
+#[cfg(feature = "arrow")]
+impl From<arrow_schema::Field> for Field {
+    fn from(value: arrow_schema::Field) -> Self {
+        let data_type = value.data_type().clone().into();
+        let metadata = value
+            .metadata()
+            .iter()
+            .map(|(k, v)| (k.clone(), v.clone()))
+            .collect();
+        Self::new(value.name(), data_type, value.is_nullable()).with_metadata(metadata)
+    }
+}
diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs
index dafbb57848e..2582bb7a6cd 100644
--- a/src/datatypes/mod.rs
+++ b/src/datatypes/mod.rs
@@ -160,6 +160,149 @@ pub enum DataType {
     Extension(String, Box<DataType>, Option<String>),
 }
 
+/// Converts an arrow2 [`DataType`] into an arrow-rs [`arrow_schema::DataType`].
+///
+/// The conversion is lossy for two variants: the third field of
+/// [`DataType::Dictionary`] is discarded, and [`DataType::Extension`] is
+/// replaced by its inner data type, dropping its other two fields.
+///
+/// # Panics
+///
+/// Panics if a [`DataType::Union`] without explicit type ids has more fields
+/// than can be numbered with an `i8`.
+#[cfg(feature = "arrow")]
+impl From<DataType> for arrow_schema::DataType {
+    fn from(value: DataType) -> Self {
+        match value {
+            DataType::Null => Self::Null,
+            DataType::Boolean => Self::Boolean,
+            DataType::Int8 => Self::Int8,
+            DataType::Int16 => Self::Int16,
+            DataType::Int32 => Self::Int32,
+            DataType::Int64 => Self::Int64,
+            DataType::UInt8 => Self::UInt8,
+            DataType::UInt16 => Self::UInt16,
+            DataType::UInt32 => Self::UInt32,
+            DataType::UInt64 => Self::UInt64,
+            DataType::Float16 => Self::Float16,
+            DataType::Float32 => Self::Float32,
+            DataType::Float64 => Self::Float64,
+            DataType::Timestamp(unit, tz) => Self::Timestamp(unit.into(), tz),
+            DataType::Date32 => Self::Date32,
+            DataType::Date64 => Self::Date64,
+            DataType::Time32(unit) => Self::Time32(unit.into()),
+            DataType::Time64(unit) => Self::Time64(unit.into()),
+            DataType::Duration(unit) => Self::Duration(unit.into()),
+            DataType::Interval(unit) => Self::Interval(unit.into()),
+            DataType::Binary => Self::Binary,
+            DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _),
+            DataType::LargeBinary => Self::LargeBinary,
+            DataType::Utf8 => Self::Utf8,
+            DataType::LargeUtf8 => Self::LargeUtf8,
+            DataType::List(f) => Self::List(Box::new((*f).into())),
+            DataType::FixedSizeList(f, size) => {
+                Self::FixedSizeList(Box::new((*f).into()), size as _)
+            }
+            DataType::LargeList(f) => Self::LargeList(Box::new((*f).into())),
+            DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()),
+            DataType::Union(fields, Some(ids), mode) => {
+                let ids = ids.into_iter().map(|x| x as _).collect();
+                let fields = fields.into_iter().map(Into::into).collect();
+                Self::Union(fields, ids, mode.into())
+            }
+            DataType::Union(fields, None, mode) => {
+                // No explicit ids: number the fields 0, 1, 2, ... Each index is
+                // converted individually because `fields.len() as i8` wraps for
+                // 128 or more fields and would silently yield too few ids.
+                let ids = (0..fields.len())
+                    .map(|id| i8::try_from(id).expect("union field index exceeds i8::MAX"))
+                    .collect();
+                let fields = fields.into_iter().map(Into::into).collect();
+                Self::Union(fields, ids, mode.into())
+            }
+            DataType::Map(f, ordered) => Self::Map(Box::new((*f).into()), ordered),
+            DataType::Dictionary(key, value, _) => Self::Dictionary(
+                Box::new(DataType::from(key).into()),
+                Box::new((*value).into()),
+            ),
+            DataType::Decimal(precision, scale) => Self::Decimal128(precision as _, scale as _),
+            DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _),
+            DataType::Extension(_, d, _) => (*d).into(),
+        }
+    }
+}
+
+/// Converts an arrow-rs [`arrow_schema::DataType`] into an arrow2 [`DataType`].
+///
+/// Dictionaries are always produced with their third field set to `false`.
+///
+/// # Panics
+///
+/// Panics if a dictionary key type is not one of the eight signed or unsigned
+/// integer types, or if `value` is `RunEndEncoded`, which arrow2 does not support.
+#[cfg(feature = "arrow")]
+impl From<arrow_schema::DataType> for DataType {
+    fn from(value: arrow_schema::DataType) -> Self {
+        use arrow_schema::DataType;
+        match value {
+            DataType::Null => Self::Null,
+            DataType::Boolean => Self::Boolean,
+            DataType::Int8 => Self::Int8,
+            DataType::Int16 => Self::Int16,
+            DataType::Int32 => Self::Int32,
+            DataType::Int64 => Self::Int64,
+            DataType::UInt8 => Self::UInt8,
+            DataType::UInt16 => Self::UInt16,
+            DataType::UInt32 => Self::UInt32,
+            DataType::UInt64 => Self::UInt64,
+            DataType::Float16 => Self::Float16,
+            DataType::Float32 => Self::Float32,
+            DataType::Float64 => Self::Float64,
+            DataType::Timestamp(unit, tz) => Self::Timestamp(unit.into(), tz),
+            DataType::Date32 => Self::Date32,
+            DataType::Date64 => Self::Date64,
+            DataType::Time32(unit) => Self::Time32(unit.into()),
+            DataType::Time64(unit) => Self::Time64(unit.into()),
+            DataType::Duration(unit) => Self::Duration(unit.into()),
+            DataType::Interval(unit) => Self::Interval(unit.into()),
+            DataType::Binary => Self::Binary,
+            DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _),
+            DataType::LargeBinary => Self::LargeBinary,
+            DataType::Utf8 => Self::Utf8,
+            DataType::LargeUtf8 => Self::LargeUtf8,
+            DataType::List(f) => Self::List(Box::new((*f).into())),
+            DataType::FixedSizeList(f, size) => {
+                Self::FixedSizeList(Box::new((*f).into()), size as _)
+            }
+            DataType::LargeList(f) => Self::LargeList(Box::new((*f).into())),
+            DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()),
+            DataType::Union(fields, ids, mode) => {
+                let ids = ids.into_iter().map(|x| x as _).collect();
+                let fields = fields.into_iter().map(Into::into).collect();
+                Self::Union(fields, Some(ids), mode.into())
+            }
+            DataType::Map(f, ordered) => Self::Map(Box::new((*f).into()), ordered),
+            DataType::Dictionary(key, value) => {
+                let key = match *key {
+                    DataType::Int8 => IntegerType::Int8,
+                    DataType::Int16 => IntegerType::Int16,
+                    DataType::Int32 => IntegerType::Int32,
+                    DataType::Int64 => IntegerType::Int64,
+                    DataType::UInt8 => IntegerType::UInt8,
+                    DataType::UInt16 => IntegerType::UInt16,
+                    DataType::UInt32 => IntegerType::UInt32,
+                    DataType::UInt64 => IntegerType::UInt64,
+                    d => panic!("illegal dictionary key type: {d}"),
+                };
+                Self::Dictionary(key, Box::new((*value).into()), false)
+            }
+            DataType::Decimal128(precision, scale) => Self::Decimal(precision as _, scale as _),
+            DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _),
+            DataType::RunEndEncoded(_, _) => panic!("Run-end encoding not supported by arrow2"),
+        }
+    }
+}
+
 /// Mode of [`DataType::Union`]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))]
@@ -170,6 +313,28 @@ pub enum UnionMode {
     Sparse,
 }
 
+/// Converts an arrow2 [`UnionMode`] into its arrow-rs counterpart.
+#[cfg(feature = "arrow")]
+impl From<UnionMode> for arrow_schema::UnionMode {
+    fn from(value: UnionMode) -> Self {
+        match value {
+            UnionMode::Dense => Self::Dense,
+            UnionMode::Sparse => Self::Sparse,
+        }
+    }
+}
+
+/// Converts an arrow-rs [`arrow_schema::UnionMode`] into its arrow2 counterpart.
+#[cfg(feature = "arrow")]
+impl From<arrow_schema::UnionMode> for UnionMode {
+    fn from(value: arrow_schema::UnionMode) -> Self {
+        match value {
+            arrow_schema::UnionMode::Dense => Self::Dense,
+            arrow_schema::UnionMode::Sparse => Self::Sparse,
+        }
+    }
+}
+
 impl UnionMode {
     /// Constructs a [`UnionMode::Sparse`] if the input bool is true,
     /// or otherwise constructs a [`UnionMode::Dense`]
@@ -206,6 +371,32 @@ pub enum TimeUnit {
     Nanosecond,
 }
 
+/// Converts an arrow2 [`TimeUnit`] into its arrow-rs counterpart.
+#[cfg(feature = "arrow")]
+impl From<TimeUnit> for arrow_schema::TimeUnit {
+    fn from(value: TimeUnit) -> Self {
+        match value {
+            TimeUnit::Nanosecond => Self::Nanosecond,
+            TimeUnit::Millisecond => Self::Millisecond,
+            TimeUnit::Microsecond => Self::Microsecond,
+            TimeUnit::Second => Self::Second,
+        }
+    }
+}
+
+/// Converts an arrow-rs [`arrow_schema::TimeUnit`] into its arrow2 counterpart.
+#[cfg(feature = "arrow")]
+impl From<arrow_schema::TimeUnit> for TimeUnit {
+    fn from(value: arrow_schema::TimeUnit) -> Self {
+        match value {
+            arrow_schema::TimeUnit::Nanosecond => Self::Nanosecond,
+            arrow_schema::TimeUnit::Millisecond => Self::Millisecond,
+            arrow_schema::TimeUnit::Microsecond => Self::Microsecond,
+            arrow_schema::TimeUnit::Second => Self::Second,
+        }
+    }
+}
+
 /// Interval units defined in Arrow
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))]
@@ -219,6 +410,30 @@ pub enum IntervalUnit {
     MonthDayNano,
 }
 
+/// Converts an arrow2 [`IntervalUnit`] into its arrow-rs counterpart.
+#[cfg(feature = "arrow")]
+impl From<IntervalUnit> for arrow_schema::IntervalUnit {
+    fn from(value: IntervalUnit) -> Self {
+        match value {
+            IntervalUnit::YearMonth => Self::YearMonth,
+            IntervalUnit::DayTime => Self::DayTime,
+            IntervalUnit::MonthDayNano => Self::MonthDayNano,
+        }
+    }
+}
+
+/// Converts an arrow-rs [`arrow_schema::IntervalUnit`] into its arrow2 counterpart.
+#[cfg(feature = "arrow")]
+impl From<arrow_schema::IntervalUnit> for IntervalUnit {
+    fn from(value: arrow_schema::IntervalUnit) -> Self {
+        match value {
+            arrow_schema::IntervalUnit::YearMonth => Self::YearMonth,
+            arrow_schema::IntervalUnit::DayTime => Self::DayTime,
+            arrow_schema::IntervalUnit::MonthDayNano => Self::MonthDayNano,
+        }
+    }
+}
+
 impl DataType {
     /// the [`PhysicalType`] of this [`DataType`].
     pub fn to_physical_type(&self) -> PhysicalType {