Skip to content

Commit

Permalink
Merge (ginkgo-project#1651): Unify batch functionality: Multivector
Browse files Browse the repository at this point in the history
Unify and simplify batch functionality: Multivector

Related PR: ginkgo-project#1651
  • Loading branch information
pratikvn authored and MarcelKoch committed Dec 2, 2024
2 parents 35195b5 + fd24ac5 commit 4e9a501
Show file tree
Hide file tree
Showing 28 changed files with 341 additions and 311 deletions.
1 change: 1 addition & 0 deletions common/cuda_hip/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include(${PROJECT_SOURCE_DIR}/cmake/template_instantiation.cmake)
set(CUDA_HIP_SOURCES
base/batch_multi_vector_kernels.cpp
base/device_matrix_data_kernels.cpp
base/index_set_kernels.cpp
components/prefix_sum_kernels.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#include "common/cuda_hip/base/batch_multi_vector_kernels.hpp"

#include <thrust/functional.h>
#include <thrust/transform.h>

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "core/base/batch_multi_vector_kernels.hpp"
#include "core/base/batch_struct.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_multi_vector {


constexpr auto default_block_size = 256;


template <typename ValueType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* const alpha,
Expand All @@ -11,16 +37,19 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const auto alpha_ub = get_batch_struct(alpha);
const auto x_ub = get_batch_struct(x);
if (alpha->get_common_size()[1] == 1) {
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
alpha_ub, x_ub,
[] __device__(int row, int col, int stride) { return 0; });
} else if (alpha->get_common_size() == x->get_common_size()) {
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
alpha_ub, x_ub, [] __device__(int row, int col, int stride) {
return row * stride + col;
});
} else {
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
batch_single_kernels::scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
alpha_ub, x_ub,
[] __device__(int row, int col, int stride) { return col; });
}
Expand All @@ -42,12 +71,12 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
if (alpha->get_common_size()[1] == 1) {
add_scaled_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::add_scaled_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
alpha_ub, x_ub, y_ub, [] __device__(int col) { return 0; });
} else {
add_scaled_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::add_scaled_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
alpha_ub, x_ub, y_ub, [] __device__(int col) { return col; });
}
}
Expand All @@ -67,8 +96,8 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
const auto res_ub = get_batch_struct(result);
compute_gen_dot_product_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::compute_gen_dot_product_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, y_ub, res_ub, [] __device__(auto val) { return val; });
}

Expand All @@ -87,8 +116,8 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
const auto x_ub = get_batch_struct(x);
const auto y_ub = get_batch_struct(y);
const auto res_ub = get_batch_struct(result);
compute_gen_dot_product_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
batch_single_kernels::compute_gen_dot_product_kernel<<<
num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); });
}

Expand All @@ -105,8 +134,9 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
const auto num_rhs = x->get_common_size()[1];
const auto x_ub = get_batch_struct(x);
const auto res_ub = get_batch_struct(result);
compute_norm2_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(x_ub, res_ub);
batch_single_kernels::compute_norm2_kernel<<<num_blocks, default_block_size,
0, exec->get_stream()>>>(
x_ub, res_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
Expand All @@ -121,8 +151,15 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
const auto num_blocks = x->get_num_batch_items();
const auto result_ub = get_batch_struct(result);
const auto x_ub = get_batch_struct(x);
copy_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, result_ub);
batch_single_kernels::
copy_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
x_ub, result_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,44 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#include <thrust/functional.h>
#include <thrust/transform.h>

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "common/cuda_hip/base/thrust.hpp"
#include "common/cuda_hip/base/types.hpp"
#include "common/cuda_hip/components/cooperative_groups.hpp"
#include "common/cuda_hip/components/format_conversion.hpp"
#include "common/cuda_hip/components/reduction.hpp"
#include "common/cuda_hip/components/segment_scan.hpp"
#include "common/cuda_hip/components/thread_ids.hpp"
#include "common/cuda_hip/components/warp_blas.hpp"

