Skip to content

Commit

Permalink
Refactor: smoother conversion from SUNMatrixWrapper to SUNMatrix
Browse files Browse the repository at this point in the history
Adds an implicit conversion function to SUNMatrixWrapper make things more readable.
  • Loading branch information
dweindl committed Feb 25, 2024
1 parent 8b324bc commit 91e6652
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 46 deletions.
2 changes: 1 addition & 1 deletion include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class ReturnData : public ModelDimensions {

if (!this->J.empty()) {
SUNMatrixWrapper J(nx_solver, nx_solver);
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J.get());
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J);
// CVODES uses colmajor, so we need to transform to rowmajor
for (int ix = 0; ix < model.nx_solver; ix++)
for (int jx = 0; jx < model.nx_solver; jx++)
Expand Down
5 changes: 5 additions & 0 deletions include/amici/sundials_matrix_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class SUNMatrixWrapper {

~SUNMatrixWrapper();

/**
* @brief Conversion function.
*/
operator SUNMatrix() { return get(); };

/**
* @brief Copy constructor
* @param other
Expand Down
8 changes: 4 additions & 4 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2249,7 +2249,7 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) {

auto tmp_sparse = SUNMatrixWrapper(tmp_dense, 0.0, CSC_MAT);
auto ret = SUNMatScaleAdd(
1.0, derived_state_.dJydy_.at(iyt).get(), tmp_sparse.get()
1.0, derived_state_.dJydy_.at(iyt), tmp_sparse
);
if (ret != SUNMAT_SUCCESS) {
throw AmiException(
Expand Down Expand Up @@ -2897,7 +2897,7 @@ void Model::fdwdp(realtype const t, realtype const* x) {
}

if (always_check_finite_) {
checkFinite(derived_state_.dwdp_.get(), ModelQuantity::dwdp, t);
checkFinite(derived_state_.dwdp_, ModelQuantity::dwdp, t);
}
}

Expand Down Expand Up @@ -2943,7 +2943,7 @@ void Model::fdwdx(realtype const t, realtype const* x) {
}

if (always_check_finite_) {
checkFinite(derived_state_.dwdx_.get(), ModelQuantity::dwdx, t);
checkFinite(derived_state_.dwdx_, ModelQuantity::dwdx, t);
}
}

Expand All @@ -2960,7 +2960,7 @@ void Model::fdwdw(realtype const t, realtype const* x) {
);

if (always_check_finite_) {
checkFinite(dwdw_.get(), ModelQuantity::dwdw, t);
checkFinite(dwdw_, ModelQuantity::dwdw, t);
}
}

Expand Down
20 changes: 10 additions & 10 deletions src/model_dae.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void Model_DAE::fJ(
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
const_N_Vector /*xdot*/, SUNMatrix J
) {
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);
derived_state_.J_.refresh();
auto JDense = SUNMatrixWrapper(J);
derived_state_.J_.to_dense(JDense);
Expand Down Expand Up @@ -88,7 +88,7 @@ void Model_DAE::fJv(
N_Vector Jv, realtype cj
) {
N_VConst(0.0, Jv);
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);
derived_state_.J_.refresh();
derived_state_.J_.multiply(Jv, v);
}
Expand Down Expand Up @@ -135,7 +135,7 @@ void Model_DAE::fJDiag(
realtype const t, AmiVector& JDiag, realtype const /*cj*/,
AmiVector const& x, AmiVector const& dx
) {
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_.get());
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_);

Check warning on line 138 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L138

Added line #L138 was not covered by tests
derived_state_.J_.refresh();
derived_state_.J_.to_diag(JDiag.getNVector());
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag) != AMICI_SUCCESS)
Expand Down Expand Up @@ -355,7 +355,7 @@ void Model_DAE::fJB(
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB
) {
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);

Check warning on line 358 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L358

Added line #L358 was not covered by tests
derived_state_.J_.refresh();
auto JBDense = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JBDense, -1.0, nxtrue_solver);
Expand All @@ -376,7 +376,7 @@ void Model_DAE::fJSparseB(
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB
) {
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);

