Skip to content

Commit

Permalink
[wip] add conversion to sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Feb 5, 2025
1 parent a9b37ae commit 32eaf97
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
18 changes: 13 additions & 5 deletions examples/batched-matrix-free-templated/benchmark/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
#include <core/test/utils/batch_helpers.hpp>
#include <examples/batched-matrix-free-templated/tensor.hpp>

DEFINE_string(apply, "matrix-free",
"The apply implementation: either >matrix-free<, or "
">matrix-based<, or a >,< separated list.");
DEFINE_string(
apply, "matrix-free",
"The apply implementation: either >matrix-free<, >matrix-dense<, or "
">matrix-sparse<, or a >,< separated list.");

using vtype = tensor::ValueType;

Expand Down Expand Up @@ -110,12 +111,19 @@ struct TensorBenchmark : public Benchmark<TensorState> {
operation_case["repetitions"] = ic.get_num_repetitions();
};

using Dense = gko::batch::matrix::Dense<vtype>;
using Csr = gko::batch::matrix::Csr<vtype>;

auto tensor =
std::make_shared<tensor::TensorLeft>(gko::clone(state.data_1d));
if (operation == "matrix-free") {
run_impl(tensor);
} else if (operation == "matrix-based") {
run_impl(tensor::convert(tensor));
} else if (operation == "matrix-dense") {
run_impl(tensor::convert<Dense>(tensor));
} else if (operation == "matrix-sparse") {
auto size_1d = state.data_1d->get_common_size()[0];
auto nnz = size_1d * size_1d * size_1d * size_1d;
run_impl(tensor::convert<Csr>(tensor, nnz));
} else {
throw std::runtime_error("Unsupported operation: " + operation);
}
Expand Down
14 changes: 9 additions & 5 deletions examples/batched-matrix-free-templated/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,26 @@ __device__ void simple_apply(
#endif


std::unique_ptr<gko::batch::matrix::Dense<ValueType>> convert(
gko::ptr_param<const TensorLeft> tensor)
template <typename MatrixType, typename... ExtraArgs>
std::unique_ptr<MatrixType> convert(gko::ptr_param<const TensorLeft> tensor,
ExtraArgs&&... args)
{
auto result = gko::batch::matrix::Dense<ValueType>::create(
tensor->get_executor(), tensor->get_size());
auto result = MatrixType::create(tensor->get_executor(), tensor->get_size(),
std::forward<ExtraArgs>(args)...);

auto size_1d = tensor->get_data()->get_common_size()[0];
auto id = gko::matrix::Identity<ValueType>::create(tensor->get_executor(),
size_1d);
auto intermediate = gko::matrix::Dense<ValueType>::create(
tensor->get_executor(), gko::dim<2>{size_1d * size_1d});
auto intermediate2 = gko::matrix::Dense<ValueType>::create(
tensor->get_executor(), gko::dim<2>{size_1d * size_1d * size_1d});
for (gko::size_type batch = 0; batch < tensor->get_num_batch_items();
++batch) {
convert_tensor(tensor->get_data()->create_const_view_for_item(batch),
id, intermediate);
convert_tensor(id, intermediate, result->create_view_for_item(batch));
convert_tensor(id, intermediate, intermediate2);
intermediate2->convert_to(result->create_view_for_item(batch));
}

return result;
Expand Down
25 changes: 20 additions & 5 deletions examples/batched-matrix-free-templated/test/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,36 @@ class TensorApply : public CommonTestFixture {
}
}

using Vector = gko::batch::MultiVector<tensor::ValueType>;
using Dense = gko::batch::matrix::Dense<tensor::ValueType>;
using Csr = gko::batch::matrix::Csr<tensor::ValueType>;

std::default_random_engine engine{42};

std::unique_ptr<tensor::TensorLeft> tensor;
std::unique_ptr<gko::batch::MultiVector<tensor::ValueType>> x;
std::unique_ptr<Vector> x;
std::unique_ptr<gko::batch::MultiVector<tensor::ValueType>> b;
};

TEST_F(TensorApply, CanConvert)
TEST_F(TensorApply, CanConvertDense)
{
auto mat = convert(tensor);
auto mat = convert<Dense>(tensor);

ASSERT_EQ(mat->get_size(), tensor->get_size());
gko::write(std::ofstream("batch.mtx"), mat->create_view_for_item(1));
}

TEST_F(TensorApply, CanConvertCsr)
{
auto size_1d = tensor->get_data()->get_common_size()[0];
auto nnz = size_1d * size_1d * size_1d * size_1d;

auto csr = convert<Csr>(tensor, nnz);

auto dense = convert<Dense>(tensor);
GKO_ASSERT_BATCH_MTX_NEAR(csr, dense, 0.0);
}

#if defined(GKO_COMPILING_HIP) || defined(GKO_COMPILING_CUDA)

__global__ void call_simple_apply_kernel(
Expand Down Expand Up @@ -207,7 +222,7 @@ TEST_F(TensorApply, CanApplySingleBatch)
gko::batch::extract_batch_item(b_view, batch_id));
exec->synchronize();

auto dense = convert(tensor);
auto dense = convert<Dense>(tensor);
auto expected_b = gko::clone(b);
dense->apply(x, expected_b);
GKO_ASSERT_MTX_NEAR(b->create_view_for_item(batch_id).get(),
Expand All @@ -219,7 +234,7 @@ TEST_F(TensorApply, CanApply)
{
tensor->apply(x, b);

auto dense = convert(tensor);
auto dense = convert<Dense>(tensor);
auto expected_b = gko::clone(b);
dense->apply(x, expected_b);
GKO_ASSERT_BATCH_MTX_NEAR(b, expected_b, r<tensor::ValueType>::value);
Expand Down

0 comments on commit 32eaf97

Please sign in to comment.