From d19487c9682738066ff3394d7492943adaadb509 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sun, 30 Jun 2024 03:40:09 -0700 Subject: [PATCH] Minor: Add more support for ScalarValue::Float16 (#11156) --- datafusion/common/src/scalar/mod.rs | 36 ++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5b9c4a223de6..bd2265c85003 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -982,6 +982,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(0)), DataType::UInt32 => ScalarValue::UInt32(Some(0)), DataType::UInt64 => ScalarValue::UInt64(Some(0)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), DataType::Timestamp(TimeUnit::Second, tz) => { @@ -1035,6 +1036,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(1)), DataType::UInt32 => ScalarValue::UInt32(Some(1)), DataType::UInt64 => ScalarValue::UInt64(Some(1)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), _ => { @@ -1053,6 +1055,7 @@ impl ScalarValue { DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), _ => { @@ -1074,6 +1077,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(10)), DataType::UInt32 => ScalarValue::UInt32(Some(10)), DataType::UInt64 => ScalarValue::UInt64(Some(10)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), _ => { @@ -1181,8 +1185,12 @@ impl ScalarValue { | ScalarValue::Int16(None) | ScalarValue::Int32(None) | ScalarValue::Int64(None) + | ScalarValue::Float16(None) | ScalarValue::Float32(None) | ScalarValue::Float64(None) => Ok(self.clone()), + ScalarValue::Float16(Some(v)) => { + Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32())))) + } ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), @@ -1435,6 +1443,9 @@ impl ScalarValue { (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _), (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _), // TODO: we might want to look into supporting ceil/floor here for floats. + (Self::Float16(Some(l)), Self::Float16(Some(r))) => { + Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _) + } (Self::Float32(Some(l)), Self::Float32(Some(r))) => { Some((l - r).abs().round() as _) } @@ -2452,6 +2463,7 @@ impl ScalarValue { DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?, DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, @@ -5635,7 +5647,6 @@ mod tests { } #[test] - #[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")] fn f16_test_overflow() { // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case let cases = [ @@ -5805,6 +5816,21 @@ mod tests { ScalarValue::UInt64(Some(10)), 5, ), + ( + ScalarValue::Float16(Some(f16::from_f32(1.1))), + ScalarValue::Float16(Some(f16::from_f32(1.9))), + 1, + ), + ( + ScalarValue::Float16(Some(f16::from_f32(-5.3))), + ScalarValue::Float16(Some(f16::from_f32(-9.2))), + 4, + ), + ( + ScalarValue::Float16(Some(f16::from_f32(-5.3))), + ScalarValue::Float16(Some(f16::from_f32(-9.7))), + 4, + ), ( ScalarValue::Float32(Some(1.0)), ScalarValue::Float32(Some(2.0)), @@ -5877,6 +5903,14 @@ mod tests { // Different type (ScalarValue::Int8(Some(1)), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Float32(Some(1.0))), + ( + ScalarValue::Float16(Some(f16::from_f32(1.0))), + ScalarValue::Float32(Some(1.0)), + ), + ( + ScalarValue::Float16(Some(f16::from_f32(1.0))), + ScalarValue::Int32(Some(1)), + ), ( ScalarValue::Float64(Some(1.1)), ScalarValue::Float32(Some(2.2)),