From a20e8aa296a4c6b290d4b99ec63a1284157dfcf4 Mon Sep 17 00:00:00 2001 From: Jonas Irgens Kylling Date: Wed, 23 Oct 2024 13:43:52 +0200 Subject: [PATCH] Handle more types and fix offset bug Ideally, I'd like to rewrite this to have a very different interface. Something like ```rust trait TransformArray<'a> { fn transform(&'a Array) -> impl Lookup<'a> } trait Lookup<'a> { type Item; fn get(&self, index: usize) -> Option; } ``` Then all the iterators converting from RecordBatch or StructArray would be extension traits on top of this trait. --- Cargo.lock | 102 +++++++++++++ Cargo.toml | 5 +- arrow_struct/src/lib.rs | 253 +++++++++++++++++++++++++++---- arrow_struct_derive/Cargo.toml | 3 +- arrow_struct_derive/src/lib.rs | 111 ++++++++++++-- benchmarks/benches/benchmarks.rs | 14 +- examples/src/lib.rs | 119 ++++++++++++++- 7 files changed, 554 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 895cb6f..e1d8551 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "arrow" version = "52.2.0" @@ -280,6 +286,7 @@ name = "arrow_struct_derive" version = "0.1.0" dependencies = [ "convert_case", + "deluxe", "proc-macro2", "quote", "syn", @@ -563,6 +570,47 @@ dependencies = [ "memchr", ] +[[package]] +name = "deluxe" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ed332aaf752b459088acf3dd4eca323e3ef4b83c70a84ca48fb0ec5305f1488" +dependencies = [ + "deluxe-core", + "deluxe-macros", + "once_cell", + "proc-macro2", + "syn", +] + +[[package]] +name = "deluxe-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddada51c8576df9d6a8450c351ff63042b092c9458b8ac7d20f89cbd0ffd313" +dependencies = [ + "arrayvec", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "deluxe-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87546d9c837f0b7557e47b8bd6eae52c3c223141b76aa233c345c9ab41d9117" +dependencies = [ + "deluxe-core", + "heck", + "if_chain", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.13.0" @@ -625,6 +673,12 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.4.0" @@ -654,6 +708,12 @@ dependencies = [ "cc", ] +[[package]] +name = "if_chain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" + [[package]] name = "indexmap" version = "2.5.0" @@ -907,6 +967,16 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -1064,6 +1134,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "syn" version = "2.0.76" @@ -1094,6 +1170,23 @@ dependencies = [ "serde_json", ] +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.19.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] + [[package]] name = "unicode-ident" version = "1.0.12" @@ -1293,6 +1386,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/Cargo.toml b/Cargo.toml index a490c2f..a711c46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,8 @@ members = [ "arrow_struct", "examples", "benchmarks", ] +resolver = "2" + [workspace.package] @@ -22,4 +24,5 @@ quote = "1.0.37" syn = "2.0.76" serde_arrow = { version = "0.11.7", features = ["arrow-52"] } serde = { version = "1.0.210" } -criterion = { version = "0.5.1" } \ No newline at end of file +criterion = { version = "0.5.1" } +deluxe = "0.5.0" \ No newline at end of file diff --git a/arrow_struct/src/lib.rs b/arrow_struct/src/lib.rs index ba9e964..88629e7 100644 --- a/arrow_struct/src/lib.rs +++ b/arrow_struct/src/lib.rs @@ -11,6 +11,7 @@ use arrow::datatypes::{ pub use arrow::record_batch::RecordBatch; pub use bytes::Bytes; use std::fmt::Debug; +pub use std::option::Option; pub use arrow_struct_derive::Deserialize; @@ -18,10 +19,52 @@ pub trait FromArrayRef<'a>: Sized { fn from_array_ref(array: &'a ArrayRef) -> impl Iterator; } +pub trait FromArrayRefOpt<'a>: Sized { + type Item; + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator>; +} + +impl<'a, T: FromArrayRefOpt<'a, Item = T>> FromArrayRef<'a> for Option { + fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { + T::from_array_ref_opt(array) + } +} + +impl<'a, T: FromArrayRefOpt<'a, Item = T>> FromArrayRefOpt<'a> for Option { + type Item = T; + + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + T::from_array_ref_opt(array) + } +} + +// Effectively a marker trait, since stable Rust does not have specialization or negative trait bounds +pub trait NullConversion: Sized { + type Item; + fn convert(item: Option) -> Self; +} + +impl NullConversion for Option { + type Item = T; + + fn convert(item: Option) -> Self { + item + } +} + +impl NullConversion for Vec { + type Item = Vec; + + fn convert(item: Option) -> Self { + item.unwrap() + } +} + macro_rules! impl_from_array_ref_primitive { ($native_ty:ty, $data_ty:ty) => { - impl<'a> FromArrayRef<'a> for Option<$native_ty> { - fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { + impl<'a> FromArrayRefOpt<'a> for $native_ty { + type Item = $native_ty; + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { let array = array .as_primitive_opt::<$data_ty>() .expect(&format!(concat!(stringify!(Expected #data_ty), ", was {:?}"), array.data_type())); @@ -38,6 +81,22 @@ macro_rules! impl_from_array_ref_primitive { array.iter().map(Option::unwrap) } } + + impl_null_conversion_simple_type!($native_ty); + }; +} + +macro_rules! impl_null_conversion_simple_type { + ($native_ty:ty) => { + impl NullConversion for $native_ty { + type Item = $native_ty; + + fn convert(item: Option) -> Self { + item.unwrap() + } + } + + impl NotNull for $native_ty {} }; } @@ -52,16 +111,40 @@ impl_from_array_ref_primitive!(u64, UInt64Type); impl_from_array_ref_primitive!(f32, Float32Type); impl_from_array_ref_primitive!(f64, Float64Type); -impl<'a> FromArrayRef<'a> for Option { - fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { +impl_null_conversion_simple_type!(String); +impl_null_conversion_simple_type!(Bytes); +impl_null_conversion_simple_type!(bool); + +impl<'a> FromArrayRefOpt<'a> for bool { + type Item = bool; + + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { let array = array.as_boolean(); - array.iter() + let nulls = array.nulls(); + let mut iter = array.iter(); + let mut position = 0; + std::iter::from_fn(move || { + if let Some(next) = iter.next() { + position += 1; + if nulls + .map(|nulls| nulls.is_null(position)) + .unwrap_or_default() + { + Some(None) + } else { + Some(next) + } + } else { + None + } + }) } } -impl<'a> FromArrayRef<'a> for Option { - fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { - let res: Box> = match array.data_type() { +impl<'a> FromArrayRefOpt<'a> for String { + type Item = String; + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + let res: Box>> = match array.data_type() { DataType::Utf8 => { let array = array.as_string::(); Box::new(array.iter().map(|s| s.map(|s| s.to_string()))) @@ -78,9 +161,17 @@ impl<'a> FromArrayRef<'a> for Option { } } -impl<'a> FromArrayRef<'a> for Option<&'a str> { +impl<'a> FromArrayRef<'a> for String { fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { - let res: Box> = match array.data_type() { + Option::::from_array_ref(array).map(|x| x.expect("unwrap String")) + } +} + +impl<'a> FromArrayRefOpt<'a> for &'a str { + type Item = Self; + + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + let res: Box>> = match array.data_type() { DataType::Utf8 => { let array = array.as_string::(); Box::new(array.iter()) @@ -97,33 +188,54 @@ impl<'a> FromArrayRef<'a> for Option<&'a str> { } } -impl<'a> FromArrayRef<'a> for Option { +impl<'a> FromArrayRef<'a> for &'a str { fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { - let res: Box> = match array.data_type() { + Option::<&'a str>::from_array_ref(array).map(|x| x.expect("unwrap str")) + } +} + +impl<'a> FromArrayRefOpt<'a> for Bytes { + type Item = Bytes; + + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + let res: Box>> = match array.data_type() { DataType::Binary => { let array = array.as_binary::(); - Box::new(array.iter() - .map(|bytes| bytes.map(|bytes| Bytes::from(bytes.to_vec())))) + Box::new( + array + .iter() + .map(|bytes| bytes.map(|bytes| Bytes::from(bytes.to_vec()))), + ) } DataType::LargeBinary => { let array = array.as_binary::(); - Box::new(array.iter() - .map(|bytes| bytes.map(|bytes| Bytes::from(bytes.to_vec())))) + Box::new( + array + .iter() + .map(|bytes| bytes.map(|bytes| Bytes::from(bytes.to_vec()))), + ) } _ => { - panic!("Expected String, was {:?}", array.data_type()) + panic!("Expected Binary, was {:?}", array.data_type()) } }; res } } -impl<'a, 'c> FromArrayRef<'a> for Option<&'c [u8]> +impl<'a> FromArrayRef<'a> for Bytes { + fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { + Option::::from_array_ref(array).map(|x| x.expect("unwrap bytes")) + } +} + +impl<'a, 'c> FromArrayRefOpt<'a> for &'c [u8] where 'a: 'c, { - fn from_array_ref(array: &'a ArrayRef) -> impl Iterator> { - let res: Box> = match array.data_type() { + type Item = &'c [u8]; + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + let res: Box>> = match array.data_type() { DataType::Binary => { let array = array.as_binary::(); Box::new(array.iter()) @@ -140,7 +252,9 @@ where } } -impl<'a, T: FromArrayRef<'a> + Debug + 'a> FromArrayRef<'a> for Option> { +impl<'a, T: FromArrayRefOpt<'a> + Debug + 'a> FromArrayRefOpt<'a> for Vec> { + type Item = Vec>; + // TODO: Needs extensive testing. // This is a bit verbose, but the naive implementation below is too slow: // array.iter() @@ -149,14 +263,22 @@ impl<'a, T: FromArrayRef<'a> + Debug + 'a> FromArrayRef<'a> for Option> { // We must use array.values() directly and handle the offsets, as we cannot call // T::from_array_ref in any kind of loop. // Could be room for more optimization by not using iterators? - fn from_array_ref(array: &'a ArrayRef) -> impl Iterator { - fn helper<'a, O: OffsetSizeTrait + Into, T: FromArrayRef<'a> + Debug + 'a>( + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + fn helper<'a, O: OffsetSizeTrait + Into, T: FromArrayRefOpt<'a> + Debug + 'a>( array: &'a GenericListArray, - ) -> impl Iterator>> + 'a { + ) -> impl Iterator>::Item>>>> + 'a + { let nulls = array.logical_nulls(); - let mut inner = T::from_array_ref(array.values()); + let offsets = array.offsets(); + let mut inner = T::from_array_ref_opt(array.values()); let mut current_position = 0; + if let Some(first_offset) = offsets.first() { + for _ in 0..first_offset.as_usize() { + let _ = inner.next(); + } + } + std::iter::from_fn(move || { if current_position >= array.len() { return None; @@ -185,17 +307,90 @@ impl<'a, T: FromArrayRef<'a> + Debug + 'a> FromArrayRef<'a> for Option> { }) } - let res: Box> = match array.data_type() { + let res: Box>> = match array.data_type() { DataType::List(_) => { let array = array.as_list::(); - Box::new(helper(array)) + Box::new(helper::<_, T>(array)) } DataType::LargeList(_) => { let array = array.as_list::(); - Box::new(helper(array)) + Box::new(helper::<_, T>(array)) } _ => { - panic!("Expected Binary, was {:?}", array.data_type()) + panic!("Expected List, was {:?}", array.data_type()) + } + }; + res + } +} + +pub trait NotNull {} + +impl<'a, T: FromArrayRefOpt<'a> + Debug + NotNull + 'a> FromArrayRefOpt<'a> for Vec { + type Item = Vec; + + // TODO: Needs extensive testing. + // This is a bit verbose, but the naive implementation below is too slow: + // array.iter() + // .map(|element| + // element.as_ref().map(|element| T::from_array_ref(element).collect::>())) + // We must use array.values() directly and handle the offsets, as we cannot call + // T::from_array_ref in any kind of loop. + // Could be room for more optimization by not using iterators? + fn from_array_ref_opt(array: &'a ArrayRef) -> impl Iterator> { + fn helper<'a, O: OffsetSizeTrait + Into, T: FromArrayRefOpt<'a> + Debug + 'a>( + array: &'a GenericListArray, + ) -> impl Iterator>::Item>>> + 'a { + let nulls = array.logical_nulls(); + let offsets = array.offsets(); + let mut inner = T::from_array_ref_opt(array.values()); + let mut current_position = 0; + + if let Some(first_offset) = offsets.first() { + for _ in 0..first_offset.as_usize() { + let _ = inner.next(); + } + } + + std::iter::from_fn(move || { + if current_position >= array.len() { + return None; + } + + let len = array.value_length(current_position).into(); + let is_null = nulls + .as_ref() + .map(|buffer| buffer.is_null(current_position)) + .unwrap_or_default(); + let res = if is_null { + for _ in 0..len { + // This can happen if record batch has values which are nulled. It's weird to construct RecordBatches this way, but it's possible + let _ = inner.next().unwrap(); + } + None + } else { + let mut out = Vec::with_capacity(len as usize); + for _ in 0..len { + out.push(inner.next().unwrap().expect("unwrap in vec")); + } + Some(out) + }; + current_position += 1; + Some(res) + }) + } + + let res: Box>> = match array.data_type() { + DataType::List(_) => { + let array = array.as_list::(); + Box::new(helper::<_, T>(array)) + } + DataType::LargeList(_) => { + let array = array.as_list::(); + Box::new(helper::<_, T>(array)) + } + _ => { + panic!("Expected List, was {:?}", array.data_type()) } }; res diff --git a/arrow_struct_derive/Cargo.toml b/arrow_struct_derive/Cargo.toml index a07aa01..4ed341d 100644 --- a/arrow_struct_derive/Cargo.toml +++ b/arrow_struct_derive/Cargo.toml @@ -10,4 +10,5 @@ proc-macro = true quote = { workspace = true } syn = { workspace = true } proc-macro2 = { workspace = true } -convert_case = { workspace = true } \ No newline at end of file +convert_case = { workspace = true } +deluxe = { workspace = true } \ No newline at end of file diff --git a/arrow_struct_derive/src/lib.rs b/arrow_struct_derive/src/lib.rs index 04aaf76..aca1e0c 100644 --- a/arrow_struct_derive/src/lib.rs +++ b/arrow_struct_derive/src/lib.rs @@ -3,15 +3,59 @@ extern crate proc_macro; use convert_case::{Case, Casing}; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; +use std::str::FromStr; use syn::spanned::Spanned; use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LifetimeParam}; -#[proc_macro_derive(Deserialize)] +#[derive(deluxe::ExtractAttributes, Debug, Default)] +#[deluxe(attributes(arrow_struct))] +struct Attributes { + #[deluxe(default = String::from("none"))] + rename_all: String, +} + +#[derive(Default, Debug)] +enum RenameAll { + #[default] + None, + SnakeCase, + CamelCase, +} + +impl FromStr for RenameAll { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "snake_case" => Ok(Self::SnakeCase), + "none" => Ok(Self::None), + "camelCase" => Ok(Self::CamelCase), + _ => Err(format!("Unknown case: {}", s)), + } + } +} + +impl From for Option { + fn from(value: RenameAll) -> Self { + match value { + RenameAll::None => None, + RenameAll::SnakeCase => Some(Case::Snake), + RenameAll::CamelCase => Some(Case::Camel), + } + } +} + +#[proc_macro_derive(Deserialize, attributes(arrow_struct))] pub fn derive_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); + let mut attrs = input.attrs; + let res: Attributes = deluxe::extract_attributes(&mut attrs).unwrap(); + + let rename_all: RenameAll = RenameAll::from_str(&res.rename_all).unwrap(); + let case = rename_all.into(); let name = input.ident; - let (_, ty_generics, where_clause) = input.generics.split_for_impl(); + let (plain_impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); // We add our reserved lifetime parameter 'ar (like 'de of serde::Deserialize) and add all existing lifetimes as bounds let lt = syn::Lifetime::new("'ar", Span::call_site()); @@ -23,22 +67,41 @@ pub fn derive_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenSt new_generics.params.push(GenericParam::Lifetime(ltp)); let (impl_generics, _, _) = new_generics.split_for_impl(); - let inner = inner_implementation(&input.data); + let inner = inner_implementation(&input.data, case); let expanded = quote! { - impl #impl_generics arrow_struct::FromArrayRef<'ar> for #name #ty_generics #where_clause { - fn from_array_ref(array: &'ar arrow_struct::ArrayRef) -> impl Iterator { + impl #impl_generics arrow_struct::FromArrayRefOpt<'ar> for #name #ty_generics #where_clause { + type Item=Self; + fn from_array_ref_opt(array: &'ar arrow_struct::ArrayRef) -> impl Iterator>{ let array = arrow_struct::AsArray::as_struct(array); #inner } } + + impl #impl_generics arrow_struct::FromArrayRef<'ar> for #name #ty_generics #where_clause { + fn from_array_ref(array: &'ar arrow_struct::ArrayRef) -> impl Iterator { + <#name #ty_generics as arrow_struct::FromArrayRefOpt<'ar>>::from_array_ref_opt(array) + .map(|x| arrow_struct::Option::expect(x, stringify!(unwrap on #name))) + } + } + + impl #plain_impl_generics arrow_struct::NullConversion for #name #ty_generics #where_clause { + type Item=Self; + fn convert(item: Option) -> Self { + item.unwrap() + } + } + + impl #plain_impl_generics arrow_struct::NotNull for #name #ty_generics #where_clause { + + } }; proc_macro::TokenStream::from(expanded) } -fn inner_implementation(data: &Data) -> TokenStream { +fn inner_implementation(data: &Data, case: Option) -> TokenStream { match *data { Data::Struct(ref data) => match data.fields { Fields::Named(ref fields) => { @@ -46,33 +109,53 @@ fn inner_implementation(data: &Data) -> TokenStream { let idents_clone = idents.clone(); let iterators = fields.named.iter().map(|field| { + let ident = field.ident.clone(); let name = field.ident.as_ref().unwrap().to_string(); let field_type = field.ty.clone(); - let column_name = name.clone(); //.to_case(Case::Camel); + let column_name = if let Some(case) = case { name.clone().to_case(case) } else { name.clone() }; let iterator_name = format_ident!("__arrow_struct_derive_{}", name); let iterator_declaration = quote_spanned! {field.span()=> let mut #iterator_name = { let array = array.column_by_name(#column_name) .expect(stringify!(no column named #column_name)); - <#field_type as arrow_struct::FromArrayRef>::from_array_ref(array) + <#field_type as arrow_struct::FromArrayRefOpt>::from_array_ref_opt(array) }; }; - (iterator_name, iterator_declaration) + let conversion = quote_spanned! {field.span()=> + let #ident = <#field_type as arrow_struct::NullConversion>::convert(#ident); + }; + (iterator_name, iterator_declaration, conversion) }); - let iterator_declarations = iterators.clone().map(|(_, declaration)| declaration); - let iterator_next = iterators.clone().map(|(name, _)| quote! { #name.next() }); + let iterator_declarations = + iterators.clone().map(|(_, declaration, _)| declaration); + let iterator_next = iterators + .clone() + .map(|(name, _, _)| quote! { #name.next() }); + let conversions = iterators.clone().map(|(_, _, conversion)| conversion); quote! { + // TODO: See if nulls can be used instead + let is_null = arrow_struct::Array::logical_nulls(array); + #(#iterator_declarations)* + let mut pos = 0; std::iter::from_fn(move || { - if let (#(Some(#idents)),*) = (#(#iterator_next),*) { - Some(Self { #(#idents_clone),* }) + let res = if let (#(Some(#idents)),*) = (#(#iterator_next),*) { + let is_null = is_null.as_ref().map(|x| x.is_null(pos)).unwrap_or(false); + if !is_null { + #(#conversions)* + Some(Some(Self { #(#idents_clone),* })) + } else { + Some(None) + } } else { None - } + }; + pos += 1; + res }) } } diff --git a/benchmarks/benches/benchmarks.rs b/benchmarks/benches/benchmarks.rs index b6b268f..94973c6 100644 --- a/benchmarks/benches/benchmarks.rs +++ b/benchmarks/benches/benchmarks.rs @@ -30,12 +30,14 @@ fn benchmark< let batch = setup_record_batch::(size); let struct_array: StructArray = batch.clone().into(); let array: ArrayRef = Arc::new(struct_array); - c.bench_function(&format!("serde_arrow {} {}", std::any::type_name::(), size), |b| { - b.iter_with_large_drop(|| serde_arrow_convert::(black_box(&batch))) - }); - c.bench_function(&format!("arrow_struct {} {}", std::any::type_name::(), size), |b| { - b.iter_with_large_drop(|| arrow_struct_convert::(black_box(&array))) - }); + c.bench_function( + &format!("serde_arrow {} {}", std::any::type_name::(), size), + |b| b.iter_with_large_drop(|| serde_arrow_convert::(black_box(&batch))), + ); + c.bench_function( + &format!("arrow_struct {} {}", std::any::type_name::(), size), + |b| b.iter_with_large_drop(|| arrow_struct_convert::(black_box(&array))), + ); } fn benchmark_small(c: &mut Criterion) { diff --git a/examples/src/lib.rs b/examples/src/lib.rs index b2c96cf..93a34f1 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -2,11 +2,12 @@ mod tests { use arrow::array::{ Array, BinaryArray, GenericListBuilder, Int32Builder, LargeBinaryArray, LargeStringArray, - RecordBatch, StructArray, + LargeStringBuilder, RecordBatch, StructArray, }; - use arrow::datatypes::{DataType, Field, FieldRef, Schema}; + use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; use arrow_struct::Deserialize; use arrow_struct::FromArrayRef; + use serde_arrow::_impl::arrow::_raw::buffer::NullBufferBuilder; use serde_arrow::_impl::arrow::array::StringArray; use serde_arrow::schema::{SchemaLike, TracingOptions}; use std::sync::Arc; @@ -30,6 +31,7 @@ mod tests { #[test] fn all_primitive_types() { + println!("here"); let some_string = "0123456789"; let data = (0u8..10) .map(|i| AllPrimitiveTypes { @@ -227,4 +229,117 @@ mod tests { SmallAndLargeArrays::from_array_ref(&array).collect::>() ); } + + #[test] + fn null_object() { + #[allow(dead_code)] + #[derive(serde::Deserialize, serde::Serialize, Deserialize, Debug, PartialEq)] + #[arrow_struct(rename_all = "camelCase")] + struct NullOuter { + inner1: Option, + } + #[allow(dead_code)] + #[derive(serde::Deserialize, serde::Serialize, Deserialize, Debug, PartialEq)] + struct NullInner { + string1: String, + } + + let data = [ + NullOuter { inner1: None }, + NullOuter { + inner1: Some(NullInner { + string1: "hello".to_string(), + }), + }, + NullOuter { inner1: None }, + NullOuter { + inner1: Some(NullInner { + string1: "world".to_string(), + }), + }, + ]; + + let mut string_array_builder = LargeStringBuilder::new(); + string_array_builder.append_null(); + string_array_builder.append_value("hello"); + string_array_builder.append_null(); + string_array_builder.append_value("world"); + + let mut null_buffer_builder = NullBufferBuilder::new(10); + null_buffer_builder.append_null(); + null_buffer_builder.append(true); + null_buffer_builder.append_null(); + null_buffer_builder.append(true); + + let fields_inner = Vec::::from_type::( + TracingOptions::default().allow_null_fields(true), + ) + .unwrap(); + let struct_array_inner = StructArray::new( + Fields::from(fields_inner), + vec![Arc::new(string_array_builder.finish())], + null_buffer_builder.finish(), + ); + + let fields = Vec::::from_type::( + TracingOptions::default().allow_null_fields(true), + ) + .unwrap(); + let struct_array = StructArray::new( + Fields::from(fields), + vec![Arc::new(struct_array_inner)], + None, + ); + let batch = RecordBatch::from(struct_array); + let i = 2; + let length = 2; + let struct_array: StructArray = batch.clone().into(); + let struct_array = struct_array.slice(i, length); + + let array = Arc::new(struct_array) as _; + assert_eq!( + &data[i..i + length], + NullOuter::from_array_ref(&array).collect::>() + ); + } + + #[test] + fn null_object_vec() { + #[allow(dead_code)] + #[derive(serde::Deserialize, serde::Serialize, Deserialize, Debug, PartialEq)] + struct Struct { + string: Vec, + } + let data = (1..=10) + .map(|x| Struct { + string: (1..=x).map(|y| y.to_string()).collect(), + }) + .collect::>(); + let fields = Vec::::from_type::(TracingOptions::default()).unwrap(); + let batch = serde_arrow::to_record_batch(&fields, &data).unwrap(); + let batch = batch.slice(5, 5); + + let struct_array: StructArray = batch.clone().into(); + let array = Arc::new(struct_array) as _; + let actual = Struct::from_array_ref(&array).collect::>(); + assert_eq!(data[5..10], actual); + } + + #[test] + fn camel_case() { + #[allow(dead_code)] + #[derive(serde::Deserialize, serde::Serialize, Deserialize, Debug, PartialEq)] + #[serde(rename_all = "camelCase")] + #[arrow_struct(rename_all = "camelCase")] + struct Struct { + camel_case: i64, + } + let data = vec![Struct { camel_case: 1 }]; + let fields = Vec::::from_type::(TracingOptions::default()).unwrap(); + let batch = serde_arrow::to_record_batch(&fields, &data).unwrap(); + + let struct_array: StructArray = batch.clone().into(); + let array = Arc::new(struct_array) as _; + println!("{:?}", Struct::from_array_ref(&array).collect::>()); + } }