Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small simplify improvements #560

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion builder/rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace slinky {
namespace rewrite {

// The maximum number of values pattern_wildcard::idx and pattern_constant::idx can have, starting from 0.
constexpr int symbol_count = 6;
constexpr int symbol_count = 7;
constexpr int constant_count = 5;

template <int N>
Expand Down
38 changes: 28 additions & 10 deletions builder/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ bool empty_intersection(const std::set<T>& a, const std::set<T>& b) {
class constant_adder : public node_mutator {
public:
index_t c;
bool cancelled = false;

constant_adder(index_t c) : c(c) {}

Expand All @@ -120,6 +121,11 @@ class constant_adder : public node_mutator {

template <typename T>
void visit_add_sub(const T* op, int sign_b) {
if (is_constant(op->b, c * -sign_b)) {
set_result(op->a);
cancelled = true;
return;
}
expr a = mutate(op->a);
if (a.defined()) {
set_result(T::make(std::move(a), op->b));
Expand All @@ -140,15 +146,22 @@ class constant_adder : public node_mutator {

template <typename T>
void visit_min_max(const T* op) {
// Here, we want to rewrite something like max(x + 2, y) - 2 to max(x, y - 2), but we don't want to rewrite
// something like max(x, 2) - 2 to max(x - 2, 0). To do this, we need to know if we cancelled something.
bool old_cancelled = cancelled;
cancelled = false;
expr a = mutate(op->a);
if (a.defined()) {
expr b = mutate(op->b);
if (b.defined()) {
set_result(T::make(std::move(a), std::move(b)));
return;
}
expr b = mutate(op->b);
cancelled = old_cancelled || cancelled;
if (a.defined() && b.defined()) {
set_result(T::make(std::move(a), std::move(b)));
} else if (a.defined() && cancelled) {
set_result(T::make(std::move(a), add::make(op->b, c)));
} else if (b.defined() && cancelled) {
set_result(T::make(add::make(op->a, c), std::move(b)));
} else {
set_result(expr());
}
set_result(expr());
}
void visit(const class min* op) override { visit_min_max(op); }
void visit(const class max* op) override { visit_min_max(op); }
Expand Down Expand Up @@ -1225,7 +1238,8 @@ class simplifier : public node_mutator {
alignment_type alignment;
if (auto cstep = as_constant(step)) alignment.modulus = *cstep;
if (auto cmin = as_constant(bounds.min)) alignment.remainder = *cmin;
stmt body = mutate_with_bounds(op->body, op->sym, bounds, alignment);
// If we're in the body of the loop, then we know that bounds.max >= bounds.min.
stmt body = mutate_with_bounds(op->body, op->sym, {bounds.min, mutate(max(bounds.min, bounds.max))}, alignment);
for (auto& i : buffers) {
if (i) --i->loop_depth;
}
Expand Down Expand Up @@ -2457,8 +2471,12 @@ bool can_evaluate(intrinsic fn) {
expr constant_lower_bound(const expr& x) { return constant_evaluator(false).mutate(x, -1); }
expr constant_upper_bound(const expr& x) { return constant_evaluator(false).mutate(x, 1); }
std::optional<index_t> evaluate_constant(const expr& x) { return as_constant(constant_evaluator().mutate(x, 0)); }
std::optional<index_t> evaluate_constant_lower_bound(const expr& x) { return as_constant(constant_evaluator().mutate(x, -1)); }
std::optional<index_t> evaluate_constant_upper_bound(const expr& x) { return as_constant(constant_evaluator().mutate(x, 1)); }
std::optional<index_t> evaluate_constant_lower_bound(const expr& x) {
return as_constant(constant_evaluator().mutate(x, -1));
}
std::optional<index_t> evaluate_constant_upper_bound(const expr& x) {
return as_constant(constant_evaluator().mutate(x, 1));
}

std::optional<bool> attempt_to_prove(
const expr& condition, const bounds_map& expr_bounds, const alignment_map& alignment) {
Expand Down
11 changes: 10 additions & 1 deletion builder/simplify_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rewrite::pattern_wildcard<2> z;
rewrite::pattern_wildcard<3> w;
rewrite::pattern_wildcard<4> u;
rewrite::pattern_wildcard<5> v;
rewrite::pattern_wildcard<6> t;

rewrite::pattern_constant<0> c0;
rewrite::pattern_constant<1> c1;
Expand All @@ -39,7 +40,9 @@ bool apply_min_rules(Fn&& apply) {
apply(min(x, y), x && y, is_boolean(x) && is_boolean(y)) ||

// This might be the only rule that doesn't have an analogous max rule.
apply(min(max(x, c0), c1), max(min(x, c1), c0), c0 <= c1) ||
apply(min(max(x, c0), c1),
c0, c0 == c1,
max(min(x, c1), c0), c0 < c1) ||

// Canonicalize trees and find duplicate terms.
apply(min(min(x, y), min(x, z)), min(x, min(y, z))) ||
Expand Down Expand Up @@ -104,6 +107,9 @@ bool apply_min_rules(Fn&& apply) {
apply(min(w - select(x, y, z), select(x, u, v)), select(x, min(u, w - y), min(v, w - z))) ||
apply(min(select(x, y, z) - w, select(x, u, v)), select(x, min(u, y - w), min(v, z - w))) ||

apply(min(select(x, y, select(z, w, u)), select(z, v, t)), select(z, min(v, select(x, y, w)), min(t, select(x, y, u)))) ||
apply(min(select(x, select(z, w, u), y), select(z, v, t)), select(z, min(v, select(x, w, y)), min(t, select(x, u, y)))) ||

apply(min(x + c2, select(c0 < x, y, c1)), select(c0 < x, min(x, y - c2), x) + c2, c1 >= c0 + c2) ||
apply(min(x + c2, select(c0 < x, c1, y)), select(c0 < x, c1, min(y, x + c2)), c1 <= c0 + c2) ||
apply(min(x + c2, select(x < c0, y, c1)), select(x < c0, min(y, x + c2), c1), c1 <= c0 + c2) ||
Expand Down Expand Up @@ -271,6 +277,9 @@ bool apply_max_rules(Fn&& apply) {
apply(max(w + select(x, y, z), select(x, u, v)), select(x, max(u, w + y), max(v, w + z))) ||
apply(max(w - select(x, y, z), select(x, u, v)), select(x, max(u, w - y), max(v, w - z))) ||
apply(max(select(x, y, z) - w, select(x, u, v)), select(x, max(u, y - w), max(v, z - w))) ||

apply(max(select(x, y, select(z, w, u)), select(z, v, t)), select(z, max(v, select(x, y, w)), max(t, select(x, y, u)))) ||
apply(max(select(x, select(z, w, u), y), select(z, v, t)), select(z, max(v, select(x, w, y)), max(t, select(x, u, y)))) ||

apply(max(x + c2, select(c0 < x, y, c1)), select(c0 < x, max(y, x + c2), c1), c1 >= c0 + c2) ||
apply(max(x + c2, select(c0 < x, c1, y)), select(c0 < x, x, max(x, y - c2)) + c2, c1 <= c0 + c2) ||
Expand Down
11 changes: 11 additions & 0 deletions builder/test/simplify/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ TEST(simplify, basic) {
matches(call_stmt::make(nullptr, {}, {x}, {})));
ASSERT_THAT(simplify(slice_buffer::make(y, x, {}, call_stmt::make(nullptr, {}, {y}, {}))),
matches(call_stmt::make(nullptr, {}, {x}, {})));

ASSERT_THAT(simplify(max(select(z <= 0, -1, select(1 <= y, min(x, z + -1), 0)) + 1, select((1 <= y), z, 0))),
matches(select((1 <= y), max(z, 0), (0 < z))));

}

TEST(simplify, let) {
Expand Down Expand Up @@ -520,6 +524,11 @@ TEST(simplify, buffer_bounds) {
ASSERT_THAT(simplify(decl_bounds(b0, {{0, select(1 < x, y, 1) + -1}},
crop_dim::make(b1, b0, 0, {0, select(1 < x, expr(), 0)}, use_buffer(b1)))),
matches(decl_bounds(b0, {{0, select(1 < x, y, 1) + -1}}, use_buffer(b0))));

ASSERT_THAT(simplify(loop::make(x, loop::serial, {0, y}, z,
crop_dim::make(b1, b0, 0, {select(x <= 0, x, expr()), y}, use_buffer(b1)))),
matches(loop::make(
x, loop::serial, {0, y}, z, crop_dim::make(b1, b0, 0, {select(x <= 0, 0, expr()), y}, use_buffer(b1)))));
}

TEST(simplify, crop_not_needed) {
Expand Down Expand Up @@ -795,6 +804,8 @@ TEST(simplify, knowledge) {
check::make(buffer_max(b0, 0) <= ((buffer_max(b0, 0) + 16) / 16) * 16 - 1)))),
matches(stmt()));

ASSERT_THAT(simplify(let::make(x, clamp(y, 0, 10), select(x <= 0, x, 0))), matches(0));

expr huge_select = 1;
for (int i = 0; i < 100; ++i) {
switch (i % 4) {
Expand Down
Loading