Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize context copies for parallel loops #579

Merged
merged 9 commits into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions base/BUILD
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ package(
cc_library(
name = "base",
hdrs = [
"allocator.h",
"arithmetic.h",
"ref_count.h",
"modulus_remainder.h",
45 changes: 45 additions & 0 deletions base/allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef SLINKY_BASE_ALLOCATOR_H
#define SLINKY_BASE_ALLOCATOR_H

#include <cstddef>
#include <memory>
#include <type_traits>

namespace slinky {

// This is an STL allocator that doesn't default construct, enabling an STL container to manage uninitialized memory.
// https://howardhinnant.github.io/allocator_boilerplate.html
template <class T>
class uninitialized_allocator {
public:
using value_type = T;

uninitialized_allocator() noexcept {}
template <class U>
uninitialized_allocator(uninitialized_allocator<U> const&) noexcept {}

value_type* allocate(std::size_t n) { return static_cast<value_type*>(::operator new(n * sizeof(value_type))); }

void deallocate(value_type* p, std::size_t) noexcept { ::operator delete(p); }

template <class U, class... Args>
void construct(U* p, Args&&... args) {
if (sizeof...(args) > 0) {
::new (p) U(std::forward<Args>(args)...);
}
}
};

template <class T, class U>
bool operator==(uninitialized_allocator<T> const&, uninitialized_allocator<U> const&) noexcept {
return true;
}

template <class T, class U>
bool operator!=(uninitialized_allocator<T> const& x, uninitialized_allocator<U> const& y) noexcept {
return !(x == y);
}

} // namespace slinky

#endif // SLINKY_BASE_ARITHMETIC_H
10 changes: 5 additions & 5 deletions builder/node_mutator.cc
Original file line number Diff line number Diff line change
@@ -9,8 +9,8 @@ namespace slinky {

namespace {

template <typename T>
auto mutate_let(node_mutator* this_, const T* op) {
template <typename T, typename... Args>
auto mutate_let(node_mutator* this_, const T* op, Args... args) {
std::vector<std::pair<var, expr>> lets;
lets.reserve(op->lets.size());
bool changed = false;
@@ -23,7 +23,7 @@ auto mutate_let(node_mutator* this_, const T* op) {
if (!changed) {
return decltype(body){op};
} else {
return T::make(std::move(lets), std::move(body));
return T::make(std::move(lets), std::move(body), args...);
}
}

@@ -71,7 +71,7 @@ stmt clone_with(const transpose* op, var sym, stmt new_body) {
return transpose::make(sym, op->src, op->dims, std::move(new_body));
}

stmt clone_with(const let_stmt* op, stmt new_body) { return let_stmt::make(op->lets, std::move(new_body)); }
stmt clone_with(const let_stmt* op, stmt new_body) { return let_stmt::make(op->lets, std::move(new_body), op->is_closure); }

stmt clone_with(const loop* op, stmt new_body) { return clone_with(op, op->sym, std::move(new_body)); }
stmt clone_with(const allocate* op, stmt new_body) { return clone_with(op, op->sym, std::move(new_body)); }
@@ -132,7 +132,7 @@ void stmt_mutator::visit(const transpose* op) { set_result(mutate_decl(this, op)
void node_mutator::visit(const variable* op) { set_result(op); }
void node_mutator::visit(const constant* op) { set_result(op); }
void node_mutator::visit(const let* op) { set_result(mutate_let(this, op)); }
void node_mutator::visit(const let_stmt* op) { set_result(mutate_let(this, op)); }
void node_mutator::visit(const let_stmt* op) { set_result(mutate_let(this, op, op->is_closure)); }
void node_mutator::visit(const add* op) { set_result(mutate_binary(this, op)); }
void node_mutator::visit(const sub* op) { set_result(mutate_binary(this, op)); }
void node_mutator::visit(const mul* op) { set_result(mutate_binary(this, op)); }
13 changes: 11 additions & 2 deletions builder/optimizations.cc
Original file line number Diff line number Diff line change
@@ -1045,7 +1045,7 @@ stmt implement_copy(const copy_stmt* op, node_context& ctx) {
const raw_buffer* src_buf = ctx.lookup_buffer(op->outputs[0]);
const raw_buffer* dst_buf = ctx.lookup_buffer(op->outputs[1]);
const void* pad_value = (!padding || padding->empty()) ? nullptr : padding->data();
ctx.copy(*src_buf, *dst_buf, pad_value);
ctx.config->copy(*src_buf, *dst_buf, pad_value);
return 0;
},
{}, {op->src, dst}, std::move(copy_attrs));
@@ -1348,7 +1348,16 @@ class reuse_shadows : public stmt_mutator {
// We're entering a parallel loop. All the buffers in scope cannot be mutated in this scope.
symbol_map<bool> old_can_mutate;
std::swap(can_mutate, old_can_mutate);
stmt_mutator::visit(op);

stmt body = mutate(op->body);
std::vector<var> referenced = find_dependencies(body);
std::vector<std::pair<var, expr>> lets;
for (var i : referenced) {
lets.push_back({i, expr(i)});
}
body = let_stmt::make(std::move(lets), std::move(body), /*is_closure=*/true);
set_result(loop::make(op->sym, op->max_workers, op->bounds, op->step, std::move(body)));

can_mutate = std::move(old_can_mutate);
} else {
stmt_mutator::visit(op);
3 changes: 2 additions & 1 deletion builder/optimizations.h
Original file line number Diff line number Diff line change
@@ -31,7 +31,8 @@ stmt deshadow(const stmt& s, span<var> external_symbols, node_context& ctx);

// We can improve `evaluate`'s performance and memory usage if:
// - Buffer mutators are self-shadowing, so they can be performed in-place on existing buffers.
// - Symbols are indexed such that there are no unused symbol indices.
// - Make closures for parallel loop bodies, so evaluate doesn't need to copy the entire context.
// - (TODO) Symbols are indexed such that there are no unused symbol indices.
stmt optimize_symbols(const stmt& s, node_context& ctx);

// Guarantees that if match(a, b) is true, then a.same_as(b) is true, i.e. it rewrites matching nodes to be the same
5 changes: 3 additions & 2 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
@@ -1375,14 +1375,15 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
result = simplify(result);
}

result = optimize_symbols(result, ctx);

result = insert_early_free(result);

if (options.trace) {
result = inject_traces(result, ctx);
}

// This pass adds closures around parallel loop bodies, any following passes need to maintain this closure.
result = optimize_symbols(result, ctx);

result = canonicalize_nodes(result);

if (is_verbose()) {
7 changes: 7 additions & 0 deletions builder/simplify.cc
Original file line number Diff line number Diff line change
@@ -1112,6 +1112,13 @@ class simplifier : public node_mutator {
}
}

while (const T* let_body = body.template as<T>()) {
// Flatten nested lets
lets.insert(lets.end(), let_body->lets.begin(), let_body->lets.end());
body = let_body->body;
values_changed = true;
}

if (lets.empty()) {
// All lets were removed.
set_result(std::move(body), std::move(body_info));
4 changes: 3 additions & 1 deletion builder/test/checks.cc
Original file line number Diff line number Diff line change
@@ -30,7 +30,9 @@ TEST(pipeline, checks) {
int checks_failed = 0;

eval_context eval_ctx;
eval_ctx.check_failed = [&](const expr& c) { checks_failed++; };
eval_config eval_cfg;
eval_cfg.check_failed = [&](const expr& c) { checks_failed++; };
eval_ctx.config = &eval_cfg;

buffer<int, 1> in_buf({N});
buffer<int, 1> out_buf({N});
18 changes: 10 additions & 8 deletions builder/test/context.cc
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

namespace slinky {

void setup_tracing(eval_context& ctx, const std::string& filename) {
void setup_tracing(eval_config& cfg, const std::string& filename) {
struct tracer {
std::string trace_file;
// Store the trace in a stringstream and write it at the end, to avoid overhead influencing the trace.
@@ -25,40 +25,42 @@ void setup_tracing(eval_context& ctx, const std::string& filename) {

auto t = std::make_shared<tracer>(filename);

ctx.trace_begin = [t](const char* op) -> index_t {
cfg.trace_begin = [t](const char* op) -> index_t {
t->trace.begin(op);
// chrome_trace expects trace_begin and trace_end to pass the string, while slinky's API expects to pass a token to
// trace_end returned by trace_begin. Because `index_t` must be able to hold a pointer, we'll just use the token to
// store the pointer.
return reinterpret_cast<index_t>(op);
};
ctx.trace_end = [t](index_t token) { t->trace.end(reinterpret_cast<const char*>(token)); };
cfg.trace_end = [t](index_t token) { t->trace.end(reinterpret_cast<const char*>(token)); };
}

test_context::test_context() {
static thread_pool_impl threads;

allocate = [this](var, raw_buffer* b) {
config.allocate = [this](var, raw_buffer* b) {
void* allocation = b->allocate();
heap.track_allocate(b->size_bytes());
return allocation;
};
free = [this](var, raw_buffer* b, void* allocation) {
config.free = [this](var, raw_buffer* b, void* allocation) {
::free(allocation);
heap.track_free(b->size_bytes());
};

copy = [this](const raw_buffer& src, const raw_buffer& dst, const void* padding) {
config.copy = [this](const raw_buffer& src, const raw_buffer& dst, const void* padding) {
++copy_calls;
copy_elements += dst.elem_count();
slinky::copy(src, dst, padding);
};
pad = [this](const dim* in_bounds, const raw_buffer& dst, const void* padding) {
config.pad = [this](const dim* in_bounds, const raw_buffer& dst, const void* padding) {
++pad_calls;
slinky::pad(in_bounds, dst, padding);
};

thread_pool = &threads;
config.thread_pool = &threads;

eval_context::config = &config;
}

} // namespace slinky
4 changes: 3 additions & 1 deletion builder/test/context.h
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

namespace slinky {

void setup_tracing(eval_context& ctx, const std::string& filename);
void setup_tracing(eval_config& config, const std::string& filename);

struct memory_info {
std::atomic<index_t> live_count = 0;
@@ -37,6 +37,8 @@ class test_context : public eval_context {
int copy_elements = 0;
int pad_calls = 0;

eval_config config;

test_context();
};

4 changes: 2 additions & 2 deletions builder/test/pipeline.cc
Original file line number Diff line number Diff line change
@@ -599,9 +599,9 @@ TEST_P(stencil_chain, pipeline) {

std::string test_name = "stencil_chain_split_" + std::string(max_workers == loop::serial ? "serial" : "parallel") +
"_split_" + std::to_string(split);
setup_tracing(eval_ctx, test_name + ".json");
setup_tracing(eval_ctx.config, test_name + ".json");

p.evaluate(inputs, outputs, eval_ctx);
p.evaluate(inputs, outputs, eval_ctx);

// Run the pipeline stages manually to get the reference result.
buffer<short, 2> ref_intm({W + 4, H + 4});
4 changes: 4 additions & 0 deletions builder/test/simplify/simplify.cc
Original file line number Diff line number Diff line change
@@ -240,6 +240,10 @@ TEST(simplify, let) {

// Duplicate lets
ASSERT_THAT(simplify(let::make({{x, y * 2}, {z, y * 2}}, x + z)), matches(let::make(x, y * 2, x * 2)));

// Nested lets
ASSERT_THAT(
simplify(let::make(x, y * 2, let::make(z, w + 2, x + z))), matches(let::make({{x, y * 2}, {z, w + 2}}, x + z)));
}

TEST(simplify, loop) {
Loading