Check warning on line 379 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L379

Added line #L379 was not covered by tests
derived_state_.J_.refresh();
auto JSparseB = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver);
Expand All @@ -387,7 +387,7 @@ void Model_DAE::fJvB(
const_N_Vector dxB, const_N_Vector vB, N_Vector JvB, realtype cj
) {
N_VConst(0.0, JvB);
fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_.get());
fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_);

Check warning on line 390 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L390

Added line #L390 was not covered by tests
derived_state_.JB_.refresh();
derived_state_.JB_.multiply(JvB, vB);
}
Expand All @@ -397,7 +397,7 @@ void Model_DAE::fxBdot(
const_N_Vector dxB, N_Vector xBdot
) {
N_VConst(0.0, xBdot);
fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_.get());
fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_);

Check warning on line 400 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L400

Added line #L400 was not covered by tests
derived_state_.JB_.refresh();
fM(t, x);
derived_state_.JB_.multiply(xBdot, xB);
Expand Down Expand Up @@ -454,7 +454,7 @@ void Model_DAE::fqBdot_ss(

void Model_DAE::fJSparseB_ss(SUNMatrix JB) {
/* Just pass the model Jacobian on to JB */
SUNMatCopy(derived_state_.JB_.get(), JB);
SUNMatCopy(derived_state_.JB_, JB);

Check warning on line 457 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L457

Added line #L457 was not covered by tests
derived_state_.JB_.refresh();
}

Expand All @@ -465,7 +465,7 @@ void Model_DAE::writeSteadystateJB(
/* Get backward Jacobian */
fJSparseB(
t, cj, x.getNVector(), dx.getNVector(), xB.getNVector(),
dxB.getNVector(), derived_state_.JB_.get()
dxB.getNVector(), derived_state_.JB_
);
derived_state_.JB_.refresh();
/* Switch sign, as we integrate forward in time, not backward */
Expand All @@ -491,7 +491,7 @@ void Model_DAE::fsxdot(
// the same for all remaining
fM(t, x);
fdxdotdp(t, x, dx);
fJSparse(t, 0.0, x, dx, derived_state_.J_.get());
fJSparse(t, 0.0, x, dx, derived_state_.J_);
derived_state_.J_.refresh();
}

Expand Down
20 changes: 10 additions & 10 deletions src/model_ode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void Model_ODE::fJ(
void Model_ODE::fJ(
realtype t, const_N_Vector x, const_N_Vector /*xdot*/, SUNMatrix J
) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);
derived_state_.J_.refresh();
auto JDense = SUNMatrixWrapper(J);
derived_state_.J_.to_dense(JDense);
Expand Down Expand Up @@ -77,7 +77,7 @@ void Model_ODE::fJv(
const_N_Vector v, N_Vector Jv, realtype t, const_N_Vector x
) {
N_VConst(0.0, Jv);
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);

Check warning on line 80 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L80

Added line #L80 was not covered by tests
derived_state_.J_.refresh();
derived_state_.J_.multiply(Jv, v);
}
Expand Down Expand Up @@ -355,7 +355,7 @@ void Model_ODE::fJB(
realtype t, const_N_Vector x, const_N_Vector /*xB*/,
const_N_Vector /*xBdot*/, SUNMatrix JB
) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);

Check warning on line 358 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L358

Added line #L358 was not covered by tests
derived_state_.J_.refresh();
auto JDenseB = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JDenseB, -1.0, nxtrue_solver);
Expand All @@ -373,14 +373,14 @@ void Model_ODE::fJSparseB(
realtype t, const_N_Vector x, const_N_Vector /*xB*/,
const_N_Vector /*xBdot*/, SUNMatrix JB
) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);
derived_state_.J_.refresh();
auto JSparseB = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver);
}

void Model_ODE::fJDiag(realtype t, N_Vector JDiag, const_N_Vector x) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);

