Skip to content

Commit

Permalink
Handle non-buffer<> arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Jan 10, 2025
1 parent 05c6900 commit 7b930bc
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
29 changes: 25 additions & 4 deletions builder/pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,29 @@ class buffer_expr : public ref_counted<buffer_expr> {
static void destroy(buffer_expr* p) { delete p; }
};

namespace internal {

// Adapts a `const raw_buffer*` to the parameter type `T` that a user callback declares.
//
// Primary template: `T` (after stripping references and cv-qualifiers) must expose a nested
// `element` type. The raw buffer is handed to `raw_buffer::cast<element>()` and the result is
// returned by const reference.
template <typename T>
struct buffer_converter {
  static SLINKY_ALWAYS_INLINE const auto& convert(const raw_buffer* buffer) {
    // Resolved inside the function body so that `T::element` is only required when `convert`
    // itself is instantiated.
    using arg_type = std::remove_cv_t<std::remove_reference_t<T>>;
    using element_type = typename arg_type::element;
    return buffer->cast<element_type>();
  }
};

// Callback takes `raw_buffer` by value: pass the buffer through untyped.
// The caller's by-value parameter is initialized from the returned reference.
template <>
struct buffer_converter<raw_buffer> {
  static SLINKY_ALWAYS_INLINE const raw_buffer& convert(const raw_buffer* buffer) {
    return *buffer;
  }
};

// Callback takes `const raw_buffer&`: pass the buffer through untyped, no cast.
template <>
struct buffer_converter<const raw_buffer&> {
  static SLINKY_ALWAYS_INLINE const raw_buffer& convert(const raw_buffer* buffer) {
    return *buffer;
  }
};

// Callback takes `const raw_buffer*`: forward the pointer unchanged.
template <>
struct buffer_converter<const raw_buffer*> {
  static SLINKY_ALWAYS_INLINE const raw_buffer* convert(const raw_buffer* buffer) {
    return buffer;
  }
};

}  // namespace internal

