Skip to content

Commit

Permalink
Avoid extra std::function layer in func::make (#546)
Browse files Browse the repository at this point in the history
For the lambda version of `func::make`, we currently get two layers of
`std::function`: one for `call_stmt::callable`, which is unavoidable,
but the other is just a wrapper to change the arguments, and doesn't
need to be an `std::function`, which adds some extra overhead.

This PR avoids the inner `std::function`, calling the lambda directly
instead.
  • Loading branch information
dsharlet authored Jan 10, 2025
1 parent 9de98b3 commit 32fbea5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
45 changes: 40 additions & 5 deletions builder/pipeline.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef SLINKY_BUILDER_PIPELINE_H
#define SLINKY_BUILDER_PIPELINE_H

#include <type_traits>

#include "base/ref_count.h"
#include "runtime/evaluate.h"
#include "runtime/expr.h"
Expand Down Expand Up @@ -87,6 +89,29 @@ class buffer_expr : public ref_counted<buffer_expr> {
static void destroy(buffer_expr* p) { delete p; }
};

namespace internal {

template <typename T>
struct buffer_converter {
static SLINKY_ALWAYS_INLINE const auto& convert(const raw_buffer* buffer) {
return buffer->cast<typename std::remove_cv<typename std::remove_reference<T>::type>::type::element>();
}
};
template <>
struct buffer_converter<raw_buffer> {
static SLINKY_ALWAYS_INLINE const raw_buffer& convert(const raw_buffer* buffer) { return *buffer; }
};
template <>
struct buffer_converter<const raw_buffer&> {
static SLINKY_ALWAYS_INLINE const raw_buffer& convert(const raw_buffer* buffer) { return *buffer; }
};
template <>
struct buffer_converter<const raw_buffer*> {
static SLINKY_ALWAYS_INLINE const raw_buffer* convert(const raw_buffer* buffer) { return buffer; }
};

} // namespace internal

// Represents a node of computation in a pipeline.
class func {
public:
Expand Down Expand Up @@ -192,13 +217,20 @@ class func {

private:
template <typename... T, std::size_t... Indices>
static inline index_t call_impl(
static SLINKY_ALWAYS_INLINE index_t call_impl(
const func::callable<T...>& impl, eval_context& ctx, const call_stmt* op, std::index_sequence<Indices...>) {
return impl(
ctx.lookup_buffer(Indices < op->inputs.size() ? op->inputs[Indices] : op->outputs[Indices - op->inputs.size()])
->template cast<T>()...);
}

template <typename ArgTypes, typename Fn, std::size_t... Indices>
static SLINKY_ALWAYS_INLINE index_t call_impl_tuple(
const Fn& impl, eval_context& ctx, const call_stmt* op, std::index_sequence<Indices...>) {
return impl(internal::buffer_converter<typename std::tuple_element<Indices, ArgTypes>::type>::convert(ctx.lookup_buffer(
Indices < op->inputs.size() ? op->inputs[Indices] : op->outputs[Indices - op->inputs.size()]))...);
}

template <typename Lambda>
struct lambda_call_signature : lambda_call_signature<decltype(&Lambda::operator())> {};

Expand Down Expand Up @@ -234,14 +266,17 @@ class func {
template <typename Lambda>
static func make(
Lambda&& lambda, std::vector<input> inputs, std::vector<output> outputs, call_stmt::attributes attrs = {}) {
using sig = lambda_call_signature<Lambda>;
// Verify that the lambda returns an index_t; a different return type will fail to match
// the std::function call and just call this same function in an endless death spiral.
using sig = lambda_call_signature<Lambda>;
static_assert(std::is_same_v<typename sig::ret_type, index_t>);

using std_function_type = typename sig::std_function_type;
std_function_type impl = std::move(lambda);
return make_impl(std::move(impl), std::move(inputs), std::move(outputs), std::move(attrs));
auto wrapper = [lambda = std::move(lambda)](const call_stmt* op, eval_context& ctx) -> index_t {
return call_impl_tuple<typename sig::arg_types>(
lambda, ctx, op, std::make_index_sequence<std::tuple_size<typename sig::arg_types>::value>());
};

return func(std::move(wrapper), std::move(inputs), std::move(outputs), std::move(attrs));
}

// Version for plain old function ptrs
Expand Down
16 changes: 9 additions & 7 deletions builder/test/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ TEST_P(elementwise, pipeline_2d) {

// Here we explicitly use lambdas to wrap the local calls,
// purely to verify that the relevant func::make calls work correctly.
auto m2 = [](const buffer<const int>& a, const buffer<int>& b) -> index_t { return multiply_2<int>(a, b); };
auto a1 = [](const buffer<const int>& a, const buffer<int>& b) -> index_t { return add_1<int>(a, b); };
auto m2 = [](const buffer<const int>& a, raw_buffer b) -> index_t { return multiply_2<int>(a, b.cast<int>()); };
auto a1 = [](const raw_buffer& a, const raw_buffer* b) -> index_t {
return add_1<int>(a.cast<const int>(), b->cast<int>());
};

func mul = func::make(
std::move(m2), {{in, {point(x), point(y)}}}, {{intm, {x, y}}}, call_stmt::attributes{.allow_in_place = true});
Expand Down Expand Up @@ -638,7 +640,8 @@ TEST_P(stencil_chain, pipeline) {

class multiple_outputs : public testing::TestWithParam<std::tuple<int, int, bool>> {};

INSTANTIATE_TEST_SUITE_P(split_mode, multiple_outputs, testing::Combine(loop_modes, testing::Range(0, 4), testing::Bool()),
INSTANTIATE_TEST_SUITE_P(split_mode, multiple_outputs,
testing::Combine(loop_modes, testing::Range(0, 4), testing::Bool()),
test_params_to_string<multiple_outputs::ParamType>);

TEST_P(multiple_outputs, pipeline) {
Expand All @@ -663,7 +666,8 @@ TEST_P(multiple_outputs, pipeline) {
// For a 3D input in(x, y, z), compute sum_x = sum(input(:, y, z)) and sum_xy = sum(input(:, :, z)) in one stage.
func::callable<const int, int, int> sum_x_xy = [](const buffer<const int>& in, const buffer<int>& sum_x,
const buffer<int>& sum_xy) -> index_t {
for (index_t z = std::min(sum_xy.dim(0).min(), sum_x.dim(1).min()); z <= std::max(sum_xy.dim(0).max(), sum_x.dim(1).max()); ++z) {
for (index_t z = std::min(sum_xy.dim(0).min(), sum_x.dim(1).min());
z <= std::max(sum_xy.dim(0).max(), sum_x.dim(1).max()); ++z) {
if (sum_xy.contains(z)) sum_xy(z) = 0;
for (index_t y = sum_x.dim(0).min(); y <= sum_x.dim(0).max(); ++y) {
if (sum_x.contains(y, z)) sum_x(y, z) = 0;
Expand Down Expand Up @@ -1461,8 +1465,7 @@ TEST(split, pipeline) {

class upsample : public testing::TestWithParam<std::tuple<int, int>> {};

INSTANTIATE_TEST_SUITE_P(split_mode, upsample,
testing::Combine(loop_modes, testing::Range(0, 2)),
INSTANTIATE_TEST_SUITE_P(split_mode, upsample, testing::Combine(loop_modes, testing::Range(0, 2)),
test_params_to_string<upsample::ParamType>);

TEST_P(upsample, pipeline) {
Expand All @@ -1480,7 +1483,6 @@ TEST_P(upsample, pipeline) {
var x(ctx, "x");
var y(ctx, "y");


func add = func::make(add_1<short>, {{in, {point(x), point(x)}}}, {{intm, {x, y}}});
func upsample = func::make(upsample_nn_2x<short>, {{intm, {point(x) / 2, point(y) / 2}}}, {{out, {x, y}}});

Expand Down
2 changes: 2 additions & 0 deletions runtime/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class raw_buffer {
}

public:
using element = void;
using pointer = void*;

void* base;
Expand Down Expand Up @@ -414,6 +415,7 @@ class buffer : public raw_buffer {
}

public:
using element = T;
using pointer = T*;

using raw_buffer::cast;
Expand Down

0 comments on commit 32fbea5

Please sign in to comment.