From ec0dba4d439eb6898816fef51206843ec236d547 Mon Sep 17 00:00:00 2001 From: Jason Kaye Date: Wed, 24 Jan 2024 12:00:45 -0500 Subject: [PATCH] implemented vector rhs for gelss (#56) Co-authored-by: Nils Wentzell --- c++/nda/lapack/gelss.hpp | 15 ++++++++++----- c++/nda/lapack/gelss_worker.hpp | 8 ++++++++ test/c++/nda_lapack.cpp | 11 +++++++++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/c++/nda/lapack/gelss.hpp b/c++/nda/lapack/gelss.hpp index a0f5c2913..24528cf55 100644 --- a/c++/nda/lapack/gelss.hpp +++ b/c++/nda/lapack/gelss.hpp @@ -68,10 +68,11 @@ namespace nda::lapack { * if INFO = i, i off-diagonal elements of an intermediate * bidiagonal form did not converge to zero. */ - template + template requires(have_same_value_type_v and mem::on_host and is_blas_lapack_v>) int gelss(A &&a, B &&b, S &&s, double rcond, int &rank) { static_assert(has_F_layout and has_F_layout, "C order not implemented"); + static_assert(MemoryVector or MemoryMatrix, "B must be vector or matrix"); using T = get_value_t; auto dm = std::min(a.extent(0), a.extent(1)); @@ -86,14 +87,18 @@ namespace nda::lapack { // First call to get the optimal bufferSize T bufferSize_T{}; int info = 0; - f77::gelss(a.extent(0), a.extent(1), b.extent(1), a.data(), get_ld(a), b.data(), get_ld(b), s.data(), rcond, rank, &bufferSize_T, -1, - rwork.data(), info); + int nrhs = 1, ldb = b.size(); // Defaults for B MemoryVector + if constexpr (MemoryMatrix) { + nrhs = b.extent(1); + ldb = get_ld(b); + } + f77::gelss(a.extent(0), a.extent(1), nrhs, a.data(), get_ld(a), b.data(), ldb, s.data(), rcond, rank, &bufferSize_T, -1, rwork.data(), info); int bufferSize = static_cast(std::ceil(std::real(bufferSize_T))); // Allocate work buffer and perform actual library call array work(bufferSize); - f77::gelss(a.extent(0), a.extent(1), b.extent(1), a.data(), get_ld(a), b.data(), get_ld(b), s.data(), rcond, rank, work.data(), bufferSize, - rwork.data(), info); + f77::gelss(a.extent(0), a.extent(1), nrhs, a.data(), get_ld(a), b.data(), ldb, s.data(), rcond, rank, work.data(), bufferSize, rwork.data(), + info); if (info) NDA_RUNTIME_ERROR << "Error in gesvd : info = " << info; return info; diff --git a/c++/nda/lapack/gelss_worker.hpp b/c++/nda/lapack/gelss_worker.hpp index 473b2db81..b1293a069 100644 --- a/c++/nda/lapack/gelss_worker.hpp +++ b/c++/nda/lapack/gelss_worker.hpp @@ -19,6 +19,7 @@ #include #include "./gesvd.hpp" +#include "../linalg.hpp" namespace nda::lapack { @@ -85,6 +86,13 @@ namespace nda::lapack { } return std::make_pair(V_x_InvS_x_UT * B, err); } + + std::pair, double> operator()(vector_const_view B, std::optional /*inner_matrix_dim*/ = {}) const { + using std::sqrt; + double err = 0.0; + if (M != N) { err = norm(UT_NULL * B) / sqrt(B.size()); } + return std::make_pair(V_x_InvS_x_UT * B, err); + } }; // Least square solver version specific for hermitian tail-fitting. diff --git a/test/c++/nda_lapack.cpp b/test/c++/nda_lapack.cpp index af733e645..b00efbb89 100644 --- a/test/c++/nda_lapack.cpp +++ b/test/c++/nda_lapack.cpp @@ -145,8 +145,9 @@ TEST(lapack, zgesvd) { test_gesvd(); } //NOLINT template void test_gelss() { // Cf. http://www.netlib.org/lapack/explore-html/d3/d77/example___d_g_e_l_s__colmajor_8c_source.html - auto A = matrix{{1, 1, 1}, {2, 3, 4}, {3, 5, 2}, {4, 2, 5}, {5, 4, 3}}; - auto B = matrix{{-10, -3}, {12, 14}, {14, 12}, {16, 16}, {18, 16}}; + auto A = matrix{{1, 1, 1}, {2, 3, 4}, {3, 5, 2}, {4, 2, 5}, {5, 4, 3}}; + auto B = matrix{{-10, -3}, {12, 14}, {14, 12}, {16, 16}, {18, 16}}; + auto Bvec = vector{-10, 12, 14, 16, 18}; // For testing vector right hand side auto [M, N] = A.shape(); auto x_exact = matrix{{2, 1}, {1, 1}, {1, 2}}; @@ -155,11 +156,17 @@ void test_gelss() { auto gelss_new = lapack::gelss_worker{A}; auto [x_1, eps_1] = gelss_new(B); EXPECT_ARRAY_NEAR(x_exact, x_1, 1e-14); + auto [x_2, eps_2] = gelss_new(Bvec); + EXPECT_ARRAY_NEAR(x_exact(_, 0), x_2, 1e-14); int rank{}; matrix AF{A}, BF{B}; lapack::gelss(AF, BF, S, 1e-18, rank); EXPECT_ARRAY_NEAR(x_exact, BF(range(N), _), 1e-14); + + AF = A; + lapack::gelss(AF, Bvec, S, 1e-18, rank); + EXPECT_ARRAY_NEAR(x_exact(_, 0), Bvec(range(N)), 1e-14); } TEST(lapack, gelss) { test_gelss(); } //NOLINT TEST(lapack, zgelss) { test_gelss(); } //NOLINT