Skip to content

Commit

Permalink
Handle split loops
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Jan 6, 2024
1 parent ea1f932 commit b08f8a4
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 15 deletions.
7 changes: 4 additions & 3 deletions src/infer_bounds.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class slider : public node_mutator {
symbol_id sym;
expr og_min;
expr min;
expr step;
};
std::vector<loop_info> loops;

Expand Down Expand Up @@ -315,8 +316,8 @@ class slider : public node_mutator {
for (int d = 0; d < static_cast<int>(bounds->size()); ++d) {
interval_expr cur_bounds_d = (*bounds)[d];
interval_expr prev_bounds_d{
substitute(cur_bounds_d.min, loop_sym, loop_var - 1),
substitute(cur_bounds_d.max, loop_sym, loop_var - 1),
substitute(cur_bounds_d.min, loop_sym, loop_var - loops[l].step),
substitute(cur_bounds_d.max, loop_sym, loop_var - loops[l].step),
};

if (prove_true(prev_bounds_d.min <= cur_bounds_d.min) && prove_true(prev_bounds_d.max < cur_bounds_d.max)) {
Expand Down Expand Up @@ -394,7 +395,7 @@ class slider : public node_mutator {
void visit(const loop* l) override {
var orig_loop_min(ctx, ctx.name(l->sym) + "_min.orig");

loops.emplace_back(l->sym, orig_loop_min, orig_loop_min);
loops.emplace_back(l->sym, orig_loop_min, orig_loop_min, l->step);
stmt body = mutate(l->body);
expr loop_min = loops.back().min;
loops.pop_back();
Expand Down
5 changes: 2 additions & 3 deletions src/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,8 @@ class pipeline_builder {
for (const func::output& o : f->outputs()) {
for (int d = 0; d < static_cast<int>(o.dims.size()); ++d) {
if (o.dims[d].sym() == loop.sym()) {
// TODO: Clamp at buffer max here to handle loop extents not a multiple of the step.
// expr loop_max = buffer_max(var(o.sym()), d);
interval_expr bounds = slinky::bounds(loop.var, simplify(loop.var + loop.step - 1));
expr loop_max = buffer_max(var(o.sym()), d);
interval_expr bounds = slinky::bounds(loop.var, min(loop.var + loop.step - 1, loop_max));
body = crop_dim::make(o.sym(), d, bounds, body);
}
}
Expand Down
8 changes: 0 additions & 8 deletions src/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1158,10 +1158,6 @@ class simplifier : public node_mutator {
interval_expr bounds = mutate(op->bounds, &min_bounds, &max_bounds);
expr step = mutate(op->step);

if (!step.defined()) {
step = 1;
}

if (prove_true(min_bounds.min > max_bounds.max)) {
// This loop is dead.
set_result(stmt());
Expand All @@ -1175,10 +1171,6 @@ class simplifier : public node_mutator {
auto set_bounds = set_value_in_scope(expr_bounds, op->sym, bounds);
stmt body = mutate(op->body);

if (is_constant(step, 1)) {
step = expr();
}

if (bounds.same_as(op->bounds) && step.same_as(op->step) && body.same_as(op->body)) {
set_result(op);
} else {
Expand Down
2 changes: 1 addition & 1 deletion test/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ TEST(pipeline_elementwise_1d_explicit) {
func mul = func::make<const int, int>(multiply_2<int>, {in, {point(x)}}, {intm, {x}});
func add = func::make<const int, int>(add_1<int>, {intm, {point(x)}}, {out, {x}});

add.loops({x});
add.loops({{x, 3}}); // Doesn't divide the extent of the buffer below.
mul.compute_at({&add, x});

intm->store_at({&add, x});
Expand Down

0 comments on commit b08f8a4

Please sign in to comment.