From c86b41c95d96076e4017e39efa6bad19b8f4820a Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 20:28:45 -0800 Subject: [PATCH 1/9] Add `uninitialized_allocator` --- base/BUILD | 1 + base/allocator.h | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 base/allocator.h diff --git a/base/BUILD b/base/BUILD index 0a3705cd..ca42f07a 100644 --- a/base/BUILD +++ b/base/BUILD @@ -6,6 +6,7 @@ package( cc_library( name = "base", hdrs = [ + "allocator.h", "arithmetic.h", "ref_count.h", "modulus_remainder.h", diff --git a/base/allocator.h b/base/allocator.h new file mode 100644 index 00000000..ec73f59b --- /dev/null +++ b/base/allocator.h @@ -0,0 +1,44 @@ +#ifndef SLINKY_BASE_ALLOCATOR_H +#define SLINKY_BASE_ALLOCATOR_H + +#include +#include +#include + +namespace slinky { + +// https://howardhinnant.github.io/allocator_boilerplate.html, modified to not default construct. +template +class uninitialized_allocator { +public: + using value_type = T; + + uninitialized_allocator() noexcept {} + template + uninitialized_allocator(uninitialized_allocator const&) noexcept {} + + value_type* allocate(std::size_t n) { return static_cast(::operator new(n * sizeof(value_type))); } + + void deallocate(value_type* p, std::size_t) noexcept { ::operator delete(p); } + + template + void construct(U* p, Args&&... 
args) { + if (sizeof...(args) > 0) { + ::new (p) U(std::forward(args)...); + } + } +}; + +template +bool operator==(uninitialized_allocator const&, uninitialized_allocator const&) noexcept { + return true; +} + +template +bool operator!=(uninitialized_allocator const& x, uninitialized_allocator const& y) noexcept { + return !(x == y); +} + +} // namespace slinky + +#endif // SLINKY_BASE_ALLOCATOR_H From 2d1f77f1166880b562dfa98594982df5e9557f8e Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 20:29:00 -0800 Subject: [PATCH 2/9] optimize symbols after tracing --- builder/pipeline.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/builder/pipeline.cc b/builder/pipeline.cc index 772fd772..b0008a3b 100644 --- a/builder/pipeline.cc +++ b/builder/pipeline.cc @@ -1375,14 +1375,14 @@ stmt build_pipeline(node_context& ctx, const std::vector& input result = simplify(result); } - result = optimize_symbols(result, ctx); - result = insert_early_free(result); if (options.trace) { result = inject_traces(result, ctx); } + result = optimize_symbols(result, ctx); + result = canonicalize_nodes(result); if (is_verbose()) { From 17d19e4017117d54eb0aba5bd4ef868df9667fa7 Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 20:30:44 -0800 Subject: [PATCH 3/9] Leave eval context uninitialized --- runtime/evaluate.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/runtime/evaluate.h b/runtime/evaluate.h index 3549b978..df97a9e2 100644 --- a/runtime/evaluate.h +++ b/runtime/evaluate.h @@ -1,6 +1,7 @@ #ifndef SLINKY_RUNTIME_EVALUATE_H #define SLINKY_RUNTIME_EVALUATE_H +#include "base/allocator.h" #include "runtime/expr.h" #include "runtime/stmt.h" @@ -9,9 +10,8 @@ namespace slinky { class thread_pool; class eval_context { - // TODO: This should be uninitialized memory, not just for performance, but so we can detect uninitialized memory - // usage when evaluating. 
- std::vector values_; + // Leave uninitialized to avoid overhead and to detect uninitialized memory access via msan. + std::vector> values_; public: void reserve(std::size_t size) { From 6814f464466554319c109406faab1e2ae10a1da4 Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 20:57:18 -0800 Subject: [PATCH 4/9] Unify `find_dependencies` implementation with `depends_on` --- runtime/depends_on.cc | 179 +++++++++++++++---------------------- runtime/depends_on.h | 8 +- runtime/test/depends_on.cc | 20 ++++- 3 files changed, 97 insertions(+), 110 deletions(-) diff --git a/runtime/depends_on.cc b/runtime/depends_on.cc index 8dd7be0e..17e14b43 100644 --- a/runtime/depends_on.cc +++ b/runtime/depends_on.cc @@ -21,7 +21,11 @@ class dependencies : public recursive_node_visitor { std::vector> var_deps; depends_on_result dummy_deps; + // If non-null, we'll also add dependencies of variables as we find them. + std::map* unknown_deps = nullptr; + dependencies() {} + dependencies(std::map& unknown_deps) : unknown_deps(&unknown_deps) {} dependencies(std::vector> var_deps) : var_deps(var_deps) {} dependencies(span> deps) { var_deps.reserve(deps.size()); @@ -35,9 +39,23 @@ class dependencies : public recursive_node_visitor { for (auto i = var_deps.rbegin(); i != var_deps.rend(); ++i) { if (i->first == s) return i->second; } + + if (unknown_deps) { + return &(*unknown_deps)[s]; + } return nullptr; } + // Check if we need to shadow a declaration. + bool decl_needs_shadow(var s) const { + // Go in reverse order to handle shadowed declarations properly. + for (auto i = var_deps.rbegin(); i != var_deps.rend(); ++i) { + if (i->first == s) return i->second != &dummy_deps; + } + + return unknown_deps != nullptr; + } + depends_on_result* no_dummy(depends_on_result* deps) const { return deps != &dummy_deps ? 
deps : nullptr; } void visit(const variable* op) override { @@ -86,7 +104,7 @@ class dependencies : public recursive_node_visitor { size_t var_deps_count = var_deps.size(); for (const auto& p : op->lets) { p.second.accept(this); - if (no_dummy(find_deps(p.first))) { + if (decl_needs_shadow(p.first)) { var_deps.push_back({p.first, &dummy_deps}); } } @@ -94,13 +112,16 @@ class dependencies : public recursive_node_visitor { var_deps.resize(var_deps_count); } + void visit(const let* op) override { visit_let(op); } + void visit(const let_stmt* op) override { visit_let(op); } + void visit_sym_body(var sym, depends_on_result* src_deps, const stmt& body) { if (!body.defined()) return; size_t var_deps_count = var_deps.size(); if (no_dummy(src_deps)) { // We have src_deps we want to transitively add to via this declaration. var_deps.push_back({sym, src_deps}); - } else if (no_dummy(find_deps(sym))) { + } else if (decl_needs_shadow(sym)) { // We are shadowing something we are finding the dependencies of. Point at the dummy instead to avoid // contaminating the dependencies. var_deps.push_back({sym, &dummy_deps}); @@ -163,7 +184,7 @@ class dependencies : public recursive_node_visitor { // copy_stmt is effectively a declaration of the dst_x symbols for the src_x expressions. size_t var_deps_count = var_deps.size(); for (std::size_t i = 0; i < op->dst_x.size(); ++i) { - if (no_dummy(find_deps(op->dst_x[i]))) { + if (decl_needs_shadow(op->dst_x[i])) { var_deps.push_back({op->dst_x[i], &dummy_deps}); } } @@ -322,122 +343,64 @@ bool is_pure(expr_ref x) { namespace { -// Find the buffers accessed in a stmt. This is trickier than it seems, we need to track the lineage of buffers and -// report the buffer as it is visible to the caller. So something like: -// -// x = crop_dim(y, ...) { -// call(f, {x}, {}) -// } -// -// Will report that y is consumed in the stmt, not x. 
-class dependency_finder : public recursive_node_visitor { - bool input = true; - bool output = true; - bool data_only = false; - - symbol_map aliases; - -public: - std::vector result; - - dependency_finder() {} - dependency_finder(bool input, bool output) : input(input), output(output) {} - dependency_finder(bool data_only) : data_only(data_only) {} - - std::optional lookup_alias(var x) { - if (aliases.contains(x)) { - return aliases.lookup(x); - } else { - return x; - } +std::vector keys(const std::map& m) { std::vector result; + result.reserve(m.size()); + for (const auto& i : m) { + result.push_back(i.first); } - - void visit_buffer(var x) { - if (aliases.contains(x)) { - if (aliases.lookup(x)->defined()) { - x = *aliases.lookup(x); - } else { - return; - } - } - // Maintain result as a sorted unique list. - auto i = std::lower_bound(result.begin(), result.end(), x); - if (i == result.end() || *i != x) result.insert(i, x); - } - - void visit(const variable* op) override { - if (op->field != buffer_field::none && !data_only) { - visit_buffer(op->sym); - } - } - - void visit(const call* op) override { - if (op->intrinsic == intrinsic::buffer_at) { - auto buf = as_variable(op->args[0]); - assert(buf); - visit_buffer(*buf); - } - recursive_node_visitor::visit(op); - } - - void visit(const call_stmt* op) override { - if (input) { - for (const var& i : op->inputs) { - visit_buffer(i); - } - } - if (output) { - for (const var& i : op->outputs) { - visit_buffer(i); - } - } - } - void visit(const copy_stmt* op) override { - if (input) visit_buffer(op->src); - if (output) visit_buffer(op->dst); - } - - void visit(const allocate* op) override { - if (!op->body.defined()) return; - auto s = set_value_in_scope(aliases, op->sym, var()); - op->body.accept(this); - } - void visit(const make_buffer* op) override { - if (!op->body.defined()) return; - auto s = set_value_in_scope(aliases, op->sym, lookup_alias(find_buffer_data_dependency(op->base))); - op->body.accept(this); - } 
- - template - void visit_buffer_alias(const T* op) { - if (!op->body.defined()) return; - auto s = set_value_in_scope(aliases, op->sym, lookup_alias(op->src)); - op->body.accept(this); - } - void visit(const crop_buffer* op) override { visit_buffer_alias(op); } - void visit(const crop_dim* op) override { visit_buffer_alias(op); } - void visit(const slice_buffer* op) override { visit_buffer_alias(op); } - void visit(const slice_dim* op) override { visit_buffer_alias(op); } - void visit(const transpose* op) override { visit_buffer_alias(op); } - void visit(const clone_buffer* op) override { visit_buffer_alias(op); } -}; + return result; +} } // namespace -var find_buffer_data_dependency(expr_ref e) { - dependency_finder accessed(true); - if (e.defined()) e.accept(&accessed); - return accessed.result.size() == 1 ? accessed.result.front() : var(); +std::vector find_dependencies(expr_ref e) { + std::map deps; + dependencies v(deps); + if (e.defined()) e.accept(&v); + return keys(deps); +} +std::vector find_dependencies(stmt_ref s) { + std::map deps; + dependencies v(deps); + if (s.defined()) s.accept(&v); + return keys(deps); } std::vector find_buffer_dependencies(stmt_ref s) { return find_buffer_dependencies(s, true, true); } std::vector find_buffer_dependencies(stmt_ref s, bool input, bool output) { - dependency_finder accessed(input, output); - if (s.defined()) s.accept(&accessed); - return accessed.result; + std::map deps; + dependencies v(deps); + if (s.defined()) s.accept(&v); + + std::vector result; + result.reserve(deps.size()); + for (const auto& i : deps) { + if ((input && (i.second.buffer_input || i.second.buffer_src)) || + (output && (i.second.buffer_output || i.second.buffer_dst))) { + result.push_back(i.first); + } + } + return result; } +var find_buffer_data_dependency(expr_ref e) { + std::map deps; + dependencies v(deps); + if (e.defined()) e.accept(&v); + var result; + for (const auto& i : deps) { + if (i.second.buffer_base) { + if (result.defined()) { 
+ // More than one data dependency + return var(); + } else { + result = i.first; + } + } + } + return result; +} } // namespace slinky diff --git a/runtime/depends_on.h b/runtime/depends_on.h index 8193a638..5824a133 100644 --- a/runtime/depends_on.h +++ b/runtime/depends_on.h @@ -47,8 +47,14 @@ bool can_substitute_buffer(const depends_on_result& r); // Check if the node depends on anything that may change value. bool is_pure(expr_ref x); -// Find the buffers used by a stmt or expr. Returns the vars accessed in sorted order. +// Find all the variables used by a node. +std::vector find_dependencies(expr_ref e); +std::vector find_dependencies(stmt_ref s); + +// Find a single buffer data dependency in e. var find_buffer_data_dependency(expr_ref e); + +// Find the buffers used by a stmt or expr. Returns the vars accessed in sorted order. std::vector find_buffer_dependencies(stmt_ref s); std::vector find_buffer_dependencies(stmt_ref s, bool input, bool output); diff --git a/runtime/test/depends_on.cc b/runtime/test/depends_on.cc index 6d857a60..07dbeade 100644 --- a/runtime/test/depends_on.cc +++ b/runtime/test/depends_on.cc @@ -107,12 +107,21 @@ TEST(depends_on, is_pure) { } TEST(find_buffer_dependencies, basic) { + ASSERT_THAT(find_buffer_dependencies(crop_buffer::make(z, y, {}, call_stmt::make(nullptr, {x}, {z}, {})), + /*input=*/false, /*output=*/true), + testing::ElementsAre(y)); ASSERT_EQ(find_buffer_data_dependency(buffer_at(x)), x); ASSERT_EQ(find_buffer_data_dependency(buffer_at(x, buffer_min(y, 0))), x); ASSERT_EQ(find_buffer_data_dependency(buffer_at(x) + buffer_at(y)), var()); ASSERT_THAT(find_buffer_dependencies(crop_buffer::make(x, y, {}, call_stmt::make(nullptr, {y}, {x}, {}))), testing::ElementsAre(y)); + ASSERT_THAT(find_buffer_dependencies(crop_buffer::make(z, y, {}, call_stmt::make(nullptr, {x}, {z}, {})), + /*input=*/true, /*output=*/false), + testing::ElementsAre(x)); + ASSERT_THAT(find_buffer_dependencies(crop_buffer::make(z, y, {}, 
call_stmt::make(nullptr, {x}, {z}, {})), + /*input=*/false, /*output=*/true), + testing::ElementsAre(y)); stmt test = block::make({ crop_buffer::make(z, x, {}, call_stmt::make(nullptr, {y}, {z}, {})), @@ -121,7 +130,16 @@ TEST(find_buffer_dependencies, basic) { }); ASSERT_THAT(find_buffer_dependencies(test, /*input=*/true, /*output=*/false), testing::ElementsAre(x, y)); - ASSERT_THAT(find_buffer_dependencies(test, /*input=*/false, /*output=*/true), testing::ElementsAre(x, w, u)); + ASSERT_THAT(find_buffer_dependencies(test, /*input=*/false, /*output=*/true), testing::ElementsAre(x, w)); +} + +TEST(find_dependencies, basic) { + ASSERT_THAT(find_dependencies(buffer_at(x)), testing::ElementsAre(x)); + ASSERT_THAT(find_dependencies(x + y), testing::ElementsAre(x, y)); + ASSERT_THAT(find_dependencies(let::make(x, y, x + z)), testing::ElementsAre(y, z)); + ASSERT_THAT(find_dependencies(crop_dim::make(x, y, 0, {z, z}, call_stmt::make(nullptr, {w}, {u}, {}))), + testing::ElementsAre(y, z, w, u)); + ASSERT_THAT(find_dependencies(block::make({check::make(x), check::make(y)})), testing::ElementsAre(x, y)); } } // namespace slinky From ad47d08a38abfcb93e4b78db2c0c2c689749ba20 Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 20:58:47 -0800 Subject: [PATCH 5/9] Move context config to a separate object --- builder/optimizations.cc | 2 +- builder/test/checks.cc | 4 +- builder/test/context.cc | 18 +++++---- builder/test/context.h | 4 +- builder/test/pipeline.cc | 4 +- runtime/evaluate.cc | 33 ++++++++------- runtime/evaluate.h | 64 ++++++++++++++++-------------- runtime/test/evaluate.cc | 8 +++- runtime/test/evaluate_benchmark.cc | 4 +- 9 files changed, 82 insertions(+), 59 deletions(-) diff --git a/builder/optimizations.cc b/builder/optimizations.cc index 38e2212a..4b707c8f 100644 --- a/builder/optimizations.cc +++ b/builder/optimizations.cc @@ -1045,7 +1045,7 @@ stmt implement_copy(const copy_stmt* op, node_context& ctx) { const raw_buffer* src_buf = 
ctx.lookup_buffer(op->outputs[0]); const raw_buffer* dst_buf = ctx.lookup_buffer(op->outputs[1]); const void* pad_value = (!padding || padding->empty()) ? nullptr : padding->data(); - ctx.copy(*src_buf, *dst_buf, pad_value); + ctx.config->copy(*src_buf, *dst_buf, pad_value); return 0; }, {}, {op->src, dst}, std::move(copy_attrs)); diff --git a/builder/test/checks.cc b/builder/test/checks.cc index 1e6477e7..c9b3a0d6 100644 --- a/builder/test/checks.cc +++ b/builder/test/checks.cc @@ -30,7 +30,9 @@ TEST(pipeline, checks) { int checks_failed = 0; eval_context eval_ctx; - eval_ctx.check_failed = [&](const expr& c) { checks_failed++; }; + eval_config eval_cfg; + eval_cfg.check_failed = [&](const expr& c) { checks_failed++; }; + eval_ctx.config = &eval_cfg; buffer in_buf({N}); buffer out_buf({N}); diff --git a/builder/test/context.cc b/builder/test/context.cc index faa40649..f22bb07a 100644 --- a/builder/test/context.cc +++ b/builder/test/context.cc @@ -9,7 +9,7 @@ namespace slinky { -void setup_tracing(eval_context& ctx, const std::string& filename) { +void setup_tracing(eval_config& cfg, const std::string& filename) { struct tracer { std::string trace_file; // Store the trace in a stringstream and write it at the end, to avoid overhead influencing the trace. @@ -25,40 +25,42 @@ void setup_tracing(eval_context& ctx, const std::string& filename) { auto t = std::make_shared(filename); - ctx.trace_begin = [t](const char* op) -> index_t { + cfg.trace_begin = [t](const char* op) -> index_t { t->trace.begin(op); // chrome_trace expects trace_begin and trace_end to pass the string, while slinky's API expects to pass a token to // trace_end returned by trace_begin. Because `index_t` must be able to hold a pointer, we'll just use the token to // store the pointer. 
return reinterpret_cast(op); }; - ctx.trace_end = [t](index_t token) { t->trace.end(reinterpret_cast(token)); }; + cfg.trace_end = [t](index_t token) { t->trace.end(reinterpret_cast(token)); }; } test_context::test_context() { static thread_pool_impl threads; - allocate = [this](var, raw_buffer* b) { + config.allocate = [this](var, raw_buffer* b) { void* allocation = b->allocate(); heap.track_allocate(b->size_bytes()); return allocation; }; - free = [this](var, raw_buffer* b, void* allocation) { + config.free = [this](var, raw_buffer* b, void* allocation) { ::free(allocation); heap.track_free(b->size_bytes()); }; - copy = [this](const raw_buffer& src, const raw_buffer& dst, const void* padding) { + config.copy = [this](const raw_buffer& src, const raw_buffer& dst, const void* padding) { ++copy_calls; copy_elements += dst.elem_count(); slinky::copy(src, dst, padding); }; - pad = [this](const dim* in_bounds, const raw_buffer& dst, const void* padding) { + config.pad = [this](const dim* in_bounds, const raw_buffer& dst, const void* padding) { ++pad_calls; slinky::pad(in_bounds, dst, padding); }; - thread_pool = &threads; + config.thread_pool = &threads; + + eval_context::config = &config; } } // namespace slinky diff --git a/builder/test/context.h b/builder/test/context.h index 8d4c641d..479f9263 100644 --- a/builder/test/context.h +++ b/builder/test/context.h @@ -9,7 +9,7 @@ namespace slinky { -void setup_tracing(eval_context& ctx, const std::string& filename); +void setup_tracing(eval_config& config, const std::string& filename); struct memory_info { std::atomic live_count = 0; @@ -37,6 +37,8 @@ class test_context : public eval_context { int copy_elements = 0; int pad_calls = 0; + eval_config config; + test_context(); }; diff --git a/builder/test/pipeline.cc b/builder/test/pipeline.cc index cfd636ed..8cbbc5ee 100644 --- a/builder/test/pipeline.cc +++ b/builder/test/pipeline.cc @@ -599,9 +599,9 @@ TEST_P(stencil_chain, pipeline) { std::string test_name = 
"stencil_chain_split_" + std::string(max_workers == loop::serial ? "serial" : "parallel") + "_split_" + std::to_string(split); - setup_tracing(eval_ctx, test_name + ".json"); + setup_tracing(eval_ctx.config, test_name + ".json"); - p.evaluate(inputs, outputs, eval_ctx); + p.evaluate(inputs, outputs, eval_ctx); // Run the pipeline stages manually to get the reference result. buffer ref_intm({W + 4, H + 4}); diff --git a/runtime/evaluate.cc b/runtime/evaluate.cc index 61343733..765e4c08 100644 --- a/runtime/evaluate.cc +++ b/runtime/evaluate.cc @@ -43,6 +43,11 @@ void dump_context_for_expr( } } +eval_context::eval_context() { + static eval_config default_config; + config = &default_config; +} + namespace { struct allocated_buffer : public raw_buffer { @@ -240,7 +245,7 @@ class evaluator { assert(op->args.size() == 2); index_t* sem = reinterpret_cast(eval(op->args[0])); index_t count = eval(op->args[1], 0); - context.thread_pool->atomic_call([=]() { *sem = count; }); + context.config->thread_pool->atomic_call([=]() { *sem = count; }); return 1; } @@ -253,7 +258,7 @@ class evaluator { sems[i] = reinterpret_cast(eval(op->args[i * 2 + 0])); counts[i] = eval(op->args[i * 2 + 1], 1); } - context.thread_pool->atomic_call([=]() { + context.config->thread_pool->atomic_call([=]() { for (std::size_t i = 0; i < sem_count; ++i) { *sems[i] += counts[i]; } @@ -270,7 +275,7 @@ class evaluator { sems[i] = reinterpret_cast(eval(op->args[i * 2 + 0])); counts[i] = eval(op->args[i * 2 + 1], 1); } - context.thread_pool->wait_for([=]() { + context.config->thread_pool->wait_for([=]() { // Check we can acquire all of the semaphores before acquiring any of them. for (std::size_t i = 0; i < sem_count; ++i) { if (*sems[i] < counts[i]) return false; @@ -287,13 +292,13 @@ class evaluator { index_t eval_trace_begin(const call* op) { assert(op->args.size() == 1); const char* name = reinterpret_cast(eval(op->args[0])); - return context.trace_begin ? 
context.trace_begin(name) : 0; + return context.config->trace_begin ? context.config->trace_begin(name) : 0; } index_t eval_trace_end(const call* op) { assert(op->args.size() == 1); - if (context.trace_end) { - context.trace_end(eval(op->args[0])); + if (context.config->trace_end) { + context.config->trace_end(eval(op->args[0])); } return 1; } @@ -302,7 +307,7 @@ class evaluator { assert(op->args.size() == 1); var sym = *as_variable(op->args[0]); allocated_buffer* buf = reinterpret_cast(context.lookup(sym)); - context.free(sym, buf, buf->allocation); + context.config->free(sym, buf, buf->allocation); buf->allocation = nullptr; return 1; } @@ -413,7 +418,7 @@ class evaluator { std::size_t n = ceil_div(bounds.max - bounds.min + 1, step); context.reserve(op->sym.id + 1); index_t old_value = context.set(op->sym, 0); - context.thread_pool->parallel_for( + context.config->thread_pool->parallel_for( n, [context = this->context, step, min = bounds.min, op, &result](index_t i) mutable { context.set(op->sym, i * step + min); @@ -457,8 +462,8 @@ class evaluator { } SLINKY_NO_INLINE void call_failed(index_t result, const call_stmt* op) { - if (context.call_failed) { - context.call_failed(op); + if (context.config->call_failed) { + context.config->call_failed(op); } else { std::cerr << "call_stmt failed: " << stmt(op) << "->" << result << std::endl; std::abort(); @@ -495,7 +500,7 @@ class evaluator { } if (op->storage == memory_type::heap) { - buffer.allocation = context.allocate(op->sym, &buffer); + buffer.allocation = context.config->allocate(op->sym, &buffer); } else { assert(op->storage == memory_type::stack); std::size_t size = buffer.init_strides(); @@ -506,7 +511,7 @@ class evaluator { index_t result = eval_with_value(op->body, op->sym, reinterpret_cast(&buffer)); if (op->storage == memory_type::heap) { - context.free(op->sym, &buffer, buffer.allocation); + context.config->free(op->sym, &buffer, buffer.allocation); } return result; @@ -735,8 +740,8 @@ class evaluator { } 
SLINKY_NO_INLINE index_t check_failed(const check* op) { - if (context.check_failed) { - context.check_failed(op->condition); + if (context.config->check_failed) { + context.config->check_failed(op->condition); } else { std::cerr << "Check failed: " << op->condition << std::endl; std::cerr << "Context: " << std::endl; diff --git a/runtime/evaluate.h b/runtime/evaluate.h index df97a9e2..ddb714ed 100644 --- a/runtime/evaluate.h +++ b/runtime/evaluate.h @@ -9,11 +9,45 @@ namespace slinky { class thread_pool; +struct eval_config { + // These two functions implement allocation. `allocate` is called before + // running the body, and should assign `base` of the buffer to the address + // of the min in each dimension. `free` is called after running the body, + // passing the result of `allocate` in addition to the buffer. + // If these functions are not defined, the default handler will call + // `raw_buffer::allocate` and `::free`. + std::function allocate = [](var, raw_buffer* buf) { return buf->allocate(); }; + std::function free = [](var, raw_buffer*, void* allocation) { ::free(allocation); }; + + // Functions called when there is a failure in the pipeline. + // If these functions are not defined, the default handler will write a + // message to cerr and abort. + std::function check_failed; + std::function call_failed; + + // A pointer to a thread pool, required for parallel loops and the semaphore intrinsics. + slinky::thread_pool* thread_pool = nullptr; + + // Functions implementing buffer data movement: + // - `copy` should copy from `src` to `dst`, filling `dst` with `padding` when out of bounds of `src`. + // - `pad` should fill the area out of bounds of `src_dims` with `padding` in `dst`. + std::function copy = + static_cast(slinky::copy); + std::function pad = + static_cast(slinky::pad); + + // Functions implementing the `trace_begin` and `trace_end` intrinsics. 
+ std::function trace_begin; + std::function trace_end; +}; + class eval_context { // Leave uninitialized to avoid overhead and to detect uninitialized memory access via msan. std::vector> values_; public: + eval_context(); + void reserve(std::size_t size) { if (size > values_.size()) { values_.resize(std::max(values_.size() * 2, size)); @@ -46,35 +80,7 @@ class eval_context { std::size_t size() const { return values_.size(); } - // These two functions implement allocation. `allocate` is called before - // running the body, and should assign `base` of the buffer to the address - // of the min in each dimension. `free` is called after running the body, - // passing the result of `allocate` in addition to the buffer. - // If these functions are not defined, the default handler will call - // `raw_buffer::allocate` and `::free`. - std::function allocate = [](var, raw_buffer* buf) { return buf->allocate(); }; - std::function free = [](var, raw_buffer*, void* allocation) { ::free(allocation); }; - - // Functions called when there is a failure in the pipeline. - // If these functions are not defined, the default handler will write a - // message to cerr and abort. - std::function check_failed; - std::function call_failed; - - // A pointer to a thread pool, required for parallel - slinky::thread_pool* thread_pool = nullptr; - - // Functions implementing buffer data movement: - // - `copy` should copy from `src` to `dst`, filling `dst` with `padding` when out of bounds of `src`. - // - `pad` should fill the area out of bounds of `src_dims` with `padding` in `dst`. - std::function copy = - static_cast(slinky::copy); - std::function pad = - static_cast(slinky::pad); - - // Functions called every time a stmt begins or ends evaluation. 
- std::function trace_begin; - std::function trace_end; + const eval_config* config; }; index_t evaluate(const expr& e, eval_context& context); diff --git a/runtime/test/evaluate.cc b/runtime/test/evaluate.cc index eb2d5b4d..8234990b 100644 --- a/runtime/test/evaluate.cc +++ b/runtime/test/evaluate.cc @@ -98,7 +98,9 @@ TEST(evaluate, call) { TEST(evaluate, loop) { eval_context ctx; thread_pool_impl t; - ctx.thread_pool = &t; + eval_config cfg; + cfg.thread_pool = &t; + ctx.config = &cfg; for (int max_workers : {loop::serial, 2, 3, loop::parallel}) { std::atomic sum_x = 0; @@ -229,7 +231,9 @@ TEST(evaluate, clone_buffer) { TEST(evaluate, semaphore) { eval_context ctx; thread_pool_impl t; - ctx.thread_pool = &t; + eval_config cfg; + cfg.thread_pool = &t; + ctx.config = &cfg; index_t sem1 = 0; index_t sem2 = 0; diff --git a/runtime/test/evaluate_benchmark.cc b/runtime/test/evaluate_benchmark.cc index 29c66f6a..4f5148bf 100644 --- a/runtime/test/evaluate_benchmark.cc +++ b/runtime/test/evaluate_benchmark.cc @@ -216,8 +216,10 @@ void benchmark_parallel_loop(benchmark::State& state, bool synchronize) { body = loop::make(x, workers, range(0, iterations), 1, body); eval_context eval_ctx; + eval_config config; thread_pool_impl t(workers); - eval_ctx.thread_pool = &t; + config.thread_pool = &t; + eval_ctx.config = &config; for (auto _ : state) { evaluate(body, eval_ctx); From 9452c5a47e9e3a1d2de14ddaaaa5027dbade7dd5 Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 21:18:54 -0800 Subject: [PATCH 6/9] Add `is_closure` flag to `let_stmt` --- builder/node_mutator.cc | 10 +++++----- builder/optimizations.cc | 11 ++++++++++- runtime/evaluate.cc | 19 ++++++++++++++++--- runtime/expr_stmt.cc | 12 +++++++----- runtime/print.cc | 5 +++-- runtime/stmt.h | 10 +++++++++- 6 files changed, 50 insertions(+), 17 deletions(-) diff --git a/builder/node_mutator.cc b/builder/node_mutator.cc index 31a4b341..b7c05b0f 100644 --- a/builder/node_mutator.cc +++ b/builder/node_mutator.cc @@ 
-9,8 +9,8 @@ namespace slinky { namespace { -template -auto mutate_let(node_mutator* this_, const T* op) { +template +auto mutate_let(node_mutator* this_, const T* op, Args... args) { std::vector> lets; lets.reserve(op->lets.size()); bool changed = false; @@ -23,7 +23,7 @@ auto mutate_let(node_mutator* this_, const T* op) { if (!changed) { return decltype(body){op}; } else { - return T::make(std::move(lets), std::move(body)); + return T::make(std::move(lets), std::move(body), args...); } } @@ -71,7 +71,7 @@ stmt clone_with(const transpose* op, var sym, stmt new_body) { return transpose::make(sym, op->src, op->dims, std::move(new_body)); } -stmt clone_with(const let_stmt* op, stmt new_body) { return let_stmt::make(op->lets, std::move(new_body)); } +stmt clone_with(const let_stmt* op, stmt new_body) { return let_stmt::make(op->lets, std::move(new_body), op->is_closure); } stmt clone_with(const loop* op, stmt new_body) { return clone_with(op, op->sym, std::move(new_body)); } stmt clone_with(const allocate* op, stmt new_body) { return clone_with(op, op->sym, std::move(new_body)); } @@ -132,7 +132,7 @@ void stmt_mutator::visit(const transpose* op) { set_result(mutate_decl(this, op) void node_mutator::visit(const variable* op) { set_result(op); } void node_mutator::visit(const constant* op) { set_result(op); } void node_mutator::visit(const let* op) { set_result(mutate_let(this, op)); } -void node_mutator::visit(const let_stmt* op) { set_result(mutate_let(this, op)); } +void node_mutator::visit(const let_stmt* op) { set_result(mutate_let(this, op, op->is_closure)); } void node_mutator::visit(const add* op) { set_result(mutate_binary(this, op)); } void node_mutator::visit(const sub* op) { set_result(mutate_binary(this, op)); } void node_mutator::visit(const mul* op) { set_result(mutate_binary(this, op)); } diff --git a/builder/optimizations.cc b/builder/optimizations.cc index 4b707c8f..d0b554f7 100644 --- a/builder/optimizations.cc +++ b/builder/optimizations.cc @@ 
-1348,7 +1348,16 @@ class reuse_shadows : public stmt_mutator { // We're entering a parallel loop. All the buffers in scope cannot be mutated in this scope. symbol_map old_can_mutate; std::swap(can_mutate, old_can_mutate); - stmt_mutator::visit(op); + + stmt body = mutate(op->body); + std::vector referenced = find_dependencies(body); + std::vector> lets; + for (var i : referenced) { + lets.push_back({i, expr(i)}); + } + body = let_stmt::make(std::move(lets), std::move(body), /*is_closure=*/true); + set_result(loop::make(op->sym, op->max_workers, op->bounds, op->step, std::move(body))); + can_mutate = std::move(old_can_mutate); } else { stmt_mutator::visit(op); diff --git a/runtime/evaluate.cc b/runtime/evaluate.cc index 765e4c08..41f9f91a 100644 --- a/runtime/evaluate.cc +++ b/runtime/evaluate.cc @@ -417,10 +417,24 @@ class evaluator { std::atomic result = 0; std::size_t n = ceil_div(bounds.max - bounds.min + 1, step); context.reserve(op->sym.id + 1); - index_t old_value = context.set(op->sym, 0); context.config->thread_pool->parallel_for( n, - [context = this->context, step, min = bounds.min, op, &result](index_t i) mutable { + [parent_context = &context, step, min = bounds.min, op, &result](index_t i) mutable { + eval_context context; + if (const let_stmt* closure = is_closure(op->body)) { + // The body is a closure, so we know exactly which symbols we need to copy to the new local context. + context.reserve(parent_context->size()); + context.config = parent_context->config; + + // Assume that this let_stmt is a closure for this loop. We'll evaluate the values using the parent context, + // but assign them to our local context. + for (const std::pair& i : closure->lets) { + context[i.first] = evaluate(i.second, *parent_context); + } + } else { + // We don't have a closure, just copy the whole context. + context = *parent_context; + } context.set(op->sym, i * step + min); // Evaluate the parallel loop body with our copy of the context. 
index_t result_i = evaluate(op->body, context); @@ -430,7 +444,6 @@ class evaluator { } }, op->max_workers); - context.set(op->sym, old_value); return result; } diff --git a/runtime/expr_stmt.cc b/runtime/expr_stmt.cc index 63394c6d..3c28aa53 100644 --- a/runtime/expr_stmt.cc +++ b/runtime/expr_stmt.cc @@ -62,7 +62,7 @@ expr make_bin_op(expr a, expr b) { } template -Body make_let(std::vector> lets, Body body) { +T* make_let(std::vector> lets, Body body) { auto n = new T(); n->lets = std::move(lets); if (const T* l = body.template as()) { @@ -71,17 +71,19 @@ Body make_let(std::vector> lets, Body body) { } else { n->body = std::move(body); } - return Body(n); + return n; } expr let::make(std::vector> lets, expr body) { - return make_let(std::move(lets), std::move(body)); + return expr(make_let(std::move(lets), std::move(body))); } expr let::make(var sym, expr value, expr body) { return make({{sym, std::move(value)}}, std::move(body)); } -stmt let_stmt::make(std::vector> lets, stmt body) { - return make_let(std::move(lets), std::move(body)); +stmt let_stmt::make(std::vector> lets, stmt body, bool is_closure) { + let_stmt* n = make_let(std::move(lets), std::move(body)); + n->is_closure = is_closure; + return stmt(n); } stmt let_stmt::make(var sym, expr value, stmt body) { return make({{sym, std::move(value)}}, std::move(body)); } diff --git a/runtime/print.cc b/runtime/print.cc index e9a4559d..860c92d4 100644 --- a/runtime/print.cc +++ b/runtime/print.cc @@ -233,10 +233,11 @@ class printer : public expr_visitor, public stmt_visitor { } void visit(const let_stmt* l) override { + const char* tag = l->is_closure ? 
"closure" : "let"; if (l->lets.size() == 1) { - *this << indent() << "let " << l->lets.front().first << " = " << l->lets.front().second << " in {\n"; + *this << indent() << tag << " " << l->lets.front().first << " = " << l->lets.front().second << " in {\n"; } else { - *this << indent() << "let {\n"; + *this << indent() << tag << " {\n"; *this << indent(2); print_vector(l->lets, ",\n" + indent(2)); *this << "\n" << indent() << "} in {\n"; diff --git a/runtime/stmt.h b/runtime/stmt.h index 6b4c60ba..cb6bb28e 100644 --- a/runtime/stmt.h +++ b/runtime/stmt.h @@ -183,9 +183,12 @@ class let_stmt : public stmt_node { std::vector> lets; stmt body; + // If this is true, then the body does not access any symbols outside of those defined by `lets`. + bool is_closure; + void accept(stmt_visitor* v) const override; - static stmt make(std::vector> lets, stmt body); + static stmt make(std::vector> lets, stmt body, bool is_closure = false); static stmt make(var sym, expr value, stmt body); @@ -396,6 +399,11 @@ class check : public stmt_node { static constexpr stmt_node_type static_type = stmt_node_type::check; }; +SLINKY_ALWAYS_INLINE inline const let_stmt* is_closure(const stmt& s) { + const let_stmt* let = s.as(); + return let && let->is_closure ? 
let : nullptr; +} + class stmt_visitor { public: virtual ~stmt_visitor() = default; From 999b24df56c9ffc6dacf03677a13aa16c44f8486 Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 21:53:57 -0800 Subject: [PATCH 7/9] Print closures on a single line --- runtime/print.cc | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/runtime/print.cc b/runtime/print.cc index 860c92d4..efd97570 100644 --- a/runtime/print.cc +++ b/runtime/print.cc @@ -1,8 +1,8 @@ #include "runtime/print.h" #include -#include #include +#include #include "runtime/expr.h" #include "runtime/stmt.h" @@ -162,7 +162,13 @@ class printer : public expr_visitor, public stmt_visitor { return *this << "{" << d.bounds << ", " << d.stride << ", " << d.fold_factor << "}"; } - printer& operator<<(const std::pair& let) { return *this << let.first << " = " << let.second; } + printer& operator<<(const std::pair& let) { + if (is_variable(let.second, let.first)) { + return *this << let.first; + } else { + return *this << let.first << " = " << let.second; + } + } template void print_vector(const std::vector& v, const std::string& sep = ", ") { @@ -212,7 +218,7 @@ class printer : public expr_visitor, public stmt_visitor { std::string indent(int extra = 0) const { return std::string(depth + extra, ' '); } - void visit(const variable* v) override { + void visit(const variable* v) override { switch (v->field) { case buffer_field::none: *this << v->sym; return; case buffer_field::rank: *this << "buffer_rank(" << v->sym << ")"; return; @@ -234,7 +240,11 @@ class printer : public expr_visitor, public stmt_visitor { void visit(const let_stmt* l) override { const char* tag = l->is_closure ? 
"closure" : "let"; - if (l->lets.size() == 1) { + if (std::all_of(l->lets.begin(), l->lets.end(), [&](const auto& i) { return as_variable(i.second); })) { + *this << indent() << tag << " {"; + print_vector(l->lets, ", "); + *this << "} in {\n"; + } else if (l->lets.size() == 1) { *this << indent() << tag << " " << l->lets.front().first << " = " << l->lets.front().second << " in {\n"; } else { *this << indent() << tag << " {\n"; @@ -246,9 +256,7 @@ class printer : public expr_visitor, public stmt_visitor { *this << indent() << "}\n"; } - void visit_bin_op(const expr& a, const char* s, const expr& b) { - *this << "(" << a << s << b << ")"; - } + void visit_bin_op(const expr& a, const char* s, const expr& b) { *this << "(" << a << s << b << ")"; } void visit(const add* op) override { visit_bin_op(op->a, " + ", op->b); } void visit(const sub* op) override { visit_bin_op(op->a, " - ", op->b); } @@ -425,14 +433,14 @@ void print(std::ostream& os, const stmt& s, const node_context* ctx) { p << s; } -std::string to_string(var x) { +std::string to_string(var x) { std::stringstream ss; printer p(ss, default_context); p << x; return ss.str(); } -std::ostream& operator<<(std::ostream& os, var sym) { +std::ostream& operator<<(std::ostream& os, var sym) { print(os, sym); return os; } From 708508dfa7b9a1ded5a810079510a3461c1363ac Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 23:27:39 -0800 Subject: [PATCH 8/9] Don't flatten lets at construction --- builder/simplify.cc | 7 +++++++ builder/test/simplify/simplify.cc | 4 ++++ runtime/expr_stmt.cc | 7 +------ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/builder/simplify.cc b/builder/simplify.cc index aeaa156a..5cbfd3db 100644 --- a/builder/simplify.cc +++ b/builder/simplify.cc @@ -1112,6 +1112,13 @@ class simplifier : public node_mutator { } } + while (const T* let_body = body.template as()) { + // Flatten nested lets + lets.insert(lets.end(), let_body->lets.begin(), let_body->lets.end()); + body = 
let_body->body; + values_changed = true; + } + if (lets.empty()) { // All lets were removed. set_result(std::move(body), std::move(body_info)); diff --git a/builder/test/simplify/simplify.cc b/builder/test/simplify/simplify.cc index 4bc596ea..a55790da 100644 --- a/builder/test/simplify/simplify.cc +++ b/builder/test/simplify/simplify.cc @@ -240,6 +240,10 @@ TEST(simplify, let) { // Duplicate lets ASSERT_THAT(simplify(let::make({{x, y * 2}, {z, y * 2}}, x + z)), matches(let::make(x, y * 2, x * 2))); + + // Nested lets + ASSERT_THAT( + simplify(let::make(x, y * 2, let::make(z, w + 2, x + z))), matches(let::make({{x, y * 2}, {z, w + 2}}, x + z))); } TEST(simplify, loop) { diff --git a/runtime/expr_stmt.cc b/runtime/expr_stmt.cc index 3c28aa53..4710c8c1 100644 --- a/runtime/expr_stmt.cc +++ b/runtime/expr_stmt.cc @@ -65,12 +65,7 @@ template T* make_let(std::vector> lets, Body body) { auto n = new T(); n->lets = std::move(lets); - if (const T* l = body.template as()) { - n->lets.insert(n->lets.end(), l->lets.begin(), l->lets.end()); - n->body = l->body; - } else { - n->body = std::move(body); - } + n->body = std::move(body); return n; } From b6ca2dd6ac20d419b6584f7d9cda62d0229b71e9 Mon Sep 17 00:00:00 2001 From: Dillon Date: Thu, 30 Jan 2025 23:31:30 -0800 Subject: [PATCH 9/9] Add comments --- base/allocator.h | 3 ++- builder/optimizations.h | 3 ++- builder/pipeline.cc | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/base/allocator.h b/base/allocator.h index ec73f59b..ad9a453a 100644 --- a/base/allocator.h +++ b/base/allocator.h @@ -7,7 +7,8 @@ namespace slinky { -// https://howardhinnant.github.io/allocator_boilerplate.html, modified to not default construct. +// This is an STL allocator that doesn't default construct, enabling an STL container to manage uninitialized memory. 
+// https://howardhinnant.github.io/allocator_boilerplate.html template class uninitialized_allocator { public: diff --git a/builder/optimizations.h b/builder/optimizations.h index c69a1c08..e8487887 100644 --- a/builder/optimizations.h +++ b/builder/optimizations.h @@ -31,7 +31,8 @@ stmt deshadow(const stmt& s, span external_symbols, node_context& ctx); // We can improve `evaluate`'s performance and memory usage if: // - Buffer mutators are self-shadowing, so they can be performed in-place on existing buffers. -// - Symbols are indexed such that there are no unused symbol indices. +// - Parallel loop bodies are wrapped in closures, so evaluate doesn't need to copy the entire context. +// - (TODO) Symbols are indexed such that there are no unused symbol indices. stmt optimize_symbols(const stmt& s, node_context& ctx); // Guarantees that if match(a, b) is true, then a.same_as(b) is true, i.e. it rewrites matching nodes to be the same diff --git a/builder/pipeline.cc b/builder/pipeline.cc index b0008a3b..bc477514 100644 --- a/builder/pipeline.cc +++ b/builder/pipeline.cc @@ -1381,6 +1381,7 @@ stmt build_pipeline(node_context& ctx, const std::vector& input result = inject_traces(result, ctx); } + // This pass adds closures around parallel loop bodies; any following passes need to maintain these closures. result = optimize_symbols(result, ctx); result = canonicalize_nodes(result);