Skip to content

Commit

Permalink
Minor: Add more support for ScalarValue::Float16 (#11156)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms authored Jun 30, 2024
1 parent e52b5e5 commit d19487c
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ impl ScalarValue {
DataType::UInt16 => ScalarValue::UInt16(Some(0)),
DataType::UInt32 => ScalarValue::UInt32(Some(0)),
DataType::UInt64 => ScalarValue::UInt64(Some(0)),
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
DataType::Timestamp(TimeUnit::Second, tz) => {
Expand Down Expand Up @@ -1035,6 +1036,7 @@ impl ScalarValue {
DataType::UInt16 => ScalarValue::UInt16(Some(1)),
DataType::UInt32 => ScalarValue::UInt32(Some(1)),
DataType::UInt64 => ScalarValue::UInt64(Some(1)),
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))),
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
_ => {
Expand All @@ -1053,6 +1055,7 @@ impl ScalarValue {
DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)),
DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)),
DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)),
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))),
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
_ => {
Expand All @@ -1074,6 +1077,7 @@ impl ScalarValue {
DataType::UInt16 => ScalarValue::UInt16(Some(10)),
DataType::UInt32 => ScalarValue::UInt32(Some(10)),
DataType::UInt64 => ScalarValue::UInt64(Some(10)),
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))),
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
_ => {
Expand Down Expand Up @@ -1181,8 +1185,12 @@ impl ScalarValue {
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None)
| ScalarValue::Float16(None)
| ScalarValue::Float32(None)
| ScalarValue::Float64(None) => Ok(self.clone()),
ScalarValue::Float16(Some(v)) => {
Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32()))))
}
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
Expand Down Expand Up @@ -1435,6 +1443,9 @@ impl ScalarValue {
(Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _),
// TODO: we might want to look into supporting ceil/floor here for floats.
(Self::Float16(Some(l)), Self::Float16(Some(r))) => {
Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _)
}
(Self::Float32(Some(l)), Self::Float32(Some(r))) => {
Some((l - r).abs().round() as _)
}
Expand Down Expand Up @@ -2452,6 +2463,7 @@ impl ScalarValue {
DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?,
DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?,
DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?,
DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?,
DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?,
DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?,
DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?,
Expand Down Expand Up @@ -5635,7 +5647,6 @@ mod tests {
}

#[test]
#[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
fn f16_test_overflow() {
// TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
let cases = [
Expand Down Expand Up @@ -5805,6 +5816,21 @@ mod tests {
ScalarValue::UInt64(Some(10)),
5,
),
(
ScalarValue::Float16(Some(f16::from_f32(1.1))),
ScalarValue::Float16(Some(f16::from_f32(1.9))),
1,
),
(
ScalarValue::Float16(Some(f16::from_f32(-5.3))),
ScalarValue::Float16(Some(f16::from_f32(-9.2))),
4,
),
(
ScalarValue::Float16(Some(f16::from_f32(-5.3))),
ScalarValue::Float16(Some(f16::from_f32(-9.7))),
4,
),
(
ScalarValue::Float32(Some(1.0)),
ScalarValue::Float32(Some(2.0)),
Expand Down Expand Up @@ -5877,6 +5903,14 @@ mod tests {
// Different type
(ScalarValue::Int8(Some(1)), ScalarValue::Int16(Some(1))),
(ScalarValue::Int8(Some(1)), ScalarValue::Float32(Some(1.0))),
(
ScalarValue::Float16(Some(f16::from_f32(1.0))),
ScalarValue::Float32(Some(1.0)),
),
(
ScalarValue::Float16(Some(f16::from_f32(1.0))),
ScalarValue::Int32(Some(1)),
),
(
ScalarValue::Float64(Some(1.1)),
ScalarValue::Float32(Some(2.2)),
Expand Down

0 comments on commit d19487c

Please sign in to comment.