Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize context copies for parallel loops #579

Merged
merged 9 commits into from
Jan 31, 2025
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions base/BUILD
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ package(
cc_library(
name = "base",
hdrs = [
"allocator.h",
"arithmetic.h",
"ref_count.h",
"modulus_remainder.h",
44 changes: 44 additions & 0 deletions base/allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef SLINKY_BASE_ALLOCATOR_H
#define SLINKY_BASE_ALLOCATOR_H

#include <cstddef>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>

namespace slinky {

// A stateless allocator that skips argument-less element construction. Based on
// https://howardhinnant.github.io/allocator_boilerplate.html, modified to not default construct.
//
// Standard containers construct elements through `std::allocator_traits<A>::construct`, which calls the allocator's
// own `construct` member when one is available. Requests with no constructor arguments (e.g. the new elements of
// `std::vector<T, A>::resize(n)`) are a no-op here, so that storage is left exactly as `allocate` returned it.
// Requests with one or more arguments construct the object normally.
template <class T>
class uninitialized_allocator {
public:
  using value_type = T;

  uninitialized_allocator() noexcept {}

  // Converting constructor used when a container rebinds the allocator to another element type. There are no data
  // members, so there is nothing to copy.
  template <class U>
  uninitialized_allocator(uninitialized_allocator<U> const&) noexcept {}

  // Returns raw storage for `n` objects of `value_type`, obtained from the global `operator new`.
  value_type* allocate(std::size_t n) { return static_cast<value_type*>(::operator new(n * sizeof(value_type))); }

  // Releases storage previously returned by `allocate`. The element count is not needed.
  void deallocate(value_type* p, std::size_t) noexcept { ::operator delete(p); }

  // Construction with no arguments: intentionally does nothing. This is a separate overload (rather than a runtime
  // check on `sizeof...(args)`) so that `U()` is never instantiated, which also lets it compile for types without a
  // default constructor.
  template <class U>
  void construct(U*) noexcept {}

  // Construction with at least one argument: forwards the arguments to `U`'s constructor. The cast to `void*`
  // ensures the standard, non-allocating placement `operator new` is the one selected.
  template <class U, class Arg, class... Args>
  void construct(U* p, Arg&& arg, Args&&... args) {
    ::new (static_cast<void*>(p)) U(std::forward<Arg>(arg), std::forward<Args>(args)...);
  }
};

template <class T, class U>
bool operator==(uninitialized_allocator<T> const&, uninitialized_allocator<U> const&) noexcept {
return true;
}

template <class T, class U>
bool operator!=(uninitialized_allocator<T> const& x, uninitialized_allocator<U> const& y) noexcept {
return !(x == y);
}

} // namespace slinky

#endif // SLINKY_BASE_ALLOCATOR_H
10 changes: 5 additions & 5 deletions builder/node_mutator.cc
Original file line number Diff line number Diff line change
@@ -9,8 +9,8 @@ namespace slinky {

namespace {

template <typename T>
auto mutate_let(node_mutator* this_, const T* op) {
template <typename T, typename... Args>
auto mutate_let(node_mutator* this_, const T* op, Args... args) {
std::vector<std::pair<var, expr>> lets;
lets.reserve(op->lets.size());
bool changed = false;
@@ -23,7 +23,7 @@ auto mutate_let(node_mutator* this_, const T* op) {
if (!changed) {
return decltype(body){op};
} else {
return T::make(std::move(lets), std::move(body));
return T::make(std::move(lets), std::move(body), args...);
}
}

@@ -71,7 +71,7 @@ stmt clone_with(const transpose* op, var sym, stmt new_body) {
return transpose::make(sym, op->src, op->dims, std::move(new_body));
}

stmt clone_with(const let_stmt* op, stmt new_body) { return let_stmt::make(op->lets, std::move(new_body)); }
stmt clone_with(const let_stmt* op, stmt new_body) { return let_stmt::make(op->lets, std::move(new_body), op->is_closure); }

stmt clone_with(const loop* op, stmt new_body) { return clone_with(op, op->sym, std::move(new_body)); }
stmt clone_with(const allocate* op, stmt new_body) { return clone_with(op, op->sym, std::move(new_body)); }
@@ -132,7 +132,7 @@ void stmt_mutator::visit(const transpose* op) { set_result(mutate_decl(this, op)
void node_mutator::visit(const variable* op) { set_result(op); }
void node_mutator::visit(const constant* op) { set_result(op); }
void node_mutator::visit(const let* op) { set_result(mutate_let(this, op)); }
void node_mutator::visit(const let_stmt* op) { set_result(mutate_let(this, op)); }
void node_mutator::visit(const let_stmt* op) { set_result(mutate_let(this, op, op->is_closure)); }
void node_mutator::visit(const add* op) { set_result(mutate_binary(this, op)); }
void node_mutator::visit(const sub* op) { set_result(mutate_binary(this, op)); }
void node_mutator::visit(const mul* op) { set_result(mutate_binary(this, op)); }
13 changes: 11 additions & 2 deletions builder/optimizations.cc
Original file line number Diff line number Diff line change
@@ -1045,7 +1045,7 @@ stmt implement_copy(const copy_stmt* op, node_context& ctx) {
const raw_buffer* src_buf = ctx.lookup_buffer(op->outputs[0]);
const raw_buffer* dst_buf = ctx.lookup_buffer(op->outputs[1]);
const void* pad_value = (!padding || padding->empty()) ? nullptr : padding->data();
ctx.copy(*src_buf, *dst_buf, pad_value);
ctx.config->copy(*src_buf, *dst_buf, pad_value);
return 0;
},
{}, {op->src, dst}, std::move(copy_attrs));
@@ -1348,7 +1348,16 @@ class reuse_shadows : public stmt_mutator {
// We're entering a parallel loop. All the buffers in scope cannot be mutated in this scope.
symbol_map<bool> old_can_mutate;
std::swap(can_mutate, old_can_mutate);
stmt_mutator::visit(op);

stmt body = mutate(op->body);
std::vector<var> referenced = find_dependencies(body);
std::vector<std::pair<var, expr>> lets;
for (var i : referenced) {
lets.push_back({i, expr(i)});
}
body = let_stmt::make(std::move(lets), std::move(body), /*is_closure=*/true);
set_result(loop::make(op->sym, op->max_workers, op->bounds, op->step, std::move(body)));

can_mutate = std::move(old_can_mutate);
} else {
stmt_mutator::visit(op);
4 changes: 2 additions & 2 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
@@ -1375,14 +1375,14 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
result = simplify(result);
}

result = optimize_symbols(result, ctx);

result = insert_early_free(result);

if (options.trace) {
result = inject_traces(result, ctx);
}

result = optimize_symbols(result, ctx);

result = canonicalize_nodes(result);

if (is_verbose()) {
4 changes: 3 additions & 1 deletion builder/test/checks.cc
Original file line number Diff line number Diff line change
@@ -30,7 +30,9 @@ TEST(pipeline, checks) {
int checks_failed = 0;

eval_context eval_ctx;
eval_ctx.check_failed = [&](const expr& c) { checks_failed++; };
eval_config eval_cfg;
eval_cfg.check_failed = [&](const expr& c) { checks_failed++; };
eval_ctx.config = &eval_cfg;

buffer<int, 1> in_buf({N});
buffer<int, 1> out_buf({N});
18 changes: 10 additions & 8 deletions builder/test/context.cc
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

namespace slinky {

void setup_tracing(eval_context& ctx, const std::string& filename) {
void setup_tracing(eval_config& cfg, const std::string& filename) {
struct tracer {
std::string trace_file;
// Store the trace in a stringstream and write it at the end, to avoid overhead influencing the trace.
@@ -25,40 +25,42 @@ void setup_tracing(eval_context& ctx, const std::string& filename) {

auto t = std::make_shared<tracer>(filename);

ctx.trace_begin = [t](const char* op) -> index_t {
cfg.trace_begin = [t](const char* op) -> index_t {
t->trace.begin(op);
// chrome_trace expects trace_begin and trace_end to pass the string, while slinky's API expects to pass a token to
// trace_end returned by trace_begin. Because `index_t` must be able to hold a pointer, we'll just use the token to
// store the pointer.
return reinterpret_cast<index_t>(op);
};
ctx.trace_end = [t](index_t token) { t->trace.end(reinterpret_cast<const char*>(token)); };
cfg.trace_end = [t](index_t token) { t->trace.end(reinterpret_cast<const char*>(token)); };
}

test_context::test_context() {
static thread_pool_impl threads;

allocate = [this](var, raw_buffer* b) {
config.allocate = [this](var, raw_buffer* b) {
void* allocation = b->allocate();
heap.track_allocate(b->size_bytes());
return allocation;
};
free = [this](var, raw_buffer* b, void* allocation) {
config.free = [this](var, raw_buffer* b, void* allocation) {
::free(allocation);
heap.track_free(b->size_bytes());
};

copy = [this](const raw_buffer& src, const raw_buffer& dst, const void* padding) {
config.copy = [this](const raw_buffer& src, const raw_buffer& dst, const void* padding) {
++copy_calls;
copy_elements += dst.elem_count();
slinky::copy(src, dst, padding);
};
pad = [this](const dim* in_bounds, const raw_buffer& dst, const void* padding) {
config.pad = [this](const dim* in_bounds, const raw_buffer& dst, const void* padding) {
++pad_calls;
slinky::pad(in_bounds, dst, padding);
};

thread_pool = &threads;
config.thread_pool = &threads;

eval_context::config = &config;
}

} // namespace slinky
4 changes: 3 additions & 1 deletion builder/test/context.h
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

namespace slinky {

void setup_tracing(eval_context& ctx, const std::string& filename);
void setup_tracing(eval_config& config, const std::string& filename);

struct memory_info {
std::atomic<index_t> live_count = 0;
@@ -37,6 +37,8 @@ class test_context : public eval_context {
int copy_elements = 0;
int pad_calls = 0;

eval_config config;

test_context();
};

4 changes: 2 additions & 2 deletions builder/test/pipeline.cc
Original file line number Diff line number Diff line change
@@ -599,9 +599,9 @@ TEST_P(stencil_chain, pipeline) {

std::string test_name = "stencil_chain_split_" + std::string(max_workers == loop::serial ? "serial" : "parallel") +
"_split_" + std::to_string(split);
setup_tracing(eval_ctx, test_name + ".json");
setup_tracing(eval_ctx.config, test_name + ".json");

p.evaluate(inputs, outputs, eval_ctx);
p.evaluate(inputs, outputs, eval_ctx);

// Run the pipeline stages manually to get the reference result.
buffer<short, 2> ref_intm({W + 4, H + 4});
Loading