From 3d9e67b55415e3cf23cda947d7ef419243e336be Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Wed, 23 Oct 2024 11:51:44 -0700 Subject: [PATCH] Fix likely bugs in unary operator tester - The order in the declaration of this is `min, stride, max`, but the uses look like `min, max, stride`. - Loops use "less", but a max suggests it should be "less equal". After these fixes, the leaky-relu-nc test required some tweaks to the range used to avoid asserts that the scale was out of range. PiperOrigin-RevId: 689049238 --- test/unary-operator-tester.h | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/test/unary-operator-tester.h b/test/unary-operator-tester.h index 869827e7020..4b8715a2434 100644 --- a/test/unary-operator-tester.h +++ b/test/unary-operator-tester.h @@ -444,8 +444,8 @@ class UnaryOperatorTester { template struct LoopLimits { T min; - T stride; T max; + T stride; std::string ToString() const { return "[" + std::to_string(min) + ":" + std::to_string(stride) + ":" + std::to_string(max) + "]"; @@ -605,14 +605,14 @@ inline std::ostream& operator<<(std::ostream& os, UnaryOpTestParams params) { TEST_P(Tester##datatype, Test##datatype) { \ const UnaryOpTestParams& test_case = GetParam(); \ for (size_t channels = test_case.channels.min; \ - channels < test_case.channels.max; \ + channels <= test_case.channels.max; \ channels += test_case.channels.stride) { \ LoopLimits input_scale_limits{1, 2, 2}; \ if (test_case.input_scale) { \ input_scale_limits = *test_case.input_scale; \ } \ for (float input_scale = input_scale_limits.min; \ - input_scale < input_scale_limits.max; \ + input_scale <= input_scale_limits.max; \ input_scale *= input_scale_limits.stride) { \ LoopLimits input_zero_point_limits{0, 1, 1}; \ if (test_case.input_zero_point) { \ @@ -672,21 +672,20 @@ inline std::ostream& operator<<(std::ostream& os, UnaryOpTestParams params) { UnaryOpTestParams::UnitBatch(), \ UnaryOpTestParams::UnitBatch().Qmin(128), \ UnaryOpTestParams::UnitBatch().Qmax(128), \ - UnaryOpTestParams::UnitBatch().InputScale({1.0e-2f, 1.0e2f, 10.0f}), \ + UnaryOpTestParams::UnitBatch().InputScale({1.0e-2f, 50.0f, 10.0f}), \ UnaryOpTestParams::UnitBatch().InputZeroPoint({0, 255, 51}), \ UnaryOpTestParams::SmallBatch(), \ UnaryOpTestParams::SmallBatch().InputStride(129), \ UnaryOpTestParams::SmallBatch().OutputStride(117), \ UnaryOpTestParams::SmallBatch().Qmin(128), \ UnaryOpTestParams::SmallBatch().Qmax(128), \ - UnaryOpTestParams::SmallBatch().InputScale( \ - {1.0e-2f, 1.0e2f, 10.0f}), \ + UnaryOpTestParams::SmallBatch().InputScale({1.0e-2f, 50.0f, 10.0f}), \ UnaryOpTestParams::SmallBatch().InputZeroPoint({0, 255, 51}), \ UnaryOpTestParams::StridedBatch(), \ UnaryOpTestParams::StridedBatch().Qmin(128), \ UnaryOpTestParams::StridedBatch().Qmax(128), \ UnaryOpTestParams::StridedBatch().InputScale( \ - {1.0e-2f, 1.0e2f, 10.0f}), \ + {1.0e-2f, 50.0f, 10.0f}), \ UnaryOpTestParams::StridedBatch().InputZeroPoint({0, 255, 51}), \ }), \ [](const testing::TestParamInfo& info) { \ @@ -699,17 +698,16 @@ inline std::ostream& operator<<(std::ostream& os, UnaryOpTestParams params) { datatype, Tester##datatype, \ testing::ValuesIn({ \ UnaryOpTestParams::UnitBatch(), \ - UnaryOpTestParams::UnitBatch().InputScale({1.0e-2f, 1.0e2f, 10.0f}), \ + UnaryOpTestParams::UnitBatch().InputScale({1.0e-2f, 50.0f, 10.0f}), \ UnaryOpTestParams::UnitBatch().InputZeroPoint({0, 255, 51}), \ UnaryOpTestParams::SmallBatch(), \ UnaryOpTestParams::SmallBatch().InputStride(129), \ UnaryOpTestParams::SmallBatch().OutputStride(117), \ - UnaryOpTestParams::SmallBatch().InputScale( \ - {1.0e-2f, 1.0e2f, 10.0f}), \ + UnaryOpTestParams::SmallBatch().InputScale({1.0e-2f, 50.0f, 10.0f}), \ UnaryOpTestParams::SmallBatch().InputZeroPoint({0, 255, 51}), \ UnaryOpTestParams::StridedBatch(), \ UnaryOpTestParams::StridedBatch().InputScale( \ - {1.0e-2f, 1.0e2f, 10.0f}), \ + {1.0e-2f, 50.0f, 10.0f}), \ UnaryOpTestParams::StridedBatch().InputZeroPoint({0, 255, 51}), \ }), \ [](const testing::TestParamInfo& info) { \