diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index f90632455fd9..b2ecbb796312 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::{ArrowError, Field, FieldRef, SchemaBuilder}; use std::ops::Deref; use std::sync::Arc; +use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder}; + /// A cheaply cloneable, owned slice of [`FieldRef`] /// /// Similar to `Arc>` or `Arc<[FieldRef]>` @@ -99,6 +100,90 @@ impl Fields { .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b)) } + /// Performs a depth-first scan of [`Fields`] filtering the [`FieldRef`] with no children + /// + /// Returns a new [`Fields`] comprising the [`FieldRef`] for which `filter` returned `true` + /// + /// ``` + /// # use arrow_schema::{DataType, Field, Fields}; + /// let fields = Fields::from(vec![ + /// Field::new("a", DataType::Int32, true), + /// Field::new("b", DataType::Struct(Fields::from(vec![ + /// Field::new("c", DataType::Float32, false), + /// Field::new("d", DataType::Float64, false), + /// ])), false) + /// ]); + /// let filtered = fields.filter_leaves(|idx, _| idx == 0 || idx == 2); + /// let expected = Fields::from(vec![ + /// Field::new("a", DataType::Int32, true), + /// Field::new("b", DataType::Struct(Fields::from(vec![ + /// Field::new("d", DataType::Float64, false), + /// ])), false) + /// ]); + /// assert_eq!(filtered, expected); + /// ``` + pub fn filter_leaves bool>(&self, mut filter: F) -> Self { + fn filter_field bool>( + f: &FieldRef, + filter: &mut F, + ) -> Option { + use DataType::*; + + let (k, v) = match f.data_type() { + Dictionary(k, v) => (Some(k.clone()), v.as_ref()), + d => (None, d), + }; + let d = match v { + List(child) => List(filter_field(child, filter)?), + LargeList(child) => LargeList(filter_field(child, filter)?), + Map(child, ordered) => Map(filter_field(child, filter)?, *ordered), + FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size), + Struct(fields) => { + let filtered: Fields = fields + .iter() + .filter_map(|f| filter_field(f, filter)) + .collect(); + + if filtered.is_empty() { + return None; + } + + Struct(filtered) + } + Union(fields, mode) => { + let filtered: UnionFields = fields + .iter() + .filter_map(|(id, f)| Some((id, filter_field(f, filter)?))) + .collect(); + + if filtered.is_empty() { + return None; + } + + Union(filtered, *mode) + } + _ => return filter(f).then(|| f.clone()), + }; + let d = match k { + Some(k) => Dictionary(k, Box::new(d)), + None => d, + }; + Some(Arc::new(f.as_ref().clone().with_data_type(d))) + } + + let mut leaf_idx = 0; + let mut filter = |f: &FieldRef| { + let t = filter(leaf_idx, f); + leaf_idx += 1; + t + }; + + self.0 + .iter() + .filter_map(|f| filter_field(f, &mut filter)) + .collect() + } + /// Remove a field by index and return it. /// /// # Panic @@ -307,3 +392,107 @@ impl FromIterator<(i8, FieldRef)> for UnionFields { Self(iter.into_iter().collect()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::UnionMode; + + #[test] + fn test_filter() { + let floats = Fields::from(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ]); + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("floats", DataType::Struct(floats.clone()), true), + Field::new("b", DataType::Int16, true), + Field::new( + "c", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + Field::new( + "d", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Struct(floats.clone())), + ), + false, + ), + Field::new_list( + "e", + Field::new("floats", DataType::Struct(floats.clone()), true), + true, + ), + Field::new( + "f", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3), + false, + ), + Field::new_map( + "g", + "entries", + Field::new("keys", DataType::LargeUtf8, false), + Field::new("values", DataType::Int32, true), + false, + false, + ), + Field::new( + "h", + DataType::Union( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("field1", DataType::UInt8, false), + Field::new("field3", DataType::Utf8, false), + ], + ), + UnionMode::Dense, + ), + true, + ), + ]); + + let floats_a = DataType::Struct(vec![floats[0].clone()].into()); + + let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1); + assert_eq!(r.len(), 2); + assert_eq!(r[0], fields[0]); + assert_eq!(r[1].data_type(), &floats_a); + + let r = fields.filter_leaves(|_, f| f.name() == "a"); + assert_eq!(r.len(), 4); + assert_eq!(r[0], fields[0]); + assert_eq!(r[1].data_type(), &floats_a); + assert_eq!( + r[2].data_type(), + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone())) + ); + assert_eq!( + r[3].as_ref(), + &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true) + ); + + let r = fields.filter_leaves(|_, f| f.name() == "floats"); + assert_eq!(r.len(), 0); + + let r = fields.filter_leaves(|idx, _| idx == 9); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[6]); + + let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[7]); + + let union = DataType::Union( + UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]), + UnionMode::Dense, + ); + + let r = fields.filter_leaves(|idx, _| idx == 12); + assert_eq!(r.len(), 1); + assert_eq!(r[0].data_type(), &union); + } +}