Skip to content

Commit

Permalink
Implement row an column getter
Browse files Browse the repository at this point in the history
  • Loading branch information
niermann999 committed Dec 4, 2024
1 parent 524ee52 commit b87866c
Show file tree
Hide file tree
Showing 18 changed files with 213 additions and 76 deletions.
3 changes: 2 additions & 1 deletion frontend/array_cmath/include/algebra/array_cmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ namespace getter {
/// @{

using cmath::storage::block;
using cmath::storage::column;
using cmath::storage::element;
using cmath::storage::row;
using cmath::storage::set_block;
using cmath::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/eigen_eigen/include/algebra/eigen_eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ namespace getter {
/// @{

using eigen::storage::block;
using eigen::storage::column;
using eigen::storage::element;
using eigen::storage::row;
using eigen::storage::set_block;
using eigen::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/eigen_generic/include/algebra/eigen_generic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ namespace getter {
/// @{

using eigen::storage::block;
using eigen::storage::column;
using eigen::storage::element;
using eigen::storage::row;
using eigen::storage::set_block;
using eigen::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/fastor_fastor/include/algebra/fastor_fastor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ namespace getter {
/// @{

using fastor::storage::block;
using fastor::storage::column;
using fastor::storage::element;
using fastor::storage::row;
using fastor::storage::set_block;
using fastor::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/smatrix_generic/include/algebra/smatrix_generic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ namespace getter {
/// @{

using smatrix::storage::block;
using smatrix::storage::column;
using smatrix::storage::element;
using smatrix::storage::row;
using smatrix::storage::set_block;
using smatrix::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/smatrix_smatrix/include/algebra/smatrix_smatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ namespace getter {
/// @{

using smatrix::storage::block;
using smatrix::storage::column;
using smatrix::storage::element;
using smatrix::storage::row;
using smatrix::storage::set_block;
using smatrix::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/vc_aos/include/algebra/vc_aos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ namespace getter {
/// @{

using vc_aos::storage::block;
using vc_aos::storage::column;
using vc_aos::storage::element;
using vc_aos::storage::row;
using vc_aos::storage::set_block;
using vc_aos::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/vc_aos_generic/include/algebra/vc_aos_generic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ namespace getter {
/// @{

using vc_aos::storage::block;
using vc_aos::storage::column;
using vc_aos::storage::element;
using vc_aos::storage::row;
using vc_aos::storage::set_block;
using vc_aos::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/vc_soa/include/algebra/vc_soa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ namespace getter {
/// @{

using vc_soa::storage::block;
using vc_soa::storage::column;
using vc_soa::storage::element;
using vc_soa::storage::row;
using vc_soa::storage::set_block;
using vc_soa::storage::vector;

/// @}

Expand Down
3 changes: 2 additions & 1 deletion frontend/vecmem_cmath/include/algebra/vecmem_cmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ namespace getter {
/// @{

using cmath::storage::block;
using cmath::storage::column;
using cmath::storage::element;
using cmath::storage::row;
using cmath::storage::set_block;
using cmath::storage::vector;

/// @}

Expand Down
41 changes: 36 additions & 5 deletions storage/cmath/include/algebra/storage/impl/cmath_getter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,31 @@ struct block_getter {
return submatrix;
}

/// Operator producing a vector out of a const matrix
/// Operator producing a row vector out of a const matrix
template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE inline array_t<scalar_t, SIZE> vector(
ALGEBRA_HOST_DEVICE inline array_t<scalar_t, SIZE> row(
const array_t<array_t<scalar_t, ROWS>, COLS> &m, std::size_t row,
std::size_t col) {

assert(col + SIZE <= COLS);
assert(row < ROWS);

array_t<scalar_t, SIZE> subvector{};

for (std::size_t icol = col; icol < col + SIZE; ++icol) {
subvector[icol - col] = m[icol][row];
}

return subvector;
}

/// Operator producing a column vector out of a const matrix
template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE inline array_t<scalar_t, SIZE> column(
const array_t<array_t<scalar_t, ROWS>, COLS> &m, std::size_t row,
std::size_t col) {

Expand All @@ -196,15 +216,26 @@ ALGEBRA_HOST_DEVICE decltype(auto) block(const input_matrix_type &m,
return block_getter().template operator()<ROWS, COLS>(m, row, col);
}

/// Function extracting a vector from a matrix
/// Function extracting a row vector from a matrix
template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE inline array_t<scalar_t, SIZE> row(
const array_t<array_t<scalar_t, ROWS>, COLS> &m, std::size_t row,
std::size_t col) {

return block_getter().template row<SIZE>(m, row, col);
}

/// Function extracting a column vector from a matrix
template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE inline array_t<scalar_t, SIZE> vector(
ALGEBRA_HOST_DEVICE inline array_t<scalar_t, SIZE> column(
const array_t<array_t<scalar_t, ROWS>, COLS> &m, std::size_t row,
std::size_t col) {

return block_getter().template vector<SIZE>(m, row, col);
return block_getter().template column<SIZE>(m, row, col);
}

/// Sets a matrix of dimension @tparam ROW and @tparam COL as submatrix of
Expand Down
48 changes: 45 additions & 3 deletions storage/common/include/algebra/storage/matrix_getter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,42 @@ struct block_getter {
return res_m;
}

/// Get a vector of a const matrix
/// Get a row vector of a const matrix
template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE constexpr auto vector(
ALGEBRA_HOST_DEVICE constexpr auto row(
const matrix<array_t, scalar_t, ROWS, COLS> &m, const std::size_t row,
const std::size_t col) noexcept {

static_assert(SIZE <= ROWS);
static_assert(SIZE <= COLS);
assert(row < ROWS);
assert(col + SIZE <= COLS);

using vector_t = algebra::storage::vector<SIZE, scalar_t, array_t>;

vector_t res_v{};

for (std::size_t i = col; i < col + SIZE; ++i) {
res_v[i - col] = m[i][row];
}

return res_v;
}

/// Get a column vector of a const matrix
template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE constexpr auto column(
const matrix<array_t, scalar_t, ROWS, COLS> &m, const std::size_t row,
const std::size_t col) noexcept {

static_assert(SIZE <= ROWS);
static_assert(SIZE <= COLS);
assert(row + SIZE <= ROWS);
assert(col <= COLS);
assert(col < COLS);

using input_matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
using vector_t = algebra::storage::vector<SIZE, scalar_t, array_t>;
Expand Down Expand Up @@ -283,4 +307,22 @@ ALGEBRA_HOST_DEVICE constexpr void set_block(
}
}

template <std::size_t SIZE, std::size_t ROW, std::size_t COL,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE constexpr auto row(
const algebra::storage::matrix<array_t, scalar_t, ROW, COL> &m,
const std::size_t row, const std::size_t col) noexcept {
return algebra::storage::block_getter{}.template row<SIZE>(m, row, col);
}

template <std::size_t SIZE, std::size_t ROW, std::size_t COL,
concepts::scalar scalar_t,
template <typename, std::size_t> class array_t>
ALGEBRA_HOST_DEVICE constexpr auto column(
const algebra::storage::matrix<array_t, scalar_t, ROW, COL> &m,
const std::size_t row, const std::size_t col) noexcept {
return algebra::storage::block_getter{}.template column<SIZE>(m, row, col);
}

} // namespace algebra::storage
78 changes: 53 additions & 25 deletions storage/eigen/include/algebra/storage/impl/eigen_getter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ struct element_getter {
/// Get non-const access to a matrix element
template <typename derived_type, concepts::index size_type_1,
concepts::index size_type_2>
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline auto &operator()(
Eigen::MatrixBase<derived_type> &m, size_type_1 row,
size_type_2 col) const {
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline auto &operator()(
Eigen::MatrixBase<derived_type> &m, size_type_1 row,
size_type_2 col) const {

return m(static_cast<Eigen::Index>(row), static_cast<Eigen::Index>(col));
}
Expand All @@ -58,11 +58,11 @@ struct element_getter {
}
/// Get non-const access to a matrix element
template <typename derived_type, concepts::index size_type>
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline auto &operator()(
Eigen::MatrixBase<derived_type> &m, size_type row) const {
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline auto &operator()(
Eigen::MatrixBase<derived_type> &m, size_type row) const {

return m(static_cast<Eigen::Index>(row));
}
Expand All @@ -87,11 +87,11 @@ ALGEBRA_HOST_DEVICE inline decltype(auto) element(

/// Function extracting an element from a matrix (non-const)
template <typename derived_type>
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline decltype(auto) element(
Eigen::MatrixBase<derived_type> &m, std::size_t row, std::size_t col) {
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline decltype(auto) element(
Eigen::MatrixBase<derived_type> &m, std::size_t row, std::size_t col) {

return element_getter()(m, static_cast<Eigen::Index>(row),
static_cast<Eigen::Index>(col));
Expand All @@ -107,11 +107,11 @@ ALGEBRA_HOST_DEVICE inline decltype(auto) element(

/// Function extracting an element from a matrix (non-const)
template <typename derived_type>
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline decltype(auto) element(
Eigen::MatrixBase<derived_type> &m, std::size_t row) {
requires std::is_base_of_v<
Eigen::DenseCoeffsBase<derived_type, Eigen::WriteAccessors>,
Eigen::MatrixBase<derived_type> >
ALGEBRA_HOST_DEVICE inline decltype(auto) element(
Eigen::MatrixBase<derived_type> &m, std::size_t row) {

return element_getter()(m, static_cast<Eigen::Index>(row));
}
Expand Down Expand Up @@ -139,7 +139,25 @@ struct block_getter {

template <int SIZE, typename derived_type, concepts::index size_type_1,
concepts::index size_type_2>
ALGEBRA_HOST_DEVICE decltype(auto) vector(Eigen::MatrixBase<derived_type> &m,
ALGEBRA_HOST_DEVICE decltype(auto) row(Eigen::MatrixBase<derived_type> &m,
size_type_1 row,
size_type_2 col) const {

return m.template block<1, SIZE>(row, col);
}

template <int SIZE, typename derived_type, concepts::index size_type_1,
concepts::index size_type_2>
ALGEBRA_HOST_DEVICE decltype(auto) row(
const Eigen::MatrixBase<derived_type> &m, size_type_1 row,
size_type_2 col) const {

return m.template block<1, SIZE>(row, col);
}

template <int SIZE, typename derived_type, concepts::index size_type_1,
concepts::index size_type_2>
ALGEBRA_HOST_DEVICE decltype(auto) column(Eigen::MatrixBase<derived_type> &m,
size_type_1 row,
size_type_2 col) const {

Expand All @@ -148,7 +166,7 @@ struct block_getter {

template <int SIZE, typename derived_type, concepts::index size_type_1,
concepts::index size_type_2>
ALGEBRA_HOST_DEVICE decltype(auto) vector(
ALGEBRA_HOST_DEVICE decltype(auto) column(
const Eigen::MatrixBase<derived_type> &m, size_type_1 row,
size_type_2 col) const {

Expand Down Expand Up @@ -176,11 +194,21 @@ ALGEBRA_HOST_DEVICE decltype(auto) block(Eigen::MatrixBase<derived_type> &m,

/// Function extracting a slice from the matrix
template <int SIZE, typename derived_type>
ALGEBRA_HOST_DEVICE inline decltype(auto) vector(
ALGEBRA_HOST_DEVICE inline decltype(auto) row(
const Eigen::MatrixBase<derived_type> &m, std::size_t row,
std::size_t col) {

return block_getter{}.template row<SIZE>(m, static_cast<Eigen::Index>(row),
static_cast<Eigen::Index>(col));
}

/// Function extracting a slice from the matrix
template <int SIZE, typename derived_type>
ALGEBRA_HOST_DEVICE inline decltype(auto) column(
const Eigen::MatrixBase<derived_type> &m, std::size_t row,
std::size_t col) {

return block_getter{}.template vector<SIZE>(m, static_cast<Eigen::Index>(row),
return block_getter{}.template column<SIZE>(m, static_cast<Eigen::Index>(row),
static_cast<Eigen::Index>(col));
}

Expand Down
Loading

0 comments on commit b87866c

Please sign in to comment.