Check warning on line 383 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L383

Added line #L383 was not covered by tests
derived_state_.J_.refresh();
derived_state_.J_.to_diag(JDiag);
}
Expand All @@ -390,14 +390,14 @@ void Model_ODE::fJvB(
const_N_Vector xB
) {
N_VConst(0.0, JvB);
fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get());
fJSparseB(t, x, xB, nullptr, derived_state_.JB_);

Check warning on line 393 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L393

Added line #L393 was not covered by tests
derived_state_.JB_.refresh();
derived_state_.JB_.multiply(JvB, vB);
}

void Model_ODE::fxBdot(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot) {
N_VConst(0.0, xBdot);
fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get());
fJSparseB(t, x, xB, nullptr, derived_state_.JB_);
derived_state_.JB_.refresh();
derived_state_.JB_.multiply(xBdot, xB);
}
Expand Down Expand Up @@ -456,7 +456,7 @@ void Model_ODE::fqBdot_ss(realtype /*t*/, N_Vector xB, N_Vector qBdot) const {

void Model_ODE::fJSparseB_ss(SUNMatrix JB) {
/* Just copy the model Jacobian */
SUNMatCopy(derived_state_.JB_.get(), JB);
SUNMatCopy(derived_state_.JB_, JB);
derived_state_.JB_.refresh();
}

Expand All @@ -468,7 +468,7 @@ void Model_ODE::writeSteadystateJB(
/* Get backward Jacobian */
fJSparseB(
t, x.getNVector(), xB.getNVector(), xBdot.getNVector(),
derived_state_.JB_.get()
derived_state_.JB_
);
derived_state_.JB_.refresh();
/* Switch sign, as we integrate forward in time, not backward */
Expand All @@ -492,7 +492,7 @@ void Model_ODE::fsxdot(
// we only need to call this for the first parameter index will be
// the same for all remaining
fdxdotdp(t, x);
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);
derived_state_.J_.refresh();
}
if (pythonGenerated) {
Expand Down
26 changes: 13 additions & 13 deletions src/newton_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void NewtonSolver::computeNewtonSensis(
NewtonSolverDense::NewtonSolverDense(Model const& model)
: NewtonSolver(model)
, Jtmp_(model.nx_solver, model.nx_solver)
, linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_.get())) {
, linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_)) {
auto status = SUNLinSolInitialize_Dense(linsol_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolInitialize_Dense");
Expand All @@ -117,26 +117,26 @@ NewtonSolverDense::NewtonSolverDense(Model const& model)
void NewtonSolverDense::prepareLinearSystem(
Model& model, SimulationState const& state
) {
model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get());
model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_);
Jtmp_.refresh();
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_Dense");
}

void NewtonSolverDense::prepareLinearSystemB(
Model& model, SimulationState const& state
) {
model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get());
model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_);

Check warning on line 130 in src/newton_solver.cpp

View check run for this annotation

Codecov / codecov/patch

src/newton_solver.cpp#L130

Added line #L130 was not covered by tests
Jtmp_.refresh();
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_);

Check warning on line 132 in src/newton_solver.cpp

View check run for this annotation

Codecov / codecov/patch

src/newton_solver.cpp#L132

Added line #L132 was not covered by tests
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_Dense");
}

