From 94d90a7f00b65cf43051a352fef2774ee4a685cd Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 19 Dec 2024 19:44:33 -0800 Subject: [PATCH] optimization for review comments --- cpp/bench/prims/linalg/masked_matmul.cu | 2 +- cpp/include/raft/sparse/convert/csr.cuh | 47 +++++++ .../sparse/convert/detail/bitset_to_csr.cuh | 46 +++++-- .../raft/sparse/linalg/masked_matmul.cuh | 117 ++++++++++++++++++ .../raft/sparse/linalg/masked_matmul.hpp | 104 ++-------------- cpp/test/sparse/masked_matmul.cu | 30 +++-- 6 files changed, 230 insertions(+), 116 deletions(-) create mode 100644 cpp/include/raft/sparse/linalg/masked_matmul.cuh diff --git a/cpp/bench/prims/linalg/masked_matmul.cu b/cpp/bench/prims/linalg/masked_matmul.cu index b831528b81..b96e14a25d 100644 --- a/cpp/bench/prims/linalg/masked_matmul.cu +++ b/cpp/bench/prims/linalg/masked_matmul.cu @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 818b572a23..5237edd383 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -136,6 +136,53 @@ void bitmap_to_csr(raft::resources const& handle, * The bitset format inherently supports only a single-row matrix (rows=1). If the CSR matrix * requires multiple rows, the data from the bitset will be repeated for each row in the output. * + * Example usage: + * + * @code{.cpp} + * #include + * #include + * #include + * + * #include + * + * using bitset_t = uint32_t; + * using index_t = int; + * using value_t = float; + * using nnz_t = index_t; + * + * raft::resources handle; + * auto stream = resource::get_cuda_stream(handle); + * index_t n_rows = 3; + * index_t n_cols = 30; + * + * nnz_t nnz_for_bitset = 4; + * nnz_t nnz_for_csr = nnz_for_bitset * n_rows; + * + * index_t bitset_size = (n_cols + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); // = 1 + * + * rmm::device_uvector bitset_d(bitset_size, stream); + * std::vector bitset_h = { + * bitset_t(0b11001010), + * }; // nnz_for_bitset = 4; + * + * raft::copy(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); + * + * auto bitset_view = raft::core::bitset_view(bitset_d.data(), n_cols); + * auto csr = raft::make_device_csr_matrix(handle, n_rows, n_cols, nnz_for_csr); + * + * raft::sparse::convert::bitset_to_csr(handle, bitset_view, csr); + * resource::sync_stream(handle); + * + * // Results: + * // csr.indptr = [0, 4, 8, 12]; + * // csr.indices = [1, 3, 6, 7, + * // 1, 3, 6, 7, + * // 1, 3, 6, 7]; + * // csr.values = [1, 1, 1, 1, + * // 1, 1, 1, 1, + * // 1, 1, 1, 1]; + * @endcode + * * @tparam bitset_t The data type of the elements in the bitset matrix. * @tparam index_t The data type used for indexing the elements in the matrices. * @tparam csr_matrix_t Specifies the CSR matrix type, constrained to diff --git a/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh index f4660f4ecf..72abd02f7e 100644 --- a/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh @@ -118,24 +118,52 @@ void bitset_to_csr(raft::resources const& handle, RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream)); - calc_nnz_by_rows(handle, bitset.data(), row_t(1), csr_view.get_n_cols(), indptr); - thrust::exclusive_scan(thrust_policy, indptr, indptr + 2, indptr); + size_t sub_nnz_size = 0; + index_t bits_per_sub_col = 0; + + // Get buffer size and number of bits per each sub-columns + calc_nnz_by_rows(handle, + bitset.data(), + row_t(1), + csr_view.get_n_cols(), + static_cast(nullptr), + sub_nnz_size, + bits_per_sub_col); + + rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); + rmm::device_uvector sub_nnz(sub_nnz_size + 1, stream, device_memory); + + calc_nnz_by_rows(handle, + bitset.data(), + row_t(1), + csr_view.get_n_cols(), + sub_nnz.data(), + sub_nnz_size, + bits_per_sub_col); + + thrust::exclusive_scan( + thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data()); index_t bitset_nnz = 0; - if constexpr (is_device_csr_sparsity_owning_v) { - RAFT_CUDA_TRY( - cudaMemcpyAsync(&bitset_nnz, indptr + 1, sizeof(index_t), cudaMemcpyDeviceToHost, stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync( + &bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream)); resource::sync_stream(handle); csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows()); } else { bitset_nnz = csr_view.get_nnz() / csr_view.get_n_rows(); } - constexpr bool check_nnz = is_device_csr_sparsity_preserving_v; - fill_indices_by_rows( - handle, bitset.data(), indptr, 1, csr_view.get_n_cols(), bitset_nnz, indices); - + fill_indices_by_rows(handle, + bitset.data(), + indptr, + 1, + csr_view.get_n_cols(), + csr_view.get_nnz(), + indices, + sub_nnz.data(), + bits_per_sub_col, + sub_nnz_size); if (csr_view.get_n_rows() > 1) { gpu_repeat_csr(handle, indptr, diff --git a/cpp/include/raft/sparse/linalg/masked_matmul.cuh b/cpp/include/raft/sparse/linalg/masked_matmul.cuh new file mode 100644 index 0000000000..288068dae2 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/masked_matmul.cuh @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain A copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @defgroup masked_matmul Masked Matrix Multiplication + * @{ + */ + +/** + * @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise + * multiplication with the sparsity pattern defined by the mask, resulting in the computation + * C = alpha * ((A * B) ∘ spy(mask)) + beta * C. + * + * This function multiplies two dense matrices A and B, and then applies an element-wise + * multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha + * and added to beta times the original matrix C. + * + * @tparam value_t Data type of elements in the input matrices (e.g., half, float, double) + * @tparam output_t Data type of elements in the output matrices (e.g., float, double) + * @tparam index_t Type used for matrix indices + * @tparam nnz_t Type used for the number of non-zero entries in CSR format + * @tparam bitmap_t Type of the bitmap used for the mask + * + * @param[in] handle RAFT handle for resource management + * @param[in] A Input dense matrix (device_matrix_view) with shape [m, k] + * @param[in] B Input dense matrix (device_matrix_view) with shape [n, k] + * @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape + * [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is + * included (1) or masked out (0). + * @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m, + * n] + * @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if + * std::nullopt) + * @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt) + */ +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::core::bitmap_view mask, + raft::device_csr_matrix_view C, + std::optional> alpha = std::nullopt, + std::optional> beta = std::nullopt) +{ + detail::masked_matmul(handle, A, B, mask, C, alpha, beta); +} + +/** + * @brief Computes a sparse matrix product with a masked sparsity pattern and scaling. + * + * This function computes the result of: + * C = alpha * ((A * B) ∘ spy(mask)) + beta * C + * where: + * - A and B are dense input matrices. + * - "mask" defines the sparsity pattern for element-wise multiplication. + * - The result is scaled by alpha and added to beta times the original C. + * + * **Special behavior of the mask**: + * - The `bitset` mask represents a single row of data, with its bits indicating whether + * each corresponding element in (A * B) is included (1) or masked out (0). + * - If the output CSR matrix `C` has multiple rows, the `bitset` is logically repeated + * across all rows of `C`. For example, if `C` has `n_rows` rows, the same `bitset` + * pattern is applied to all rows. + * + * @tparam value_t Data type of input matrix elements (e.g., half, float, double). + * @tparam output_t Data type of output matrix elements (e.g., float, double). + * @tparam index_t Type for matrix indices. + * @tparam nnz_t Type for non-zero entries in CSR format. + * @tparam bitmap_t Type for the bitmap mask. + * + * @param[in] handle RAFT handle for managing resources. + * @param[in] A Dense input matrix [m, k] (row-major). + * @param[in] B Dense input matrix [n, k] (row-major). + * @param[in] mask Bitmap view representing a single row [1, n], where each bit + * indicates if the corresponding element in (A * B) is included (1) + * or masked out (0). The pattern is repeated for all rows of `C`. + * @param[inout] C Output sparse matrix in CSR format [m, n]. + * @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt). + * @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt). + */ +template +void masked_matmul(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::core::bitset_view mask, + raft::device_csr_matrix_view C, + std::optional> alpha = std::nullopt, + std::optional> beta = std::nullopt) +{ + detail::masked_matmul(handle, A, B, mask, C, alpha, beta); +} + +/** @} */ // end of masked_matmul + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/masked_matmul.hpp b/cpp/include/raft/sparse/linalg/masked_matmul.hpp index 288068dae2..32322b90f6 100644 --- a/cpp/include/raft/sparse/linalg/masked_matmul.hpp +++ b/cpp/include/raft/sparse/linalg/masked_matmul.hpp @@ -13,105 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once - -#include - -namespace raft { -namespace sparse { -namespace linalg { - /** - * @defgroup masked_matmul Masked Matrix Multiplication - * @{ + * This file is deprecated and will be removed in future release. + * Please use the cuh version instead. */ /** - * @brief Performs a masked multiplication of dense matrices A and B, followed by an element-wise - * multiplication with the sparsity pattern defined by the mask, resulting in the computation - * C = alpha * ((A * B) ∘ spy(mask)) + beta * C. - * - * This function multiplies two dense matrices A and B, and then applies an element-wise - * multiplication using the sparsity pattern provided by the mask. The result is scaled by alpha - * and added to beta times the original matrix C. - * - * @tparam value_t Data type of elements in the input matrices (e.g., half, float, double) - * @tparam output_t Data type of elements in the output matrices (e.g., float, double) - * @tparam index_t Type used for matrix indices - * @tparam nnz_t Type used for the number of non-zero entries in CSR format - * @tparam bitmap_t Type of the bitmap used for the mask - * - * @param[in] handle RAFT handle for resource management - * @param[in] A Input dense matrix (device_matrix_view) with shape [m, k] - * @param[in] B Input dense matrix (device_matrix_view) with shape [n, k] - * @param[in] mask Bitmap view representing the sparsity pattern (bitmap_view) with logical shape - * [m, n]. Each bit in the mask indicates whether the corresponding element pair in A and B is - * included (1) or masked out (0). - * @param[inout] C Output sparse matrix in CSR format (device_csr_matrix_view) with dense shape [m, - * n] - * @param[in] alpha Optional scalar multiplier for the product of A and B (default: 1.0 if - * std::nullopt) - * @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt) + * DISCLAIMER: this file is deprecated: use masked_matmul.cuh instead */ -template -void masked_matmul(raft::resources const& handle, - raft::device_matrix_view A, - raft::device_matrix_view B, - raft::core::bitmap_view mask, - raft::device_csr_matrix_view C, - std::optional> alpha = std::nullopt, - std::optional> beta = std::nullopt) -{ - detail::masked_matmul(handle, A, B, mask, C, alpha, beta); -} -/** - * @brief Computes a sparse matrix product with a masked sparsity pattern and scaling. - * - * This function computes the result of: - * C = alpha * ((A * B) ∘ spy(mask)) + beta * C - * where: - * - A and B are dense input matrices. - * - "mask" defines the sparsity pattern for element-wise multiplication. - * - The result is scaled by alpha and added to beta times the original C. - * - * **Special behavior of the mask**: - * - The `bitset` mask represents a single row of data, with its bits indicating whether - * each corresponding element in (A * B) is included (1) or masked out (0). - * - If the output CSR matrix `C` has multiple rows, the `bitset` is logically repeated - * across all rows of `C`. For example, if `C` has `n_rows` rows, the same `bitset` - * pattern is applied to all rows. - * - * @tparam value_t Data type of input matrix elements (e.g., half, float, double). - * @tparam output_t Data type of output matrix elements (e.g., float, double). - * @tparam index_t Type for matrix indices. - * @tparam nnz_t Type for non-zero entries in CSR format. - * @tparam bitmap_t Type for the bitmap mask. - * - * @param[in] handle RAFT handle for managing resources. - * @param[in] A Dense input matrix [m, k] (row-major). - * @param[in] B Dense input matrix [n, k] (row-major). - * @param[in] mask Bitmap view representing a single row [1, n], where each bit - * indicates if the corresponding element in (A * B) is included (1) - * or masked out (0). The pattern is repeated for all rows of `C`. - * @param[inout] C Output sparse matrix in CSR format [m, n]. - * @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt). - * @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt). - */ -template -void masked_matmul(raft::resources const& handle, - raft::device_matrix_view A, - raft::device_matrix_view B, - raft::core::bitset_view mask, - raft::device_csr_matrix_view C, - std::optional> alpha = std::nullopt, - std::optional> beta = std::nullopt) -{ - detail::masked_matmul(handle, A, B, mask, C, alpha, beta); -} +#pragma once -/** @} */ // end of masked_matmul +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the cuh version instead.") +#endif -} // end namespace linalg -} // end namespace sparse -} // end namespace raft +#include diff --git a/cpp/test/sparse/masked_matmul.cu b/cpp/test/sparse/masked_matmul.cu index a235038018..5ee1677015 100644 --- a/cpp/test/sparse/masked_matmul.cu +++ b/cpp/test/sparse/masked_matmul.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -46,6 +46,8 @@ struct MaskedMatmulInputs { unsigned long long int seed; }; +enum class BitsLayout { Bitset, Bitmap }; + template struct sum_abs_op { __host__ __device__ value_t operator()(const value_t& x, const value_t& y) const @@ -87,7 +89,7 @@ bool isCuSparseVersionGreaterThan_12_0_1() template @@ -286,12 +288,14 @@ class MaskedMatmulTest resource::sync_stream(handle); index_t c_true_nnz = 0; - if constexpr (bitmap_or_bitset) { + if constexpr (bits_layout == BitsLayout::Bitmap) { c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bits_h); - } else { + } else if constexpr (bits_layout == BitsLayout::Bitset) { c_true_nnz = create_sparse_matrix(1, params.n, params.sparsity, bits_h); repeat_cpu_bitset_inplace(bits_h, params.n, params.m - 1); c_true_nnz *= params.m; + } else { + GTEST_SKIP() << "Unsupported BitsLayout!"; } std::vector c_indptr_h(params.m + 1); @@ -343,12 +347,14 @@ class MaskedMatmulTest auto C = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); - if constexpr (bitmap_or_bitset) { + if constexpr (bits_layout == BitsLayout::Bitmap) { auto mask = raft::core::bitmap_view(bits_d.data(), params.m, params.n); raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); - } else { + } else if constexpr (bits_layout == BitsLayout::Bitset) { auto mask = raft::core::bitset_view(bits_d.data(), params.n); raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); + } else { + GTEST_SKIP() << "Unsupported BitsLayout!"; } resource::sync_stream(handle); @@ -386,22 +392,22 @@ class MaskedMatmulTest rmm::device_uvector c_expected_data_d; }; -using MaskedMatmulOnBitmapTestF = MaskedMatmulTest; +using MaskedMatmulOnBitmapTestF = MaskedMatmulTest; TEST_P(MaskedMatmulOnBitmapTestF, Result) { Run(); } -using MaskedMatmulOnBitmapTestD = MaskedMatmulTest; +using MaskedMatmulOnBitmapTestD = MaskedMatmulTest; TEST_P(MaskedMatmulOnBitmapTestD, Result) { Run(); } -using MaskedMatmulOnBitmapTestH = MaskedMatmulTest; +using MaskedMatmulOnBitmapTestH = MaskedMatmulTest; TEST_P(MaskedMatmulOnBitmapTestH, Result) { Run(); } -using MaskedMatmulOnBitsetTestF = MaskedMatmulTest; +using MaskedMatmulOnBitsetTestF = MaskedMatmulTest; TEST_P(MaskedMatmulOnBitsetTestF, Result) { Run(); } -using MaskedMatmulOnBitsetTestD = MaskedMatmulTest; +using MaskedMatmulOnBitsetTestD = MaskedMatmulTest; TEST_P(MaskedMatmulOnBitsetTestD, Result) { Run(); } -using MaskedMatmulOnBitsetTestH = MaskedMatmulTest; +using MaskedMatmulOnBitsetTestH = MaskedMatmulTest; TEST_P(MaskedMatmulOnBitsetTestH, Result) { Run(); } const std::vector> sddmm_inputs_f = {