#if defined(GKO_COMPILING_CUDA)
#include "cuda/base/batch_struct.hpp"
#elif defined(GKO_COMPILING_HIP)
#include "hip/base/batch_struct.hip.hpp"
#else
#error "batch struct def missing"
#endif


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace batch_single_kernels {


constexpr auto default_block_size = 256;


template <typename ValueType, typename Mapping>
__device__ __forceinline__ void scale(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
Expand All @@ -20,8 +58,7 @@ __device__ __forceinline__ void scale(


template <typename ValueType, typename Mapping>
__global__
__launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel(
__global__ __launch_bounds__(default_block_size) void scale_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch::multi_vector::uniform_batch<ValueType> x, Mapping map)
{
Expand Down Expand Up @@ -52,20 +89,10 @@ __device__ __forceinline__ void add_scaled(


template <typename ValueType, typename Mapping>
__global__ __launch_bounds__(
default_block_size,
sm_oversubscription) void add_scaled_kernel(const gko::batch::multi_vector::
uniform_batch<
const ValueType>
alpha,
const gko::batch::multi_vector::
uniform_batch<
const ValueType>
x,
const gko::batch::multi_vector::
uniform_batch<ValueType>
y,
Mapping map)
__global__ __launch_bounds__(default_block_size) void add_scaled_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch::multi_vector::uniform_batch<const ValueType> x,
const gko::batch::multi_vector::uniform_batch<ValueType> y, Mapping map)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items;
batch_id += gridDim.x) {
Expand Down Expand Up @@ -145,7 +172,7 @@ __device__ __forceinline__ void compute_gen_dot_product(

template <typename ValueType, typename Mapping>
__global__
__launch_bounds__(default_block_size, sm_oversubscription) void compute_gen_dot_product_kernel(
__launch_bounds__(default_block_size) void compute_gen_dot_product_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> x,
const gko::batch::multi_vector::uniform_batch<const ValueType> y,
const gko::batch::multi_vector::uniform_batch<ValueType> result,
Expand Down Expand Up @@ -232,19 +259,10 @@ __device__ __forceinline__ void compute_norm2(


template <typename ValueType>
__global__ __launch_bounds__(
default_block_size,
sm_oversubscription) void compute_norm2_kernel(const gko::batch::
multi_vector::
uniform_batch<
const ValueType>
x,
const gko::batch::
multi_vector::
uniform_batch<
remove_complex<
ValueType>>
result)
__global__ __launch_bounds__(default_block_size) void compute_norm2_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> x,
const gko::batch::multi_vector::uniform_batch<remove_complex<ValueType>>
result)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items;
batch_id += gridDim.x) {
Expand Down Expand Up @@ -287,8 +305,7 @@ __device__ __forceinline__ void copy(


template <typename ValueType>
__global__
__launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
__global__ __launch_bounds__(default_block_size) void copy_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> src,
const gko::batch::multi_vector::uniform_batch<ValueType> dst)
{
Expand All @@ -299,3 +316,9 @@ __launch_bounds__(default_block_size, sm_oversubscription) void copy_kernel(
copy(src_b, dst_b);
}
}


} // namespace batch_single_kernels
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
41 changes: 26 additions & 15 deletions common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ __device__ __forceinline__ void initialize(
__syncthreads();

if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry, res_norm);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2(subgroup, num_rows, r_shared_entry,
res_norm);
} else if (threadIdx.x / config::warp_size == 1) {
// Compute norms of rhs
single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, rhs_norm);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2(subgroup, num_rows, b_global_entry,
rhs_norm);
}
__syncthreads();

Expand Down Expand Up @@ -70,8 +74,9 @@ __device__ __forceinline__ void compute_alpha(
const ValueType* const v_shared_entry, ValueType& alpha)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
v_shared_entry, alpha);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
v_shared_entry, alpha);
}
__syncthreads();
if (threadIdx.x == 0) {
Expand Down Expand Up @@ -99,11 +104,13 @@ __device__ __forceinline__ void compute_omega(
const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
{
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
s_shared_entry, omega);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
s_shared_entry, omega);
} else if (threadIdx.x / config::warp_size == 1) {
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
t_shared_entry, temp);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
t_shared_entry, temp);
}

__syncthreads();
Expand Down Expand Up @@ -271,8 +278,9 @@ __global__ void apply_kernel(

// rho_new = < r_hat , r > = (r_hat)' * (r)
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh,
rho_new_sh[0]);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh,
r_sh, rho_new_sh[0]);
}
__syncthreads();

Expand Down Expand Up @@ -301,8 +309,9 @@ __global__ void apply_kernel(

// an estimate of residual norms
if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_norm2(subgroup, num_rows, s_sh,
norms_res_sh[0]);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2(subgroup, num_rows, s_sh,
norms_res_sh[0]);
}
__syncthreads();

Expand Down Expand Up @@ -333,8 +342,9 @@ __global__ void apply_kernel(
__syncthreads();

if (threadIdx.x / config::warp_size == 0) {
single_rhs_compute_norm2(subgroup, num_rows, r_sh,
norms_res_sh[0]);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_compute_norm2(subgroup, num_rows, r_sh,
norms_res_sh[0]);
}
//__syncthreads();

Expand All @@ -347,7 +357,8 @@ __global__ void apply_kernel(
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

// copy x back to global memory
single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr);
gko::kernels::GKO_DEVICE_NAMESPACE::batch_single_kernels::
single_rhs_copy(num_rows, x_sh, x_gl_entry_ptr);
__syncthreads();
}
}
Loading

0 comments on commit 4e9a501

Please sign in to comment.