Skip to content

Commit

Permalink
use lower tolerance for zero detection in simplify_algebra (#2696)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Feb 1, 2024
1 parent 93fa001 commit 96f85a0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 15 deletions.
18 changes: 8 additions & 10 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP

#include <migraphx/float_equal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
Expand Down Expand Up @@ -872,7 +873,7 @@ inline auto literal_value_checker(F f)

/**
* Uses integer multiples of the corresponding floating point epsilon and
* compares with abs(y - x) < eps * (atol_mult + rtol_mult * abs(y)).
* compares with abs(y - x) < eps * (atol_mult + rtol_mult * abs(x)).
* atol_mult controls the absolute tolerance.
* rtol_mult controls the relative tolerance.
* Uses no tolerance for integral types.
Expand All @@ -885,24 +886,21 @@ inline auto has_value(T x, std::size_t atol_mult = 10, std::size_t rtol_mult = 1
l.visit([&](auto v) {
// cast to the literal's data type before comparing
using type = typename decltype(v)::value_type;
if constexpr(std::is_integral<type>{})
auto tolerance = atol_mult + rtol_mult * std::fabs(x);
if(migraphx::float_equal(tolerance, 0) or std::is_integral<type>{})
{
if(std::all_of(
v.begin(), v.end(), [&](auto val) { return val == static_cast<type>(x); }))
{
if(std::all_of(v.begin(), v.end(), [&](auto val) {
return migraphx::float_equal(val, static_cast<type>(x));
}))
b = true;
}
}
else
{
auto eps = std::numeric_limits<type>::epsilon();
if(std::all_of(v.begin(), v.end(), [&](auto val) {
return std::fabs(val - static_cast<type>(x)) <
eps * (atol_mult + rtol_mult * std::fabs(val));
return std::fabs(val - static_cast<type>(x)) < (eps * tolerance);
}))
{
b = true;
}
}
});
return b;
Expand Down
10 changes: 5 additions & 5 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,9 +1217,9 @@ struct find_unit_ops
auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f, 1, 0), match::any().bind("x")));
match::either_arg(0, 1)(match::has_value(0.0f, 0, 0), match::any().bind("x")));
auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f, 1, 0)));
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f, 0, 0)));
return match::any_of(mul_1, div_1, add_0, sub_0);
}

Expand All @@ -1241,7 +1241,7 @@ struct find_neg_unit_ops
auto div_neg_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f)));
auto sub_0 =
match::name("sub")(match::args(match::has_value(0.0f), match::any().bind("x")));
match::name("sub")(match::args(match::has_value(0.0f, 0, 0), match::any().bind("x")));
return match::any_of(mul_neg_1, div_neg_1, sub_0);
}

Expand All @@ -1260,9 +1260,9 @@ struct find_zero_ops
auto matcher() const
{
auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
match::either_arg(0, 1)(match::has_value(0.0f, 0, 0).bind("x"), match::any()));
auto div_zero =
match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
match::name("div")(match::args(match::has_value(0.0f, 0, 0).bind("x"), match::any()));
return match::any_of(mul_zero, div_zero);
}

Expand Down
58 changes: 58 additions & 0 deletions test/matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,64 @@ TEST_CASE(match_has_value6)
EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_has_value7)
{
// zero detection
migraphx::module mm;
auto s = migraphx::shape{migraphx::shape::half_type, {1}, {0}};
auto zero = mm.add_literal(migraphx::literal{s, {0.00239754}});
auto one = mm.add_literal(migraphx::literal{s, {1.0}});
auto sum1 = mm.add_instruction(sum_op{}, one, zero);
mm.add_instruction(pass_op{}, sum1);
auto m1 = match::has_value(0.0f, 0, 0);
auto r1 = find_match(mm, m1);
EXPECT(bool{r1.result == mm.end()});
// increase tolerance
auto m2 = match::has_value(0.0f);
auto r2 = find_match(mm, m2);
EXPECT(bool{r2.result == zero});
}

TEST_CASE(match_has_value8)
{
// zero detection
migraphx::module mm;
auto s = migraphx::shape{migraphx::shape::half_type, {1}, {0}};
auto zero = mm.add_literal(migraphx::literal{s, {9.07183e-05}});
auto one = mm.add_literal(migraphx::literal{s, {1.0}});
auto sum1 = mm.add_instruction(sum_op{}, one, zero);
mm.add_instruction(pass_op{}, sum1);
auto m1 = match::has_value(0.0f, 0, 0);
auto r1 = find_match(mm, m1);
EXPECT(bool{r1.result == mm.end()});
// increase tolerance
auto m2 = match::has_value(0.0f);
auto r2 = find_match(mm, m2);
EXPECT(bool{r2.result == zero});
}

TEST_CASE(match_has_value9)
{
migraphx::module mm;
auto s = migraphx::shape{migraphx::shape::half_type, {1}, {0}};
auto n_five = mm.add_literal(migraphx::literal{s, {-5.0}});
mm.add_instruction(pass_op{}, n_five);
auto m1 = match::has_value(5.0f);
auto r1 = find_match(mm, m1);
EXPECT(bool{r1.result == mm.end()});
// increase tolerance
auto m2 = match::has_value(-5.0f);
auto r2 = find_match(mm, m2);
EXPECT(bool{r2.result == n_five});
// do exact match
auto m3 = match::has_value(-5.0f, 0, 0);
auto r3 = find_match(mm, m3);
EXPECT(bool{r3.result == n_five});
// do exact match
auto m4 = match::has_value(5.0f, 0, 0);
auto r4 = find_match(mm, m4);
EXPECT(bool{r4.result == mm.end()});
}
TEST_CASE(match_has_value_eps1)
{
migraphx::module mm;
Expand Down

0 comments on commit 96f85a0

Please sign in to comment.