Skip to content

Commit

Permalink
Sliding window of a loop with a step > 1 works
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Jan 6, 2024
1 parent b08f8a4 commit 4eb4484
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 26 deletions.
63 changes: 41 additions & 22 deletions src/infer_bounds.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,16 @@ class slider : public node_mutator {
symbol_map<std::pair<int, expr>> fold_factors;
struct loop_info {
symbol_id sym;
expr og_min;
expr min;
expr orig_min;
interval_expr bounds;
expr step;
};
std::vector<loop_info> loops;

slider(node_context& ctx) : ctx(ctx) {}
// We need a symbol for an unknown that we can write equations in terms of and then solve for.
var x;

slider(node_context& ctx) : ctx(ctx), x(ctx.insert_unique("_x")) {}

void visit(const allocate* alloc) override {
box_expr bounds;
Expand Down Expand Up @@ -312,43 +315,59 @@ class slider : public node_mutator {
for (size_t l = 0; l < loops.size(); ++l) {
symbol_id loop_sym = loops[l].sym;
expr loop_var = variable::make(loop_sym);
const expr& loop_max = loops[l].bounds.max;

for (int d = 0; d < static_cast<int>(bounds->size()); ++d) {
interval_expr cur_bounds_d = (*bounds)[d];
interval_expr prev_bounds_d{
interval_expr prev_bounds_d = {
substitute(cur_bounds_d.min, loop_sym, loop_var - loops[l].step),
substitute(cur_bounds_d.max, loop_sym, loop_var - loops[l].step),
};

if (prove_true(prev_bounds_d.min <= cur_bounds_d.min) && prove_true(prev_bounds_d.max < cur_bounds_d.max)) {
// A few things here struggle to simplify when there is a min(loop_max, x) expression involved, where x is
// some expression that is bounded by the loop bounds. This min simplifies away if we know that x <= loop_max,
// but the simplifier can't figure that out. As a hopefully temporary workaround, we can just substitute
// infinity for the loop max.
auto ignore_loop_max = [=](const expr& e) { return substitute(e, loop_max, positive_infinity()); };

expr is_monotonic_increasing = prev_bounds_d.min <= cur_bounds_d.min && prev_bounds_d.max < cur_bounds_d.max;
expr is_monotonic_decreasing = prev_bounds_d.min > cur_bounds_d.min && prev_bounds_d.max >= cur_bounds_d.max;
is_monotonic_increasing = ignore_loop_max(is_monotonic_increasing);
is_monotonic_decreasing = ignore_loop_max(is_monotonic_decreasing);

if (prove_true(is_monotonic_increasing)) {
// The bounds for each loop iteration are monotonically increasing,
// so we can incrementally compute only the newly required bounds.
expr old_min = cur_bounds_d.min;
expr new_min = simplify(simplify(prev_bounds_d.max + 1));

expr fold_factor = simplify(bounds_of(cur_bounds_d.extent()).max);
fold_factors[output] = {d, fold_factor};
expr fold_factor = simplify(bounds_of(ignore_loop_max(cur_bounds_d.extent())).max);
if (!depends_on(fold_factor, loop_sym)) {
fold_factors[output] = {d, fold_factor};
} else {
// The fold factor still depends on the loop variable after simplification, so we don't record one.
}

// Now that we're only computing the newly required parts of the domain, we need
// to move the loop min back so we compute the whole required region. We'll insert
// ifs around the other parts of the loop to avoid expanding the bounds that those
// run on.
symbol_id new_loop_min_sym = ctx.insert_unique();
expr new_loop_min_var = variable::make(new_loop_min_sym);
expr new_min_at_new_loop_min = substitute(new_min, loop_sym, new_loop_min_var);
expr old_min_at_loop_min = substitute(old_min, loop_sym, loops[l].og_min);
expr new_loop_min = where_true(new_min_at_new_loop_min <= old_min_at_loop_min, new_loop_min_sym).max;
expr new_min_at_new_loop_min = substitute(new_min, loop_sym, x);
expr old_min_at_loop_min = substitute(old_min, loop_sym, loops[l].orig_min);
expr new_loop_min =
where_true(ignore_loop_max(new_min_at_new_loop_min <= old_min_at_loop_min), x.sym()).max;
if (!is_negative_infinity(new_loop_min)) {
loops[l].min = simplify(min(loops[l].min, new_loop_min));
loops[l].bounds.min = simplify(min(loops[l].bounds.min, new_loop_min));

(*bounds)[d].min = new_min;
} else {
// We couldn't find the new loop min. We need to warm up the loop on the first iteration.
(*bounds)[d].min = select(loop_var == loops[l].og_min, old_min, new_min);
// TODO: If another loop or func adjusts the loop min, this loop is going to run before the original min.
// That seems like it might be fine here anyway, but it is fragile.
(*bounds)[d].min = select(loop_var == loops[l].orig_min, old_min, new_min);
}
break;
} else if (prove_true(prev_bounds_d.min > cur_bounds_d.min) &&
prove_true(prev_bounds_d.max >= cur_bounds_d.max)) {
} else if (prove_true(is_monotonic_decreasing)) {
// TODO: We could also try to slide when the bounds are monotonically
// decreasing, but this is an unusual case.
}
Expand All @@ -358,7 +377,7 @@ class slider : public node_mutator {

// Insert ifs around these calls, in case the loop min shifts later.
for (const auto& l : loops) {
result = if_then_else::make(variable::make(l.sym) >= l.min, result, stmt());
result = if_then_else::make(variable::make(l.sym) >= l.bounds.min, result, stmt());
}
set_result(result);
}
Expand Down Expand Up @@ -393,19 +412,19 @@ class slider : public node_mutator {
void visit(const truncate_rank*) override { std::abort(); }

void visit(const loop* l) override {
var orig_loop_min(ctx, ctx.name(l->sym) + "_min.orig");
var orig_min(ctx, ctx.name(l->sym) + "_min.orig");

loops.emplace_back(l->sym, orig_loop_min, orig_loop_min, l->step);
loops.emplace_back(l->sym, orig_min, l->bounds, l->step);
stmt body = mutate(l->body);
expr loop_min = loops.back().min;
expr loop_min = loops.back().bounds.min;
loops.pop_back();

if (loop_min.same_as(orig_loop_min) && body.same_as(l->body)) {
if (loop_min.same_as(orig_min) && body.same_as(l->body)) {
set_result(l);
} else {
// We rewrote the loop min.
stmt result = loop::make(l->sym, {loop_min, l->bounds.max}, l->step, std::move(body));
set_result(let_stmt::make(orig_loop_min.sym(), l->bounds.min, result));
set_result(let_stmt::make(orig_min.sym(), l->bounds.min, result));
}
}

Expand Down
19 changes: 17 additions & 2 deletions src/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,21 @@ expr simplify(const class min* op, expr a, expr b) {
{min(min(x, c0), c1), min(x, min(c0, c1))},
{min(x, x + c0), x, c0 > 0},
{min(x, x + c0), x + c0, c0 < 0},
{min(x + c0, y + c1), min(x, y + (c1 - c0)) + c0},
{min(x + c0, c1), min(x, c1 - c0) + c0},
{min(c0 - x, c0 - y), c0 - max(x, y)},

// Algebraic simplifications
{min(x, x), x},
{min(x, max(x, y)), x},
{min(x, min(x, y)), min(x, y)},
{min(min(x, y), y + c0), min(x, min(y, y + c0))},
{min(min(x, y + c0), y), min(x, min(y, y + c0))},
{min(max(x, y), min(x, z)), min(x, z)},
{min(min(x, y), min(x, z)), min(x, min(y, z))},
{min(max(x, y), max(x, z)), max(x, min(y, z))},
{min(x, min(y, x + z)), min(y, min(x, x + z))},
{min(x, min(y, x - z)), min(y, min(x, x - z))},
{min(min(x, (y + z)), (y + w)), min(x, min(y + z, y + w))},
{min(x / z, y / z), min(x, y) / z, z > 0},
{min(x / z, y / z), max(x, y) / z, z < 0},
{min(x * z, y * z), z * min(x, y), z > 0},
Expand All @@ -285,6 +287,7 @@ expr simplify(const class min* op, expr a, expr b) {
{min(buffer_min(x, y), buffer_max(x, y)), buffer_min(x, y)},
{min(buffer_min(x, y), buffer_max(x, y) + c0), buffer_min(x, y), c0 > 0},
{min(buffer_max(x, y), buffer_min(x, y) + c0), buffer_min(x, y) + c0, c0 < 0},
{min(buffer_max(x, y) + c0, buffer_min(x, y) + c1), buffer_min(x, y) + c1, c0 > c1},
};
return rules.apply(e);
}
Expand Down Expand Up @@ -315,14 +318,15 @@ expr simplify(const class max* op, expr a, expr b) {
{max(max(x, c0), c1), max(x, max(c0, c1))},
{max(x, x + c0), x + c0, c0 > 0},
{max(x, x + c0), x, c0 < 0},
{max(x + c0, y + c1), max(x, y + (c1 - c0)) + c0},
{max(x + c0, c1), max(x, c1 - c0) + c0},
{max(c0 - x, c0 - y), c0 - min(x, y)},

// Algebraic simplifications
{max(x, x), x},
{max(x, min(x, y)), x},
{max(x, max(x, y)), max(x, y)},
{max(max(x, y), y + c0), max(x, max(y, y + c0))},
{max(max(x, y + c0), y), max(x, max(y, y + c0))},
{max(min(x, y), max(x, z)), max(x, z)},
{max(max(x, y), max(x, z)), max(x, max(y, z))},
{max(min(x, y), min(x, z)), min(x, max(y, z))},
Expand All @@ -340,6 +344,7 @@ expr simplify(const class max* op, expr a, expr b) {
{max(buffer_min(x, y), buffer_max(x, y)), buffer_max(x, y)},
{max(buffer_min(x, y), buffer_max(x, y) + c0), buffer_max(x, y) + c0, c0 > 0},
{max(buffer_max(x, y), buffer_min(x, y) + c0), buffer_max(x, y), c0 < 0},
{max(buffer_max(x, y) + c0, buffer_min(x, y) + c1), buffer_max(x, y) + c0, c0 > c1},
};
return rules.apply(e);
}
Expand Down Expand Up @@ -381,12 +386,17 @@ expr simplify(const add* op, expr a, expr b) {
{(x + c0) - y, (x - y) + c0},
{(x + c0) + (y + c1), (x + y) + (c0 + c1)},

{min(x, y - z) + z, min(y, x + z)},
{max(x, y - z) + z, max(y, x + z)},

{min(x + c0, y + c1) + c2, min(x + (c0 + c2), y + (c1 + c2))},
{max(x + c0, y + c1) + c2, max(x + (c0 + c2), y + (c1 + c2))},
{min(c0 - x, y + c1) + c2, min((c0 + c2) - x, y + (c1 + c2))},
{max(c0 - x, y + c1) + c2, max((c0 + c2) - x, y + (c1 + c2))},
{min(c0 - x, c1 - y) + c2, min((c0 + c2) - x, (c1 + c2) - y)},
{max(c0 - x, c1 - y) + c2, max((c0 + c2) - x, (c1 + c2) - y)},
{min(x, y + c0) + c1, min(x + c1, y + (c0 + c1))},
{max(x, y + c0) + c1, max(x + c1, y + (c0 + c1))},

{select(x, c0, c1) + c2, select(x, c0 + c2, c1 + c2)},
{select(x, y + c0, c1) + c2, select(x, y + (c0 + c2), c1 + c2)},
Expand Down Expand Up @@ -445,6 +455,9 @@ expr simplify(const sub* op, expr a, expr b) {
{(c0 - x) - (y - z), ((z - x) - y) + c0},
{(x + c0) - (y + c1), (x - y) + (c0 - c1)},

{min(x, y + z) - z, min(y, x - z)},
{max(x, y + z) - z, max(y, x - z)},

{c2 - select(x, c0, c1), select(x, c2 - c0, c2 - c1)},
{c2 - select(x, y + c0, c1), select(x, (c2 - c0) - y, c2 - c1)},
{c2 - select(x, c0 - y, c1), select(x, y + (c2 - c0), c2 - c1)},
Expand Down Expand Up @@ -589,6 +602,7 @@ expr simplify(const less* op, expr a, expr b) {
{x - y < z - y, x < z},

{min(x, y) < x, y < x},
{min(x, min(y, z)) < y, min(x, z) < y},
{max(x, y) < x, false},
{x < max(x, y), x < y},
{x < min(x, y), false},
Expand Down Expand Up @@ -639,6 +653,7 @@ expr simplify(const less_equal* op, expr a, expr b) {
{x - y <= z - y, x <= z},

{min(x, y) <= x, true},
{min(x, min(y, z)) <= y, true},
{max(x, y) <= x, y <= x},
{x <= max(x, y), true},
{x <= min(x, y), x <= y},
Expand Down
4 changes: 2 additions & 2 deletions test/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ TEST(pipeline_stencil) {
func stencil =
func::make<const short, short>(sum3x3<short>, {intm, {bounds(-1, 1) + x, bounds(-1, 1) + y}}, {out, {x, y}});

stencil.loops({y});
stencil.loops({{y, 2}});
add.compute_at({&stencil, y});

pipeline p(ctx, {in}, {out});
Expand All @@ -399,7 +399,7 @@ TEST(pipeline_stencil) {
const raw_buffer* outputs[] = {&out_buf};
debug_context eval_ctx;
p.evaluate(inputs, outputs, eval_ctx);
ASSERT_EQ(eval_ctx.heap.total_size, (W + 2) * 3 * sizeof(short));
ASSERT_EQ(eval_ctx.heap.total_size, (W + 2) * 4 * sizeof(short));
ASSERT_EQ(eval_ctx.heap.total_count, 1);

for (int y = 0; y < H; ++y) {
Expand Down

0 comments on commit 4eb4484

Please sign in to comment.