From 620fec7ff80a630119eaf5e6bb689aa5e873ea15 Mon Sep 17 00:00:00 2001 From: John Platts Date: Thu, 10 Oct 2024 21:18:25 -0500 Subject: [PATCH] Added MinMagnitude and MaxMagnitude ops --- g3doc/quick_reference.md | 18 +++++++ hwy/ops/generic_ops-inl.h | 57 +++++++++++++++++++++ hwy/tests/minmax_test.cc | 101 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+) diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md index 8220e9b718..059acef4d8 100644 --- a/g3doc/quick_reference.md +++ b/g3doc/quick_reference.md @@ -678,6 +678,24 @@ is qNaN, and NaN if both are. * V **Max**(V a, V b): returns `max(a[i], b[i])`. +* V **MinMagnitude**(V a, V b): returns the number with the + smaller magnitude if `a[i]` and `b[i]` are both non-NaN values. + + If `a[i]` and `b[i]` are both non-NaN, `MinMagnitude(a, b)` returns + `(|a[i]| < |b[i]| || (|a[i]| == |b[i]| && a[i] < b[i])) ? a[i] : b[i]`. + + If `a[i]` is NaN or `b[i]` is NaN, the results of `MinMagnitude(a, b)` are + implementation-defined. + +* V **MaxMagnitude**(V a, V b): returns the number with the + larger magnitude if `a[i]` and `b[i]` are both non-NaN values. + + If `a[i]` and `b[i]` are both non-NaN, `MaxMagnitude(a, b)` returns + `(|a[i]| < |b[i]| || (|a[i]| == |b[i]| && a[i] < b[i])) ? b[i] : a[i]`. + + If `a[i]` is NaN or `b[i]` is NaN, the results of `MaxMagnitude(a, b)` are + implementation-defined. 
+
 All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
 
 * `V`: `u64` \
diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h
index 99b518d99c..7447fcb6cc 100644
--- a/hwy/ops/generic_ops-inl.h
+++ b/hwy/ops/generic_ops-inl.h
@@ -488,6 +488,63 @@ HWY_API V InterleaveEven(V a, V b) {
 }
 #endif
 
+// ------------------------------ MinMagnitude/MaxMagnitude
+
+#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))  // taken only when both macros share the same defined-state
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE  // inside the guard the macro's defined-state is flipped
+#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#endif
+
+template
+HWY_API V MinMagnitude(V a, V b) {  // per lane: a if |a| < |b|; Min(a, b) if |a| == |b|; otherwise b
+  const auto abs_a = Abs(a);
+  const auto abs_b = Abs(b);
+  return IfThenElse(Lt(abs_a, abs_b), a,
+                    Min(IfThenElse(Eq(abs_a, abs_b), a, b), b));  // tie -> Min(a, b); no tie -> Min(b, b)
+}
+
+template
+HWY_API V MaxMagnitude(V a, V b) {  // per lane: b if |a| < |b|; Max(a, b) if |a| == |b|; otherwise a
+  const auto abs_a = Abs(a);
+  const auto abs_b = Abs(b);
+  return IfThenElse(Lt(abs_a, abs_b), b,
+                    Max(IfThenElse(Eq(abs_a, abs_b), b, a), a));  // tie -> Max(b, a); no tie -> Max(a, a)
+}
+
+#endif // HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+
+template
+HWY_API V MinMagnitude(V a, V b) {  // NOTE(review): presumably the signed-integer overload; the template constraint is not visible here - confirm
+  const DFromV d;
+  const RebindToUnsigned du;
+  const auto abs_a = BitCast(du, Abs(a));  // |a| and |b| are compared as unsigned lanes
+  const auto abs_b = BitCast(du, Abs(b));
+  return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), a,  // masks from the unsigned compares are rebound to d before selecting
+                    Min(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), a, b), b));
+}
+
+template
+HWY_API V MaxMagnitude(V a, V b) {  // same selection as the overload above, mirrored: b when |a| < |b|, Max on a tie, else a
+  const DFromV d;
+  const RebindToUnsigned du;
+  const auto abs_a = BitCast(du, Abs(a));  // |a| and |b| are compared as unsigned lanes
+  const auto abs_b = BitCast(du, Abs(b));
+  return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), b,
+                    Max(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), b, a), a));
+}
+
+template
+HWY_API V MinMagnitude(V a, V b) {  // NOTE(review): presumably the unsigned overload, where magnitude equals value - confirm
+  return Min(a, b);
+}
+
+template
+HWY_API V MaxMagnitude(V a, V b) {  // NOTE(review): presumably the unsigned overload, where magnitude equals value - confirm
+  return Max(a, b);
+}
+
 // ------------------------------ AddSub
 
 template , 1)>
diff --git a/hwy/tests/minmax_test.cc b/hwy/tests/minmax_test.cc
index 3ef116d30d..1a08d56aee 100644
--- a/hwy/tests/minmax_test.cc
+++ 
b/hwy/tests/minmax_test.cc
@@ -257,6 +257,106 @@ HWY_NOINLINE void TestAllMinMax128Upper() {
   ForGEVectors<128, TestMinMax128Upper>()(uint64_t());
 }
 
