Skip to content

Commit

Permalink
add fused kernel for chebyshev
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jan 3, 2025
1 parent 51513a4 commit d96108d
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 8 deletions.
67 changes: 67 additions & 0 deletions common/unified/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "core/solver/chebyshev_kernels.hpp"

#include <ginkgo/core/matrix/dense.hpp>

#include "common/unified/base/kernel_launch.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace chebyshev {


template <typename ValueType, typename ScalarType>
void init_update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha,
const matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto inner_sol,
auto update_sol, auto output) {
const auto inner_val = inner_sol(row, col);
update_sol(row, col) = val;
output(row, col) += alpha_val * inner_val;
},
output->get_size(), alpha, inner_sol, update_sol, output);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);


template <typename ValueType, typename ScalarType>
void update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha, const ScalarType* beta,
matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto beta, auto inner_sol,
auto update_sol, auto output) {
const auto val =
inner_sol(row, col) + beta[0] * update_sol(row, col);
inner_sol(row, col) = val;
update_sol(row, col) = val;
output(row, col) += alpha[0] * val;
},
output->get_size(), alpha, beta, inner_sol, update_sol, output);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


} // namespace chebyshev
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
13 changes: 12 additions & 1 deletion core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -59,6 +59,7 @@
#include "core/solver/cb_gmres_kernels.hpp"
#include "core/solver/cg_kernels.hpp"
#include "core/solver/cgs_kernels.hpp"
#include "core/solver/chebyshev_kernels.hpp"
#include "core/solver/common_gmres_kernels.hpp"
#include "core/solver/fcg_kernels.hpp"
#include "core/solver/gcr_kernels.hpp"
Expand Down Expand Up @@ -653,6 +654,16 @@ GKO_STUB_CB_GMRES_CONST(GKO_DECLARE_CB_GMRES_SOLVE_KRYLOV_KERNEL);
} // namespace cb_gmres


namespace chebyshev {


GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);
GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


} // namespace chebyshev


namespace ir {


Expand Down
23 changes: 16 additions & 7 deletions core/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -10,6 +10,7 @@

#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/chebyshev_kernels.hpp"
#include "core/solver/ir_kernels.hpp"
#include "core/solver/solver_base.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand All @@ -23,6 +24,8 @@ namespace {


GKO_REGISTER_OPERATION(initialize, ir::initialize);
GKO_REGISTER_OPERATION(init_update, chebyshev::init_update);
GKO_REGISTER_OPERATION(update, chebyshev::update);


} // anonymous namespace
Expand Down Expand Up @@ -274,8 +277,12 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
num_generated_scalar_++;
}
// x = x + alpha * inner_solution
dense_x->add_scaled(alpha_scalar.get(), inner_solution);
update_solution->copy_from(inner_solution);
// update_solultion = inner_solution
exec->run(chebyshev::make_update(
alpha_scalar->get_const_values(),
gko::detail::get_local(inner_solution.get()),
gko::detail::get_local(update_solution.get()),
gko::detail::get_local(dense_x)));
continue;
}
// beta_ref for iter == 1 is initialized in the beginning
Expand All @@ -295,10 +302,12 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
}
// z = z + beta * p
// p = z
inner_solution->add_scaled(beta_scalar.get(), update_solution);
update_solution->copy_from(inner_solution);
// x + alpha * p
dense_x->add_scaled(alpha_scalar.get(), update_solution);
// x += alpha * p
exec->run(chebyshev::make_update(
alpha_scalar->get_const_values(), beta_scalar->get_const_values(),
gko::detail::get_local(inner_solution.get()),
gko::detail::get_local(update_solution.get()),
gko::detail::get_local(dense_x)));
}
}

Expand Down
58 changes: 58 additions & 0 deletions core/solver/chebyshev_kernels.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_SOLVER_CHEBYSHEV_KERNELS_HPP_
#define GKO_CORE_SOLVER_CHEBYSHEV_KERNELS_HPP_


#include <memory>

#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/kernel_declaration.hpp"


namespace gko {
namespace kernels {
namespace chebyshev {


#define GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL(ValueType, ScalarType) \
void init_update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType* alpha, \
const matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
matrix::Dense<ValueType>* output)

#define GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL(ValueType, ScalarType) \
void update(std::shared_ptr<const DefaultExecutor> exec,\
const ScalarType* alpha, \
const ScalarType* beta, matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
matrix::Dense<ValueType>* output)

#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType, typename ScalarType> \
GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL(ValueType, ScalarType); \
template <typename ValueType, typename ScalarType> \
GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL(ValueType, ScalarType)


} // namespace chebyshev


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(chebyshev,
GKO_DECLARE_ALL_AS_TEMPLATES);


#undef GKO_DECLARE_ALL_AS_TEMPLATES


} // namespace kernels
} // namespace gko


#endif // GKO_CORE_SOLVER_CHEBYSHEV_KERNELS_HPP_
63 changes: 63 additions & 0 deletions reference/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "core/solver/chebyshev_kernels.hpp"

#include <ginkgo/core/matrix/dense.hpp>

namespace gko {
namespace kernels {
namespace reference {
namespace chebyshev {


template <typename ValueType, typename ScalarType>
void init_update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha,
const matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
const auto alpha_val = alpha[0];
for (size_t row = 0; row < output->get_size()[0]; row++) {
for (size_t col = 0; col < output->get_size()[1]; col++) {
const auto inner_val = inner_sol->at(row, col);
update_sol->at(row, col) = inner_val;
output->at(row, col) += alpha_val * inner_val;
}
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);


template <typename ValueType, typename ScalarType>
void update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha, const ScalarType* beta,
const matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
const auto alpha_val = alpha[0];
const auto beta_val = beta[0];
for (size_t row = 0; row < output->get_size()[0]; row++) {
for (size_t col = 0; col < output->get_size()[1]; col++) {
const auto val =
inner_sol->at(row, col) + beta[0] * update_sol->at(row, col);
inner_sol->at(row, col) = inner_val;
update_sol->at(row, col) = inner_val;
output->at(row, col) += alpha_val * inner_val;
}
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


} // namespace chebyshev
} // namespace reference
} // namespace kernels
} // namespace gko

0 comments on commit d96108d

Please sign in to comment.