// Represents a node of computation in a pipeline.
class func {
public:
Expand Down Expand Up @@ -204,10 +227,8 @@ class func {
// Calls `impl` with one argument per entry of `Indices` and returns its result as `index_t`.
//
// For each index `i`, the buffer is fetched with `ctx.lookup_buffer` using `op->inputs[i]` when
// `i < op->inputs.size()`, and `op->outputs[i - op->inputs.size()]` otherwise, so inputs come
// first and outputs follow. `ArgTypes` is queried with `std::tuple_element<i, ArgTypes>` to get
// the parameter type declared for position `i`.
template <typename ArgTypes, typename Fn, std::size_t... Indices>
static SLINKY_ALWAYS_INLINE index_t call_impl_tuple(
const Fn& impl, eval_context& ctx, const call_stmt* op, std::index_sequence<Indices...>) {
// NOTE(review): this text is a rendered diff hunk (-204,10 +227,8), so both versions of the
// return statement appear. The first one is the removed form: it casts every looked-up buffer
// to the `element` type of the declared parameter, which requires each parameter type to have
// a nested `element`.
return impl(
ctx.lookup_buffer(Indices < op->inputs.size() ? op->inputs[Indices] : op->outputs[Indices - op->inputs.size()])
->template cast<typename std::remove_cv<typename std::remove_reference<
typename std::tuple_element<Indices, ArgTypes>::type>::type>::type::element>()...);
// Added form: each looked-up buffer goes through internal::buffer_converter, selected by the
// declared parameter type, instead of being cast unconditionally.
return impl(internal::buffer_converter<typename std::tuple_element<Indices, ArgTypes>::type>::convert(ctx.lookup_buffer(
Indices < op->inputs.size() ? op->inputs[Indices] : op->outputs[Indices - op->inputs.size()]))...);
}

template <typename Lambda>
Expand Down
16 changes: 9 additions & 7 deletions builder/test/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ TEST_P(elementwise, pipeline_2d) {

// Here we explicitly use lambdas to wrap the local calls,
// purely to verify that the relevant func::make calls work correctly.
auto m2 = [](const buffer<const int>& a, const buffer<int>& b) -> index_t { return multiply_2<int>(a, b); };
auto a1 = [](const buffer<const int>& a, const buffer<int>& b) -> index_t { return add_1<int>(a, b); };
auto m2 = [](const buffer<const int>& a, raw_buffer b) -> index_t { return multiply_2<int>(a, b.cast<int>()); };
auto a1 = [](const raw_buffer& a, const raw_buffer* b) -> index_t {
return add_1<int>(a.cast<const int>(), b->cast<int>());
};

func mul = func::make(
std::move(m2), {{in, {point(x), point(y)}}}, {{intm, {x, y}}}, call_stmt::attributes{.allow_in_place = true});
Expand Down Expand Up @@ -638,7 +640,8 @@ TEST_P(stencil_chain, pipeline) {

// Value-parameterized test fixture with no members of its own; each test instance receives a
// std::tuple<int, int, bool> parameter.
class multiple_outputs : public testing::TestWithParam<std::tuple<int, int, bool>> {};

INSTANTIATE_TEST_SUITE_P(split_mode, multiple_outputs, testing::Combine(loop_modes, testing::Range(0, 4), testing::Bool()),
INSTANTIATE_TEST_SUITE_P(split_mode, multiple_outputs,
testing::Combine(loop_modes, testing::Range(0, 4), testing::Bool()),
test_params_to_string<multiple_outputs::ParamType>);

TEST_P(multiple_outputs, pipeline) {
Expand All @@ -663,7 +666,8 @@ TEST_P(multiple_outputs, pipeline) {
// For a 3D input in(x, y, z), compute sum_x = sum(input(:, y, z)) and sum_xy = sum(input(:, :, z)) in one stage.
func::callable<const int, int, int> sum_x_xy = [](const buffer<const int>& in, const buffer<int>& sum_x,
const buffer<int>& sum_xy) -> index_t {
for (index_t z = std::min(sum_xy.dim(0).min(), sum_x.dim(1).min()); z <= std::max(sum_xy.dim(0).max(), sum_x.dim(1).max()); ++z) {
for (index_t z = std::min(sum_xy.dim(0).min(), sum_x.dim(1).min());
z <= std::max(sum_xy.dim(0).max(), sum_x.dim(1).max()); ++z) {
if (sum_xy.contains(z)) sum_xy(z) = 0;
for (index_t y = sum_x.dim(0).min(); y <= sum_x.dim(0).max(); ++y) {
if (sum_x.contains(y, z)) sum_x(y, z) = 0;
Expand Down Expand Up @@ -1461,8 +1465,7 @@ TEST(split, pipeline) {

// Value-parameterized test fixture with no members of its own; each test instance receives a
// std::tuple<int, int> parameter.
class upsample : public testing::TestWithParam<std::tuple<int, int>> {};

INSTANTIATE_TEST_SUITE_P(split_mode, upsample,
testing::Combine(loop_modes, testing::Range(0, 2)),
INSTANTIATE_TEST_SUITE_P(split_mode, upsample, testing::Combine(loop_modes, testing::Range(0, 2)),
test_params_to_string<upsample::ParamType>);

TEST_P(upsample, pipeline) {
Expand All @@ -1480,7 +1483,6 @@ TEST_P(upsample, pipeline) {
var x(ctx, "x");
var y(ctx, "y");


func add = func::make(add_1<short>, {{in, {point(x), point(x)}}}, {{intm, {x, y}}});
func upsample = func::make(upsample_nn_2x<short>, {{intm, {point(x) / 2, point(y) / 2}}}, {{out, {x, y}}});

Expand Down

0 comments on commit 7b930bc

Please sign in to comment.