Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added MinMagnitude and MaxMagnitude ops #2353

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,24 @@ is qNaN, and NaN if both are.

* <code>V **Max**(V a, V b)</code>: returns `max(a[i], b[i])`.

* <code>V **MinMagnitude**(V a, V b)</code>: returns the number with the
smaller magnitude if `a[i]` and `b[i]` are both non-NaN values.

If `a[i]` and `b[i]` are both non-NaN, `MinMagnitude(a, b)` returns
`(|a[i]| < |b[i]| || (|a[i]| == |b[i]| && a[i] < b[i])) ? a[i] : b[i]`.

Otherwise (if `a[i]` or `b[i]` is NaN), the result of `MinMagnitude(a, b)` is
implementation-defined.

* <code>V **MaxMagnitude**(V a, V b)</code>: returns the number with the
larger magnitude if `a[i]` and `b[i]` are both non-NaN values.

If `a[i]` and `b[i]` are both non-NaN, `MaxMagnitude(a, b)` returns
`(|a[i]| < |b[i]| || (|a[i]| == |b[i]| && a[i] < b[i])) ? b[i] : a[i]`.

Otherwise (if `a[i]` or `b[i]` is NaN), the result of `MaxMagnitude(a, b)` is
implementation-defined.

All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:

* `V`: `u64` \
Expand Down
57 changes: 57 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,63 @@ HWY_API V InterleaveEven(V a, V b) {
}
#endif

// ------------------------------ MinMagnitude/MaxMagnitude

#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
#else
#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
#endif

// Generic floating-point MinMagnitude: per lane, selects whichever of `a` or
// `b` has the smaller absolute value; when the absolute values compare equal,
// the result is Min(a, b). This code contains no explicit NaN check.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MinMagnitude(V a, V b) {
const auto abs_a = Abs(a);
const auto abs_b = Abs(b);
// |a| <  |b| -> a
// |a| == |b| -> Min(a, b)
// otherwise  -> Min(b, b), which is b for non-NaN lanes
return IfThenElse(Lt(abs_a, abs_b), a,
Min(IfThenElse(Eq(abs_a, abs_b), a, b), b));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat trick to match the x86 |a| < |b|, |b| < |a|, min(a, b) cases :)

}

// Generic floating-point MaxMagnitude: per lane, selects whichever of `a` or
// `b` has the larger absolute value; when the absolute values compare equal,
// the result is Max(b, a). This code contains no explicit NaN check.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MaxMagnitude(V a, V b) {
  const auto mag_a = Abs(a);
  const auto mag_b = Abs(b);

  // On a magnitude tie, route `b` into Max so the tie resolves to Max(b, a);
  // otherwise Max(a, a) simply passes `a` through.
  const V tie_or_a = IfThenElse(Eq(mag_a, mag_b), b, a);
  const V when_not_smaller = Max(tie_or_a, a);

  // |a| < |b| -> b, else the value computed above.
  return IfThenElse(Lt(mag_a, mag_b), b, when_not_smaller);
}

#endif // HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE

// Signed-integer MinMagnitude: per lane, selects whichever of `a` or `b` has
// the smaller absolute value; when the absolute values are equal, the result
// is Min(a, b).
template <class V, HWY_IF_SIGNED_V(V)>
HWY_API V MinMagnitude(V a, V b) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;

  // Magnitudes are compared as unsigned integers.
  // NOTE(review): presumably so that Abs() of the most negative value still
  // compares as the largest magnitude - confirm against Abs() semantics.
  const auto mag_a = BitCast(du, Abs(a));
  const auto mag_b = BitCast(du, Abs(b));

  // The comparisons yield unsigned-typed masks; convert them back so they can
  // select between the signed input vectors.
  const auto a_is_smaller = RebindMask(d, Lt(mag_a, mag_b));
  const auto is_tie = RebindMask(d, Eq(mag_a, mag_b));

  // Tie -> Min(a, b); otherwise Min(b, b), i.e. b.
  const V tie_or_b = IfThenElse(is_tie, a, b);
  return IfThenElse(a_is_smaller, a, Min(tie_or_b, b));
}

