From 91e66528e0c9cd2e35bd6979505df839f88dd30e Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Sun, 25 Feb 2024 18:53:23 +0100 Subject: [PATCH] Refactor: smoother conversion from SUNMatrixWrapper to SUNMatrix Adds an implicit conversion function to SUNMatrixWrapper make things more readable. --- include/amici/rdata.h | 2 +- include/amici/sundials_matrix_wrapper.h | 5 +++++ src/model.cpp | 8 ++++---- src/model_dae.cpp | 20 +++++++++---------- src/model_ode.cpp | 20 +++++++++---------- src/newton_solver.cpp | 26 ++++++++++++------------- src/sundials_linsol_wrapper.cpp | 8 ++++---- tests/cpp/unittests/testMisc.cpp | 8 ++++---- 8 files changed, 51 insertions(+), 46 deletions(-) diff --git a/include/amici/rdata.h b/include/amici/rdata.h index 1de02c99db..6357d24748 100644 --- a/include/amici/rdata.h +++ b/include/amici/rdata.h @@ -576,7 +576,7 @@ class ReturnData : public ModelDimensions { if (!this->J.empty()) { SUNMatrixWrapper J(nx_solver, nx_solver); - model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J.get()); + model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J); // CVODES uses colmajor, so we need to transform to rowmajor for (int ix = 0; ix < model.nx_solver; ix++) for (int jx = 0; jx < model.nx_solver; jx++) diff --git a/include/amici/sundials_matrix_wrapper.h b/include/amici/sundials_matrix_wrapper.h index ee2516f78d..0bb9b9215f 100644 --- a/include/amici/sundials_matrix_wrapper.h +++ b/include/amici/sundials_matrix_wrapper.h @@ -72,6 +72,11 @@ class SUNMatrixWrapper { ~SUNMatrixWrapper(); + /** + * @brief Conversion function. + */ + operator SUNMatrix() { return get(); }; + /** * @brief Copy constructor * @param other diff --git a/src/model.cpp b/src/model.cpp index 2485867a4c..3478610bbe 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -2249,7 +2249,7 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) { auto tmp_sparse = SUNMatrixWrapper(tmp_dense, 0.0, CSC_MAT); auto ret = SUNMatScaleAdd( - 1.0, derived_state_.dJydy_.at(iyt).get(), tmp_sparse.get() + 1.0, derived_state_.dJydy_.at(iyt), tmp_sparse ); if (ret != SUNMAT_SUCCESS) { throw AmiException( @@ -2897,7 +2897,7 @@ void Model::fdwdp(realtype const t, realtype const* x) { } if (always_check_finite_) { - checkFinite(derived_state_.dwdp_.get(), ModelQuantity::dwdp, t); + checkFinite(derived_state_.dwdp_, ModelQuantity::dwdp, t); } } @@ -2943,7 +2943,7 @@ void Model::fdwdx(realtype const t, realtype const* x) { } if (always_check_finite_) { - checkFinite(derived_state_.dwdx_.get(), ModelQuantity::dwdx, t); + checkFinite(derived_state_.dwdx_, ModelQuantity::dwdx, t); } } @@ -2960,7 +2960,7 @@ void Model::fdwdw(realtype const t, realtype const* x) { ); if (always_check_finite_) { - checkFinite(dwdw_.get(), ModelQuantity::dwdw, t); + checkFinite(dwdw_, ModelQuantity::dwdw, t); } } diff --git a/src/model_dae.cpp b/src/model_dae.cpp index 3b2e74e0e1..a3a8aaea26 100644 --- a/src/model_dae.cpp +++ b/src/model_dae.cpp @@ -14,7 +14,7 @@ void Model_DAE::fJ( realtype t, realtype cj, const_N_Vector x, const_N_Vector dx, const_N_Vector /*xdot*/, SUNMatrix J ) { - fJSparse(t, cj, x, dx, derived_state_.J_.get()); + fJSparse(t, cj, x, dx, derived_state_.J_); derived_state_.J_.refresh(); auto JDense = SUNMatrixWrapper(J); derived_state_.J_.to_dense(JDense); @@ -88,7 +88,7 @@ void Model_DAE::fJv( N_Vector Jv, realtype cj ) { N_VConst(0.0, Jv); - fJSparse(t, cj, x, dx, derived_state_.J_.get()); + fJSparse(t, cj, x, dx, derived_state_.J_); derived_state_.J_.refresh(); derived_state_.J_.multiply(Jv, v); } @@ -135,7 +135,7 @@ void Model_DAE::fJDiag( realtype const t, AmiVector& JDiag, realtype const /*cj*/, AmiVector const& x, AmiVector const& dx ) { - fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_.get()); + fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_); derived_state_.J_.refresh(); derived_state_.J_.to_diag(JDiag.getNVector()); if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag) != AMICI_SUCCESS) @@ -355,7 +355,7 @@ void Model_DAE::fJB( realtype t, realtype cj, const_N_Vector x, const_N_Vector dx, const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB ) { - fJSparse(t, cj, x, dx, derived_state_.J_.get()); + fJSparse(t, cj, x, dx, derived_state_.J_); derived_state_.J_.refresh(); auto JBDense = SUNMatrixWrapper(JB); derived_state_.J_.transpose(JBDense, -1.0, nxtrue_solver); @@ -376,7 +376,7 @@ void Model_DAE::fJSparseB( realtype t, realtype cj, const_N_Vector x, const_N_Vector dx, const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB ) { - fJSparse(t, cj, x, dx, derived_state_.J_.get()); + fJSparse(t, cj, x, dx, derived_state_.J_); derived_state_.J_.refresh(); auto JSparseB = SUNMatrixWrapper(JB); derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver); @@ -387,7 +387,7 @@ void Model_DAE::fJvB( const_N_Vector dxB, const_N_Vector vB, N_Vector JvB, realtype cj ) { N_VConst(0.0, JvB); - fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_.get()); + fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_); derived_state_.JB_.refresh(); derived_state_.JB_.multiply(JvB, vB); } @@ -397,7 +397,7 @@ void Model_DAE::fxBdot( const_N_Vector dxB, N_Vector xBdot ) { N_VConst(0.0, xBdot); - fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_.get()); + fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_); derived_state_.JB_.refresh(); fM(t, x); derived_state_.JB_.multiply(xBdot, xB); @@ -454,7 +454,7 @@ void Model_DAE::fqBdot_ss( void Model_DAE::fJSparseB_ss(SUNMatrix JB) { /* Just pass the model Jacobian on to JB */ - SUNMatCopy(derived_state_.JB_.get(), JB); + SUNMatCopy(derived_state_.JB_, JB); derived_state_.JB_.refresh(); } @@ -465,7 +465,7 @@ void Model_DAE::writeSteadystateJB( /* Get backward Jacobian */ fJSparseB( t, cj, x.getNVector(), dx.getNVector(), xB.getNVector(), - dxB.getNVector(), derived_state_.JB_.get() + dxB.getNVector(), derived_state_.JB_ ); derived_state_.JB_.refresh(); /* Switch sign, as we integrate forward in time, not backward */ @@ -491,7 +491,7 @@ void Model_DAE::fsxdot( // the same for all remaining fM(t, x); fdxdotdp(t, x, dx); - fJSparse(t, 0.0, x, dx, derived_state_.J_.get()); + fJSparse(t, 0.0, x, dx, derived_state_.J_); derived_state_.J_.refresh(); } diff --git a/src/model_ode.cpp b/src/model_ode.cpp index 52275008e3..257ef45289 100644 --- a/src/model_ode.cpp +++ b/src/model_ode.cpp @@ -14,7 +14,7 @@ void Model_ODE::fJ( void Model_ODE::fJ( realtype t, const_N_Vector x, const_N_Vector /*xdot*/, SUNMatrix J ) { - fJSparse(t, x, derived_state_.J_.get()); + fJSparse(t, x, derived_state_.J_); derived_state_.J_.refresh(); auto JDense = SUNMatrixWrapper(J); derived_state_.J_.to_dense(JDense); @@ -77,7 +77,7 @@ void Model_ODE::fJv( const_N_Vector v, N_Vector Jv, realtype t, const_N_Vector x ) { N_VConst(0.0, Jv); - fJSparse(t, x, derived_state_.J_.get()); + fJSparse(t, x, derived_state_.J_); derived_state_.J_.refresh(); derived_state_.J_.multiply(Jv, v); } @@ -355,7 +355,7 @@ void Model_ODE::fJB( realtype t, const_N_Vector x, const_N_Vector /*xB*/, const_N_Vector /*xBdot*/, SUNMatrix JB ) { - fJSparse(t, x, derived_state_.J_.get()); + fJSparse(t, x, derived_state_.J_); derived_state_.J_.refresh(); auto JDenseB = SUNMatrixWrapper(JB); derived_state_.J_.transpose(JDenseB, -1.0, nxtrue_solver); @@ -373,14 +373,14 @@ void Model_ODE::fJSparseB( realtype t, const_N_Vector x, const_N_Vector /*xB*/, const_N_Vector /*xBdot*/, SUNMatrix JB ) { - fJSparse(t, x, derived_state_.J_.get()); + fJSparse(t, x, derived_state_.J_); derived_state_.J_.refresh(); auto JSparseB = SUNMatrixWrapper(JB); derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver); } void Model_ODE::fJDiag(realtype t, N_Vector JDiag, const_N_Vector x) { - fJSparse(t, x, derived_state_.J_.get()); + fJSparse(t, x, derived_state_.J_); derived_state_.J_.refresh(); derived_state_.J_.to_diag(JDiag); } @@ -390,14 +390,14 @@ void Model_ODE::fJvB( const_N_Vector xB ) { N_VConst(0.0, JvB); - fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get()); + fJSparseB(t, x, xB, nullptr, derived_state_.JB_); derived_state_.JB_.refresh(); derived_state_.JB_.multiply(JvB, vB); } void Model_ODE::fxBdot(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot) { N_VConst(0.0, xBdot); - fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get()); + fJSparseB(t, x, xB, nullptr, derived_state_.JB_); derived_state_.JB_.refresh(); derived_state_.JB_.multiply(xBdot, xB); } @@ -456,7 +456,7 @@ void Model_ODE::fqBdot_ss(realtype /*t*/, N_Vector xB, N_Vector qBdot) const { void Model_ODE::fJSparseB_ss(SUNMatrix JB) { /* Just copy the model Jacobian */ - SUNMatCopy(derived_state_.JB_.get(), JB); + SUNMatCopy(derived_state_.JB_, JB); derived_state_.JB_.refresh(); } @@ -468,7 +468,7 @@ void Model_ODE::writeSteadystateJB( /* Get backward Jacobian */ fJSparseB( t, x.getNVector(), xB.getNVector(), xBdot.getNVector(), - derived_state_.JB_.get() + derived_state_.JB_ ); derived_state_.JB_.refresh(); /* Switch sign, as we integrate forward in time, not backward */ @@ -492,7 +492,7 @@ void Model_ODE::fsxdot( // we only need to call this for the first parameter index will be // the same for all remaining fdxdotdp(t, x); - fJSparse(t, x, derived_state_.J_.get()); + fJSparse(t, x, derived_state_.J_); derived_state_.J_.refresh(); } if (pythonGenerated) { diff --git a/src/newton_solver.cpp b/src/newton_solver.cpp index 8c3fcca5f4..9f011bac1e 100644 --- a/src/newton_solver.cpp +++ b/src/newton_solver.cpp @@ -108,7 +108,7 @@ void NewtonSolver::computeNewtonSensis( NewtonSolverDense::NewtonSolverDense(Model const& model) : NewtonSolver(model) , Jtmp_(model.nx_solver, model.nx_solver) - , linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_.get())) { + , linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_)) { auto status = SUNLinSolInitialize_Dense(linsol_); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSolInitialize_Dense"); @@ -117,9 +117,9 @@ NewtonSolverDense::NewtonSolverDense(Model const& model) void NewtonSolverDense::prepareLinearSystem( Model& model, SimulationState const& state ) { - model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get()); + model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_); Jtmp_.refresh(); - auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get()); + auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSolSetup_Dense"); } @@ -127,16 +127,16 @@ void NewtonSolverDense::prepareLinearSystem( void NewtonSolverDense::prepareLinearSystemB( Model& model, SimulationState const& state ) { - model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get()); + model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_); Jtmp_.refresh(); - auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get()); + auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSolSetup_Dense"); } void NewtonSolverDense::solveLinearSystem(AmiVector& rhs) { auto status = SUNLinSolSolve_Dense( - linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0 + linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0 ); Jtmp_.refresh(); // last argument is tolerance and does not have any influence on result @@ -167,7 +167,7 @@ NewtonSolverDense::~NewtonSolverDense() { NewtonSolverSparse::NewtonSolverSparse(Model const& model) : NewtonSolver(model) , Jtmp_(model.nx_solver, model.nx_solver, model.nnz, CSC_MAT) - , linsol_(SUNKLU(x_.getNVector(), Jtmp_.get())) { + , linsol_(SUNKLU(x_.getNVector(), Jtmp_)) { auto status = SUNLinSolInitialize_KLU(linsol_); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSolInitialize_KLU"); @@ -177,9 +177,9 @@ void NewtonSolverSparse::prepareLinearSystem( Model& model, SimulationState const& state ) { /* Get sparse Jacobian */ - model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get()); + model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_); Jtmp_.refresh(); - auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get()); + auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSolSetup_KLU"); } @@ -189,10 +189,10 @@ void NewtonSolverSparse::prepareLinearSystemB( ) { /* Get sparse Jacobian */ model.fJSparseB( - state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get() + state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_ ); Jtmp_.refresh(); - auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get()); + auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSolSetup_KLU"); } @@ -200,7 +200,7 @@ void NewtonSolverSparse::prepareLinearSystemB( void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) { /* Pass pointer to the linear solver */ auto status = SUNLinSolSolve_KLU( - linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0 + linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0 ); // last argument is tolerance and does not have any influence on result @@ -211,7 +211,7 @@ void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) { void NewtonSolverSparse::reinitialize() { /* partial reinitialization, don't need to reallocate Jtmp_ */ auto status = SUNLinSol_KLUReInit( - linsol_, Jtmp_.get(), Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL + linsol_, Jtmp_, Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL ); if (status != SUNLS_SUCCESS) throw NewtonFailure(status, "SUNLinSol_KLUReInit"); diff --git a/src/sundials_linsol_wrapper.cpp b/src/sundials_linsol_wrapper.cpp index 765f2a1f91..5752ea03c3 100644 --- a/src/sundials_linsol_wrapper.cpp +++ b/src/sundials_linsol_wrapper.cpp @@ -161,7 +161,7 @@ SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A) SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw) : A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) { - solver_ = SUNLinSol_Band(const_cast(x.getNVector()), A_.get()); + solver_ = SUNLinSol_Band(const_cast(x.getNVector()), A_); if (!solver_) throw AmiException("Failed to create solver."); } @@ -170,7 +170,7 @@ SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); } SUNLinSolDense::SUNLinSolDense(AmiVector const& x) : A_(SUNMatrixWrapper(x.getLength(), x.getLength())) { - solver_ = SUNLinSol_Dense(const_cast(x.getNVector()), A_.get()); + solver_ = SUNLinSol_Dense(const_cast(x.getNVector()), A_); if (!solver_) throw AmiException("Failed to create solver."); } @@ -187,7 +187,7 @@ SUNLinSolKLU::SUNLinSolKLU( AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering ) : A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) { - solver_ = SUNLinSol_KLU(const_cast(x.getNVector()), A_.get()); + solver_ = SUNLinSol_KLU(const_cast(x.getNVector()), A_); if (!solver_) throw AmiException("Failed to create solver."); @@ -197,7 +197,7 @@ SUNLinSolKLU::SUNLinSolKLU( SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); } void SUNLinSolKLU::reInit(int nnz, int reinit_type) { - int status = SUNLinSol_KLUReInit(solver_, A_.get(), nnz, reinit_type); + int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type); if (status != SUNLS_SUCCESS) throw AmiException("SUNLinSol_KLUReInit failed with %d", status); } diff --git a/tests/cpp/unittests/testMisc.cpp b/tests/cpp/unittests/testMisc.cpp index 1f464b3433..a722b567a7 100644 --- a/tests/cpp/unittests/testMisc.cpp +++ b/tests/cpp/unittests/testMisc.cpp @@ -681,7 +681,7 @@ TEST(UnravelIndex, UnravelIndexSunMatDense) A.set_data(2, 1, 5); for(int i = 0; i < 6; ++i) { - auto idx = unravel_index(i, A.get()); + auto idx = unravel_index(i, A); EXPECT_EQ(A.get_data(idx.first, idx.second), i); } } @@ -706,7 +706,7 @@ TEST(UnravelIndex, UnravelIndexSunMatSparse) D.set_data(2, 1, 0); D.set_data(3, 1, 0); - auto S = SUNSparseFromDenseMatrix(D.get(), 1e-15, CSC_MAT); + auto S = SUNSparseFromDenseMatrix(D, 1e-15, CSC_MAT); EXPECT_EQ(unravel_index(0, S), std::make_pair((sunindextype) 2, (sunindextype) 0)); EXPECT_EQ(unravel_index(1, S), std::make_pair((sunindextype) 3, (sunindextype) 0)); @@ -720,8 +720,8 @@ TEST(UnravelIndex, UnravelIndexSunMatSparseMissingIndices) { // Sparse matrix without any indices set SUNMatrixWrapper mat = SUNMatrixWrapper(2, 3, 2, CSC_MAT); - EXPECT_EQ(unravel_index(0, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1)); - EXPECT_EQ(unravel_index(1, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1)); + EXPECT_EQ(unravel_index(0, mat), std::make_pair((sunindextype) -1, (sunindextype) -1)); + EXPECT_EQ(unravel_index(1, mat), std::make_pair((sunindextype) -1, (sunindextype) -1)); }