Skip to content

Commit

Permalink
Fuse sibling ops when possible (#559)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet authored Jan 21, 2025
1 parent 4c6388f commit 426e998
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 10 deletions.
86 changes: 86 additions & 0 deletions builder/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,92 @@ stmt alias_in_place(const stmt& s, const std::vector<buffer_expr_ptr>& outputs)
return in_place_aliaser(outputs).mutate(s);
}

namespace {

// Element-wise comparison of two spans: true only if both have the same length and each pair of
// corresponding elements satisfies `match`.
template <typename T>
bool match(span<const T> a, span<const T> b) {
  const std::size_t count = a.size();
  if (count != b.size()) return false;

  // Advance until the first mismatching pair (or the end of the spans).
  std::size_t i = 0;
  while (i < count && match(a[i], b[i])) {
    ++i;
  }
  return i == count;
}

class sibling_fuser : public stmt_mutator {
// Sibling buffer declarations can be fused if they produce the same buffer (same parameters).
static bool can_fuse(const allocate* a, const allocate* b) {
return a->storage == b->storage && match(a->elem_size, b->elem_size) && match<dim_expr>(a->dims, b->dims);
}
static bool can_fuse(const make_buffer* a, const make_buffer* b) {
return match(a->base, b->base) && match(a->elem_size, b->elem_size) && match<dim_expr>(a->dims, b->dims);
}
static bool can_fuse(const crop_dim* a, const crop_dim* b) {
return a->src == b->src && a->dim == b->dim && match(a->bounds, b->bounds);
}
static bool can_fuse(const crop_buffer* a, const crop_buffer* b) {
return a->src == b->src && match<interval_expr>(a->bounds, b->bounds);
}
static bool can_fuse(const slice_dim* a, const slice_dim* b) {
return a->src == b->src && a->dim == b->dim && match(a->at, b->at);
}
static bool can_fuse(const slice_buffer* a, const slice_buffer* b) {
return a->src == b->src && match<expr>(a->at, b->at);
}
static bool can_fuse(const transpose* a, const transpose* b) { return a->src == b->src && a->dims == b->dims; }

template <typename T>
static bool fuse(const T* a, const T* b, stmt& result) {
if (!a || !b || !can_fuse(a, b)) return false;

stmt body = block::make({a->body, substitute(b->body, b->sym, a->sym)});
result = clone_with(a, std::move(body));
return true;
}

static bool fuse(stmt& a, const stmt& b) {
return fuse(a.as<allocate>(), b.as<allocate>(), a) || fuse(a.as<make_buffer>(), b.as<make_buffer>(), a) ||
fuse(a.as<crop_dim>(), b.as<crop_dim>(), a) || fuse(a.as<crop_buffer>(), b.as<crop_buffer>(), a) ||
fuse(a.as<slice_dim>(), b.as<slice_dim>(), a) || fuse(a.as<slice_buffer>(), b.as<slice_buffer>(), a) ||
fuse(a.as<transpose>(), b.as<transpose>(), a);
}

public:
void visit(const block* op) override {
std::vector<stmt> result;
result.reserve(op->stmts.size());
bool changed = false;
for (const stmt& s : op->stmts) {
result.push_back(mutate(s));
changed = changed || !result.back().same_as(s);
}

// TODO: This currently only looks for immediately adjacent nodes that can be fused. We can also try to fuse
// ops with intervening ops, but this isn't obviously a simplification, and in the case of allocations, may
// increase peak memory usage.
for (std::size_t i = 0; i + 1 < result.size();) {
if (fuse(result[i], result[i + 1])) {
result.erase(result.begin() + i + 1);
changed = true;
} else {
++i;
}
}

if (changed) {
set_result(block::make(std::move(result)));
} else {
set_result(op);
}
}
};

} // namespace

// Runs the sibling_fuser mutator over `s` and returns the resulting stmt.
stmt fuse_siblings(const stmt& s) {
  scoped_trace trace("fuse_siblings");
  sibling_fuser fuser;
  return fuser.mutate(s);
}

