Skip to content

Commit

Permalink
Clean up infer_bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Jan 3, 2024
1 parent be456e8 commit 3beb6b8
Showing 1 changed file with 57 additions and 86 deletions.
143 changes: 57 additions & 86 deletions src/infer_bounds.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ void merge_crop(std::optional<box_expr>& bounds, const box_expr& new_bounds) {
class bounds_inferrer : public node_mutator {
public:
node_context& ctx;
symbol_map<box_expr> buffers;
symbol_map<box_expr> infer;
symbol_map<box_expr> crops;

bounds_inferrer(node_context& ctx) : ctx(ctx) {}

void visit(const allocate* alloc) override {
{
std::optional<box_expr>& info = buffers[alloc->sym];
assert(!info);
info = box_expr();
std::optional<box_expr>& bounds = infer[alloc->sym];
assert(!bounds);
bounds = box_expr();
}

stmt body = mutate(alloc->body);
Expand All @@ -64,10 +64,10 @@ class bounds_inferrer : public node_mutator {
// TODO: Is this actually a good design...?
std::vector<std::pair<expr, expr>> replacements;

box_expr& info = *buffers[alloc->sym];
box_expr& bounds = *infer[alloc->sym];
expr stride = static_cast<index_t>(alloc->elem_size);
for (index_t d = 0; d < static_cast<index_t>(info.size()); ++d) {
interval_expr& i = info[d];
for (index_t d = 0; d < static_cast<index_t>(bounds.size()); ++d) {
interval_expr& i = bounds[d];

i.min = simplify(i.min);
i.max = simplify(i.max);
Expand Down Expand Up @@ -106,21 +106,21 @@ class bounds_inferrer : public node_mutator {
// user set the bounds to something too small).
std::vector<stmt> checks;
for (index_t d = 0; d < static_cast<index_t>(dims.size()); ++d) {
checks.push_back(check::make(dims[d].min() <= info[d].min));
checks.push_back(check::make(dims[d].max() >= info[d].max));
checks.push_back(check::make(dims[d].min() <= bounds[d].min));
checks.push_back(check::make(dims[d].max() >= bounds[d].max));
}

stmt s = allocate::make(alloc->storage, alloc->sym, alloc->elem_size, std::move(dims), body);
set_result(block::make(block::make(checks), s));
}

expr buffer_intrinsic(symbol_id buffer, intrinsic fn, index_t d) {
std::optional<box_expr>& info = buffers[buffer];
if (info && d < static_cast<index_t>(info->size())) {
std::optional<box_expr>& bounds = infer[buffer];
if (bounds && d < static_cast<index_t>(bounds->size())) {
switch (fn) {
case intrinsic::buffer_min: return (*info)[d].min;
case intrinsic::buffer_max: return (*info)[d].max;
case intrinsic::buffer_extent: return (*info)[d].extent();
case intrinsic::buffer_min: return (*bounds)[d].min;
case intrinsic::buffer_max: return (*bounds)[d].max;
case intrinsic::buffer_extent: return (*bounds)[d].extent();
default: break;
}
}
Expand Down Expand Up @@ -148,28 +148,28 @@ class bounds_inferrer : public node_mutator {
}
}

std::optional<box_expr>& info = buffers[input.buffer->sym()];
assert(info);
info->reserve(input.bounds.size());
while (info->size() < input.bounds.size()) {
info->push_back(interval_expr::union_identity());
std::optional<box_expr>& bounds = infer[input.buffer->sym()];
assert(bounds);
bounds->reserve(input.bounds.size());
while (bounds->size() < input.bounds.size()) {
bounds->push_back(interval_expr::union_identity());
}
for (std::size_t d = 0; d < input.bounds.size(); ++d) {
expr min = substitute(input.bounds[d].min, mins);
expr max = substitute(input.bounds[d].max, maxs);
// We need to be careful of the case where min > max, such as when a pipeline
// flips a dimension.
// TODO: This seems janky/possibly not right.
(*info)[d] |= slinky::bounds(min, max) | slinky::bounds(max, min);
(*bounds)[d] |= slinky::bounds(min, max) | slinky::bounds(max, min);
}
}

// Add any crops necessary.
stmt result = c;
for (const func::output& output : c->fn->outputs()) {
std::optional<box_expr>& info = buffers[output.buffer->sym()];
if (info) {
result = crop_buffer::make(output.buffer->sym(), *info, result);
std::optional<box_expr>& bounds = infer[output.buffer->sym()];
if (bounds) {
result = crop_buffer::make(output.buffer->sym(), *bounds, result);
}
}

Expand All @@ -179,14 +179,14 @@ class bounds_inferrer : public node_mutator {
void visit(const crop_buffer* c) override {
std::optional<box_expr> crop = crops[c->sym];
merge_crop(crop, c->bounds);
auto new_crop = set_value_in_scope(crops, c->sym, crop);
auto set_crop = set_value_in_scope(crops, c->sym, crop);
node_mutator::visit(c);
}

void visit(const crop_dim* c) override {
std::optional<box_expr> crop = crops[c->sym];
merge_crop(crop, c->dim, c->bounds);
auto new_crop = set_value_in_scope(crops, c->sym, crop);
auto set_crop = set_value_in_scope(crops, c->sym, crop);
node_mutator::visit(c);
}

Expand All @@ -203,7 +203,7 @@ class bounds_inferrer : public node_mutator {

// We're leaving the body of l. If any of the bounds used that loop variable, we need
// to replace those uses with the bounds of the loop.
for (std::optional<box_expr>& i : buffers) {
for (std::optional<box_expr>& i : infer) {
if (!i) continue;

for (interval_expr& j : *i) {
Expand Down Expand Up @@ -236,19 +236,19 @@ class bounds_inferrer : public node_mutator {
class slider : public node_mutator {
public:
node_context& ctx;
symbol_map<box_expr> buffers;
symbol_map<box_expr> buffer_bounds;
symbol_map<std::pair<int, expr>> fold_factors;
std::vector<std::pair<symbol_id, interval_expr>> loop_bounds;

slider(node_context& ctx) : ctx(ctx) {}

void visit(const allocate* alloc) override {
box_expr info;
info.reserve(alloc->dims.size());
box_expr bounds;
bounds.reserve(alloc->dims.size());
for (const dim_expr& d : alloc->dims) {
info.push_back(d.bounds);
bounds.push_back(d.bounds);
}
auto set_buffers = set_value_in_scope(buffers, alloc->sym, info);
auto set_buffer_bounds = set_value_in_scope(buffer_bounds, alloc->sym, bounds);
stmt body = mutate(alloc->body);

// When we constructed the pipeline, the buffer dimensions were set to buffer_* calls.
Expand Down Expand Up @@ -294,28 +294,26 @@ class slider : public node_mutator {

stmt result = c;
for (const func::output& output : c->fn->outputs()) {
std::optional<box_expr>& info = buffers[output.buffer->sym()];
if (!info) continue;

box_expr& bounds = *info;
std::optional<box_expr>& bounds = buffer_bounds[output.buffer->sym()];
if (!bounds) continue;

for (size_t l = 0; l < loop_bounds.size(); ++l) {
symbol_id loop_sym = loop_bounds[l].first;
expr loop_var = variable::make(loop_sym);
expr loop_min = loop_bounds[l].second.min;
expr loop_max = loop_bounds[l].second.max;

box_expr prev_bounds(bounds.size());
for (int d = 0; d < static_cast<int>(bounds.size()); ++d) {
prev_bounds[d].min = substitute(bounds[d].min, loop_sym, loop_var - 1);
prev_bounds[d].max = substitute(bounds[d].max, loop_sym, loop_var - 1);
if (prove_true(prev_bounds[d].min <= bounds[d].min) && prove_true(prev_bounds[d].max < bounds[d].max)) {
box_expr prev_bounds(bounds->size());
for (int d = 0; d < static_cast<int>(bounds->size()); ++d) {
prev_bounds[d].min = substitute((*bounds)[d].min, loop_sym, loop_var - 1);
prev_bounds[d].max = substitute((*bounds)[d].max, loop_sym, loop_var - 1);
if (prove_true(prev_bounds[d].min <= (*bounds)[d].min) && prove_true(prev_bounds[d].max < (*bounds)[d].max)) {
// The bounds for each loop iteration are monotonically increasing,
// so we can incrementally compute only the newly required bounds.
expr& old_min = bounds[d].min;
expr& old_min = (*bounds)[d].min;
expr new_min = prev_bounds[d].max + 1;

fold_factors[output.buffer->sym()] = {d, simplify(bounds_of(bounds[d].extent()).max)};
fold_factors[output.buffer->sym()] = {d, simplify(bounds_of((*bounds)[d].extent()).max)};

// Now that we're only computing the newly required parts of the domain, we need
// to move the loop min back so we compute the whole required region. We'll insert
Expand All @@ -335,8 +333,8 @@ class slider : public node_mutator {
old_min = select(loop_var == loop_min, old_min, new_min);
}
break;
} else if (prove_true(prev_bounds[d].min > bounds[d].min) &&
prove_true(prev_bounds[d].max >= bounds[d].max)) {
} else if (prove_true(prev_bounds[d].min > (*bounds)[d].min) &&
prove_true(prev_bounds[d].max >= (*bounds)[d].max)) {
// TODO: We could also try to slide when the bounds are monotonically
// decreasing, but this is an unusual case.
}
Expand All @@ -356,11 +354,11 @@ class slider : public node_mutator {
}

void visit(const crop_buffer* c) override {
std::optional<box_expr> info = buffers[c->sym];
merge_crop(info, c->bounds);
auto new_crop = set_value_in_scope(buffers, c->sym, info);
std::optional<box_expr> bounds = buffer_bounds[c->sym];
merge_crop(bounds, c->bounds);
auto set_bounds = set_value_in_scope(buffer_bounds, c->sym, bounds);
stmt body = mutate(c->body);
box_expr new_bounds = *buffers[c->sym];
box_expr new_bounds = *buffer_bounds[c->sym];

if (const if_then_else* body_if = body.as<if_then_else>()) {
// TODO: HORRIBLE HACK: crop_dim modifies the buffer meta, which this if we inserted
Expand All @@ -370,19 +368,19 @@ class slider : public node_mutator {
// be to substitute a clamp on the loop variable for when the if is true. It should
// simplify away later anyways, and make it easier to track bounds. This isn't easily
// doable due to this hack.
set_result(
if_then_else::make(body_if->condition, crop_buffer::make(c->sym, std::move(new_bounds), body_if->true_body), stmt()));
set_result(if_then_else::make(
body_if->condition, crop_buffer::make(c->sym, std::move(new_bounds), body_if->true_body), stmt()));
} else {
set_result(crop_buffer::make(c->sym, std::move(new_bounds), std::move(body)));
}
}

void visit(const crop_dim* c) override {
std::optional<box_expr> info = buffers[c->sym];
merge_crop(info, c->dim, c->bounds);
auto set_crop = set_value_in_scope(buffers, c->sym, info);
std::optional<box_expr> bounds = buffer_bounds[c->sym];
merge_crop(bounds, c->dim, c->bounds);
auto set_bounds = set_value_in_scope(buffer_bounds, c->sym, bounds);
stmt body = mutate(c->body);
interval_expr new_bounds = (*buffers[c->sym])[c->dim];
interval_expr new_bounds = (*buffer_bounds[c->sym])[c->dim];

if (const if_then_else* body_if = body.as<if_then_else>()) {
// TODO: HORRIBLE HACK: crop_dim modifies the buffer meta, which this if we inserted
Expand All @@ -409,39 +407,12 @@ class slider : public node_mutator {
assert(l->bounds.max.same_as(loop_bounds.back().second.max));
loop_bounds.pop_back();

stmt result;
if (loop_min.same_as(l->bounds.min) && body.same_as(l->body)) {
result = l;
set_result(l);
} else {
// We rewrote the loop min.
result = loop::make(l->sym, {loop_min, l->bounds.max}, l->step, std::move(body));
}

// We're leaving the body of l. If any of the bounds used that loop variable, we need
// to replace those uses with the bounds of the loop.
// TODO: This ignores ifs inserted around parts of the body of this loop, which limit the
// range of the loop. I was debugging a failure regarding this when I made an unrelated
// change, and it magically started working. It *shouldn't* work, I expect this bug will
// appear again. See the TODO: HORRIBLE HACK: above for more.
// Use the original loop min. Hack?
loop_min = l->bounds.min;
expr loop_max = l->bounds.max;
for (std::optional<box_expr>& i : buffers) {
if (!i) continue;

for (interval_expr& j : *i) {
// We need to be careful of the case where min > max, such as when a pipeline
// flips a dimension.
// TODO: This seems janky/possibly not right.
if (depends_on(j.min, l->sym)) {
j.min = min(substitute(j.min, l->sym, loop_min), substitute(j.min, l->sym, loop_max));
}
if (depends_on(j.max, l->sym)) {
j.max = max(substitute(j.max, l->sym, loop_min), substitute(j.max, l->sym, loop_max));
}
}
set_result(loop::make(l->sym, {loop_min, l->bounds.max}, l->step, std::move(body)));
}
set_result(result);
}

void visit(const block* x) override {
Expand All @@ -463,7 +434,7 @@ stmt infer_bounds(const stmt& s, node_context& ctx, const std::vector<symbol_id>

// Tell the bounds inferrer that we are buffers the bounds of the inputs too.
for (symbol_id i : inputs) {
infer.buffers[i] = box_expr();
infer.infer[i] = box_expr();
}

// Run it.
Expand All @@ -473,7 +444,7 @@ stmt infer_bounds(const stmt& s, node_context& ctx, const std::vector<symbol_id>
std::vector<stmt> checks;
for (symbol_id i : inputs) {
expr buf_var = variable::make(i);
const box_expr& bounds = *infer.buffers[i];
const box_expr& bounds = *infer.infer[i];
for (int d = 0; d < static_cast<int>(bounds.size()); ++d) {
checks.push_back(check::make(buffer_min(buf_var, d) <= bounds[d].min));
checks.push_back(check::make(buffer_max(buf_var, d) >= bounds[d].max));
Expand All @@ -485,7 +456,7 @@ stmt infer_bounds(const stmt& s, node_context& ctx, const std::vector<symbol_id>

slider slide(ctx);
for (symbol_id i : inputs) {
slide.buffers[i] = box_expr();
slide.buffer_bounds[i] = box_expr();
}
result = slide.mutate(result);

Expand Down

0 comments on commit 3beb6b8

Please sign in to comment.