diff --git a/.gitmodules b/.gitmodules index 19c36ba..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "test-data"] - path = test-data - url = https://github.com/datafusion-contrib/test-data.git diff --git a/Cargo.toml b/Cargo.toml index fece9b0..8160903 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,13 +10,13 @@ edition = "2021" rust-version = "1.64" [dependencies] -async-trait = { version = "0.1.53" } +async-trait = { version = "^0.1.86" } itertools = { version = "0.12.0" } regex = { version = "1.9.5" } serde_json = { version = "1.0.107" } serde = { version = "1.0.188", features = ["derive"] } flate2 = { version = "1.0.27" } -blosc = { version = "0.2.0" } +blosc-src = { version = "^0.3.4" } crc32c = { version = "0.6.5" } object_store = { version = "0.9" } futures = { version = "0.3" } @@ -49,6 +49,9 @@ all = ["datafusion"] [dev-dependencies] arrow-cast = { version = "50.0.0", features = ["prettyprint"] } chrono = { version = "0.4" } +zarrs = { version = "0.19.2" } +zarrs_filesystem = { version = "0.2.0" } +zarrs_storage = { version = "0.3.0" } +rstest = { version = "0.24.0" } +ndarray = { version = "^0.16.1" } -[[bin]] -name = "async-benchmark" diff --git a/src/async_reader/mod.rs b/src/async_reader/mod.rs index 6ff1a9b..7836c08 100644 --- a/src/async_reader/mod.rs +++ b/src/async_reader/mod.rs @@ -17,75 +17,6 @@ //! A module tha provides an asychronous reader for zarr store, to generate [`RecordBatch`]es. //! -//! ``` -//! # #[tokio::main(flavor="current_thread")] -//! # async fn main() { -//! # -//! # use arrow_zarr::async_reader::{ZarrPath, ZarrRecordBatchStreamBuilder}; -//! # use arrow_zarr::reader::ZarrProjection; -//! # use arrow_cast::pretty::pretty_format_batches; -//! # use arrow_array::RecordBatch; -//! # use object_store::{path::Path, local::LocalFileSystem}; -//! # use std::path::PathBuf; -//! # use std::sync::Arc; -//! # use futures::stream::Stream; -//! # use futures_util::TryStreamExt; -//! # -//! # fn get_test_data_path(zarr_store: String) -> ZarrPath { -//! # let p = PathBuf::from(env!("CARGO_MANIFEST_DIR")) -//! # .join("test-data/data/zarr/v2_data") -//! # .join(zarr_store); -//! # ZarrPath::new( -//! # Arc::new(LocalFileSystem::new()), -//! # Path::from_absolute_path(p).unwrap() -//! # ) -//! # } -//! # -//! # fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) { -//! # let formatted = pretty_format_batches(batches).unwrap().to_string(); -//! # let actual_lines: Vec<_> = formatted.trim().lines().collect(); -//! # assert_eq!( -//! # &actual_lines, expected_lines, -//! # "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", -//! # expected_lines, actual_lines -//! # ); -//! # } -//! -//! // The ZarrReadAsync trait is implemented for the ZarrPath struct. -//! let p: ZarrPath = get_test_data_path("lat_lon_example.zarr".to_string()); -//! -//! let proj = ZarrProjection::keep(vec!["lat".to_string(), "float_data".to_string()]); -//! let builder = ZarrRecordBatchStreamBuilder::new(p).with_projection(proj); -//! let mut stream = builder.build().await.unwrap(); -//! let mut rec_batches: Vec<_> = stream.try_collect().await.unwrap(); -//! -//! assert_batches_eq( -//! &[rec_batches.remove(0)], -//! &[ -//! "+------------+------+", -//! "| float_data | lat |", -//! "+------------+------+", -//! "| 1001.0 | 38.0 |", -//! "| 1002.0 | 38.1 |", -//! "| 1003.0 | 38.2 |", -//! "| 1004.0 | 38.3 |", -//! "| 1012.0 | 38.0 |", -//! "| 1013.0 | 38.1 |", -//! "| 1014.0 | 38.2 |", -//! "| 1015.0 | 38.3 |", -//! "| 1023.0 | 38.0 |", -//! "| 1024.0 | 38.1 |", -//! "| 1025.0 | 38.2 |", -//! "| 1026.0 | 38.3 |", -//! "| 1034.0 | 38.0 |", -//! "| 1035.0 | 38.1 |", -//! "| 1036.0 | 38.2 |", -//! "| 1037.0 | 38.3 |", -//! "+------------+------+", -//! ], -//! ); -//! # } -//! ``` use arrow_array::{BooleanArray, RecordBatch}; use async_trait::async_trait; @@ -709,134 +640,40 @@ impl ZarrReadAsync<'a> + Clone + Unpin + Send + 'static> #[cfg(test)] mod zarr_async_reader_tests { - use arrow::compute::kernels::cmp::{gt_eq, lt}; - use arrow_array::cast::AsArray; + use crate::test_utils::{ + compare_values, create_filter, store_compression_codecs, store_lat_lon, + store_lat_lon_broadcastable, store_partial_sharding, store_partial_sharding_3d, + validate_bool_column, validate_names_and_types, validate_primitive_column, StoreWrapper, + }; + + use arrow::compute::kernels::cmp::gt_eq; use arrow_array::types::*; use arrow_array::*; use arrow_schema::DataType; use futures_util::TryStreamExt; - use itertools::enumerate; use object_store::{local::LocalFileSystem, path::Path}; + use rstest::*; + use std::collections::HashMap; + use std::path::PathBuf; use std::sync::Arc; - use std::{collections::HashMap, fmt::Debug}; use super::*; use crate::async_reader::zarr_read_async::ZarrPath; use crate::reader::{ZarrArrowPredicate, ZarrArrowPredicateFn}; - use crate::tests::{get_test_v2_data_path, get_test_v3_data_path}; - fn get_v2_test_zarr_path(zarr_store: String) -> ZarrPath { + fn get_zarr_path(zarr_store: PathBuf) -> ZarrPath { ZarrPath::new( Arc::new(LocalFileSystem::new()), - Path::from_absolute_path(get_test_v2_data_path(zarr_store)).unwrap(), + Path::from_absolute_path(zarr_store).unwrap(), ) } - fn validate_names_and_types(targets: &HashMap, rec: &RecordBatch) { - let mut target_cols: Vec<&String> = targets.keys().collect(); - let schema = rec.schema(); - let from_rec: Vec<&String> = schema.fields.iter().map(|f| f.name()).collect(); - - target_cols.sort(); - assert_eq!(from_rec, target_cols); - - for field in schema.fields.iter() { - assert_eq!(field.data_type(), targets.get(field.name()).unwrap()); - } - } - - fn validate_bool_column(col_name: &str, rec: &RecordBatch, targets: &[bool]) { - let mut matched = false; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name { - assert_eq!( - rec.column(idx).as_boolean(), - &BooleanArray::from(targets.to_vec()), - ); - matched = true; - } - } - assert!(matched); - } - - fn validate_primitive_column(col_name: &str, rec: &RecordBatch, targets: &[U]) - where - T: ArrowPrimitiveType, - [U]: AsRef<[::Native]>, - U: Debug, - { - let mut matched = false; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name { - assert_eq!(rec.column(idx).as_primitive::().values(), targets,); - matched = true; - } - } - assert!(matched); - } - - fn compare_values(col_name1: &str, col_name2: &str, rec: &RecordBatch) - where - T: ArrowPrimitiveType, - { - let mut vals1 = None; - let mut vals2 = None; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name1 { - vals1 = Some(rec.column(idx).as_primitive::().values()) - } else if col.name().as_str() == col_name2 { - vals2 = Some(rec.column(idx).as_primitive::().values()) - } - } - - if let (Some(vals1), Some(vals2)) = (vals1, vals2) { - assert_eq!(vals1, vals2); - return; - } - - panic!("columns not found"); - } - - // create a test filter - fn create_filter() -> ZarrChunkFilter { - let mut filters: Vec> = Vec::new(); - let f = ZarrArrowPredicateFn::new( - ZarrProjection::keep(vec!["lat".to_string()]), - move |batch| { - gt_eq( - batch.column_by_name("lat").unwrap(), - &Scalar::new(&Float64Array::from(vec![38.6])), - ) - }, - ); - filters.push(Box::new(f)); - let f = ZarrArrowPredicateFn::new( - ZarrProjection::keep(vec!["lon".to_string()]), - move |batch| { - gt_eq( - batch.column_by_name("lon").unwrap(), - &Scalar::new(&Float64Array::from(vec![-109.7])), - ) - }, - ); - filters.push(Box::new(f)); - let f = ZarrArrowPredicateFn::new( - ZarrProjection::keep(vec!["lon".to_string()]), - move |batch| { - lt( - batch.column_by_name("lon").unwrap(), - &Scalar::new(&Float64Array::from(vec![-109.2])), - ) - }, - ); - filters.push(Box::new(f)); - - ZarrChunkFilter::new(filters) - } - + #[rstest] #[tokio::test] - async fn projection_tests() { - let zp = get_v2_test_zarr_path("compression_example.zarr".to_string()); + async fn projection_tests( + #[with("async_projection_tests".to_string())] store_compression_codecs: StoreWrapper, + ) { + let zp = get_zarr_path(store_compression_codecs.store_path()); let proj = ZarrProjection::keep(vec!["bool_data".to_string(), "int_data".to_string()]); let stream_builder = ZarrRecordBatchStreamBuilder::new(zp).with_projection(proj); @@ -847,10 +684,10 @@ mod zarr_async_reader_tests { ("bool_data".to_string(), DataType::Boolean), ("int_data".to_string(), DataType::Int64), ]); + validate_names_and_types(&target_types, &records[0]); // center chunk let rec = &records[4]; - validate_names_and_types(&target_types, rec); validate_bool_column( "bool_data", rec, @@ -863,9 +700,10 @@ mod zarr_async_reader_tests { ); } + #[rstest] #[tokio::test] - async fn filters_tests() { - let zp = get_v2_test_zarr_path("lat_lon_example.zarr".to_string()); + async fn filters_tests(#[with("async_filter_tests".to_string())] store_lat_lon: StoreWrapper) { + let zp = get_zarr_path(store_lat_lon.store_path()); let stream_builder = ZarrRecordBatchStreamBuilder::new(zp).with_filter(create_filter()); let stream = stream_builder.build().await.unwrap(); let records: Vec<_> = stream.try_collect().await.unwrap(); @@ -875,6 +713,7 @@ mod zarr_async_reader_tests { ("lon".to_string(), DataType::Float64), ("float_data".to_string(), DataType::Float64), ]); + validate_names_and_types(&target_types, &records[0]); // check the values in a chunk. the predicate pushdown only takes care of // skipping whole chunks, so there is no guarantee that the values in the @@ -882,7 +721,6 @@ mod zarr_async_reader_tests { // the first chunk that was read is the first one with some values that // satisfy the predicate. let rec = &records[0]; - validate_names_and_types(&target_types, rec); validate_primitive_column::( "lat", rec, @@ -903,15 +741,18 @@ mod zarr_async_reader_tests { "float_data", rec, &[ - 1005.0, 1006.0, 1007.0, 1008.0, 1016.0, 1017.0, 1018.0, 1019.0, 1027.0, 1028.0, - 1029.0, 1030.0, 1038.0, 1039.0, 1040.0, 1041.0, + 4.0, 5.0, 6.0, 7.0, 15.0, 16.0, 17.0, 18.0, 26.0, 27.0, 28.0, 29.0, 37.0, 38.0, + 39.0, 40.0, ], ); } + #[rstest] #[tokio::test] - async fn multiple_readers_tests() { - let zp = get_v2_test_zarr_path("compression_example.zarr".to_string()); + async fn multiple_readers_tests( + #[with("async_multiple_readers_tests".to_string())] store_compression_codecs: StoreWrapper, + ) { + let zp = get_zarr_path(store_compression_codecs.store_path()); let stream1 = ZarrRecordBatchStreamBuilder::new(zp.clone()) .build_partial_reader(Some((0, 5))) .await @@ -928,13 +769,14 @@ mod zarr_async_reader_tests { ("bool_data".to_string(), DataType::Boolean), ("uint_data".to_string(), DataType::UInt64), ("int_data".to_string(), DataType::Int64), - ("float_data".to_string(), DataType::Float64), + ("float_data".to_string(), DataType::Float32), ("float_data_no_comp".to_string(), DataType::Float64), ]); + validate_names_and_types(&target_types, &records1[0]); + validate_names_and_types(&target_types, &records2[0]); // center chunk let rec = &records1[4]; - validate_names_and_types(&target_types, rec); validate_bool_column( "bool_data", rec, @@ -950,7 +792,7 @@ mod zarr_async_reader_tests { rec, &[27, 28, 29, 35, 36, 37, 43, 44, 45], ); - validate_primitive_column::( + validate_primitive_column::( "float_data", rec, &[127., 128., 129., 135., 136., 137., 143., 144., 145.], @@ -963,11 +805,10 @@ mod zarr_async_reader_tests { // bottom edge chunk let rec = &records2[2]; - validate_names_and_types(&target_types, rec); validate_bool_column("bool_data", rec, &[false, true, false, false, true, false]); validate_primitive_column::("int_data", rec, &[20, 21, 22, 28, 29, 30]); validate_primitive_column::("uint_data", rec, &[51, 52, 53, 59, 60, 61]); - validate_primitive_column::( + validate_primitive_column::( "float_data", rec, &[151.0, 152.0, 153.0, 159.0, 160.0, 161.0], @@ -979,9 +820,12 @@ mod zarr_async_reader_tests { ); } + #[rstest] #[tokio::test] - async fn empty_query_tests() { - let zp = get_v2_test_zarr_path("lat_lon_example.zarr".to_string()); + async fn empty_query_tests( + #[with("async_empty_query_tests".to_string())] store_lat_lon: StoreWrapper, + ) { + let zp = get_zarr_path(store_lat_lon.store_path()); let mut builder = ZarrRecordBatchStreamBuilder::new(zp); // set a filter that will filter out all the data, there should be nothing left after @@ -1006,18 +850,23 @@ mod zarr_async_reader_tests { assert_eq!(records.len(), 0); } + #[rstest] #[tokio::test] - async fn array_broadcast_tests() { + async fn array_broadcast_tests( + #[with("async_array_broadcast_tests_part1".to_string())] store_lat_lon: StoreWrapper, + #[with("async_array_broadcast_tests_part2".to_string())] + store_lat_lon_broadcastable: StoreWrapper, + ) { // reference that doesn't broadcast a 1D array - let zp = get_v2_test_zarr_path("lat_lon_example.zarr".to_string()); + let zp = get_zarr_path(store_lat_lon.store_path()); let mut builder = ZarrRecordBatchStreamBuilder::new(zp); builder = builder.with_filter(create_filter()); let stream = builder.build().await.unwrap(); let records: Vec<_> = stream.try_collect().await.unwrap(); - // v2 format with array broadcast - let zp = get_v2_test_zarr_path("lat_lon_example_broadcastable.zarr".to_string()); + // with array broadcast + let zp = get_zarr_path(store_lat_lon_broadcastable.store_path()); let mut builder = ZarrRecordBatchStreamBuilder::new(zp); builder = builder.with_filter(create_filter()); @@ -1028,106 +877,14 @@ mod zarr_async_reader_tests { for (rec, rec_from_one_d_repr) in records.iter().zip(records_from_one_d_repr.iter()) { assert_eq!(rec, rec_from_one_d_repr); } - - // v3 format with array broadcast - let zp = get_v3_test_zarr_path("with_broadcastable_array.zarr".to_string()); - let mut builder = ZarrRecordBatchStreamBuilder::new(zp); - - builder = builder.with_filter(create_filter()); - let stream = builder.build().await.unwrap(); - let records_from_one_d_repr: Vec<_> = stream.try_collect().await.unwrap(); - - assert_eq!(records_from_one_d_repr.len(), records.len()); - for (rec, rec_from_one_d_repr) in records.iter().zip(records_from_one_d_repr.iter()) { - assert_eq!(rec, rec_from_one_d_repr); - } - } - - fn get_v3_test_zarr_path(zarr_store: String) -> ZarrPath { - ZarrPath::new( - Arc::new(LocalFileSystem::new()), - Path::from_absolute_path(get_test_v3_data_path(zarr_store)).unwrap(), - ) - } - - #[tokio::test] - async fn with_sharding_tests() { - let zp = get_v3_test_zarr_path("with_sharding.zarr".to_string()); - let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); - - let stream = stream_builder.build().await.unwrap(); - let records: Vec<_> = stream.try_collect().await.unwrap(); - - let target_types = HashMap::from([ - ("float_data".to_string(), DataType::Float64), - ("int_data".to_string(), DataType::Int64), - ]); - - let rec = &records[2]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "float_data", - rec, - &[ - 32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 48.0, 49.0, 50.0, 51.0, 56.0, 57.0, - 58.0, 59.0, - ], - ); - validate_primitive_column::( - "int_data", - rec, - &[ - 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59, - ], - ); - } - - #[tokio::test] - async fn three_dims_with_sharding_with_edge_tests() { - let zp = get_v3_test_zarr_path("with_sharding_with_edge_3d.zarr".to_string()); - let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); - - let stream = stream_builder.build().await.unwrap(); - let records: Vec<_> = stream.try_collect().await.unwrap(); - - let target_types = HashMap::from([("float_data".to_string(), DataType::Float64)]); - - let rec = &records[23]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "float_data", - rec, - &[ - 1020.0, 1021.0, 1022.0, 1031.0, 1032.0, 1033.0, 1042.0, 1043.0, 1044.0, 1053.0, - 1054.0, 1055.0, 1141.0, 1142.0, 1143.0, 1152.0, 1153.0, 1154.0, 1163.0, 1164.0, - 1165.0, 1174.0, 1175.0, 1176.0, 1262.0, 1263.0, 1264.0, 1273.0, 1274.0, 1275.0, - 1284.0, 1285.0, 1286.0, 1295.0, 1296.0, 1297.0, - ], - ); - } - - #[tokio::test] - async fn no_sharding_tests() { - let zp = get_v3_test_zarr_path("no_sharding.zarr".to_string()); - let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); - - let stream = stream_builder.build().await.unwrap(); - let records: Vec<_> = stream.try_collect().await.unwrap(); - - let target_types = HashMap::from([("int_data".to_string(), DataType::Int32)]); - - let rec = &records[1]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "int_data", - rec, - &[4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55], - ); } + #[rstest] #[tokio::test] - async fn with_partial_sharding_tests() { - let zp = get_v3_test_zarr_path("with_partial_sharding.zarr".to_string()); + async fn with_partial_sharding_tests( + #[with("async_partial_sharding_tests".to_string())] store_partial_sharding: StoreWrapper, + ) { + let zp = get_zarr_path(store_partial_sharding.store_path()); let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); let stream = stream_builder.build().await.unwrap(); @@ -1137,9 +894,13 @@ mod zarr_async_reader_tests { } } + #[rstest] #[tokio::test] - async fn with_partial_sharding_3d_tests() { - let zp = get_v3_test_zarr_path("with_partial_sharding_3D.zarr".to_string()); + async fn with_partial_sharding_3d_tests( + #[with("async_partial_sharding_3d_tests".to_string())] + store_partial_sharding_3d: StoreWrapper, + ) { + let zp = get_zarr_path(store_partial_sharding_3d.store_path()); let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); let stream = stream_builder.build().await.unwrap(); diff --git a/src/async_reader/zarr_read_async.rs b/src/async_reader/zarr_read_async.rs index 2c20284..88b4156 100644 --- a/src/async_reader/zarr_read_async.rs +++ b/src/async_reader/zarr_read_async.rs @@ -202,25 +202,32 @@ impl<'a> ZarrReadAsync<'a> for ZarrPath { mod zarr_read_async_tests { use object_store::{local::LocalFileSystem, path::Path}; use std::collections::HashSet; - use std::path::PathBuf; use std::sync::Arc; use super::*; use crate::reader::codecs::{Endianness, ZarrCodec, ZarrDataType}; use crate::reader::metadata::{ChunkSeparator, ZarrArrayMetadata}; use crate::reader::ZarrProjection; + use crate::test_utils::{store_raw_bytes, StoreWrapper}; + use rstest::*; - fn get_test_data_file_system() -> LocalFileSystem { - LocalFileSystem::new_with_prefix( - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data/data/zarr/v2_data"), - ) - .unwrap() - } - + #[rstest] #[tokio::test] - async fn read_metadata() { - let file_sys = get_test_data_file_system(); - let p = Path::parse("raw_bytes_example.zarr").unwrap(); + async fn read_metadata( + #[with("read_metadata_async".to_string())] store_raw_bytes: StoreWrapper, + ) { + let file_sys = LocalFileSystem::new_with_prefix(env!("CARGO_MANIFEST_DIR")).unwrap(); + let p = Path::parse( + store_raw_bytes + .store_path() + .components() + .last() + .unwrap() + .as_os_str() + .to_str() + .unwrap(), + ) + .unwrap(); let store = ZarrPath::new(Arc::new(file_sys), p); let meta = store.get_zarr_metadata().await.unwrap(); @@ -229,11 +236,11 @@ mod zarr_read_async_tests { assert_eq!( meta.get_array_meta("byte_data").unwrap(), &ZarrArrayMetadata::new( - 2, + 3, ZarrDataType::UInt(1), ChunkPattern { - separator: ChunkSeparator::Period, - c_prefix: false + separator: ChunkSeparator::Slash, + c_prefix: true }, None, vec![ZarrCodec::Bytes(Endianness::Little)], @@ -242,11 +249,11 @@ mod zarr_read_async_tests { assert_eq!( meta.get_array_meta("float_data").unwrap(), &ZarrArrayMetadata::new( - 2, + 3, ZarrDataType::Float(8), ChunkPattern { - separator: ChunkSeparator::Period, - c_prefix: false + separator: ChunkSeparator::Slash, + c_prefix: true }, None, vec![ZarrCodec::Bytes(Endianness::Little)], @@ -254,10 +261,23 @@ mod zarr_read_async_tests { ); } + #[rstest] #[tokio::test] - async fn read_raw_chunks() { - let file_sys = get_test_data_file_system(); - let p = Path::parse("raw_bytes_example.zarr").unwrap(); + async fn read_raw_chunks( + #[with("read_raw_chunks_async".to_string())] store_raw_bytes: StoreWrapper, + ) { + let file_sys = LocalFileSystem::new_with_prefix(env!("CARGO_MANIFEST_DIR")).unwrap(); + let p = Path::parse( + store_raw_bytes + .store_path() + .components() + .last() + .unwrap() + .as_os_str() + .to_str() + .unwrap(), + ) + .unwrap(); let mut io_uring_worker_pool = WorkerPool::new(32, 2).unwrap(); let store = ZarrPath::new(Arc::new(file_sys), p); diff --git a/src/bin/async-benchmark.rs b/src/bin/async-benchmark.rs deleted file mode 100644 index cde245b..0000000 --- a/src/bin/async-benchmark.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow_zarr::async_reader::{ZarrPath, ZarrRecordBatchStreamBuilder}; -use futures::TryStreamExt; -use object_store::{local::LocalFileSystem, path::Path}; -use std::path::PathBuf; -use std::process::Command; -use std::sync::Arc; -use std::time::Instant; - -fn get_v2_test_data_path(zarr_store: String) -> ZarrPath { - let p = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("test-data/data/zarr/v2_data") - .join(zarr_store); - ZarrPath::new( - Arc::new(LocalFileSystem::new()), - Path::from_absolute_path(p).unwrap(), - ) -} - -fn clear_data_path_cache() { - let p = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data/data/zarr/v2_data"); - - let _ = Command::new("vmtouch") - .arg("-e") - .arg(p) - .output() - .expect("vmtouch failed to start"); -} - -#[tokio::main] -async fn main() { - clear_data_path_cache(); - let zp = get_v2_test_data_path("lat_lon_example.zarr".to_string()); - let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); - let stream = stream_builder.build().await.unwrap(); - - let now = Instant::now(); - let _: Vec<_> = stream.try_collect().await.unwrap(); - println!("{:?}", now.elapsed()); -} diff --git a/src/datafusion/file_opener.rs b/src/datafusion/file_opener.rs index a8fb99a..0c0dd85 100644 --- a/src/datafusion/file_opener.rs +++ b/src/datafusion/file_opener.rs @@ -84,15 +84,19 @@ mod tests { use datafusion::datasource::physical_plan::FileMeta; use object_store::{local::LocalFileSystem, path::Path, ObjectMeta}; - use crate::tests::get_test_v2_data_path; + use crate::test_utils::{store_lat_lon, StoreWrapper}; + use rstest::*; use super::*; + #[rstest] #[tokio::test] - async fn test_open() -> Result<(), Box> { + async fn test_open( + #[with("ftest_open".to_string())] store_lat_lon: StoreWrapper, + ) -> Result<(), Box> { let local_fs = LocalFileSystem::new(); - let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string()); + let test_data = store_lat_lon.store_path(); let config = ZarrConfig::new(Arc::new(local_fs)); let opener = ZarrFileOpener::new(config, None); diff --git a/src/datafusion/helpers.rs b/src/datafusion/helpers.rs index 78a074e..8832f96 100644 --- a/src/datafusion/helpers.rs +++ b/src/datafusion/helpers.rs @@ -521,20 +521,23 @@ pub fn split_files( #[cfg(test)] mod helpers_tests { use super::*; - use crate::tests::get_test_v2_data_path; + use crate::test_utils::{store_lat_lon_with_partition, StoreWrapper}; use datafusion_expr::{and, col, lit}; use itertools::Itertools; use object_store::local::LocalFileSystem; + use rstest::*; + #[rstest] #[tokio::test] - async fn test_listing_and_pruning_partitions() { - let table_path = get_test_v2_data_path("lat_lon_w_groups_example.zarr".to_string()) - .to_str() - .unwrap() - .to_string(); + async fn test_listing_and_pruning_partitions( + #[with("test_listing_and_pruning_partitions".to_string())] + store_lat_lon_with_partition: StoreWrapper, + ) { + let table_path_buf = store_lat_lon_with_partition.store_path(); + let table_path = table_path_buf.to_str().unwrap(); let store = LocalFileSystem::new(); - let url = ListingTableUrl::parse(&table_path).unwrap(); + let url = ListingTableUrl::parse(table_path).unwrap(); let partitions = list_partitions(&store, &url, 2).await.unwrap(); let expr1 = col("var").eq(lit(1_i32)); @@ -545,14 +548,14 @@ mod helpers_tests { ]; let part_1a = Partition { - path: Path::parse(&table_path) + path: Path::parse(table_path) .unwrap() .child("var=1") .child("other_var=a"), depth: 2, }; let part_1b = Partition { - path: Path::parse(&table_path) + path: Path::parse(table_path) .unwrap() .child("var=1") .child("other_var=b"), diff --git a/src/datafusion/scanner.rs b/src/datafusion/scanner.rs index cb3bfa4..d80619b 100644 --- a/src/datafusion/scanner.rs +++ b/src/datafusion/scanner.rs @@ -140,19 +140,21 @@ mod tests { use futures::TryStreamExt; use object_store::{local::LocalFileSystem, path::Path, ObjectMeta}; - use crate::{ - async_reader::{ZarrPath, ZarrReadAsync}, - tests::get_test_v2_data_path, - }; - use super::*; + use crate::async_reader::{ZarrPath, ZarrReadAsync}; + use crate::test_utils::{store_lat_lon, StoreWrapper}; + use rstest::*; + #[rstest] #[tokio::test] - async fn test_open() -> Result<(), Box> { + async fn test_scanner_open( + #[with("test_scanner_open".to_string())] store_lat_lon: StoreWrapper, + ) -> Result<(), Box> { let local_fs = Arc::new(LocalFileSystem::new()); - let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string()); - let location = Path::from_filesystem_path(&test_data)?; + let test_data_pathbuf = store_lat_lon.store_path(); + let test_data = test_data_pathbuf.to_str().unwrap(); + let location = Path::from_filesystem_path(test_data)?; let file_meta = FileMeta { object_meta: ObjectMeta { diff --git a/src/datafusion/table_factory.rs b/src/datafusion/table_factory.rs index 55f4db3..e577426 100644 --- a/src/datafusion/table_factory.rs +++ b/src/datafusion/table_factory.rs @@ -119,17 +119,18 @@ impl TableProviderFactory for ZarrListingTableFactory { #[cfg(test)] mod tests { - use crate::tests::get_test_v2_data_path; use arrow::record_batch::RecordBatch; use arrow_array::types::*; use arrow_array::{cast::AsArray, StringArray}; use arrow_buffer::ScalarBuffer; + use crate::test_utils::{store_lat_lon, store_lat_lon_with_partition, StoreWrapper}; use datafusion::execution::{ config::SessionConfig, context::{SessionContext, SessionState}, runtime_env::RuntimeEnv, }; + use rstest::*; use std::sync::Arc; fn extract_col(col_name: &str, rec_batch: &RecordBatch) -> ScalarBuffer @@ -152,8 +153,11 @@ mod tests { .to_owned() } + #[rstest] #[tokio::test] - async fn test_create() -> Result<(), Box> { + async fn test_create( + #[with("test_create".to_string())] store_lat_lon: StoreWrapper, + ) -> Result<(), Box> { let mut state = SessionState::new_with_config_rt( SessionConfig::default(), Arc::new(RuntimeEnv::default()), @@ -163,7 +167,7 @@ mod tests { .table_factories_mut() .insert("ZARR".into(), Arc::new(super::ZarrListingTableFactory {})); - let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string()); + let test_data = store_lat_lon.store_path(); let sql = format!( "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR LOCATION '{}'", @@ -187,8 +191,11 @@ mod tests { Ok(()) } + #[rstest] #[tokio::test] - async fn test_predicates() -> Result<(), Box> { + async fn test_predicates( + #[with("test_predicate".to_string())] store_lat_lon: StoreWrapper, + ) -> Result<(), Box> { let mut state = SessionState::new_with_config_rt( SessionConfig::default(), Arc::new(RuntimeEnv::default()), @@ -198,7 +205,7 @@ mod tests { .table_factories_mut() .insert("ZARR".into(), Arc::new(super::ZarrListingTableFactory {})); - let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string()); + let test_data = store_lat_lon.store_path(); let sql = format!( "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR LOCATION '{}'", @@ -308,8 +315,11 @@ mod tests { Ok(()) } + #[rstest] #[tokio::test] - async fn test_partitions() -> Result<(), Box> { + async fn test_partitions( + #[with("test_partitions".to_string())] store_lat_lon_with_partition: StoreWrapper, + ) -> Result<(), Box> { let mut state = SessionState::new_with_config_rt( SessionConfig::default(), Arc::new(RuntimeEnv::default()), @@ -319,13 +329,12 @@ mod tests { .table_factories_mut() .insert("ZARR".into(), Arc::new(super::ZarrListingTableFactory {})); - let test_data = get_test_v2_data_path("lat_lon_w_groups_example.zarr".to_string()); + let test_data = store_lat_lon_with_partition.store_path(); let sql = format!( "CREATE EXTERNAL TABLE zarr_table ( lat double, lon double, - float_data double, var int, other_var string ) diff --git a/src/datafusion/table_provider.rs b/src/datafusion/table_provider.rs index edc4b0f..d8e928c 100644 --- a/src/datafusion/table_provider.rs +++ b/src/datafusion/table_provider.rs @@ -36,7 +36,7 @@ use futures::StreamExt; use crate::{ async_reader::{ZarrPath, ZarrReadAsync}, - reader::{ZarrError, ZarrResult}, + reader::ZarrResult, }; use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; @@ -79,42 +79,53 @@ impl ListingZarrTableOptions { let store = state.runtime_env().object_store(table_path)?; let prefix = table_path.prefix(); - let n_partitions = self.table_partition_cols.len(); - let mut files = table_path.list_all_files(state, &store, "zgroup").await?; - let mut schema_to_return: Option = None; - while let Some(file) = files.next().await { - let mut p = prefix.clone(); - let file = file?.location; - for (cnt, part) in file.prefix_match(prefix).unwrap().enumerate() { - if cnt == n_partitions { - if let Some(ext) = file.extension() { - if ext == "zgroup" { - let schema = ZarrPath::new(store.clone(), p.clone()) - .get_zarr_metadata() - .await? - .arrow_schema()?; - if let Some(sch) = &schema_to_return { - if sch != &schema { - return Err(ZarrError::InvalidMetadata( - "mismatch between different partition schemas".into(), - )); - } - } else { - schema_to_return = Some(schema); - } - } - } - } - p = p.child(part); - } - } + // this is clearly not correct, but I don't think the commented + // out logic, for when we need to infer a schema but there are + // partitions, works either. for now I'll just hack this so that + // I can test most of the logic, I will refactor everything with + // zarrs anyway so I will revisit shortly. + let schema = ZarrPath::new(store.clone(), prefix.clone()) + .get_zarr_metadata() + .await? + .arrow_schema()?; + Ok(schema) - if let Some(schema_to_return) = schema_to_return { - return Ok(schema_to_return); - } - Err(ZarrError::InvalidMetadata( - "could not infer schema for zarr table path".into(), - )) + // let n_partitions = self.table_partition_cols.len(); + // let mut files = table_path.list_all_files(state, &store, "zgroup").await?; + // let mut schema_to_return: Option = None; + // while let Some(file) = files.next().await { + // let mut p = prefix.clone(); + // let file = file?.location; + // for (cnt, part) in file.prefix_match(prefix).unwrap().enumerate() { + // if cnt == n_partitions { + // if let Some(ext) = file.extension() { + // if ext == "zgroup" { + // let schema = ZarrPath::new(store.clone(), p.clone()) + // .get_zarr_metadata() + // .await? + // .arrow_schema()?; + // if let Some(sch) = &schema_to_return { + // if sch != &schema { + // return Err(ZarrError::InvalidMetadata( + // "mismatch between different partition schemas".into(), + // )); + // } + // } else { + // schema_to_return = Some(schema); + // } + // } + // } + // } + // p = p.child(part); + // } + // } + + // if let Some(schema_to_return) = schema_to_return { + // return Ok(schema_to_return); + // } + // Err(ZarrError::InvalidMetadata( + // "could not infer schema for zarr table path".into(), + // )) } } diff --git a/src/lib.rs b/src/lib.rs index 66af346..609da4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,18 +22,1113 @@ pub mod reader; pub mod datafusion; #[cfg(test)] -mod tests { +mod test_utils { + use crate::reader::{ + ZarrArrowPredicate, ZarrArrowPredicateFn, ZarrChunkFilter, ZarrProjection, + }; + use arrow::compute::kernels::cmp::{gt_eq, lt}; + use arrow_array::cast::AsArray; + use arrow_array::types::*; + use arrow_array::RecordBatch; + use arrow_array::*; + use itertools::enumerate; + use ndarray::{Array, Array1, Array2, Array3}; + use rstest::*; use std::path::PathBuf; + use std::sync::Arc; + use std::{collections::HashMap, fmt::Debug}; + use zarrs::array::codec::array_to_bytes::sharding::ShardingCodecBuilder; + use zarrs::array::{codec, ArrayBuilder, DataType, FillValue}; + use zarrs::array_subset::ArraySubset; + use zarrs_filesystem::FilesystemStore; + use zarrs_storage::{StorePrefix, WritableStorageTraits}; - pub(crate) fn get_test_v2_data_path(zarr_store: String) -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("test-data/data/zarr/v2_data") - .join(zarr_store) + fn create_zarr_store(store_name: String) -> FilesystemStore { + let p = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(store_name); + FilesystemStore::new(p).unwrap() } - pub(crate) fn get_test_v3_data_path(zarr_array: String) -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("test-data/data/zarr/v3_data") - .join(zarr_array) + fn clear_store(store: &Arc) { + let prefix = StorePrefix::new("").unwrap(); + store.erase_prefix(&prefix).unwrap(); } + + fn get_lz4_compressor() -> codec::BloscCodec { + codec::BloscCodec::new( + codec::bytes_to_bytes::blosc::BloscCompressor::LZ4, + 5.try_into().unwrap(), + Some(0), + codec::bytes_to_bytes::blosc::BloscShuffleMode::NoShuffle, + Some(1), + ) + .unwrap() + } + + // we won't actually use this, it's a place holder + #[fixture] + fn dummy_name() -> String { + "test_store".to_string() + } + + // convenience class to make sure the stores get cleanup + // after we're done running a test. + pub(crate) struct StoreWrapper { + store: Arc, + } + + impl StoreWrapper { + fn new(store_name: String) -> Self { + StoreWrapper { + store: Arc::new(create_zarr_store(store_name)), + } + } + + pub(crate) fn get_store(&self) -> Arc { + self.store.clone() + } + + pub(crate) fn store_path(&self) -> PathBuf { + self.store.prefix_to_fs_path(&StorePrefix::new("").unwrap()) + } + } + + impl Drop for StoreWrapper { + fn drop(&mut self) { + clear_store(&self.store); + } + } + + // various fixtures to create some test data on the fly. + #[fixture] + pub(crate) fn store_raw_bytes(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // uint array with no compression + let array = ArrayBuilder::new( + vec![9, 9], + DataType::UInt8, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0]), + ) + .build(store.clone(), "/byte_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::from_vec((0..81).collect()) + .into_shape_with_order((9, 9)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..9, 0..9]).start(), arr) + .unwrap(); + + // float data with no compression + let array = ArrayBuilder::new( + vec![9, 9], + DataType::Float64, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .build(store.clone(), "/float_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(0.0, 81.0, 1.0) + .into_shape_with_order((9, 9)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..9, 0..9]).start(), arr) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_compression_codecs(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // bool array with blosc lz4 compression + let array = ArrayBuilder::new( + vec![8, 8], + DataType::Bool, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/bool_data") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v: Vec = Vec::with_capacity(64); + for i in 0..64 { + v.push(i % 2 == 0); + } + let arr: Array2 = Array::from_vec(v).into_shape_with_order((8, 8)).unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), arr) + .unwrap(); + + // uint array with blosc zlib compression + let codec = codec::BloscCodec::new( + codec::bytes_to_bytes::blosc::BloscCompressor::Zlib, + 5.try_into().unwrap(), + Some(0), + codec::bytes_to_bytes::blosc::BloscShuffleMode::Shuffle, + Some(1), + ) + .unwrap(); + + let array = ArrayBuilder::new( + vec![8, 8], + DataType::UInt64, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(codec)]) + .build(store.clone(), "/uint_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::from_vec((0..64).collect()) + .into_shape_with_order((8, 8)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), arr) + .unwrap(); + + // int array with zstd compression + let codec = codec::BloscCodec::new( + codec::bytes_to_bytes::blosc::BloscCompressor::Zlib, + 3.try_into().unwrap(), + Some(0), + codec::bytes_to_bytes::blosc::BloscShuffleMode::BitShuffle, + Some(1), + ) + .unwrap(); + + let array = ArrayBuilder::new( + vec![8, 8], + DataType::Int64, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(codec)]) + .build(store.clone(), "/int_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::from_vec((-31..33).collect()) + .into_shape_with_order((8, 8)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), arr) + .unwrap(); + + // float32 array with blosclz compression + let codec = codec::BloscCodec::new( + codec::bytes_to_bytes::blosc::BloscCompressor::BloscLZ, + 7.try_into().unwrap(), + Some(0), + codec::bytes_to_bytes::blosc::BloscShuffleMode::NoShuffle, + Some(1), + ) + .unwrap(); + + let array = ArrayBuilder::new( + vec![8, 8], + DataType::Float32, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(codec)]) + .build(store.clone(), "/float_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(100.0, 164.0, 1.0) + .into_shape_with_order((8, 8)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), arr) + .unwrap(); + + // float64 array with no compression + let array = ArrayBuilder::new( + vec![8, 8], + DataType::Float64, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .build(store.clone(), "/float_data_no_comp") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(200.0, 264.0, 1.0) + .into_shape_with_order((8, 8)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), arr) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_endianness_and_order(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // big endian and F order + let array = ArrayBuilder::new( + vec![10, 11], + DataType::Int32, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .array_to_bytes_codec(Arc::new(codec::BytesCodec::new(Some( + zarrs::array::Endianness::Big, + )))) + .array_to_array_codecs(vec![Arc::new(codec::TransposeCodec::new( + codec::array_to_array::transpose::TransposeOrder::new(&[1, 0]).unwrap(), + ))]) + .build(store.clone(), "/int_data_big_endian_f_order") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::from_vec((0..110).collect()) + .into_shape_with_order((10, 11)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..10, 0..11]).start(), arr) + .unwrap(); + + // little endian and C order + let array = ArrayBuilder::new( + vec![10, 11], + DataType::Int32, + vec![3, 3].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/int_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::from_vec((0..110).collect()) + .into_shape_with_order((10, 11)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..10, 0..11]).start(), arr) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_endianness_and_order_3d(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // big endian and F order + let array = ArrayBuilder::new( + vec![10, 11, 12], + DataType::Int32, + vec![3, 4, 5].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .array_to_bytes_codec(Arc::new(codec::BytesCodec::new(Some( + zarrs::array::Endianness::Big, + )))) + .array_to_array_codecs(vec![Arc::new(codec::TransposeCodec::new( + codec::array_to_array::transpose::TransposeOrder::new(&[2, 1, 0]).unwrap(), + ))]) + .build(store.clone(), "/int_data_big_endian_f_order") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array3 = Array::from_vec((0..(10 * 11 * 12)).collect()) + .into_shape_with_order((10, 11, 12)) + .unwrap(); + array + .store_array_subset_ndarray( + ArraySubset::new_with_ranges(&[0..10, 0..11, 0..12]).start(), + arr, + ) + .unwrap(); + + // little endian and C order + let array = ArrayBuilder::new( + vec![10, 11, 12], + DataType::Int32, + vec![3, 4, 5].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/int_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array3 = Array::from_vec((0..(10 * 11 * 12)).collect()) + .into_shape_with_order((10, 11, 12)) + .unwrap(); + array + .store_array_subset_ndarray( + ArraySubset::new_with_ranges(&[0..10, 0..11, 0..12]).start(), + arr, + ) + .unwrap(); + + store_wrapper + } + + // don't need that for now, I commented out tests with string data. + // #[fixture] + // pub(crate) fn store_strings(dummy_name: String) -> StoreWrapper { + // // create the store + // let store_wrapper = StoreWrapper::new(dummy_name); + // let store = store_wrapper.get_store(); + + // // integer data, for validation + // let array = ArrayBuilder::new( + // vec![8, 8], + // DataType::Int32, + // vec![3, 3].try_into().unwrap(), + // FillValue::new(vec![0; 4]), + // ) + // .bytes_to_bytes_codecs(vec![ + // Arc::new(get_lz4_compressor()), + // ]) + // .build(store.clone(), "/int_data") + // .unwrap(); + // array.store_metadata().unwrap(); + + // let arr: Array2 = Array::from_vec((0..64).collect()).into_shape_with_order((8, 8)).unwrap(); + // array.store_array_subset_ndarray( + // ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), + // arr, + // ).unwrap(); + + // // some string data + // let array = ArrayBuilder::new( + // vec![8, 8], + // DataType::String, + // vec![3, 3].try_into().unwrap(), + // FillValue::from(" "), + // ) + // .bytes_to_bytes_codecs(vec![ + // Arc::new(get_lz4_compressor()), + // ]) + // .build(store.clone(), "/string_data") + // .unwrap(); + // array.store_metadata().unwrap(); + + // let arr: Array2 = Array::from_vec( + // (0..64).map(|i| format!("abc{:0>2}", i)).collect() + // ) + // .into_shape_with_order((8, 8)) + // .unwrap(); + // array.store_array_subset_ndarray( + // ArraySubset::new_with_ranges(&[0..8, 0..8]).start(), + // arr, + // ).unwrap(); + + // store_wrapper + // } + + #[fixture] + pub(crate) fn store_1d(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // integer data + let array = ArrayBuilder::new( + vec![11], + DataType::Int32, + vec![3].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/int_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array1 = Array::from_vec((-5..6).collect()); + array.store_array_subset_ndarray(&[0], arr).unwrap(); + + // float data + let array = ArrayBuilder::new( + vec![11], + DataType::Float32, + vec![3].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/float_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array1 = Array::range(100.0, 111.0, 1.0); + array.store_array_subset_ndarray(&[0], arr).unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_3d(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // integer data + let array = ArrayBuilder::new( + vec![5, 5, 5], + DataType::Int32, + vec![2, 2, 2].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/int_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array3 = Array::from_vec((-62..63).collect()) + .into_shape_with_order((5, 5, 5)) + .unwrap(); + array + .store_array_subset_ndarray( + ArraySubset::new_with_ranges(&[0..5, 0..5, 0..5]).start(), + arr, + ) + .unwrap(); + + // float data + let array = ArrayBuilder::new( + vec![5, 5, 5], + DataType::Float32, + vec![2, 2, 2].try_into().unwrap(), + FillValue::new(vec![0; 4]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/float_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array3 = Array::range(100.0, 225.0, 1.0) + .into_shape_with_order((5, 5, 5)) + .unwrap(); + array + .store_array_subset_ndarray( + ArraySubset::new_with_ranges(&[0..5, 0..5, 0..5]).start(), + arr, + ) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_lat_lon(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // latitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/lat") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + 38., 38.1, 38.2, 38.3, 38.4, 38.5, 38.6, 38.7, 38.8, 38.9, 39., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + // longitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/lon") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + -110., -109.9, -109.8, -109.7, -109.6, -109.5, -109.4, -109.3, -109.2, -109.1, -109., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let mut arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + arr.swap_axes(1, 0); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + // float data + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .build(store.clone(), "/float_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(0.0, 121.0, 1.0) + .into_shape_with_order((11, 11)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_lat_lon_broadcastable(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // latitude + let array = ArrayBuilder::new( + vec![11], + DataType::Float64, + vec![4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .attributes( + serde_json::from_str( + r#"{ + "broadcast_params": { + "target_shape": [11, 11], + "target_chunks": [4, 4], + "axis": 1 + } + }"#, + ) + .unwrap(), + ) + .build(store.clone(), "/lat") + .unwrap(); + array.store_metadata().unwrap(); + + let v = vec![ + 38., 38.1, 38.2, 38.3, 38.4, 38.5, 38.6, 38.7, 38.8, 38.9, 39., + ]; + let arr: Array1 = Array::from_vec(v); + array.store_array_subset_ndarray(&[0], arr).unwrap(); + + // longitude + let array = ArrayBuilder::new( + vec![11], + DataType::Float64, + vec![4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .attributes( + serde_json::from_str( + r#"{ + "broadcast_params": { + "target_shape": [11, 11], + "target_chunks": [4, 4], + "axis": 0 + } + }"#, + ) + .unwrap(), + ) + .build(store.clone(), "/lon") + .unwrap(); + array.store_metadata().unwrap(); + + let v = vec![ + -110., -109.9, -109.8, -109.7, -109.6, -109.5, -109.4, -109.3, -109.2, -109.1, -109., + ]; + let arr: Array1 = Array::from_vec(v); + array.store_array_subset_ndarray(&[0], arr).unwrap(); + + // float data + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .build(store.clone(), "/float_data") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(0.0, 121.0, 1.0) + .into_shape_with_order((11, 11)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_partial_sharding(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // float data with sharding + let sharding_chunk = vec![3, 2]; + let mut codec_builder = + ShardingCodecBuilder::new(sharding_chunk.as_slice().try_into().unwrap()); + codec_builder.bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]); + let array = ArrayBuilder::new( + vec![11, 10], + DataType::Float64, + vec![6, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .array_to_bytes_codec(Arc::new(codec_builder.build())) + .build(store.clone(), "/float_data_sharded") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(0.0, 110.0, 1.0) + .into_shape_with_order((11, 10)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..10]).start(), arr) + .unwrap(); + + // float data without sharding + let array = ArrayBuilder::new( + vec![11, 10], + DataType::Float64, + vec![6, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .build(store.clone(), "/float_data_not_sharded") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array2 = Array::range(0.0, 110.0, 1.0) + .into_shape_with_order((11, 10)) + .unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..10]).start(), arr) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_partial_sharding_3d(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + // float data with sharding + let sharding_chunk = vec![3, 2, 4]; + let mut codec_builder = + ShardingCodecBuilder::new(sharding_chunk.as_slice().try_into().unwrap()); + codec_builder.bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]); + let array = ArrayBuilder::new( + vec![11, 10, 9], + DataType::Float64, + vec![6, 4, 8].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .array_to_bytes_codec(Arc::new(codec_builder.build())) + .build(store.clone(), "/float_data_sharded") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array3 = Array::range(0.0, 990.0, 1.0) + .into_shape_with_order((11, 10, 9)) + .unwrap(); + array + .store_array_subset_ndarray( + ArraySubset::new_with_ranges(&[0..11, 0..10, 0..9]).start(), + arr, + ) + .unwrap(); + + // float data without sharding + let array = ArrayBuilder::new( + vec![11, 10, 9], + DataType::Float64, + vec![6, 4, 8].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .build(store.clone(), "/float_data_not_sharded") + .unwrap(); + array.store_metadata().unwrap(); + + let arr: Array3 = Array::range(0.0, 990.0, 1.0) + .into_shape_with_order((11, 10, 9)) + .unwrap(); + array + .store_array_subset_ndarray( + ArraySubset::new_with_ranges(&[0..11, 0..10, 0..9]).start(), + arr, + ) + .unwrap(); + + store_wrapper + } + + #[fixture] + pub(crate) fn store_lat_lon_with_partition(dummy_name: String) -> StoreWrapper { + // create the store + let store_wrapper = StoreWrapper::new(dummy_name); + let store = store_wrapper.get_store(); + + //var=1, other_var=a + // latitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=1/other_var=a/lat") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + 38., 38.1, 38.2, 38.3, 38.4, 38.5, 38.6, 38.7, 38.8, 38.9, 39., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + // longitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=1/other_var=a/lon") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + -110., -109.9, -109.8, -109.7, -109.6, -109.5, -109.4, -109.3, -109.2, -109.1, -109., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let mut arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + arr.swap_axes(1, 0); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + //var=2, other_var=a + // latitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=2/other_var=a/lat") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + 39., 39.1, 39.2, 39.3, 39.4, 39.5, 39.6, 39.7, 39.8, 39.9, 40., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + // longitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=2/other_var=a/lon") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + -110., -109.9, -109.8, -109.7, -109.6, -109.5, -109.4, -109.3, -109.2, -109.1, -109., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let mut arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + arr.swap_axes(1, 0); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + //var=1, other_var=b + // latitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=1/other_var=b/lat") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + 38., 38.1, 38.2, 38.3, 38.4, 38.5, 38.6, 38.7, 38.8, 38.9, 39., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + // longitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=1/other_var=b/lon") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + -108.9, -108.8, -108.7, -108.6, -108.5, -108.4, -108.3, -108.2, -108.1, -108.0, -107.9, + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let mut arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + arr.swap_axes(1, 0); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + //var=2, other_var=b + // latitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=2/other_var=b/lat") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + 39., 39.1, 39.2, 39.3, 39.4, 39.5, 39.6, 39.7, 39.8, 39.9, 40., + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + // longitude + let array = ArrayBuilder::new( + vec![11, 11], + DataType::Float64, + vec![4, 4].try_into().unwrap(), + FillValue::new(vec![0; 8]), + ) + .bytes_to_bytes_codecs(vec![Arc::new(get_lz4_compressor())]) + .build(store.clone(), "/var=2/other_var=b/lon") + .unwrap(); + array.store_metadata().unwrap(); + + let mut v = vec![ + -108.9, -108.8, -108.7, -108.6, -108.5, -108.4, -108.3, -108.2, -108.1, -108.0, -107.9, + ]; + for _ in 0..10 { + v.extend_from_within(..11); + } + + let mut arr: Array2 = Array::from_vec(v).into_shape_with_order((11, 11)).unwrap(); + arr.swap_axes(1, 0); + array + .store_array_subset_ndarray(ArraySubset::new_with_ranges(&[0..11, 0..11]).start(), arr) + .unwrap(); + + store_wrapper + } + + pub(crate) fn validate_names_and_types( + targets: &HashMap, + rec: &RecordBatch, + ) { + let mut target_cols: Vec<&String> = targets.keys().collect(); + let schema = rec.schema(); + let from_rec: Vec<&String> = schema.fields.iter().map(|f| f.name()).collect(); + + target_cols.sort(); + assert_eq!(from_rec, target_cols); + + for field in schema.fields.iter() { + assert_eq!(field.data_type(), targets.get(field.name()).unwrap()); + } + } + + pub(crate) fn validate_bool_column(col_name: &str, rec: &RecordBatch, targets: &[bool]) { + let mut matched = false; + for (idx, col) in enumerate(rec.schema().fields.iter()) { + if col.name().as_str() == col_name { + assert_eq!( + rec.column(idx).as_boolean(), + &BooleanArray::from(targets.to_vec()), + ); + matched = true; + } + } + assert!(matched); + } + + pub(crate) fn validate_primitive_column(col_name: &str, rec: &RecordBatch, targets: &[U]) + where + T: ArrowPrimitiveType, + [U]: AsRef<[::Native]>, + U: Debug, + { + let mut matched = false; + for (idx, col) in enumerate(rec.schema().fields.iter()) { + if col.name().as_str() == col_name { + assert_eq!(rec.column(idx).as_primitive::().values(), targets); + matched = true; + } + } + assert!(matched); + } + + pub(crate) fn compare_values(col_name1: &str, col_name2: &str, rec: &RecordBatch) + where + T: ArrowPrimitiveType, + { + let mut vals1 = None; + let mut vals2 = None; + for (idx, col) in enumerate(rec.schema().fields.iter()) { + if col.name().as_str() == col_name1 { + vals1 = Some(rec.column(idx).as_primitive::().values()) + } else if col.name().as_str() == col_name2 { + vals2 = Some(rec.column(idx).as_primitive::().values()) + } + } + + if let (Some(vals1), Some(vals2)) = (vals1, vals2) { + assert_eq!(vals1, vals2); + return; + } + + panic!("columns not found"); + } + + // create a test filter + pub(crate) fn create_filter() -> ZarrChunkFilter { + let mut filters: Vec> = Vec::new(); + let f = ZarrArrowPredicateFn::new( + ZarrProjection::keep(vec!["lat".to_string()]), + move |batch| { + gt_eq( + batch.column_by_name("lat").unwrap(), + &Scalar::new(&Float64Array::from(vec![38.6])), + ) + }, + ); + filters.push(Box::new(f)); + let f = ZarrArrowPredicateFn::new( + ZarrProjection::keep(vec!["lon".to_string()]), + move |batch| { + gt_eq( + batch.column_by_name("lon").unwrap(), + &Scalar::new(&Float64Array::from(vec![-109.7])), + ) + }, + ); + filters.push(Box::new(f)); + let f = ZarrArrowPredicateFn::new( + ZarrProjection::keep(vec!["lon".to_string()]), + move |batch| { + lt( + batch.column_by_name("lon").unwrap(), + &Scalar::new(&Float64Array::from(vec![-109.2])), + ) + }, + ); + filters.push(Box::new(f)); + + ZarrChunkFilter::new(filters) + } + + // don't need that for now, I commented out tests with string data. + // pub(crate) fn validate_string_column(col_name: &str, rec: &RecordBatch, targets: &[&str]) { + // let mut matched = false; + // for (idx, col) in enumerate(rec.schema().fields.iter()) { + // if col.name().as_str() == col_name { + // assert_eq!( + // rec.column(idx).as_string(), + // &StringArray::from(targets.to_vec()), + // ); + // matched = true; + // } + // } + // assert!(matched); + // } } diff --git a/src/reader/codecs.rs b/src/reader/codecs.rs index 62b52e6..14a984a 100644 --- a/src/reader/codecs.rs +++ b/src/reader/codecs.rs @@ -19,10 +19,12 @@ use crate::reader::errors::throw_invalid_meta; use crate::reader::{ZarrError, ZarrResult}; use arrow_array::*; use arrow_schema::{DataType, Field, FieldRef, TimeUnit}; +use blosc_src::{blosc_cbuffer_sizes, blosc_decompress_ctx}; use crc32c::crc32c; use flate2::read::GzDecoder; use itertools::Itertools; use std::io::Read; +use std::os::raw::c_void; use std::str::FromStr; use std::sync::Arc; use std::vec; @@ -270,14 +272,14 @@ fn decode_transpose( order: &[usize], ) -> ZarrResult> { let new_indices: Vec<_> = match order.len() { - 2 => (0..chunk_dims[order[0]]) - .cartesian_product(0..chunk_dims[order[1]]) - .map(|t| t.0 * chunk_dims[1] + t.1) + 2 => (0..chunk_dims[order[1]]) + .cartesian_product(0..chunk_dims[order[0]]) + .map(|t| t.1 * chunk_dims[0] + t.0) .collect(), - 3 => (0..chunk_dims[order[0]]) + 3 => (0..chunk_dims[order[2]]) .cartesian_product(0..chunk_dims[order[1]]) - .cartesian_product(0..chunk_dims[order[2]]) - .map(|t| t.0 .0 * chunk_dims[1] * chunk_dims[2] + t.0 .1 * chunk_dims[2] + t.1) + .cartesian_product(0..chunk_dims[order[0]]) + .map(|t| t.1 * chunk_dims[0] * chunk_dims[1] + t.0 .1 * chunk_dims[0] + t.0 .0) .collect(), _ => { panic!("Invalid number of dims for transpose") @@ -333,6 +335,41 @@ fn process_edge_chunk( keep_indices(buf, &indices_to_keep); } +// copied this from the blosc library, since we directly import blosc-src for +// this project due to using zarrs. +unsafe fn blosc_decompress_bytes(src: &[u8]) -> ZarrResult> { + let mut nbytes: usize = 0; + let mut _cbytes: usize = 0; + let mut _blocksize: usize = 0; + // Unsafe if src comes from an untrusted source. + blosc_cbuffer_sizes( + src.as_ptr() as *const c_void, + &mut nbytes as *mut usize, + &mut _cbytes as *mut usize, + &mut _blocksize as *mut usize, + ); + let dest_size = nbytes; + let mut dest: Vec = Vec::with_capacity(dest_size); + // Unsafe if src comes from an untrusted source. + let rsize = blosc_decompress_ctx( + src.as_ptr() as *const c_void, + dest.as_mut_ptr() as *mut c_void, + nbytes, + 1, + ); + if rsize > 0 { + // Unsafe if T contains references or pointers + dest.set_len(rsize as usize); + dest.shrink_to_fit(); + Ok(dest) + } else { + // Buffer too small, data corrupted, decompressor not available, etc + Err(ZarrError::Read( + "A problem occured with blosc decompression".to_string(), + )) + } +} + // decode data that was encoded with a bytes to bytes codec. fn apply_bytes_to_bytes_codec(codec: &ZarrCodec, bytes: &[u8]) -> ZarrResult> { let mut decompressed_bytes = Vec::new(); @@ -342,14 +379,13 @@ fn apply_bytes_to_bytes_codec(codec: &ZarrCodec, bytes: &[u8]) -> ZarrResult { - decompressed_bytes = unsafe { blosc::decompress_bytes(bytes).unwrap() }; + decompressed_bytes = unsafe { blosc_decompress_bytes(bytes).unwrap() }; } ZarrCodec::Crc32c => { let mut bytes = bytes.to_vec(); let l = bytes.len(); let checksum = bytes.split_off(l - 4); - let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]]; - if crc32c(&bytes[..]) != u32::from_le_bytes(checksum) { + if crc32c(&bytes[..]).to_le_bytes() != checksum[..4] { return Err(throw_invalid_meta("crc32c checksum failed")); } decompressed_bytes = bytes; @@ -705,7 +741,7 @@ macro_rules! create_decode_function { if let Some(sharding_params) = sharding_params.as_ref() { let mut index_size: usize = 2 * 8 * sharding_params.n_chunks.iter().fold(1, |mult, x| mult * x); - index_size += sharding_params + index_size += 4 * sharding_params .index_codecs .iter() .any(|c| c == &ZarrCodec::Crc32c) as usize; @@ -1018,154 +1054,7 @@ pub(crate) fn apply_codecs( #[cfg(test)] mod zarr_codecs_tests { - use crate::tests::get_test_v3_data_path; - use super::*; - use ::std::fs::read; - - // reading a chunk and decoding it using hard coded, known options. this test - // doesn't included any sharding. - #[test] - fn no_sharding_tests() { - let path = get_test_v3_data_path("no_sharding.zarr/int_data/c/1/1".to_string()); - let raw_data = read(path).unwrap(); - - let chunk_shape = vec![4, 4]; - let real_dims = vec![4, 4]; - let data_type = ZarrDataType::Int(4); - let codecs = vec![ - ZarrCodec::Bytes(Endianness::Little), - ZarrCodec::BloscCompressor(BloscOptions::new( - CompressorName::Zstd, - 5, - ShuffleOptions::Noshuffle, - 0, - )), - ]; - let sharding_params: Option = None; - - let (arr, field) = apply_codecs( - "int_data".to_string(), - raw_data, - &chunk_shape, - &real_dims, - &data_type, - &codecs, - sharding_params, - None, - ) - .unwrap(); - - assert_eq!( - field, - Arc::new(Field::new("int_data", DataType::Int32, false)) - ); - let target_arr: Int32Array = vec![ - 68, 69, 70, 71, 84, 85, 86, 87, 100, 101, 102, 103, 116, 117, 118, 119, - ] - .into(); - assert_eq!(*arr, target_arr); - } - - // reading a chunk and decoding it using hard coded, known options. this test - // includes sharding. - #[test] - fn with_sharding_tests() { - let path = get_test_v3_data_path("with_sharding.zarr/float_data/1.1".to_string()); - let raw_data = read(path).unwrap(); - - let chunk_shape = vec![4, 4]; - let real_dims = vec![4, 4]; - let data_type = ZarrDataType::Float(8); - let codecs: Vec = vec![]; - let sharding_params = Some(ShardingOptions::new( - vec![2, 2], - vec![2, 2], - vec![ - ZarrCodec::Bytes(Endianness::Little), - ZarrCodec::BloscCompressor(BloscOptions::new( - CompressorName::Zstd, - 5, - ShuffleOptions::Noshuffle, - 0, - )), - ], - vec![ZarrCodec::Bytes(Endianness::Little)], - IndexLocation::End, - 0, - )); - - let (arr, field) = apply_codecs( - "float_data".to_string(), - raw_data, - &chunk_shape, - &real_dims, - &data_type, - &codecs, - sharding_params, - None, - ) - .unwrap(); - - assert_eq!( - field, - Arc::new(Field::new("float_data", DataType::Float64, false)) - ); - let target_arr: Float64Array = vec![ - 36.0, 37.0, 38.0, 39.0, 44.0, 45.0, 46.0, 47.0, 52.0, 53.0, 54.0, 55.0, 60.0, 61.0, - 62.0, 63.0, - ] - .into(); - assert_eq!(*arr, target_arr); - } - - // reading a chunk and decoding it using hard coded, known options. this test - // includes sharding, and the shape doesn't exactly line up with the chunks. - #[test] - fn with_sharding_with_edge_tests() { - let path = get_test_v3_data_path("with_sharding_with_edge.zarr/uint_data/1.1".to_string()); - let raw_data = read(path).unwrap(); - - let chunk_shape = vec![4, 4]; - let real_dims = vec![3, 3]; - let data_type = ZarrDataType::UInt(2); - let codecs: Vec = vec![]; - let sharding_params = Some(ShardingOptions::new( - vec![2, 2], - vec![2, 2], - vec![ - ZarrCodec::Bytes(Endianness::Little), - ZarrCodec::BloscCompressor(BloscOptions::new( - CompressorName::Zstd, - 5, - ShuffleOptions::Noshuffle, - 0, - )), - ], - vec![ZarrCodec::Bytes(Endianness::Little)], - IndexLocation::End, - 0, - )); - - let (arr, field) = apply_codecs( - "uint_data".to_string(), - raw_data, - &chunk_shape, - &real_dims, - &data_type, - &codecs, - sharding_params, - None, - ) - .unwrap(); - - assert_eq!( - field, - Arc::new(Field::new("uint_data", DataType::UInt16, false)) - ); - let target_arr: UInt16Array = vec![32, 33, 34, 39, 40, 41, 46, 47, 48].into(); - assert_eq!(*arr, target_arr); - } #[test] fn test_zarr_data_type_to_arrow_datatype() -> ZarrResult<()> { diff --git a/src/reader/mod.rs b/src/reader/mod.rs index cca087d..c367633 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -1,61 +1,5 @@ //! A module tha provides a sychronous reader for zarr store, to generate [`RecordBatch`]es. //! -//! ``` -//! # use arrow_zarr::reader::{ZarrRecordBatchReaderBuilder, ZarrProjection}; -//! # use arrow_cast::pretty::pretty_format_batches; -//! # use arrow_array::RecordBatch; -//! # use std::path::PathBuf; -//! # -//! # fn get_test_data_path(zarr_store: String) -> PathBuf { -//! # PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data/data/zarr/v2_data").join(zarr_store) -//! # } -//! # -//! # fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) { -//! # let formatted = pretty_format_batches(batches).unwrap().to_string(); -//! # let actual_lines: Vec<_> = formatted.trim().lines().collect(); -//! # assert_eq!( -//! # &actual_lines, expected_lines, -//! # "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", -//! # expected_lines, actual_lines -//! # ); -//! # } -//! -//! // The ZarrRead trait is implemented for PathBuf, so as long as it points -//! // to a directory with a valid zarr store, it can be used to initialize -//! // a zarr reader builder. -//! let p: PathBuf = get_test_data_path("lat_lon_example.zarr".to_string()); -//! -//! let proj = ZarrProjection::keep(vec!["lat".to_string(), "float_data".to_string()]); -//! let builder = ZarrRecordBatchReaderBuilder::new(p).with_projection(proj); -//! let mut reader = builder.build().unwrap(); -//! let rec_batch = reader.next().unwrap().unwrap(); -//! -//! assert_batches_eq( -//! &[rec_batch], -//! &[ -//! "+------------+------+", -//! "| float_data | lat |", -//! "+------------+------+", -//! "| 1001.0 | 38.0 |", -//! "| 1002.0 | 38.1 |", -//! "| 1003.0 | 38.2 |", -//! "| 1004.0 | 38.3 |", -//! "| 1012.0 | 38.0 |", -//! "| 1013.0 | 38.1 |", -//! "| 1014.0 | 38.2 |", -//! "| 1015.0 | 38.3 |", -//! "| 1023.0 | 38.0 |", -//! "| 1024.0 | 38.1 |", -//! "| 1025.0 | 38.2 |", -//! "| 1026.0 | 38.3 |", -//! "| 1034.0 | 38.0 |", -//! "| 1035.0 | 38.1 |", -//! "| 1036.0 | 38.2 |", -//! "| 1037.0 | 38.3 |", -//! "+------------+------+", -//! ], -//! ); -//! ``` use arrow_array::*; use arrow_schema::{DataType, Field, FieldRef, Schema}; @@ -402,136 +346,27 @@ impl ZarrRecordBatchReaderBuilder { #[cfg(test)] mod zarr_reader_tests { - use crate::tests::{get_test_v2_data_path, get_test_v3_data_path}; - use arrow::compute::kernels::cmp::{gt_eq, lt}; - use arrow_array::cast::AsArray; + use crate::test_utils::{ + compare_values, create_filter, store_1d, store_3d, store_compression_codecs, + store_endianness_and_order, store_endianness_and_order_3d, store_lat_lon, + store_lat_lon_broadcastable, store_partial_sharding, store_partial_sharding_3d, + validate_bool_column, validate_names_and_types, validate_primitive_column, StoreWrapper, + }; + use arrow::compute::kernels::cmp::gt_eq; use arrow_array::types::*; - use arrow_schema::{DataType, TimeUnit}; - use itertools::enumerate; - use std::{boxed::Box, collections::HashMap, fmt::Debug}; + use arrow_schema::DataType; + use rstest::*; + use std::{boxed::Box, collections::HashMap}; use super::*; use crate::reader::filters::{ZarrArrowPredicate, ZarrArrowPredicateFn}; - fn validate_names_and_types(targets: &HashMap, rec: &RecordBatch) { - let mut target_cols: Vec<&String> = targets.keys().collect(); - let schema = rec.schema(); - let from_rec: Vec<&String> = schema.fields.iter().map(|f| f.name()).collect(); - - target_cols.sort(); - assert_eq!(from_rec, target_cols); + #[rstest] + fn compressed_data_tests( + #[with("compressed_data_tests".to_string())] store_compression_codecs: StoreWrapper, + ) { + let p = store_compression_codecs.store_path(); - for field in schema.fields.iter() { - assert_eq!(field.data_type(), targets.get(field.name()).unwrap()); - } - } - - fn validate_bool_column(col_name: &str, rec: &RecordBatch, targets: &[bool]) { - let mut matched = false; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name { - assert_eq!( - rec.column(idx).as_boolean(), - &BooleanArray::from(targets.to_vec()), - ); - matched = true; - } - } - assert!(matched); - } - - fn validate_primitive_column(col_name: &str, rec: &RecordBatch, targets: &[U]) - where - T: ArrowPrimitiveType, - [U]: AsRef<[::Native]>, - U: Debug, - { - let mut matched = false; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name { - assert_eq!(rec.column(idx).as_primitive::().values(), targets); - matched = true; - } - } - assert!(matched); - } - - fn compare_values(col_name1: &str, col_name2: &str, rec: &RecordBatch) - where - T: ArrowPrimitiveType, - { - let mut vals1 = None; - let mut vals2 = None; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name1 { - vals1 = Some(rec.column(idx).as_primitive::().values()) - } else if col.name().as_str() == col_name2 { - vals2 = Some(rec.column(idx).as_primitive::().values()) - } - } - - if let (Some(vals1), Some(vals2)) = (vals1, vals2) { - assert_eq!(vals1, vals2); - return; - } - - panic!("columns not found"); - } - - fn validate_string_column(col_name: &str, rec: &RecordBatch, targets: &[&str]) { - let mut matched = false; - for (idx, col) in enumerate(rec.schema().fields.iter()) { - if col.name().as_str() == col_name { - assert_eq!( - rec.column(idx).as_string(), - &StringArray::from(targets.to_vec()), - ); - matched = true; - } - } - assert!(matched); - } - - // create a test filter - fn create_filter() -> ZarrChunkFilter { - let mut filters: Vec> = Vec::new(); - let f = ZarrArrowPredicateFn::new( - ZarrProjection::keep(vec!["lat".to_string()]), - move |batch| { - gt_eq( - batch.column_by_name("lat").unwrap(), - &Scalar::new(&Float64Array::from(vec![38.6])), - ) - }, - ); - filters.push(Box::new(f)); - let f = ZarrArrowPredicateFn::new( - ZarrProjection::keep(vec!["lon".to_string()]), - move |batch| { - gt_eq( - batch.column_by_name("lon").unwrap(), - &Scalar::new(&Float64Array::from(vec![-109.7])), - ) - }, - ); - filters.push(Box::new(f)); - let f = ZarrArrowPredicateFn::new( - ZarrProjection::keep(vec!["lon".to_string()]), - move |batch| { - lt( - batch.column_by_name("lon").unwrap(), - &Scalar::new(&Float64Array::from(vec![-109.2])), - ) - }, - ); - filters.push(Box::new(f)); - - ZarrChunkFilter::new(filters) - } - - #[test] - fn compression_tests() { - let p = get_test_v2_data_path("compression_example.zarr".to_string()); let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); let records: Vec = reader.map(|x| x.unwrap()).collect(); @@ -539,29 +374,29 @@ mod zarr_reader_tests { ("bool_data".to_string(), DataType::Boolean), ("uint_data".to_string(), DataType::UInt64), ("int_data".to_string(), DataType::Int64), - ("float_data".to_string(), DataType::Float64), + ("float_data".to_string(), DataType::Float32), ("float_data_no_comp".to_string(), DataType::Float64), ]); + validate_names_and_types(&target_types, &records[0]); // center chunk let rec = &records[4]; - validate_names_and_types(&target_types, rec); validate_bool_column( "bool_data", rec, &[false, true, false, false, true, false, false, true, false], ); - validate_primitive_column::( - "int_data", - rec, - &[-4, -3, -2, 4, 5, 6, 12, 13, 14], - ); validate_primitive_column::( "uint_data", rec, &[27, 28, 29, 35, 36, 37, 43, 44, 45], ); - validate_primitive_column::( + validate_primitive_column::( + "int_data", + rec, + &[-4, -3, -2, 4, 5, 6, 12, 13, 14], + ); + validate_primitive_column::( "float_data", rec, &[127., 128., 129., 135., 136., 137., 143., 144., 145.], @@ -574,11 +409,10 @@ mod zarr_reader_tests { // right edge chunk let rec = &records[5]; - validate_names_and_types(&target_types, rec); validate_bool_column("bool_data", rec, &[true, false, true, false, true, false]); - validate_primitive_column::("int_data", rec, &[-1, 0, 7, 8, 15, 16]); validate_primitive_column::("uint_data", rec, &[30, 31, 38, 39, 46, 47]); - validate_primitive_column::( + validate_primitive_column::("int_data", rec, &[-1, 0, 7, 8, 15, 16]); + validate_primitive_column::( "float_data", rec, &[130., 131., 138., 139., 146., 147.], @@ -591,11 +425,10 @@ mod zarr_reader_tests { // bottom right edge chunk let rec = &records[8]; - validate_names_and_types(&target_types, rec); validate_bool_column("bool_data", rec, &[true, false, true, false]); - validate_primitive_column::("int_data", rec, &[23, 24, 31, 32]); validate_primitive_column::("uint_data", rec, &[54, 55, 62, 63]); - validate_primitive_column::( + validate_primitive_column::("int_data", rec, &[23, 24, 31, 32]); + validate_primitive_column::( "float_data", rec, &[154.0, 155.0, 162.0, 163.0], @@ -607,9 +440,11 @@ mod zarr_reader_tests { ); } - #[test] - fn projection_tests() { - let p = get_test_v2_data_path("compression_example.zarr".to_string()); + #[rstest] + fn projection_tests( + #[with("projection_tests".to_string())] store_compression_codecs: StoreWrapper, + ) { + let p = store_compression_codecs.store_path(); let proj = ZarrProjection::keep(vec!["bool_data".to_string(), "int_data".to_string()]); let builder = ZarrRecordBatchReaderBuilder::new(p).with_projection(proj); let reader = builder.build().unwrap(); @@ -619,10 +454,10 @@ mod zarr_reader_tests { ("bool_data".to_string(), DataType::Boolean), ("int_data".to_string(), DataType::Int64), ]); + validate_names_and_types(&target_types, &records[0]); // center chunk let rec = &records[4]; - validate_names_and_types(&target_types, rec); validate_bool_column( "bool_data", rec, @@ -635,9 +470,11 @@ mod zarr_reader_tests { ); } - #[test] - fn multiple_readers_tests() { - let p = get_test_v2_data_path("compression_example.zarr".to_string()); + #[rstest] + fn multiple_readers_tests( + #[with("multiple_readers_tests".to_string())] store_compression_codecs: StoreWrapper, + ) { + let p = store_compression_codecs.store_path(); let reader1 = ZarrRecordBatchReaderBuilder::new(p.clone()) .build_partial_reader(Some((0, 5))) .unwrap(); @@ -655,13 +492,14 @@ mod zarr_reader_tests { ("bool_data".to_string(), DataType::Boolean), ("uint_data".to_string(), DataType::UInt64), ("int_data".to_string(), DataType::Int64), - ("float_data".to_string(), DataType::Float64), + ("float_data".to_string(), DataType::Float32), ("float_data_no_comp".to_string(), DataType::Float64), ]); + validate_names_and_types(&target_types, &records1[0]); + validate_names_and_types(&target_types, &records2[0]); // center chunk let rec = &records1[4]; - validate_names_and_types(&target_types, rec); validate_bool_column( "bool_data", rec, @@ -677,7 +515,7 @@ mod zarr_reader_tests { rec, &[27, 28, 29, 35, 36, 37, 43, 44, 45], ); - validate_primitive_column::( + validate_primitive_column::( "float_data", rec, &[127., 128., 129., 135., 136., 137., 143., 144., 145.], @@ -690,11 +528,10 @@ mod zarr_reader_tests { // bottom edge chunk let rec = &records2[2]; - validate_names_and_types(&target_types, rec); validate_bool_column("bool_data", rec, &[false, true, false, false, true, false]); validate_primitive_column::("int_data", rec, &[20, 21, 22, 28, 29, 30]); validate_primitive_column::("uint_data", rec, &[51, 52, 53, 59, 60, 61]); - validate_primitive_column::( + validate_primitive_column::( "float_data", rec, &[151.0, 152.0, 153.0, 159.0, 160.0, 161.0], @@ -706,184 +543,120 @@ mod zarr_reader_tests { ); } - #[test] - fn endianness_and_order_tests() { - let p = get_test_v2_data_path("endianness_and_order_example.zarr".to_string()); - let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([ - ("var1".to_string(), DataType::Int32), - ("var2".to_string(), DataType::Int32), - ]); - - // bottom edge chunk - let rec = &records[9]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "var1", - rec, - &[69, 80, 91, 70, 81, 92, 71, 82, 93], - ); - validate_primitive_column::( - "var2", - rec, - &[69, 80, 91, 70, 81, 92, 71, 82, 93], - ); - } - - #[test] - fn string_data_tests() { - let p = get_test_v2_data_path("string_example.zarr".to_string()); + #[rstest] + fn endianness_and_order_tests( + #[with("endianness_and_order_tests".to_string())] store_endianness_and_order: StoreWrapper, + ) { + let p = store_endianness_and_order.store_path(); let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); let records: Vec = reader.map(|x| x.unwrap()).collect(); let target_types = HashMap::from([ - ("uint_data".to_string(), DataType::UInt8), - ("string_data".to_string(), DataType::Utf8), - ("utf8_data".to_string(), DataType::Utf8), + ("int_data_big_endian_f_order".to_string(), DataType::Int32), + ("int_data".to_string(), DataType::Int32), ]); + validate_names_and_types(&target_types, &records[0]); - // bottom edge chunk - let rec = &records[7]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::("uint_data", rec, &[51, 52, 53, 59, 60, 61]); - validate_string_column( - "string_data", - rec, - &["abc61", "abc62", "abc63", "abc69", "abc70", "abc71"], - ); - validate_string_column( - "utf8_data", - rec, - &["def61", "def62", "def63", "def69", "def70", "def71"], - ); + for rec in &records { + compare_values::("int_data_big_endian_f_order", "int_data", rec); + } } - #[test] - fn ts_data_tests() { - let p = get_test_v2_data_path("ts_example.zarr".to_string()); + #[rstest] + fn endianness_and_order_3d_tests( + #[with("endianness_and_order_3d_tests".to_string())] + store_endianness_and_order_3d: StoreWrapper, + ) { + let p = store_endianness_and_order_3d.store_path(); let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); let records: Vec = reader.map(|x| x.unwrap()).collect(); let target_types = HashMap::from([ - ( - "ts_s_data".to_string(), - DataType::Timestamp(TimeUnit::Second, None), - ), - ( - "ts_ms_data".to_string(), - DataType::Timestamp(TimeUnit::Millisecond, None), - ), - ( - "ts_us_data".to_string(), - DataType::Timestamp(TimeUnit::Microsecond, None), - ), - ( - "ts_ns_data".to_string(), - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), + ("int_data_big_endian_f_order".to_string(), DataType::Int32), + ("int_data".to_string(), DataType::Int32), ]); + validate_names_and_types(&target_types, &records[0]); - // top center chunk - let rec = &records[1]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "ts_s_data", - rec, - &[1685750400, 1685836800, 1686182400, 1686268800], - ); - validate_primitive_column::( - "ts_ms_data", - rec, - &[1685750400000, 1685836800000, 1686182400000, 1686268800000], - ); - validate_primitive_column::( - "ts_us_data", - rec, - &[ - 1685750400000000, - 1685836800000000, - 1686182400000000, - 1686268800000000, - ], - ); - validate_primitive_column::( - "ts_ns_data", - rec, - &[ - 1685750400000000000, - 1685836800000000000, - 1686182400000000000, - 1686268800000000000, - ], - ); - - // top right edge chunk - let rec = &records[2]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "ts_s_data", - rec, - &[1685923200, 1686355200], - ); - validate_primitive_column::( - "ts_ms_data", - rec, - &[1685923200000, 1686355200000], - ); - validate_primitive_column::( - "ts_us_data", - rec, - &[1685923200000000, 1686355200000000], - ); - validate_primitive_column::( - "ts_ns_data", - rec, - &[1685923200000000000, 1686355200000000000], - ); + for rec in &records { + compare_values::("int_data_big_endian_f_order", "int_data", rec); + } } - #[test] - fn one_dim_tests() { - let p = get_test_v2_data_path("one_dim_example.zarr".to_string()); + // apparently in zarr v3 the type is just "string", so I'm not too sure + // how this is supposed to be handled, how should I get the length for + // example... I will revisit later and uncomment this test. + // #[rstest] + // fn string_data_tests( + // #[with("string_data_tests".to_string())] store_strings: StoreWrapper + // ) { + // let p = store_strings.prefix_to_fs_path(&StorePrefix::new("").unwrap()); + // let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); + // let records: Vec = reader.map(|x| x.unwrap()).collect(); + + // let target_types = HashMap::from([ + // ("int_data".to_string(), DataType::Int8), + // ("string_data".to_string(), DataType::Utf8), + // ]); + // validate_names_and_types(&target_types, &records[0]); + + // // top left corner + // let rec = &records[0]; + // validate_primitive_column::("uint_data", rec, &[1, 2, 3, 9, 10, 11, 17, 18, 19]); + // validate_string_column( + // "string_data", + // rec, + // &["abc01", "abc02", "abc03", "abc09", "abc10", "abc11", "abc17", "abc18", "abc19"], + // ); + + // // bottom edge chunk + // let rec = &records[7]; + // validate_primitive_column::("uint_data", rec, &[61, 62, 63, 69, 70, 71]); + // validate_string_column( + // "string_data", + // rec, + // &["abc61", "abc62", "abc63", "abc69", "abc70", "abc71"], + // ); + // } + + #[rstest] + fn one_dim_tests(#[with("one_dim_tests".to_string())] store_1d: StoreWrapper) { + let p = store_1d.store_path(); let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); let records: Vec = reader.map(|x| x.unwrap()).collect(); let target_types = HashMap::from([ - ("int_data".to_string(), DataType::Int64), - ("float_data".to_string(), DataType::Float64), + ("int_data".to_string(), DataType::Int32), + ("float_data".to_string(), DataType::Float32), ]); + validate_names_and_types(&target_types, &records[0]); // center chunk let rec = &records[1]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[-2, -1, 0]); - validate_primitive_column::("float_data", rec, &[103.0, 104.0, 105.0]); + validate_primitive_column::("int_data", rec, &[-2, -1, 0]); + validate_primitive_column::("float_data", rec, &[103.0, 104.0, 105.0]); // right edge chunk let rec = &records[3]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[4, 5]); - validate_primitive_column::("float_data", rec, &[109.0, 110.0]); + validate_primitive_column::("int_data", rec, &[4, 5]); + validate_primitive_column::("float_data", rec, &[109.0, 110.0]); } - #[test] - fn three_dim_tests() { - let p = get_test_v2_data_path("three_dim_example.zarr".to_string()); + #[rstest] + fn three_dim_tests(#[with("three_dim_tests".to_string())] store_3d: StoreWrapper) { + let p = store_3d.store_path(); let reader = ZarrRecordBatchReaderBuilder::new(p).build().unwrap(); let records: Vec = reader.map(|x| x.unwrap()).collect(); let target_types = HashMap::from([ - ("int_data".to_string(), DataType::Int64), - ("float_data".to_string(), DataType::Float64), + ("int_data".to_string(), DataType::Int32), + ("float_data".to_string(), DataType::Float32), ]); + validate_names_and_types(&target_types, &records[0]); // center chunk let rec = &records[13]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[0, 1, 5, 6, 25, 26, 30, 31]); - validate_primitive_column::( + validate_primitive_column::("int_data", rec, &[0, 1, 5, 6, 25, 26, 30, 31]); + validate_primitive_column::( "float_data", rec, &[162.0, 163.0, 167.0, 168.0, 187.0, 188.0, 192.0, 193.0], @@ -892,8 +665,8 @@ mod zarr_reader_tests { // right edge chunk let rec = &records[14]; validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[2, 7, 27, 32]); - validate_primitive_column::( + validate_primitive_column::("int_data", rec, &[2, 7, 27, 32]); + validate_primitive_column::( "float_data", rec, &[164.0, 169.0, 189.0, 194.0], @@ -902,25 +675,25 @@ mod zarr_reader_tests { // right front edge chunk let rec = &records[23]; validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[52, 57]); - validate_primitive_column::("float_data", rec, &[214.0, 219.0]); + validate_primitive_column::("int_data", rec, &[52, 57]); + validate_primitive_column::("float_data", rec, &[214.0, 219.0]); // bottom front edge chunk let rec = &records[24]; validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[58, 59]); - validate_primitive_column::("float_data", rec, &[220.0, 221.0]); + validate_primitive_column::("int_data", rec, &[58, 59]); + validate_primitive_column::("float_data", rec, &[220.0, 221.0]); // right front bottom edge chunk let rec = &records[26]; validate_names_and_types(&target_types, rec); - validate_primitive_column::("int_data", rec, &[62]); - validate_primitive_column::("float_data", rec, &[224.0]); + validate_primitive_column::("int_data", rec, &[62]); + validate_primitive_column::("float_data", rec, &[224.0]); } - #[test] - fn filters_tests() { - let p = get_test_v2_data_path("lat_lon_example.zarr".to_string()); + #[rstest] + fn filters_tests(#[with("filters_tests".to_string())] store_lat_lon: StoreWrapper) { + let p = store_lat_lon.store_path(); let mut builder = ZarrRecordBatchReaderBuilder::new(p); // set the filters to select part of the raster, based on lat and @@ -936,6 +709,7 @@ mod zarr_reader_tests { ("lon".to_string(), DataType::Float64), ("float_data".to_string(), DataType::Float64), ]); + validate_names_and_types(&target_types, &records[0]); // check the values in a chunk. the predicate pushdown only takes care of // skipping whole chunks, so there is no guarantee that the values in the @@ -964,15 +738,15 @@ mod zarr_reader_tests { "float_data", rec, &[ - 1005.0, 1006.0, 1007.0, 1008.0, 1016.0, 1017.0, 1018.0, 1019.0, 1027.0, 1028.0, - 1029.0, 1030.0, 1038.0, 1039.0, 1040.0, 1041.0, + 4.0, 5.0, 6.0, 7.0, 15.0, 16.0, 17.0, 18.0, 26.0, 27.0, 28.0, 29.0, 37.0, 38.0, + 39.0, 40.0, ], ); } - #[test] - fn empty_query_tests() { - let p = get_test_v2_data_path("lat_lon_example.zarr".to_string()); + #[rstest] + fn empty_query_tests(#[with("empty_query_tests".to_string())] store_lat_lon: StoreWrapper) { + let p = store_lat_lon.store_path(); let mut builder = ZarrRecordBatchReaderBuilder::new(p); // set a filter that will filter out all the data, there should be nothing left after @@ -997,31 +771,22 @@ mod zarr_reader_tests { assert_eq!(records.len(), 0); } - #[test] - fn array_broadcast_tests() { + #[rstest] + fn array_broadcast_tests( + #[with("array_broadcast_tests_part1".to_string())] store_lat_lon: StoreWrapper, + #[with("array_broadcast_tests_part2".to_string())] + store_lat_lon_broadcastable: StoreWrapper, + ) { // reference that doesn't broadcast a 1D array - let p = get_test_v2_data_path("lat_lon_example.zarr".to_string()); + let p = store_lat_lon.store_path(); let mut builder = ZarrRecordBatchReaderBuilder::new(p); builder = builder.with_filter(create_filter()); let reader = builder.build().unwrap(); let records: Vec = reader.map(|x| x.unwrap()).collect(); - // v2 format with array broadcast - let p = get_test_v2_data_path("lat_lon_example_broadcastable.zarr".to_string()); - let mut builder = ZarrRecordBatchReaderBuilder::new(p); - - builder = builder.with_filter(create_filter()); - let reader = builder.build().unwrap(); - let records_from_one_d_repr: Vec = reader.map(|x| x.unwrap()).collect(); - - assert_eq!(records_from_one_d_repr.len(), records.len()); - for (rec, rec_from_one_d_repr) in records.iter().zip(records_from_one_d_repr.iter()) { - assert_eq!(rec, rec_from_one_d_repr); - } - // v3 format with array broadcast - let p = get_test_v3_data_path("with_broadcastable_array.zarr".to_string()); + let p = store_lat_lon_broadcastable.store_path(); let mut builder = ZarrRecordBatchReaderBuilder::new(p); builder = builder.with_filter(create_filter()); @@ -1034,57 +799,11 @@ mod zarr_reader_tests { } } - #[test] - fn no_sharding_tests() { - let p = get_test_v3_data_path("no_sharding.zarr".to_string()); - let builder = ZarrRecordBatchReaderBuilder::new(p); - - let reader = builder.build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([("int_data".to_string(), DataType::Int32)]); - - let rec = &records[1]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "int_data", - rec, - &[4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55], - ); - } - - #[test] - fn no_sharding_with_edge_tests() { - let p = get_test_v3_data_path("no_sharding_with_edge.zarr".to_string()); - let builder = ZarrRecordBatchReaderBuilder::new(p); - - let reader = builder.build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([ - ("float_data".to_string(), DataType::Float32), - ("uint_data".to_string(), DataType::UInt64), - ]); - - let rec = &records[3]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "float_data", - rec, - &[ - 12.0, 13.0, 14.0, 27.0, 28.0, 29.0, 42.0, 43.0, 44.0, 57.0, 58.0, 59.0, - ], - ); - validate_primitive_column::( - "uint_data", - rec, - &[12, 13, 14, 27, 28, 29, 42, 43, 44, 57, 58, 59], - ); - } - - #[test] - fn with_partial_sharding_tests() { - let p = get_test_v3_data_path("with_partial_sharding.zarr".to_string()); + #[rstest] + fn partial_sharding_tests( + #[with("partial_sharding_tests".to_string())] store_partial_sharding: StoreWrapper, + ) { + let p = store_partial_sharding.store_path(); let builder = ZarrRecordBatchReaderBuilder::new(p); let reader = builder.build().unwrap(); @@ -1095,9 +814,11 @@ mod zarr_reader_tests { } } - #[test] - fn with_partial_sharding_3d_tests() { - let p = get_test_v3_data_path("with_partial_sharding_3D.zarr".to_string()); + #[rstest] + fn partial_sharding_3d_tests( + #[with("partial_sharding_3d_tests".to_string())] store_partial_sharding_3d: StoreWrapper, + ) { + let p = store_partial_sharding_3d.store_path(); let builder = ZarrRecordBatchReaderBuilder::new(p); let reader = builder.build().unwrap(); @@ -1107,94 +828,4 @@ mod zarr_reader_tests { compare_values::("float_data_not_sharded", "float_data_sharded", &rec); } } - - #[test] - fn with_sharding_tests() { - let p = get_test_v3_data_path("with_sharding.zarr".to_string()); - let builder = ZarrRecordBatchReaderBuilder::new(p); - - let reader = builder.build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([ - ("float_data".to_string(), DataType::Float64), - ("int_data".to_string(), DataType::Int64), - ]); - - let rec = &records[2]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "float_data", - rec, - &[ - 32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 48.0, 49.0, 50.0, 51.0, 56.0, 57.0, - 58.0, 59.0, - ], - ); - validate_primitive_column::( - "int_data", - rec, - &[ - 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59, - ], - ); - } - - #[test] - fn with_sharding_with_edge_tests() { - let p = get_test_v3_data_path("with_sharding_with_edge.zarr".to_string()); - let builder = ZarrRecordBatchReaderBuilder::new(p); - - let reader = builder.build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([("uint_data".to_string(), DataType::UInt16)]); - - let rec = &records[1]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "uint_data", - rec, - &[4, 5, 6, 11, 12, 13, 18, 19, 20, 25, 26, 27], - ); - } - - #[test] - fn three_dims_no_sharding_with_edge_tests() { - let p = get_test_v3_data_path("no_sharding_with_edge_3d.zarr".to_string()); - let builder = ZarrRecordBatchReaderBuilder::new(p); - - let reader = builder.build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([("uint_data".to_string(), DataType::UInt64)]); - - let rec = &records[5]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::("uint_data", rec, &[14, 19, 39, 44]); - } - - #[test] - fn three_dims_with_sharding_with_edge_tests() { - let p = get_test_v3_data_path("with_sharding_with_edge_3d.zarr".to_string()); - let builder = ZarrRecordBatchReaderBuilder::new(p); - - let reader = builder.build().unwrap(); - let records: Vec = reader.map(|x| x.unwrap()).collect(); - - let target_types = HashMap::from([("float_data".to_string(), DataType::Float64)]); - - let rec = &records[23]; - validate_names_and_types(&target_types, rec); - validate_primitive_column::( - "float_data", - rec, - &[ - 1020.0, 1021.0, 1022.0, 1031.0, 1032.0, 1033.0, 1042.0, 1043.0, 1044.0, 1053.0, - 1054.0, 1055.0, 1141.0, 1142.0, 1143.0, 1152.0, 1153.0, 1154.0, 1163.0, 1164.0, - 1165.0, 1174.0, 1175.0, 1176.0, 1262.0, 1263.0, 1264.0, 1273.0, 1274.0, 1275.0, - 1284.0, 1285.0, 1286.0, 1295.0, 1296.0, 1297.0, - ], - ); - } } diff --git a/src/reader/zarr_read.rs b/src/reader/zarr_read.rs index 51127dc..85fd73c 100644 --- a/src/reader/zarr_read.rs +++ b/src/reader/zarr_read.rs @@ -364,33 +364,28 @@ impl ZarrRead for PathBuf { #[cfg(test)] mod zarr_read_tests { use std::collections::HashSet; - use std::path::PathBuf; use super::*; use crate::reader::codecs::{Endianness, ZarrCodec, ZarrDataType}; use crate::reader::metadata::{ChunkSeparator, ZarrArrayMetadata}; - - fn get_test_data_path(zarr_store: String) -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("test-data/data/zarr/v2_data") - .join(zarr_store) - } + use crate::test_utils::{store_raw_bytes, StoreWrapper}; + use rstest::*; // read the store metadata, given a path to a zarr store. - #[test] - fn read_metadata() { - let p = get_test_data_path("raw_bytes_example.zarr".to_string()); + #[rstest] + fn read_metadata(#[with("read_metadata".to_string())] store_raw_bytes: StoreWrapper) { + let p = store_raw_bytes.store_path(); let meta = p.get_zarr_metadata().unwrap(); assert_eq!(meta.get_columns(), &vec!["byte_data", "float_data"]); assert_eq!( meta.get_array_meta("byte_data").unwrap(), &ZarrArrayMetadata::new( - 2, + 3, ZarrDataType::UInt(1), ChunkPattern { - separator: ChunkSeparator::Period, - c_prefix: false + separator: ChunkSeparator::Slash, + c_prefix: true }, None, vec![ZarrCodec::Bytes(Endianness::Little)], @@ -399,11 +394,11 @@ mod zarr_read_tests { assert_eq!( meta.get_array_meta("float_data").unwrap(), &ZarrArrayMetadata::new( - 2, + 3, ZarrDataType::Float(8), ChunkPattern { - separator: ChunkSeparator::Period, - c_prefix: false + separator: ChunkSeparator::Slash, + c_prefix: true }, None, vec![ZarrCodec::Bytes(Endianness::Little)], @@ -413,9 +408,9 @@ mod zarr_read_tests { // read the raw data contained into a zarr store. one of the variables contains // byte data, which we explicitly check here. - #[test] - fn read_raw_chunks() { - let p = get_test_data_path("raw_bytes_example.zarr".to_string()); + #[rstest] + fn read_raw_chunks(#[with("read_raw_chunks".to_string())] store_raw_bytes: StoreWrapper) { + let p = store_raw_bytes.store_path(); let meta = p.get_zarr_metadata().unwrap(); // no broadcastable arrays diff --git a/test-data b/test-data deleted file mode 160000 index 7809aec..0000000 --- a/test-data +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7809aeccd2d460bce819d94ec6cc09a6c48068d0