stmt implement_copy(const copy_stmt* op, node_context& ctx) {
scoped_trace trace("implement_copy");
// Start by making a call to copy.
Expand Down
3 changes: 3 additions & 0 deletions builder/optimizations.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ stmt alias_copies(const stmt& s, node_context& ctx, const std::vector<buffer_exp
// Replace allocations of input buffers to calls that can be computed in place with crops of the output buffer.
stmt alias_in_place(const stmt& s, const std::vector<buffer_expr_ptr>& outputs);

// Fuse immediately adjacent sibling buffer declarations (allocate, make_buffer, crop, slice, transpose) that have
// matching parameters into a single declaration whose body is a block of the original bodies, where possible.
stmt fuse_siblings(const stmt& s);

// Given a copy_stmt, produce an implementation that calls `slinky::copy`, possibly inside loops that implement copy
// operations that `slinky::copy` cannot express.
stmt implement_copy(const copy_stmt* c, node_context& ctx);
Expand Down
2 changes: 2 additions & 0 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,8 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
result = deshadow(result, builder.external_symbols(), ctx);
result = simplify(result);

result = fuse_siblings(result);

if (options.no_checks) {
result = recursive_mutate<check>(result, [](const check* op) { return stmt(); });
// Simplify again, in case there are lets that the checks used that are now dead.
Expand Down
48 changes: 48 additions & 0 deletions builder/test/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,54 @@ MATCHER_P(matches, expected, "") { return match(arg, expected); }

} // namespace

// Verifies that fuse_siblings merges adjacent allocations with identical parameters, and leaves the
// block unchanged when the allocations differ in any parameter or are not adjacent.
TEST(optimizations, fuse_siblings) {
  auto use_buffer = [](var x) { return call_stmt::make(nullptr, {}, {x}, {}); };

  // Identical adjacent allocations are fused, and the use of y is rewritten to use x.
  ASSERT_THAT(fuse_siblings(block::make({
                  allocate::make(x, memory_type::heap, 1, {}, use_buffer(x)),
                  allocate::make(y, memory_type::heap, 1, {}, use_buffer(y)),
              })),
      matches(allocate::make(x, memory_type::heap, 1, {}, block::make({use_buffer(x), use_buffer(x)}))));

  // The negative cases below compare against the *unfused* input. Comparing fuse_siblings(X)
  // against fuse_siblings(X) would pass trivially, even if the allocations were incorrectly fused.

  // Different element sizes must not be fused.
  const stmt different_elem_size = block::make({
      allocate::make(x, memory_type::heap, 1, {}, use_buffer(x)),
      allocate::make(y, memory_type::heap, 2, {}, use_buffer(y)),
  });
  ASSERT_THAT(fuse_siblings(different_elem_size), matches(different_elem_size));

  // Different storage types must not be fused.
  const stmt different_storage = block::make({
      allocate::make(x, memory_type::heap, 1, {}, use_buffer(x)),
      allocate::make(y, memory_type::stack, 1, {}, use_buffer(y)),
  });
  ASSERT_THAT(fuse_siblings(different_storage), matches(different_storage));

  // Different dimensions must not be fused.
  const stmt different_dims = block::make({
      allocate::make(x, memory_type::heap, 1, {{}}, use_buffer(x)),
      allocate::make(y, memory_type::heap, 1, {}, use_buffer(y)),
  });
  ASSERT_THAT(fuse_siblings(different_dims), matches(different_dims));

  // Matching allocations separated by another statement are not adjacent, so they are not fused.
  ASSERT_THAT(fuse_siblings(block::make({
                  allocate::make(x, memory_type::heap, 1, {}, use_buffer(x)),
                  use_buffer(z),
                  allocate::make(y, memory_type::heap, 1, {}, use_buffer(y)),
              })),
      matches(block::make({
          allocate::make(x, memory_type::heap, 1, {}, use_buffer(x)),
          use_buffer(z),
          allocate::make(y, memory_type::heap, 1, {}, use_buffer(y)),
      })));
}

TEST(optimizations, optimize_symbols) {
auto make_dummy_decl = [](var x, stmt body) { return allocate::make(x, memory_type::heap, 1, {}, body); };

Expand Down
2 changes: 1 addition & 1 deletion builder/test/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ TEST_P(softmax, pipeline) {
if (copy_at_the_end == 2) {
ASSERT_EQ(eval_ctx.heap.allocs.size(), 6);
} else {
ASSERT_EQ(eval_ctx.heap.allocs.size(), 5);
ASSERT_EQ(eval_ctx.heap.allocs.size(), split_c ? 5 : 4);
}
}

Expand Down
12 changes: 3 additions & 9 deletions builder/test/visualize/softmax_split_0.html
Original file line number Diff line number Diff line change
Expand Up @@ -379,23 +379,17 @@
consume(softmax_in);
produce(max_in);
__event_t++;
free(softmax_in);
}
{ let exp_in = allocate('exp_in', 4, [
{bounds:[min(g, g_0), max(g, g_0)], stride:NaN, fold_factor:NaN},
{bounds:[buffer_min(out, 1), buffer_max(out, 1)], stride:NaN, fold_factor:NaN}
]);
consume(__in);
consume(max_in);
produce(exp_in);
produce(softmax_in);
produce(sum_exp_in);
__event_t++;
check(free(max_in));
consume(exp_in);
consume(softmax_in);
consume(sum_exp_in);
produce(softmax_out);
__event_t++;
free(exp_in);
free(softmax_in);
}
free(max_in);
}
Expand Down

0 comments on commit 426e998

Please sign in to comment.