From 31f6472d954c50e9bbcaa19cb6b2dc7e1831561b Mon Sep 17 00:00:00 2001 From: Satyam Singh Date: Fri, 15 Dec 2023 08:30:54 +0530 Subject: [PATCH] ensure statistics are type compliant with schema (#571) This PR ensures that all the types in the table statistics are compatible with the corresponding table schema types. --- server/src/catalog/column.rs | 36 ++++++++++++ server/src/query/stream_schema_provider.rs | 68 ++++++---------------- 2 files changed, 53 insertions(+), 51 deletions(-) diff --git a/server/src/catalog/column.rs b/server/src/catalog/column.rs index 17361609a..59074b611 100644 --- a/server/src/catalog/column.rs +++ b/server/src/catalog/column.rs @@ -18,6 +18,8 @@ use std::cmp::{max, min}; +use arrow_schema::DataType; +use datafusion::scalar::ScalarValue; use parquet::file::statistics::Statistics; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -85,6 +87,40 @@ impl TypedStatistics { _ => panic!("Cannot update wrong types"), } } + + pub fn min_max_as_scalar(self, datatype: &DataType) -> Option<(ScalarValue, ScalarValue)> { + let (min, max) = match (self, datatype) { + (TypedStatistics::Bool(stats), DataType::Boolean) => ( + ScalarValue::Boolean(Some(stats.min)), + ScalarValue::Boolean(Some(stats.max)), + ), + (TypedStatistics::Int(stats), DataType::Int32) => ( + ScalarValue::Int32(Some(stats.min as i32)), + ScalarValue::Int32(Some(stats.max as i32)), + ), + (TypedStatistics::Int(stats), DataType::Int64) => ( + ScalarValue::Int64(Some(stats.min)), + ScalarValue::Int64(Some(stats.max)), + ), + (TypedStatistics::Float(stats), DataType::Float32) => ( + ScalarValue::Float32(Some(stats.min as f32)), + ScalarValue::Float32(Some(stats.max as f32)), + ), + (TypedStatistics::Float(stats), DataType::Float64) => ( + ScalarValue::Float64(Some(stats.min)), + ScalarValue::Float64(Some(stats.max)), + ), + (TypedStatistics::String(stats), DataType::Utf8) => ( + ScalarValue::Utf8(Some(stats.min)), + ScalarValue::Utf8(Some(stats.max)), + ), + _ => { + return None; + } + }; + 
Some((min, max)) + } } /// Column statistics are used to track statistics for a column in a given file. diff --git a/server/src/query/stream_schema_provider.rs b/server/src/query/stream_schema_provider.rs index 637972130..cbdf0acbd 100644 --- a/server/src/query/stream_schema_provider.rs +++ b/server/src/query/stream_schema_provider.rs @@ -18,7 +18,7 @@ use std::{any::Any, collections::HashMap, ops::Bound, sync::Arc}; -use arrow_schema::{DataType, Schema, SchemaRef, SortOptions}; +use arrow_schema::{Schema, SchemaRef, SortOptions}; use bytes::Bytes; use chrono::{NaiveDateTime, Timelike, Utc}; use datafusion::{ @@ -236,57 +236,23 @@ fn partitioned_files( count += num_rows; } - let mut statistics = vec![]; - - for field in table_schema.fields() { - let Some(stats) = column_statistics - .get(field.name()) - .and_then(|stats| stats.as_ref()) - else { - statistics.push(datafusion::common::ColumnStatistics::default()); - break; - }; - - let datatype = field.data_type(); - - let (min, max) = match (stats, datatype) { - (TypedStatistics::Bool(stats), DataType::Boolean) => ( - ScalarValue::Boolean(Some(stats.min)), - ScalarValue::Boolean(Some(stats.max)), - ), - (TypedStatistics::Int(stats), DataType::Int32) => ( - ScalarValue::Int32(Some(stats.min as i32)), - ScalarValue::Int32(Some(stats.max as i32)), - ), - (TypedStatistics::Int(stats), DataType::Int64) => ( - ScalarValue::Int64(Some(stats.min)), - ScalarValue::Int64(Some(stats.max)), - ), - (TypedStatistics::Float(stats), DataType::Float32) => ( - ScalarValue::Float32(Some(stats.min as f32)), - ScalarValue::Float32(Some(stats.max as f32)), - ), - (TypedStatistics::Float(stats), DataType::Float64) => ( - ScalarValue::Float64(Some(stats.min)), - ScalarValue::Float64(Some(stats.max)), - ), - (TypedStatistics::String(stats), DataType::Utf8) => ( - ScalarValue::Utf8(Some(stats.min.clone())), - ScalarValue::Utf8(Some(stats.max.clone())), - ), - _ => { - statistics.push(datafusion::common::ColumnStatistics::default()); - break; 
- } - }; - - statistics.push(datafusion::common::ColumnStatistics { - null_count: None, - max_value: Some(max), - min_value: Some(min), - distinct_count: None, + let statistics = table_schema + .fields() + .iter() + .map(|field| { + column_statistics + .get(field.name()) + .and_then(|stats| stats.as_ref()) + .and_then(|stats| stats.clone().min_max_as_scalar(field.data_type())) + .map(|(min, max)| datafusion::common::ColumnStatistics { + null_count: None, + max_value: Some(max), + min_value: Some(min), + distinct_count: None, + }) + .unwrap_or_default() }) - } + .collect(); let statistics = datafusion::common::Statistics { num_rows: Some(count as usize),