// Signed-integer MaxMagnitude: per lane, selects whichever of `a` or `b` has
// the larger absolute value; when the absolute values are equal, the result
// is Max(b, a).
template <class V, HWY_IF_SIGNED_V(V)>
HWY_API V MaxMagnitude(V a, V b) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;

  // Magnitudes are compared as unsigned integers.
  // NOTE(review): presumably so that Abs() of the most negative value still
  // compares as the largest magnitude - confirm against Abs() semantics.
  const auto mag_a = BitCast(du, Abs(a));
  const auto mag_b = BitCast(du, Abs(b));

  // The comparisons yield unsigned-typed masks; convert them back so they can
  // select between the signed input vectors.
  const auto a_is_smaller = RebindMask(d, Lt(mag_a, mag_b));
  const auto is_tie = RebindMask(d, Eq(mag_a, mag_b));

  // Tie -> Max(b, a); otherwise Max(a, a), i.e. a.
  const V tie_or_a = IfThenElse(is_tie, b, a);
  return IfThenElse(a_is_smaller, b, Max(tie_or_a, a));
}

// Unsigned-integer MinMagnitude: an unsigned value is its own magnitude, so
// this forwards directly to Min(a, b).
template <class V, HWY_IF_UNSIGNED_V(V)>
HWY_API V MinMagnitude(V a, V b) {
return Min(a, b);
}

// Unsigned-integer MaxMagnitude: an unsigned value is its own magnitude, so
// this forwards directly to Max(a, b).
template <class V, HWY_IF_UNSIGNED_V(V)>
HWY_API V MaxMagnitude(V a, V b) {
return Max(a, b);
}

// ------------------------------ AddSub

template <class V, HWY_IF_LANES_D(DFromV<V>, 1)>
Expand Down
101 changes: 101 additions & 0 deletions hwy/tests/minmax_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,106 @@ HWY_NOINLINE void TestAllMinMax128Upper() {
ForGEVectors<128, TestMinMax128Upper>()(uint64_t());
}