+struct TestMinMaxMagnitude {  // compares MinMagnitude/MaxMagnitude with per-lane scalar expectations
+  template
+  static constexpr MakeSigned MaxPosIotaVal(hwy::FloatTag /*type_tag*/) {  // float tag: MantissaMask + 1
+    return static_cast>(MantissaMask() + 1);
+  }
+  template
+  static constexpr MakeSigned MaxPosIotaVal(hwy::NonFloatTag /*type_tag*/) {  // non-float tag: (LimitsMax >> 1) + 1
+    return static_cast>(((LimitsMax>()) >> 1) + 1);
+  }
+
+  template
+  HWY_NOINLINE static void VerifyMinMaxMagnitude(  // builds scalar expectations, then checks both ops in both argument orders
+      D d, const TFromD* HWY_RESTRICT in1_lanes,
+      const TFromD* HWY_RESTRICT in2_lanes, const int line) {  // `line` is forwarded to AssertVecEqual
+    using T = TFromD;
+    using TAbs = If() || IsSpecialFloat(), T, MakeUnsigned>;  // magnitude type: T for float types, else the unsigned counterpart
+
+    const char* file = __FILE__;
+    const size_t N = Lanes(d);
+    auto expected_min_mag = AllocateAligned(N);
+    auto expected_max_mag = AllocateAligned(N);
+    HWY_ASSERT(expected_min_mag && expected_max_mag);
+
+    for (size_t i = 0; i < N; i++) {
+      const T val1 = in1_lanes[i];
+      const T val2 = in2_lanes[i];
+      const TAbs abs_val1 = static_cast(ScalarAbs(val1));
+      const TAbs abs_val2 = static_cast(ScalarAbs(val2));
+      if (abs_val1 < abs_val2 || (abs_val1 == abs_val2 && val1 < val2)) {  // val1 has the smaller magnitude, or ties and is the smaller value
+        expected_min_mag[i] = val1;
+        expected_max_mag[i] = val2;
+      } else {
+        expected_min_mag[i] = val2;
+        expected_max_mag[i] = val1;
+      }
+    }
+
+    const auto in1 = Load(d, in1_lanes);
+    const auto in2 = Load(d, in2_lanes);
+    AssertVecEqual(d, expected_min_mag.get(), MinMagnitude(in1, in2), file,
+                   line);
+    AssertVecEqual(d, expected_min_mag.get(), MinMagnitude(in2, in1), file,  // swapped arguments must give the same result
+                   line);
+    AssertVecEqual(d, expected_max_mag.get(), MaxMagnitude(in1, in2), file,
+                   line);
+    AssertVecEqual(d, expected_max_mag.get(), MaxMagnitude(in2, in1), file,  // swapped arguments must give the same result
+                   line);
+  }
+
+  template
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {  // fills four input arrays and verifies pairings of them
+    using TI = MakeSigned;
+    using TU = MakeUnsigned;
+    constexpr TI kMaxPosIotaVal = MaxPosIotaVal(hwy::IsFloatTag());
+    static_assert(kMaxPosIotaVal > 0, "kMaxPosIotaVal > 0 must be true");
+
+    constexpr size_t kPositiveIotaMask = static_cast(
+        static_cast(kMaxPosIotaVal - 1) & (HWY_MAX_LANES_D(D) - 1));
+
+    const size_t N = Lanes(d);
+    auto in1_lanes = AllocateAligned(N);
+    auto in2_lanes = AllocateAligned(N);
+    auto in3_lanes = AllocateAligned(N);
+    auto in4_lanes = AllocateAligned(N);
+    HWY_ASSERT(in1_lanes && in2_lanes && in3_lanes && in4_lanes);
+
+    for (size_t i = 0; i < N; i++) {
+      const TI x1 = static_cast((i & kPositiveIotaMask) + 1);  // masked lane index plus one, so at least 1
+      const TI x2 = static_cast(kMaxPosIotaVal - x1);  // complement of x1 with respect to kMaxPosIotaVal
+      const TI x3 = static_cast(-x1);  // negated x1
+      const TI x4 = static_cast(-x2);  // negated x2
+
+      in1_lanes[i] = ConvertScalarTo(x1);
+      in2_lanes[i] = ConvertScalarTo(x2);
+      in3_lanes[i] = ConvertScalarTo(x3);
+      in4_lanes[i] = ConvertScalarTo(x4);
+    }
+
+    VerifyMinMaxMagnitude(d, in1_lanes.get(), in2_lanes.get(), __LINE__);  // all six pairings of the four inputs follow
+    VerifyMinMaxMagnitude(d, in1_lanes.get(), in3_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in1_lanes.get(), in4_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in2_lanes.get(), in3_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in2_lanes.get(), in4_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in3_lanes.get(), in4_lanes.get(), __LINE__);
+
+    in2_lanes[0] = HighestValue();  // lane 0 of in2 and in4 is overwritten with HighestValue / LowestValue
+    in4_lanes[0] = LowestValue();
+
+    VerifyMinMaxMagnitude(d, in1_lanes.get(), in2_lanes.get(), __LINE__);  // re-check only the pairings that involve in2 or in4
+    VerifyMinMaxMagnitude(d, in1_lanes.get(), in4_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in2_lanes.get(), in3_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in2_lanes.get(), in4_lanes.get(), __LINE__);
+    VerifyMinMaxMagnitude(d, in3_lanes.get(), in4_lanes.get(), __LINE__);
+  }
+};
+
+HWY_NOINLINE void TestAllMinMaxMagnitude() {  // test entry point; registered through HWY_EXPORT_AND_TEST_P in this file
+  ForAllTypes(ForPartialVectors());
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 } // namespace HWY_NAMESPACE
 } // namespace hwy
@@ -269,6 +369,7 @@ HWY_BEFORE_TEST(HwyMinMaxTest);
 HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax);
 HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax128);
 HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax128Upper);
+HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMaxMagnitude);
 HWY_AFTER_TEST();
 } // namespace hwy