From 32fbea536e841ffdaf5f81e46a468e0f16367db6 Mon Sep 17 00:00:00 2001 From: Dillon Date: Fri, 10 Jan 2025 09:09:07 -0800 Subject: [PATCH] Avoid extra `std::function` layer in `func::make` (#546) For the lambda version of `func::make`, we currently get two layers of `std::function`: one for `call_stmt::callable`, which is unavoidable, but the other is just a wrapper to change the arguments, and doesn't need to be an `std::function`, which adds some extra overhead. This PR avoids the inner `std::function`, calling the lambda directly instead. --- builder/pipeline.h | 45 +++++++++++++++++++++++++++++++++++----- builder/test/pipeline.cc | 16 +++++++------- runtime/buffer.h | 2 ++ 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/builder/pipeline.h b/builder/pipeline.h index f39f7c96..7f5f3c95 100644 --- a/builder/pipeline.h +++ b/builder/pipeline.h @@ -1,6 +1,8 @@ #ifndef SLINKY_BUILDER_PIPELINE_H #define SLINKY_BUILDER_PIPELINE_H +#include + #include "base/ref_count.h" #include "runtime/evaluate.h" #include "runtime/expr.h" @@ -87,6 +89,29 @@ class buffer_expr : public ref_counted { static void destroy(buffer_expr* p) { delete p; } }; +namespace internal { + +template +struct buffer_converter { + static SLINKY_ALWAYS_INLINE const auto& convert(const raw_buffer* buffer) { + return buffer->cast::type>::type::element>(); + } +}; +template <> +struct buffer_converter { + static SLINKY_ALWAYS_INLINE const raw_buffer& convert(const raw_buffer* buffer) { return *buffer; } +}; +template <> +struct buffer_converter { + static SLINKY_ALWAYS_INLINE const raw_buffer& convert(const raw_buffer* buffer) { return *buffer; } +}; +template <> +struct buffer_converter { + static SLINKY_ALWAYS_INLINE const raw_buffer* convert(const raw_buffer* buffer) { return buffer; } +}; + +} // namespace internal + // Represents a node of computation in a pipeline. class func { public: @@ -192,13 +217,20 @@ class func { private: template - static inline index_t call_impl( + static SLINKY_ALWAYS_INLINE index_t call_impl( const func::callable& impl, eval_context& ctx, const call_stmt* op, std::index_sequence) { return impl( ctx.lookup_buffer(Indices < op->inputs.size() ? op->inputs[Indices] : op->outputs[Indices - op->inputs.size()]) ->template cast()...); } + template + static SLINKY_ALWAYS_INLINE index_t call_impl_tuple( + const Fn& impl, eval_context& ctx, const call_stmt* op, std::index_sequence) { + return impl(internal::buffer_converter::type>::convert(ctx.lookup_buffer( + Indices < op->inputs.size() ? op->inputs[Indices] : op->outputs[Indices - op->inputs.size()]))...); + } + template struct lambda_call_signature : lambda_call_signature {}; @@ -234,14 +266,17 @@ class func { template static func make( Lambda&& lambda, std::vector inputs, std::vector outputs, call_stmt::attributes attrs = {}) { + using sig = lambda_call_signature; // Verify that the lambda returns an index_t; a different return type will fail to match // the std::function call and just call this same function in an endless death spiral. - using sig = lambda_call_signature; static_assert(std::is_same_v); - using std_function_type = typename sig::std_function_type; - std_function_type impl = std::move(lambda); - return make_impl(std::move(impl), std::move(inputs), std::move(outputs), std::move(attrs)); + auto wrapper = [lambda = std::move(lambda)](const call_stmt* op, eval_context& ctx) -> index_t { + return call_impl_tuple( + lambda, ctx, op, std::make_index_sequence::value>()); + }; + + return func(std::move(wrapper), std::move(inputs), std::move(outputs), std::move(attrs)); } // Version for plain old function ptrs diff --git a/builder/test/pipeline.cc b/builder/test/pipeline.cc index cfa765d7..bf20dadf 100644 --- a/builder/test/pipeline.cc +++ b/builder/test/pipeline.cc @@ -190,8 +190,10 @@ TEST_P(elementwise, pipeline_2d) { // Here we explicitly use lambdas to wrap the local calls, // purely to verify that the relevant func::make calls work correctly. - auto m2 = [](const buffer& a, const buffer& b) -> index_t { return multiply_2(a, b); }; - auto a1 = [](const buffer& a, const buffer& b) -> index_t { return add_1(a, b); }; + auto m2 = [](const buffer& a, raw_buffer b) -> index_t { return multiply_2(a, b.cast()); }; + auto a1 = [](const raw_buffer& a, const raw_buffer* b) -> index_t { + return add_1(a.cast(), b->cast()); + }; func mul = func::make( std::move(m2), {{in, {point(x), point(y)}}}, {{intm, {x, y}}}, call_stmt::attributes{.allow_in_place = true}); @@ -638,7 +640,8 @@ TEST_P(stencil_chain, pipeline) { class multiple_outputs : public testing::TestWithParam> {}; -INSTANTIATE_TEST_SUITE_P(split_mode, multiple_outputs, testing::Combine(loop_modes, testing::Range(0, 4), testing::Bool()), +INSTANTIATE_TEST_SUITE_P(split_mode, multiple_outputs, + testing::Combine(loop_modes, testing::Range(0, 4), testing::Bool()), test_params_to_string); TEST_P(multiple_outputs, pipeline) { @@ -663,7 +666,8 @@ TEST_P(multiple_outputs, pipeline) { // For a 3D input in(x, y, z), compute sum_x = sum(input(:, y, z)) and sum_xy = sum(input(:, :, z)) in one stage. func::callable sum_x_xy = [](const buffer& in, const buffer& sum_x, const buffer& sum_xy) -> index_t { - for (index_t z = std::min(sum_xy.dim(0).min(), sum_x.dim(1).min()); z <= std::max(sum_xy.dim(0).max(), sum_x.dim(1).max()); ++z) { + for (index_t z = std::min(sum_xy.dim(0).min(), sum_x.dim(1).min()); + z <= std::max(sum_xy.dim(0).max(), sum_x.dim(1).max()); ++z) { if (sum_xy.contains(z)) sum_xy(z) = 0; for (index_t y = sum_x.dim(0).min(); y <= sum_x.dim(0).max(); ++y) { if (sum_x.contains(y, z)) sum_x(y, z) = 0; @@ -1461,8 +1465,7 @@ TEST(split, pipeline) { class upsample : public testing::TestWithParam> {}; -INSTANTIATE_TEST_SUITE_P(split_mode, upsample, - testing::Combine(loop_modes, testing::Range(0, 2)), +INSTANTIATE_TEST_SUITE_P(split_mode, upsample, testing::Combine(loop_modes, testing::Range(0, 2)), test_params_to_string); TEST_P(upsample, pipeline) { @@ -1480,7 +1483,6 @@ TEST_P(upsample, pipeline) { var x(ctx, "x"); var y(ctx, "y"); - func add = func::make(add_1, {{in, {point(x), point(x)}}}, {{intm, {x, y}}}); func upsample = func::make(upsample_nn_2x, {{intm, {point(x) / 2, point(y) / 2}}}, {{out, {x, y}}}); diff --git a/runtime/buffer.h b/runtime/buffer.h index 21d6d9c6..100e5c01 100644 --- a/runtime/buffer.h +++ b/runtime/buffer.h @@ -174,6 +174,7 @@ class raw_buffer { } public: + using element = void; using pointer = void*; void* base; @@ -414,6 +415,7 @@ class buffer : public raw_buffer { } public: + using element = T; using pointer = T*; using raw_buffer::cast;