Change IR generation to minimize the depth of the allocation nesting #543

Merged (55 commits) on Jan 16, 2025

Changes from 35 commits

Commits
89491e3
First cut of reduced nesting algorithm
vksnk Jan 6, 2025
a1c6ee3
Formatting
vksnk Jan 6, 2025
5f684cb
Re-enable checks and old function for the reference
vksnk Jan 6, 2025
922b5d9
Debugging
vksnk Jan 6, 2025
842208e
Add assert
vksnk Jan 6, 2025
5e2c5a6
Remove old stuff
vksnk Jan 7, 2025
6636f72
Merge main
vksnk Jan 7, 2025
18d50b6
Fix merge bugs
vksnk Jan 7, 2025
0856e7e
put back old stuff for comparison
vksnk Jan 7, 2025
82fed14
debug
vksnk Jan 7, 2025
a286c15
Merge branch 'main' into vksnk/lessnest
vksnk Jan 7, 2025
5b897f6
Fix aliasing
dsharlet Jan 7, 2025
f70efd9
Merge branch 'main' into vksnk/lessnest
vksnk Jan 8, 2025
d7a0420
Update tests
dsharlet Jan 8, 2025
698f29c
alias_copy_src should set the permutation too (#542)
dsharlet Jan 8, 2025
62ae4f6
Remove unneeded code
vksnk Jan 8, 2025
b66c271
Remove unused structu
vksnk Jan 8, 2025
469ad38
Renaming
vksnk Jan 8, 2025
70f65af
use sym as a key to store candidates
vksnk Jan 8, 2025
972f6f9
Update viz
vksnk Jan 8, 2025
ce64574
Disable some of the checks
vksnk Jan 9, 2025
1970854
Remove old stuff
vksnk Jan 9, 2025
2332781
Reenable checks
vksnk Jan 9, 2025
68f7efb
Remove debugging code
vksnk Jan 9, 2025
345e608
Move deps counting into separate function
vksnk Jan 9, 2025
42d67e3
More cleanup
vksnk Jan 9, 2025
5822084
Fix accidental replacement
vksnk Jan 9, 2025
b9b8b26
Comments
vksnk Jan 9, 2025
cfe8cf5
Fix comment
vksnk Jan 9, 2025
5c364e7
Fix comment
vksnk Jan 9, 2025
a014fe1
Formatting
vksnk Jan 9, 2025
82c205c
Merge branch 'main' into vksnk/lessnest
vksnk Jan 9, 2025
5f55ec8
Merge branch 'main' into vksnk/lessnest
vksnk Jan 9, 2025
80a1724
Combine various structures into one
vksnk Jan 9, 2025
2b7c46d
Speedup check for which of the buffers are produced/consumed inside o…
vksnk Jan 9, 2025
1978659
Address feedback
vksnk Jan 10, 2025
4a2cdb9
Remove unneeded check
vksnk Jan 10, 2025
ebdc830
Replace tuple with the existing structu
vksnk Jan 10, 2025
41cbadd
Replace tuple with the new struct
vksnk Jan 10, 2025
2873db4
Merge branch 'main' into vksnk/lessnest
vksnk Jan 10, 2025
dacceee
Put allocations of the copy's inputs outside
vksnk Jan 10, 2025
604f55a
A more efficient way to make inputs of the copy to wrap outputs
vksnk Jan 11, 2025
d28b284
Remove old stuff
vksnk Jan 13, 2025
6843ad4
Merge branch 'main' into vksnk/lessnest
vksnk Jan 13, 2025
5b90206
Fix warning
vksnk Jan 13, 2025
864e413
More accurately track allocations within the range
vksnk Jan 13, 2025
6ee00d5
Update viz
vksnk Jan 13, 2025
9a3c2b3
Merge branch 'main' into vksnk/lessnest
vksnk Jan 13, 2025
19b5031
Reverse constraint logic for the padded copies
vksnk Jan 14, 2025
ce26da1
Merge branch 'main' into vksnk/lessnest
vksnk Jan 16, 2025
cebfa14
Revert test/copy_pipeline.cc changes
vksnk Jan 16, 2025
191f071
Update test
vksnk Jan 16, 2025
2cc4c34
Revert optimizations.cc changes
vksnk Jan 16, 2025
74e99bd
Remove commented assert
vksnk Jan 16, 2025
3420c29
Reserve vectors
vksnk Jan 16, 2025
18 changes: 18 additions & 0 deletions builder/optimizations.cc
@@ -452,6 +452,7 @@ class buffer_aliaser : public stmt_mutator {
i.may_mutate = i.may_mutate || alias.may_mutate;
i.assume_in_bounds = i.assume_in_bounds && alias.assume_in_bounds;
}
target_info->uses += info.uses;

if (elem_size.defined()) {
result = block::make({check::make(elem_size == op->elem_size), result});
@@ -783,6 +784,23 @@ class buffer_aliaser : public stmt_mutator {
// TODO: We should be able to handle this.
std::abort();
}

void visit(const block* op) override {
// Visit blocks in reverse order so we see uses of buffers before they are produced.
std::vector<stmt> stmts;
stmts.reserve(op->stmts.size());
bool changed = false;
for (auto i = op->stmts.rbegin(); i != op->stmts.rend(); ++i) {
stmts.push_back(mutate(*i));
changed = changed || !stmts.back().same_as(*i);
}
if (!changed) {
set_result(op);
} else {
std::reverse(stmts.begin(), stmts.end());
set_result(block::make(std::move(stmts)));
}
}
};

} // namespace
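The optimizations.cc hunk above adds a visit(const block*) override that walks a block's statements in reverse order, so a bottom-up pass sees every later use of a buffer before it reaches the statement that produces it. The following is a minimal standalone sketch of that traversal idea; the toy_stmt model and the three-statement block are illustrative assumptions, not this project's IR types.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// A toy statement: the buffer it produces (empty if none) and the buffers it reads.
struct toy_stmt {
  std::string produces;
  std::vector<std::string> consumes;
};

int main() {
  // Hypothetical block: produce a, then b from a, then out from b.
  std::vector<toy_stmt> block = {{"a", {}}, {"b", {"a"}}, {"out", {"b"}}};

  std::map<std::string, int> later_uses;
  // Walk the block back to front: when we reach the producer of a buffer, later_uses
  // already counts every consumer that appears after it, which is the information an
  // aliasing decision needs.
  for (auto it = block.rbegin(); it != block.rend(); ++it) {
    std::cout << "producer of '" << it->produces << "' sees " << later_uses[it->produces]
              << " later use(s)\n";
    for (const std::string& c : it->consumes) later_uses[c]++;
  }
  return 0;
}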
271 changes: 241 additions & 30 deletions builder/pipeline.cc
@@ -552,6 +552,31 @@ stmt substitute_inputs(const stmt& s, const symbol_map<var>& subs) {
class pipeline_builder {
node_context& ctx;

struct allocation_candidate {
buffer_expr_ptr buffer;
int deps_count = 0;
int consumers_produced = 0;
int lifetime_start = -1;
int lifetime_end = -1;

explicit allocation_candidate(buffer_expr_ptr b) : buffer(b) {}
};

struct loop_id_less {
bool operator()(const loop_id& a, const loop_id& b) const {
if (a.root() && b.root()) return false;
if (a.root()) return true;
if (b.root()) return false;
if (a.func == b.func) return a.var < b.var;
return a.func < b.func;
}
};
std::map<loop_id, std::set<var>, loop_id_less> candidates_for_allocation_;
// Information tracking the lifetimes of the buffers.
symbol_map<allocation_candidate> allocation_info_;

int functions_produced_ = 0;

// Topologically sorted functions.
std::vector<const func*> order_;
// A mapping between func's and their compute_at locations.
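To make the bookkeeping in allocation_candidate concrete, here is a minimal sketch of how deps_count, consumers_produced, lifetime_start, and lifetime_end interact, mirroring what compute_deps_count() and produce() do later in this diff. The toy_buffer/toy_func model and the three-stage pipeline are hypothetical, not the project's types.

#include <iostream>
#include <map>
#include <string>
#include <vector>

struct toy_buffer {
  int deps_count = 0;          // total number of consuming funcs
  int consumers_produced = 0;  // consumers produced so far
  int lifetime_start = -1;     // step at which the buffer is produced
  int lifetime_end = -1;       // step at which its last consumer is produced
};

struct toy_func {
  std::string output;
  std::vector<std::string> inputs;
};

int main() {
  // Hypothetical pipeline in production order; 'a' is consumed by two later funcs.
  std::vector<toy_func> order = {{"a", {}}, {"b", {"a"}}, {"c", {"a", "b"}}};
  std::map<std::string, toy_buffer> info;

  // Analogue of compute_deps_count(): count the consumers of every buffer up front.
  for (const toy_func& f : order)
    for (const std::string& in : f.inputs) info[in].deps_count++;

  // Analogue of produce(): walk in production order, tracking lifetimes.
  int step = 0;
  for (const toy_func& f : order) {
    toy_buffer& out = info[f.output];
    out.lifetime_start = step;
    if (out.consumers_produced == out.deps_count) out.lifetime_end = step;  // no consumers at all
    for (const std::string& in : f.inputs) {
      toy_buffer& b = info[in];
      if (++b.consumers_produced == b.deps_count) b.lifetime_end = step;  // last consumer seen
    }
    ++step;
  }

  // Prints a: [0, 2], b: [1, 2], c: [2, 2].
  for (const auto& kv : info)
    std::cout << kv.first << ": [" << kv.second.lifetime_start << ", " << kv.second.lifetime_end << "]\n";
  return 0;
}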
@@ -702,12 +727,16 @@ }
}

// Generate the loops that we want to be explicit.
stmt make_loops(const func* f) {
// Returns generated statement as well as the lifetime range covered by it.
std::tuple<stmt, int, int> make_loops(const func* f) {
int old_function_produced = functions_produced_;

stmt result;
for (const auto& loop : f->loops()) {
result = make_loop(result, f, loop);
}
return result;

return std::make_tuple(result, old_function_produced, functions_produced_);
}

void compute_allocation_bounds() {
@@ -741,12 +770,98 @@
}
}

stmt produce(const func* f) {
// Returns generated statement for this function, as well as the
// lifetime range covered by it.
std::tuple<stmt, int, int> produce(const func* f) {
stmt result = sanitizer_.mutate(f->make_call());

for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (output_syms_.count(b->sym())) continue;

if (b->store_at()) {
candidates_for_allocation_[*b->store_at()].insert(b->sym());
} else {
candidates_for_allocation_[loop_id()].insert(b->sym());
}

allocation_info_[b->sym()]->buffer = b;
allocation_info_[b->sym()]->lifetime_start = functions_produced_;

if (allocation_info_[b->sym()]->consumers_produced == allocation_info_[b->sym()]->deps_count) {
allocation_info_[b->sym()]->lifetime_end = functions_produced_;
}
}

for (const auto& i : f->inputs()) {
const auto& input = i.buffer;
if (input->constant()) {
continue;
}
if (!input->producer()) {
continue;
}

allocation_info_[input->sym()]->consumers_produced++;

if (allocation_info_[input->sym()]->consumers_produced == allocation_info_[input->sym()]->deps_count) {
allocation_info_[input->sym()]->lifetime_end = functions_produced_;
}
}

functions_produced_++;

return std::make_tuple(result, functions_produced_ - 1, functions_produced_ - 1);
}

// Wraps provided body statement with the allocation node for a given buffer.
stmt produce_allocation(const buffer_expr_ptr& b, stmt body, symbol_map<var>& uncropped_subs) {
var uncropped = ctx.insert_unique(ctx.name(b->sym()) + ".uncropped");
uncropped_subs[b->sym()] = uncropped;
stmt result = clone_buffer::make(uncropped, b->sym(), body);

const std::vector<dim_expr>& dims = *inferred_dims_[b->sym()];
assert(allocation_bounds_[b->sym()]);
const box_expr& bounds = *allocation_bounds_[b->sym()];
result = allocate::make(b->sym(), b->storage(), b->elem_size(), dims, result);

std::vector<stmt> checks;
for (std::size_t d = 0; d < std::min(dims.size(), bounds.size()); ++d) {
checks.push_back(check::make(dims[d].min() <= bounds[d].min));
checks.push_back(check::make(dims[d].max() >= bounds[d].max));
}

result = block::make(std::move(checks), result);
return result;
}

// Computes number of consumers for each of the buffers.
void compute_deps_count() {
for (const func* f : order_) {
for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (output_syms_.count(b->sym())) continue;

if (!allocation_info_[b->sym()]) {
allocation_info_[b->sym()].emplace(b);
}
}
}

for (const func* f : order_) {
for (const auto& i : f->inputs()) {
const auto& input = i.buffer;
if (input->constant()) {
continue;
}
if (!input->producer()) {
continue;
}
allocation_info_[input->sym()]->deps_count++;
}
}
}

public:
pipeline_builder(node_context& ctx, const std::vector<buffer_expr_ptr>& inputs,
const std::vector<buffer_expr_ptr>& outputs, std::set<buffer_expr_ptr>& constants)
@@ -778,6 +893,9 @@ class pipeline_builder {

// Substitute inferred bounds into user provided dims.
substitute_buffer_dims();

// Compute number of consumers for each of the buffers.
compute_deps_count();
}

const std::vector<var>& external_symbols() const { return sanitizer_.external; }
@@ -795,55 +913,148 @@
// For each of the new loops, `build()` is called for the case when there
// are funcs which need to be produced in that new loop.
stmt build(const stmt& body, const func* base_f, const loop_id& at) {
std::vector<stmt> results;

symbol_map<var> uncropped_subs;
std::vector<std::tuple<stmt, int, int>> results;
// Build the functions computed at this loop level.
for (auto i = order_.rbegin(); i != order_.rend(); ++i) {
const func* f = *i;
const auto& compute_at = compute_at_levels_.find(f);
assert(compute_at != compute_at_levels_.end());
std::set<var> old_candidates = candidates_for_allocation_[at];

const auto& realize_at = realization_levels_.find(f);
assert(realize_at != realization_levels_.end());

if (compute_at->second == at && !f->loops().empty()) {
results.push_back(make_loops(f));
std::tuple<stmt, int, int> f_body = make_loops(f);
// This is a special case for buffers which are produced and consumed inside
// of this loop. In this case we simply wrap the loop body with the corresponding allocations.
if (candidates_for_allocation_[at].size() > old_candidates.size() + 1) {
std::set<var> to_remove;
for (auto b : candidates_for_allocation_[at]) {
if (old_candidates.count(b) > 0) continue;
if (allocation_info_[b]->consumers_produced != allocation_info_[b]->deps_count) continue;
if ((allocation_info_[b]->buffer->store_at() && *allocation_info_[b]->buffer->store_at() == at) ||
(!allocation_info_[b]->buffer->store_at() && at.root())) {
std::get<0>(f_body) =
produce_allocation(allocation_info_[b]->buffer, std::get<0>(f_body), uncropped_subs);
to_remove.insert(b);
}
}
for (auto b : to_remove) {
candidates_for_allocation_[at].erase(b);
}
}

results.push_back(f_body);
} else if (realize_at->second == at) {
results.push_back(produce(f));
std::tuple<stmt, int, int> f_body = produce(f);

results.push_back(f_body);
}
}

stmt result = block::make(std::move(results), body);
// This attempts to lay out allocation nodes such that the nesting
// is minimized. The general idea is to iteratively build up a tree of
// allocations, starting with the allocations with the shortest lifetimes
// as the lowest level of the tree. This is not always possible in general,
// and there are corner cases where the nesting of allocations will have
// depth N regardless of the approach, but in most practical situations this
// will produce a structure close to a tree (for example, for a linear pipeline it
// should build a perfect tree of depth ~log(N)). Similarly, the complexity of this
// algorithm is O(N^2) in the worst case, but for most practical pipelines it
// should be O(N*log(N)).

// Combine buffer sym, start and end of the lifetime into a vector of tuples.
std::vector<std::tuple<buffer_expr_ptr, int, int>> lifetimes;
for (const auto& b : candidates_for_allocation_[at]) {
if (output_syms_.count(b)) continue;
if ((allocation_info_[b]->buffer->store_at() && *(allocation_info_[b]->buffer->store_at()) == at) ||
(!allocation_info_[b]->buffer->store_at() && at.root())) {
lifetimes.push_back(std::make_tuple(
allocation_info_[b]->buffer, allocation_info_[b]->lifetime_start, allocation_info_[b]->lifetime_end));
}
}

symbol_map<var> uncropped_subs;
// Add all allocations at this loop level. The allocations can be added in any order. This order enables aliasing
// copy dsts to srcs, which is more flexible than aliasing srcs to dsts.
for (const func* f : order_) {
for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (output_syms_.count(b->sym())) continue;
// Sort the vector by (end - start) of the lifetime, and then by the lifetime start.
std::sort(lifetimes.begin(), lifetimes.end(),
[](std::tuple<buffer_expr_ptr, int, int> a, std::tuple<buffer_expr_ptr, int, int> b) {
if (std::get<2>(a) - std::get<1>(a) == std::get<2>(b) - std::get<1>(b)) {
return std::get<1>(a) < std::get<1>(b);
}
return std::get<2>(a) - std::get<1>(a) < std::get<2>(b) - std::get<1>(b);
});

int iteration_count = 0;
while (true) {
std::vector<std::tuple<buffer_expr_ptr, int, int>> new_lifetimes;
std::vector<std::tuple<stmt, int, int>> new_results;

std::size_t result_index = 0;
for (std::size_t ix = 0; ix < lifetimes.size() && result_index < results.size();) {
// Skip function bodies which go before the current buffer.
while (result_index < results.size() && std::get<2>(results[result_index]) < std::get<1>(lifetimes[ix])) {
new_results.push_back(results[result_index]);
result_index++;
}

if ((b->store_at() && *b->store_at() == at) || (!b->store_at() && at.root())) {
var uncropped = ctx.insert_unique(ctx.name(b->sym()) + ".uncropped");
uncropped_subs[b->sym()] = uncropped;
result = clone_buffer::make(uncropped, b->sym(), result);
int new_min = std::numeric_limits<int>::max();
int new_max = std::numeric_limits<int>::min();

// Find which function bodies overlap with the lifetime of the buffer.
std::vector<stmt> new_block;
while (result_index < results.size() && std::get<1>(results[result_index]) <= std::get<2>(lifetimes[ix]) &&
std::get<1>(lifetimes[ix]) <= std::get<2>(results[result_index])) {
new_min = std::min(new_min, std::get<1>(results[result_index]));
new_max = std::max(new_max, std::get<2>(results[result_index]));
new_block.push_back(std::get<0>(results[result_index]));
result_index++;
}

const std::vector<dim_expr>& dims = *inferred_dims_[b->sym()];
assert(allocation_bounds_[b->sym()]);
const box_expr& bounds = *allocation_bounds_[b->sym()];
result = allocate::make(b->sym(), b->storage(), b->elem_size(), dims, result);
// Combine overlapping function bodies and wrap them into current buffer allocation.
if (!new_block.empty()) {
stmt new_body = block::make(new_block);

std::vector<stmt> checks;
for (std::size_t d = 0; d < std::min(dims.size(), bounds.size()); ++d) {
checks.push_back(check::make(dims[d].min() <= bounds[d].min));
checks.push_back(check::make(dims[d].max() >= bounds[d].max));
}
buffer_expr_ptr b = std::get<0>(lifetimes[ix]);
// assert(candidates_for_allocation_[b->sym()]->consumers_produced ==
// candidates_for_allocation_[b->sym()]->deps_count);

new_body = produce_allocation(b, new_body, uncropped_subs);
candidates_for_allocation_[at].erase(b->sym());

new_results.push_back(std::make_tuple(new_body, new_min, new_max));
}

// Move to the next buffer.
ix++;

result = block::make(std::move(checks), result);
// Skip buffers which start before the end of the new statement range; they will be handled on the next iteration.
while (ix < lifetimes.size() && std::get<1>(lifetimes[ix]) <= new_max) {
new_lifetimes.push_back(lifetimes[ix]);
ix++;
}
}

for (std::size_t ix = result_index; ix < results.size(); ix++) {
new_results.push_back(results[ix]);
}

// No changes were made, so we are done.
if (lifetimes.size() == new_lifetimes.size()) break;

lifetimes = new_lifetimes;
results = new_results;
iteration_count++;
}

// Combine into one statement.
std::vector<stmt> results_stmt;
for (const auto& rs : results) {
results_stmt.push_back(std::get<0>(rs));
}

stmt result = block::make(std::move(results_stmt), body);

// Substitute references to the intermediate buffers with the 'name.uncropped' when they
// are used as input arguments. This does a batch substitution by replacing multiple
// buffer names at once and relies on the fact that the same var can't be written
@@ -976,6 +1187,7 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input

stmt result;
result = builder.build(result, nullptr, loop_id());
// std::cout << "Initial IR:\n" << result << std::endl;
result = builder.add_input_checks(result);
result = builder.make_buffers(result);
result = builder.define_sanitized_replacements(result);
@@ -1003,7 +1215,6 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
result = block::make(std::move(buffer_checks), std::move(result));

result = slide_and_fold_storage(result, ctx);

result = deshadow(result, builder.external_symbols(), ctx);
result = simplify(result);

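For reference, the allocation-nesting strategy described in the comment in build() can be illustrated with a small standalone program. The sketch below is a simplification under stated assumptions (statements are plain strings, buffer lifetimes are integer ranges, and the node struct and names are made up, not the PR's types); it only mirrors the grouping loop: sort buffers by lifetime length, wrap the overlapping statements in one allocation per pass, and repeat until no further buffer can be placed.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <tuple>
#include <vector>

// A textual stand-in for a stmt, plus the range of "functions produced" it covers.
struct node {
  std::string text;
  int start;
  int end;
};

int main() {
  // Hypothetical linear pipeline: f0 -> b0 -> f1 -> b1 -> f2 -> b2 -> f3.
  std::vector<node> results = {{"f0", 0, 0}, {"f1", 1, 1}, {"f2", 2, 2}, {"f3", 3, 3}};
  // Buffer lifetimes as (name, lifetime_start, lifetime_end).
  std::vector<std::tuple<std::string, int, int>> lifetimes = {{"b0", 0, 1}, {"b1", 1, 2}, {"b2", 2, 3}};

  // Sort by lifetime length, then by lifetime start, mirroring the sort in build().
  std::sort(lifetimes.begin(), lifetimes.end(), [](const auto& a, const auto& b) {
    int la = std::get<2>(a) - std::get<1>(a);
    int lb = std::get<2>(b) - std::get<1>(b);
    return la == lb ? std::get<1>(a) < std::get<1>(b) : la < lb;
  });

  while (true) {
    std::vector<std::tuple<std::string, int, int>> new_lifetimes;
    std::vector<node> new_results;
    std::size_t ri = 0;
    for (std::size_t i = 0; i < lifetimes.size() && ri < results.size();) {
      // Keep statements that end before this buffer's lifetime starts.
      while (ri < results.size() && results[ri].end < std::get<1>(lifetimes[i])) {
        new_results.push_back(results[ri++]);
      }
      // Collect the statements overlapping this buffer's lifetime and wrap them in its allocation.
      node group{"", std::numeric_limits<int>::max(), std::numeric_limits<int>::min()};
      std::string body;
      while (ri < results.size() && results[ri].start <= std::get<2>(lifetimes[i]) &&
             std::get<1>(lifetimes[i]) <= results[ri].end) {
        group.start = std::min(group.start, results[ri].start);
        group.end = std::max(group.end, results[ri].end);
        if (!body.empty()) body += " ";
        body += results[ri].text;
        ri++;
      }
      if (!body.empty()) {
        group.text = "alloc(" + std::get<0>(lifetimes[i]) + "){ " + body + " }";
        new_results.push_back(group);
      }
      i++;
      // Buffers whose lifetime starts inside the group we just built wait for the next pass.
      while (i < lifetimes.size() && std::get<1>(lifetimes[i]) <= group.end) {
        new_lifetimes.push_back(lifetimes[i++]);
      }
    }
    for (; ri < results.size(); ri++) new_results.push_back(results[ri]);
    if (new_lifetimes.size() == lifetimes.size()) break;  // no progress, we are done
    lifetimes = std::move(new_lifetimes);
    results = std::move(new_results);
  }

  // Prints: alloc(b1){ alloc(b0){ f0 f1 } alloc(b2){ f2 f3 } } -- nesting depth 2 instead of 3.
  for (const node& n : results) std::cout << n.text << "\n";
  return 0;
}

Each pass roughly halves the number of allocations still waiting to be placed for a linear chain, which is where the ~log(N) depth and the typical O(N*log(N)) cost mentioned in the comment come from.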