void NewtonSolverDense::solveLinearSystem(AmiVector& rhs) {
auto status = SUNLinSolSolve_Dense(
linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0
linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0
);
Jtmp_.refresh();
// last argument is tolerance and does not have any influence on result
Expand Down Expand Up @@ -167,7 +167,7 @@ NewtonSolverDense::~NewtonSolverDense() {
NewtonSolverSparse::NewtonSolverSparse(Model const& model)
: NewtonSolver(model)
, Jtmp_(model.nx_solver, model.nx_solver, model.nnz, CSC_MAT)
, linsol_(SUNKLU(x_.getNVector(), Jtmp_.get())) {
, linsol_(SUNKLU(x_.getNVector(), Jtmp_)) {
auto status = SUNLinSolInitialize_KLU(linsol_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolInitialize_KLU");
Expand All @@ -177,9 +177,9 @@ void NewtonSolverSparse::prepareLinearSystem(
Model& model, SimulationState const& state
) {
/* Get sparse Jacobian */
model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get());
model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_);
Jtmp_.refresh();
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_KLU");
}
Expand All @@ -189,18 +189,18 @@ void NewtonSolverSparse::prepareLinearSystemB(
) {
/* Get sparse Jacobian */
model.fJSparseB(
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get()
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_
);
Jtmp_.refresh();
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_KLU");
}

void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) {
/* Pass pointer to the linear solver */
auto status = SUNLinSolSolve_KLU(
linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0
linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0
);
// last argument is tolerance and does not have any influence on result

Expand All @@ -211,7 +211,7 @@ void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) {
void NewtonSolverSparse::reinitialize() {
/* partial reinitialization, don't need to reallocate Jtmp_ */
auto status = SUNLinSol_KLUReInit(
linsol_, Jtmp_.get(), Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL
linsol_, Jtmp_, Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL
);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSol_KLUReInit");
Expand Down
8 changes: 4 additions & 4 deletions src/sundials_linsol_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A)

SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) {
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_.get());
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_);

Check warning on line 164 in src/sundials_linsol_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/sundials_linsol_wrapper.cpp#L164

Added line #L164 was not covered by tests
if (!solver_)
throw AmiException("Failed to create solver.");
}
Expand All @@ -170,7 +170,7 @@ SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); }

SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength())) {
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_.get());
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");
}
Expand All @@ -187,7 +187,7 @@ SUNLinSolKLU::SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_.get());
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");

Expand All @@ -197,7 +197,7 @@ SUNLinSolKLU::SUNLinSolKLU(
SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); }

void SUNLinSolKLU::reInit(int nnz, int reinit_type) {
int status = SUNLinSol_KLUReInit(solver_, A_.get(), nnz, reinit_type);
int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type);

Check warning on line 200 in src/sundials_linsol_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/sundials_linsol_wrapper.cpp#L200

Added line #L200 was not covered by tests
if (status != SUNLS_SUCCESS)
throw AmiException("SUNLinSol_KLUReInit failed with %d", status);
}
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/unittests/testMisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ TEST(UnravelIndex, UnravelIndexSunMatDense)
A.set_data(2, 1, 5);

for(int i = 0; i < 6; ++i) {
auto idx = unravel_index(i, A.get());
auto idx = unravel_index(i, A);
EXPECT_EQ(A.get_data(idx.first, idx.second), i);
}
}
Expand All @@ -706,7 +706,7 @@ TEST(UnravelIndex, UnravelIndexSunMatSparse)
D.set_data(2, 1, 0);
D.set_data(3, 1, 0);

auto S = SUNSparseFromDenseMatrix(D.get(), 1e-15, CSC_MAT);
auto S = SUNSparseFromDenseMatrix(D, 1e-15, CSC_MAT);

EXPECT_EQ(unravel_index(0, S), std::make_pair((sunindextype) 2, (sunindextype) 0));
EXPECT_EQ(unravel_index(1, S), std::make_pair((sunindextype) 3, (sunindextype) 0));
Expand All @@ -720,8 +720,8 @@ TEST(UnravelIndex, UnravelIndexSunMatSparseMissingIndices)
{
// Sparse matrix without any indices set
SUNMatrixWrapper mat = SUNMatrixWrapper(2, 3, 2, CSC_MAT);
EXPECT_EQ(unravel_index(0, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1));
EXPECT_EQ(unravel_index(1, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1));
EXPECT_EQ(unravel_index(0, mat), std::make_pair((sunindextype) -1, (sunindextype) -1));
EXPECT_EQ(unravel_index(1, mat), std::make_pair((sunindextype) -1, (sunindextype) -1));
}


Expand Down

0 comments on commit 91e6652

Please sign in to comment.