Skip to content

Commit

Permalink
implemented vector rhs for gelss (#56)
Browse files Browse the repository at this point in the history
Co-authored-by: Nils Wentzell <[email protected]>
  • Loading branch information
jasonkaye and Wentzell authored Jan 24, 2024
1 parent 967f4a1 commit ec0dba4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
15 changes: 10 additions & 5 deletions c++/nda/lapack/gelss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ namespace nda::lapack {
* if INFO = i, i off-diagonal elements of an intermediate
* bidiagonal form did not converge to zero.
*/
template <MemoryMatrix A, MemoryMatrix B, MemoryVector S>
template <MemoryMatrix A, MemoryArray B, MemoryVector S>
requires(have_same_value_type_v<A, B> and mem::on_host<A, B, S> and is_blas_lapack_v<get_value_t<A>>)
int gelss(A &&a, B &&b, S &&s, double rcond, int &rank) {
static_assert(has_F_layout<A> and has_F_layout<B>, "C order not implemented");
static_assert(MemoryVector<B> or MemoryMatrix<B>, "B must be vector or matrix");

using T = get_value_t<A>;
auto dm = std::min(a.extent(0), a.extent(1));
Expand All @@ -86,14 +87,18 @@ namespace nda::lapack {
// First call to get the optimal bufferSize
T bufferSize_T{};
int info = 0;
f77::gelss(a.extent(0), a.extent(1), b.extent(1), a.data(), get_ld(a), b.data(), get_ld(b), s.data(), rcond, rank, &bufferSize_T, -1,
rwork.data(), info);
int nrhs = 1, ldb = b.size(); // Defaults for B MemoryVector
if constexpr (MemoryMatrix<B>) {
nrhs = b.extent(1);
ldb = get_ld(b);
}
f77::gelss(a.extent(0), a.extent(1), nrhs, a.data(), get_ld(a), b.data(), ldb, s.data(), rcond, rank, &bufferSize_T, -1, rwork.data(), info);
int bufferSize = static_cast<int>(std::ceil(std::real(bufferSize_T)));

// Allocate work buffer and perform actual library call
array<T, 1> work(bufferSize);
f77::gelss(a.extent(0), a.extent(1), b.extent(1), a.data(), get_ld(a), b.data(), get_ld(b), s.data(), rcond, rank, work.data(), bufferSize,
rwork.data(), info);
f77::gelss(a.extent(0), a.extent(1), nrhs, a.data(), get_ld(a), b.data(), ldb, s.data(), rcond, rank, work.data(), bufferSize, rwork.data(),
info);

if (info) NDA_RUNTIME_ERROR << "Error in gesvd : info = " << info;
return info;
Expand Down
8 changes: 8 additions & 0 deletions c++/nda/lapack/gelss_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <optional>

#include "./gesvd.hpp"
#include "../linalg.hpp"

namespace nda::lapack {

Expand Down Expand Up @@ -85,6 +86,13 @@ namespace nda::lapack {
}
return std::make_pair(V_x_InvS_x_UT * B, err);
}

std::pair<vector<T>, double> operator()(vector_const_view<T> B, std::optional<long> /*inner_matrix_dim*/ = {}) const {
using std::sqrt;
double err = 0.0;
if (M != N) { err = norm(UT_NULL * B) / sqrt(B.size()); }
return std::make_pair(V_x_InvS_x_UT * B, err);
}
};

// Least square solver version specific for hermitian tail-fitting.
Expand Down
11 changes: 9 additions & 2 deletions test/c++/nda_lapack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ TEST(lapack, zgesvd) { test_gesvd<dcomplex>(); } //NOLINT
template <typename value_t>
void test_gelss() {
// Cf. http://www.netlib.org/lapack/explore-html/d3/d77/example___d_g_e_l_s__colmajor_8c_source.html
auto A = matrix<value_t>{{1, 1, 1}, {2, 3, 4}, {3, 5, 2}, {4, 2, 5}, {5, 4, 3}};
auto B = matrix<value_t>{{-10, -3}, {12, 14}, {14, 12}, {16, 16}, {18, 16}};
auto A = matrix<value_t>{{1, 1, 1}, {2, 3, 4}, {3, 5, 2}, {4, 2, 5}, {5, 4, 3}};
auto B = matrix<value_t>{{-10, -3}, {12, 14}, {14, 12}, {16, 16}, {18, 16}};
auto Bvec = vector<value_t>{-10, 12, 14, 16, 18}; // For testing vector right hand side

auto [M, N] = A.shape();
auto x_exact = matrix<value_t>{{2, 1}, {1, 1}, {1, 2}};
Expand All @@ -155,11 +156,17 @@ void test_gelss() {
auto gelss_new = lapack::gelss_worker<value_t>{A};
auto [x_1, eps_1] = gelss_new(B);
EXPECT_ARRAY_NEAR(x_exact, x_1, 1e-14);
auto [x_2, eps_2] = gelss_new(Bvec);
EXPECT_ARRAY_NEAR(x_exact(_, 0), x_2, 1e-14);

int rank{};
matrix<value_t, F_layout> AF{A}, BF{B};
lapack::gelss(AF, BF, S, 1e-18, rank);
EXPECT_ARRAY_NEAR(x_exact, BF(range(N), _), 1e-14);

AF = A;
lapack::gelss(AF, Bvec, S, 1e-18, rank);
EXPECT_ARRAY_NEAR(x_exact(_, 0), Bvec(range(N)), 1e-14);
}
TEST(lapack, gelss) { test_gelss<double>(); } //NOLINT
TEST(lapack, zgelss) { test_gelss<dcomplex>(); } //NOLINT
Expand Down

0 comments on commit ec0dba4

Please sign in to comment.