// Test functor for MinMagnitude/MaxMagnitude. Builds several input arrays of
// small positive and negated values, computes the expected per-lane results
// with scalar code, and compares them against the vector ops.
struct TestMinMaxMagnitude {
// Upper bound used when generating positive test values for floating-point T:
// MantissaMask<T>() + 1, returned as the same-width signed integer type.
// NOTE(review): presumably chosen so the generated integers convert to T
// exactly - confirm.
template <class T>
static constexpr MakeSigned<T> MaxPosIotaVal(hwy::FloatTag /*type_tag*/) {
return static_cast<MakeSigned<T>>(MantissaMask<T>() + 1);
}
// Upper bound used when generating positive test values for non-float T:
// (LimitsMax of the signed type >> 1) + 1.
template <class T>
static constexpr MakeSigned<T> MaxPosIotaVal(hwy::NonFloatTag /*type_tag*/) {
return static_cast<MakeSigned<T>>(((LimitsMax<MakeSigned<T>>()) >> 1) + 1);
}

// Computes the expected MinMagnitude/MaxMagnitude of each lane pair from
// `in1_lanes`/`in2_lanes` with scalar code, then asserts that the vector ops
// return those results for both argument orders. `line` is forwarded to
// AssertVecEqual so failures are attributed to the calling line.
template <class D>
HWY_NOINLINE static void VerifyMinMaxMagnitude(
D d, const TFromD<D>* HWY_RESTRICT in1_lanes,
const TFromD<D>* HWY_RESTRICT in2_lanes, const int line) {
using T = TFromD<D>;
// Type in which magnitudes are compared: T itself for float types, otherwise
// the unsigned integer type of the same width.
using TAbs = If<IsFloat<T>() || IsSpecialFloat<T>(), T, MakeUnsigned<T>>;

const char* file = __FILE__;
const size_t N = Lanes(d);
auto expected_min_mag = AllocateAligned<T>(N);
auto expected_max_mag = AllocateAligned<T>(N);
HWY_ASSERT(expected_min_mag && expected_max_mag);

for (size_t i = 0; i < N; i++) {
const T val1 = in1_lanes[i];
const T val2 = in2_lanes[i];
const TAbs abs_val1 = static_cast<TAbs>(ScalarAbs(val1));
const TAbs abs_val2 = static_cast<TAbs>(ScalarAbs(val2));
// Scalar reference: val1 is the min-magnitude result if its magnitude is
// smaller, or if the magnitudes are equal and val1 < val2.
if (abs_val1 < abs_val2 || (abs_val1 == abs_val2 && val1 < val2)) {
expected_min_mag[i] = val1;
expected_max_mag[i] = val2;
} else {
expected_min_mag[i] = val2;
expected_max_mag[i] = val1;
}
}

const auto in1 = Load(d, in1_lanes);
const auto in2 = Load(d, in2_lanes);
// The same expected arrays are used for both argument orders, i.e. the ops
// are expected to be commutative for these inputs.
AssertVecEqual(d, expected_min_mag.get(), MinMagnitude(in1, in2), file,
line);
AssertVecEqual(d, expected_min_mag.get(), MinMagnitude(in2, in1), file,
line);
AssertVecEqual(d, expected_max_mag.get(), MaxMagnitude(in1, in2), file,
line);
AssertVecEqual(d, expected_max_mag.get(), MaxMagnitude(in2, in1), file,
line);
}

// Test entry point for one lane type T and descriptor d.
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
using TI = MakeSigned<T>;
using TU = MakeUnsigned<T>;
constexpr TI kMaxPosIotaVal = MaxPosIotaVal<T>(hwy::IsFloatTag<T>());
static_assert(kMaxPosIotaVal > 0, "kMaxPosIotaVal > 0 must be true");

// Mask applied to the lane index so that (i & mask) + 1 never exceeds
// kMaxPosIotaVal.
constexpr size_t kPositiveIotaMask = static_cast<size_t>(
static_cast<TU>(kMaxPosIotaVal - 1) & (HWY_MAX_LANES_D(D) - 1));

const size_t N = Lanes(d);
auto in1_lanes = AllocateAligned<T>(N);
auto in2_lanes = AllocateAligned<T>(N);
auto in3_lanes = AllocateAligned<T>(N);
auto in4_lanes = AllocateAligned<T>(N);
HWY_ASSERT(in1_lanes && in2_lanes && in3_lanes && in4_lanes);

// in1: ascending positive values; in2: kMaxPosIotaVal minus in1;
// in3 and in4: the negations of in1 and in2 respectively.
for (size_t i = 0; i < N; i++) {
const TI x1 = static_cast<TI>((i & kPositiveIotaMask) + 1);
const TI x2 = static_cast<TI>(kMaxPosIotaVal - x1);
const TI x3 = static_cast<TI>(-x1);
const TI x4 = static_cast<TI>(-x2);

in1_lanes[i] = ConvertScalarTo<T>(x1);
in2_lanes[i] = ConvertScalarTo<T>(x2);
in3_lanes[i] = ConvertScalarTo<T>(x3);
in4_lanes[i] = ConvertScalarTo<T>(x4);
}

// Check every unordered pair of the four input arrays.
VerifyMinMaxMagnitude(d, in1_lanes.get(), in2_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in1_lanes.get(), in3_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in1_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in3_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in3_lanes.get(), in4_lanes.get(), __LINE__);

// Put the extreme values of T into lane 0 and re-check every pair that
// involves a modified array (in1/in3 is unchanged and therefore skipped).
in2_lanes[0] = HighestValue<T>();
in4_lanes[0] = LowestValue<T>();

VerifyMinMaxMagnitude(d, in1_lanes.get(), in2_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in1_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in3_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in3_lanes.get(), in4_lanes.get(), __LINE__);
}
};

// Runs TestMinMaxMagnitude through ForAllTypes with ForPartialVectors.
HWY_NOINLINE void TestAllMinMaxMagnitude() {
ForAllTypes(ForPartialVectors<TestMinMaxMagnitude>());
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
Expand All @@ -269,6 +369,7 @@ HWY_BEFORE_TEST(HwyMinMaxTest);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax128);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax128Upper);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMaxMagnitude);
HWY_AFTER_TEST();
} // namespace hwy

Expand Down
Loading