Skip to content

Commit

Permalink
[wip] add reference apply
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Jan 31, 2025
1 parent 967edc6 commit b3b430f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 8 deletions.
49 changes: 41 additions & 8 deletions examples/batched-matrix-free-templated/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

#pragma once

#include <variant>

#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/polymorphic_object.hpp>
#include <ginkgo/core/matrix/batch_dense.hpp>
#include <ginkgo/extensions/kokkos.hpp>

#include "examples/batched-matrix-free-templated/batched/kernel_tags.hpp"
#include "ginkgo/core/matrix/identity.hpp"

namespace tensor {
Expand Down Expand Up @@ -95,24 +98,21 @@ void convert_tensor(gko::ptr_param<const gko::matrix::Identity<value_type>> A,

struct tensor_left_view {
gko::size_type num_batch_items;
gko::int32 num_rows;
gko::int32 num_cols;
gko::int32 stride;
gko::int32 size_1d;
const value_type* data;
};

struct tensor_left_item {
gko::int32 num_rows;
gko::int32 num_cols;
gko::int32 stride;
gko::int32 size_1d;
const value_type* data;
};

constexpr tensor_left_item extract_batch_item(tensor_left_view op,
gko::size_type batch_id)
{
return {op.num_rows, op.num_cols, op.stride,
op.data + batch_id * op.num_rows * op.stride};
return {op.stride, op.size_1d, op.data + batch_id * op.size_1d * op.stride};
}

class TensorLeft : public gko::EnablePolymorphicObject<TensorLeft> {
Expand Down Expand Up @@ -143,8 +143,7 @@ class TensorLeft : public gko::EnablePolymorphicObject<TensorLeft> {
[[nodiscard]] tensor_left_view create_view() const
{
return {this->get_num_batch_items(),
static_cast<gko::int32>(this->get_common_size()[0]),
static_cast<gko::int32>(this->get_common_size()[1]),
static_cast<gko::int32>(data_->get_common_size()[0]),
static_cast<gko::int32>(data_->get_common_size()[1]),
data_->get_const_values()};
}
Expand Down Expand Up @@ -172,6 +171,40 @@ class TensorLeft : public gko::EnablePolymorphicObject<TensorLeft> {
};


constexpr void advanced_apply(
double alpha, tensor_left_item a,
gko::batch::multi_vector::batch_item<const double> b, double beta,
gko::batch::multi_vector::batch_item<double> x,
[[maybe_unused]] std::variant<gko::reference_kernel, gko::omp_kernel>)
{
for (gko::int32 k = 0; k < a.size_1d; ++k) {
for (gko::int32 j = 0; j < a.size_1d; ++j) {
for (gko::int32 i = 0; i < a.size_1d; ++i) {
auto vector_start = k * a.size_1d * a.size_1d + i;

value_type acc = 0;
for (gko::size_type q = 0; q < a.size_1d; q++) {
auto vector_index = vector_start + q * a.size_1d;
acc = a.data[j * a.size_1d + q] * b.values[vector_index] +
acc;
}
auto row = k * a.size_1d * a.size_1d + j * a.size_1d + i;
x.values[row] = alpha * acc + beta * x.values[row];
}
}
}
}

constexpr void simple_apply(
const tensor_left_item& a,
const gko::batch::multi_vector::batch_item<const double>& b,
const gko::batch::multi_vector::batch_item<double>& x,
std::variant<gko::reference_kernel, gko::omp_kernel> tag)
{
advanced_apply(1.0, a, b, 0.0, x, tag);
}


std::unique_ptr<gko::batch::matrix::Dense<value_type>> convert(
gko::ptr_param<const TensorLeft> tensor)
{
Expand Down
45 changes: 45 additions & 0 deletions examples/batched-matrix-free-templated/test/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <gtest/gtest.h>

#include "core/matrix/batch_struct.hpp"
#include "core/test/utils.hpp"

auto exec = gko::ReferenceExecutor::create();
Expand Down Expand Up @@ -104,9 +105,32 @@ class Tensor2 : public testing::Test {
}

tensor = std::make_unique<tensor::TensorLeft>(std::move(data));

auto num_rows = tensor->get_common_size()[0];
auto vector_size = gko::batch_dim<2>{tensor->get_num_batch_items(),
gko::dim<2>{num_rows, 1}};
x = gko::batch::MultiVector<tensor::value_type>::create(exec,
vector_size);
b = gko::batch::MultiVector<tensor::value_type>::create(exec,
vector_size);
for (gko::size_type batch = 0; batch < x->get_num_batch_items();
++batch) {
x->create_view_for_item(batch)->read(
gko::test::generate_random_matrix_data<tensor::value_type,
gko::int32>(
vector_size.get_common_size()[0],
vector_size.get_common_size()[1],
std::uniform_int_distribution<>(1, 1),
std::uniform_real_distribution<>(), engine));
b->create_view_for_item(batch)->fill(0.0);
}
}

std::default_random_engine engine{42};

std::unique_ptr<tensor::TensorLeft> tensor;
std::unique_ptr<gko::batch::MultiVector<tensor::value_type>> x;
std::unique_ptr<gko::batch::MultiVector<tensor::value_type>> b;
};

TEST_F(Tensor2, CanConvert)
Expand All @@ -116,3 +140,24 @@ TEST_F(Tensor2, CanConvert)
ASSERT_EQ(mat->get_size(), tensor->get_size());
gko::write(std::ofstream("batch.mtx"), mat->create_view_for_item(1));
}


TEST_F(Tensor2, CanApply)
{
gko::size_type batch_id = 1;
auto view = tensor->create_view();
auto item = tensor::extract_batch_item(view, batch_id);
auto x_view = gko::batch::to_const(x->create_view());
auto b_view = b->create_view();

tensor::simple_apply(item, gko::batch::extract_batch_item(x_view, batch_id),
gko::batch::extract_batch_item(b_view, batch_id),
gko::reference_kernel{});

auto dense = convert(tensor);
auto expected_b = gko::clone(b);
dense->apply(x, expected_b);
GKO_ASSERT_MTX_NEAR(b->create_view_for_item(batch_id).get(),
expected_b->create_view_for_item(batch_id).get(),
r<tensor::value_type>::value);
}

0 comments on commit b3b430f

Please sign in to comment.