Skip to content

Commit

Permalink
fix: support min/max for Float16 type (#12050)
Browse files Browse the repository at this point in the history
* fix: support min/max for Float16 type

* minor: uncomment arrow_typeof float16 in sqllocigtests
  • Loading branch information
korowa authored Aug 19, 2024
1 parent 159ab17 commit 7c5a8eb
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 18 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ datafusion-expr = { workspace = true }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
half = { workspace = true }
log = { workspace = true }
paste = "1.0.14"
sqlparser = { workspace = true }
Expand Down
34 changes: 25 additions & 9 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,19 @@

use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray,
StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute;
use arrow::datatypes::{
DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type,
Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
};
use arrow_schema::IntervalUnit;
use datafusion_common::{
Expand All @@ -66,6 +67,7 @@ use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use half::f16;
use std::ops::Deref;

fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
Expand Down Expand Up @@ -181,6 +183,7 @@ impl AggregateUDFImpl for Max {
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
Expand Down Expand Up @@ -209,6 +212,9 @@ impl AggregateUDFImpl for Max {
UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_max_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_max_accumulator!(data_type, f32, Float32Type)
}
Expand Down Expand Up @@ -339,6 +345,9 @@ macro_rules! min_max_batch {
DataType::Float32 => {
typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
}
DataType::Float16 => {
typed_min_max_batch!($VALUES, Float16Array, Float16, $OP)
}
DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
Expand Down Expand Up @@ -623,6 +632,9 @@ macro_rules! min_max {
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
typed_min_max_float!(lhs, rhs, Float32, $OP)
}
(ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
typed_min_max_float!(lhs, rhs, Float16, $OP)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
typed_min_max!(lhs, rhs, UInt64, $OP)
}
Expand Down Expand Up @@ -950,6 +962,7 @@ impl AggregateUDFImpl for Min {
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
Expand Down Expand Up @@ -978,6 +991,9 @@ impl AggregateUDFImpl for Min {
UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_min_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_min_accumulator!(data_type, f32, Float32Type)
}
Expand Down
28 changes: 28 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5642,3 +5642,31 @@ query I??III?T
select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(NULL), nth_value(NULL, 1), string_agg(NULL, ',');
----
0 NULL NULL NULL NULL NULL NULL NULL

# test min/max Float16 without group expression
query RRTT
WITH data AS (
SELECT arrow_cast(1, 'Float16') AS f
UNION ALL
SELECT arrow_cast(6, 'Float16') AS f
)
SELECT MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) FROM data;
----
1 6 Float16 Float16

# test min/max Float16 with group expression
query IRRTT
WITH data AS (
SELECT 1 as k, arrow_cast(1.8125, 'Float16') AS f
UNION ALL
SELECT 1 as k, arrow_cast(6.8007813, 'Float16') AS f
UNION ALL
SELECT 2 AS k, arrow_cast(8.5, 'Float16') AS f
)
SELECT k, MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f))
FROM data
GROUP BY k
ORDER BY k;
----
1 1.8125 6.8007813 Float16 Float16
2 8.5 8.5 Float16 Float16
16 changes: 7 additions & 9 deletions datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ query error Error unrecognized word: unknown
SELECT arrow_cast('1', 'unknown')

# Round Trip tests:
query TTTTTTTTTTTTTTTTTTTTTTT
query TTTTTTTTTTTTTTTTTTTTTTTT
SELECT
arrow_typeof(arrow_cast(1, 'Int8')) as col_i8,
arrow_typeof(arrow_cast(1, 'Int16')) as col_i16,
Expand All @@ -112,8 +112,7 @@ SELECT
arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16,
arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32,
arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64,
-- can't seem to cast to Float16 for some reason
-- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16,
arrow_typeof(arrow_cast(1, 'Float16')) as col_f16,
arrow_typeof(arrow_cast(1, 'Float32')) as col_f32,
arrow_typeof(arrow_cast(1, 'Float64')) as col_f64,
arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8,
Expand All @@ -130,7 +129,7 @@ SELECT
arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, Some("+08:00"))')) as col_tstz_ns,
arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict
----
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8)
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8)



Expand All @@ -147,15 +146,14 @@ create table foo as select
arrow_cast(1, 'UInt16') as col_u16,
arrow_cast(1, 'UInt32') as col_u32,
arrow_cast(1, 'UInt64') as col_u64,
-- can't seem to cast to Float16 for some reason
-- arrow_cast(1.0, 'Float16') as col_f16,
arrow_cast(1.0, 'Float16') as col_f16,
arrow_cast(1.0, 'Float32') as col_f32,
arrow_cast(1.0, 'Float64') as col_f64
;

## Ensure each column in the table has the expected type

query TTTTTTTTTT
query TTTTTTTTTTT
SELECT
arrow_typeof(col_i8),
arrow_typeof(col_i16),
Expand All @@ -165,12 +163,12 @@ SELECT
arrow_typeof(col_u16),
arrow_typeof(col_u32),
arrow_typeof(col_u64),
-- arrow_typeof(col_f16),
arrow_typeof(col_f16),
arrow_typeof(col_f32),
arrow_typeof(col_f64)
FROM foo;
----
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64
Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64


statement ok
Expand Down

0 comments on commit 7c5a8eb

Please sign in to comment.