Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: smoother conversion from SUNMatrixWrapper to SUNMatrix #2317

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class ReturnData : public ModelDimensions {

if (!this->J.empty()) {
SUNMatrixWrapper J(nx_solver, nx_solver);
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J.get());
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J);
// CVODES uses colmajor, so we need to transform to rowmajor
for (int ix = 0; ix < model.nx_solver; ix++)
for (int jx = 0; jx < model.nx_solver; jx++)
Expand Down
5 changes: 5 additions & 0 deletions include/amici/sundials_matrix_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class SUNMatrixWrapper {

~SUNMatrixWrapper();

/**
* @brief Conversion function.
*/
operator SUNMatrix() { return get(); };

/**
* @brief Copy constructor
* @param other
Expand Down
8 changes: 4 additions & 4 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2249,7 +2249,7 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) {

auto tmp_sparse = SUNMatrixWrapper(tmp_dense, 0.0, CSC_MAT);
auto ret = SUNMatScaleAdd(
1.0, derived_state_.dJydy_.at(iyt).get(), tmp_sparse.get()
1.0, derived_state_.dJydy_.at(iyt), tmp_sparse
);
if (ret != SUNMAT_SUCCESS) {
throw AmiException(
Expand Down Expand Up @@ -2897,7 +2897,7 @@ void Model::fdwdp(realtype const t, realtype const* x) {
}

if (always_check_finite_) {
checkFinite(derived_state_.dwdp_.get(), ModelQuantity::dwdp, t);
checkFinite(derived_state_.dwdp_, ModelQuantity::dwdp, t);
}
}

Expand Down Expand Up @@ -2943,7 +2943,7 @@ void Model::fdwdx(realtype const t, realtype const* x) {
}

if (always_check_finite_) {
checkFinite(derived_state_.dwdx_.get(), ModelQuantity::dwdx, t);
checkFinite(derived_state_.dwdx_, ModelQuantity::dwdx, t);
}
}

Expand All @@ -2960,7 +2960,7 @@ void Model::fdwdw(realtype const t, realtype const* x) {
);

if (always_check_finite_) {
checkFinite(dwdw_.get(), ModelQuantity::dwdw, t);
checkFinite(dwdw_, ModelQuantity::dwdw, t);
}
}

Expand Down
20 changes: 10 additions & 10 deletions src/model_dae.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
const_N_Vector /*xdot*/, SUNMatrix J
) {
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);
derived_state_.J_.refresh();
auto JDense = SUNMatrixWrapper(J);
derived_state_.J_.to_dense(JDense);
Expand Down Expand Up @@ -88,7 +88,7 @@
N_Vector Jv, realtype cj
) {
N_VConst(0.0, Jv);
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);
derived_state_.J_.refresh();
derived_state_.J_.multiply(Jv, v);
}
Expand Down Expand Up @@ -135,7 +135,7 @@
realtype const t, AmiVector& JDiag, realtype const /*cj*/,
AmiVector const& x, AmiVector const& dx
) {
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_.get());
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_);

Check warning on line 138 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L138

Added line #L138 was not covered by tests
derived_state_.J_.refresh();
derived_state_.J_.to_diag(JDiag.getNVector());
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag) != AMICI_SUCCESS)
Expand Down Expand Up @@ -355,7 +355,7 @@
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB
) {
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);

Check warning on line 358 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L358

Added line #L358 was not covered by tests
derived_state_.J_.refresh();
auto JBDense = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JBDense, -1.0, nxtrue_solver);
Expand All @@ -376,7 +376,7 @@
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB
) {
fJSparse(t, cj, x, dx, derived_state_.J_.get());
fJSparse(t, cj, x, dx, derived_state_.J_);

Check warning on line 379 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L379

Added line #L379 was not covered by tests
derived_state_.J_.refresh();
auto JSparseB = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver);
Expand All @@ -387,7 +387,7 @@
const_N_Vector dxB, const_N_Vector vB, N_Vector JvB, realtype cj
) {
N_VConst(0.0, JvB);
fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_.get());
fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_);

Check warning on line 390 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L390

Added line #L390 was not covered by tests
derived_state_.JB_.refresh();
derived_state_.JB_.multiply(JvB, vB);
}
Expand All @@ -397,7 +397,7 @@
const_N_Vector dxB, N_Vector xBdot
) {
N_VConst(0.0, xBdot);
fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_.get());
fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_);

Check warning on line 400 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L400

Added line #L400 was not covered by tests
derived_state_.JB_.refresh();
fM(t, x);
derived_state_.JB_.multiply(xBdot, xB);
Expand Down Expand Up @@ -454,7 +454,7 @@

void Model_DAE::fJSparseB_ss(SUNMatrix JB) {
/* Just pass the model Jacobian on to JB */
SUNMatCopy(derived_state_.JB_.get(), JB);
SUNMatCopy(derived_state_.JB_, JB);

Check warning on line 457 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L457

Added line #L457 was not covered by tests
derived_state_.JB_.refresh();
}

Expand All @@ -465,7 +465,7 @@
/* Get backward Jacobian */
fJSparseB(
t, cj, x.getNVector(), dx.getNVector(), xB.getNVector(),
dxB.getNVector(), derived_state_.JB_.get()
dxB.getNVector(), derived_state_.JB_
);
derived_state_.JB_.refresh();
/* Switch sign, as we integrate forward in time, not backward */
Expand All @@ -491,7 +491,7 @@
// the same for all remaining
fM(t, x);
fdxdotdp(t, x, dx);
fJSparse(t, 0.0, x, dx, derived_state_.J_.get());
fJSparse(t, 0.0, x, dx, derived_state_.J_);
derived_state_.J_.refresh();
}

Expand Down
20 changes: 10 additions & 10 deletions src/model_ode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
void Model_ODE::fJ(
realtype t, const_N_Vector x, const_N_Vector /*xdot*/, SUNMatrix J
) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);
derived_state_.J_.refresh();
auto JDense = SUNMatrixWrapper(J);
derived_state_.J_.to_dense(JDense);
Expand Down Expand Up @@ -77,7 +77,7 @@
const_N_Vector v, N_Vector Jv, realtype t, const_N_Vector x
) {
N_VConst(0.0, Jv);
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);

Check warning on line 80 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L80

Added line #L80 was not covered by tests
derived_state_.J_.refresh();
derived_state_.J_.multiply(Jv, v);
}
Expand Down Expand Up @@ -355,7 +355,7 @@
realtype t, const_N_Vector x, const_N_Vector /*xB*/,
const_N_Vector /*xBdot*/, SUNMatrix JB
) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);

Check warning on line 358 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L358

Added line #L358 was not covered by tests
derived_state_.J_.refresh();
auto JDenseB = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JDenseB, -1.0, nxtrue_solver);
Expand All @@ -373,14 +373,14 @@
realtype t, const_N_Vector x, const_N_Vector /*xB*/,
const_N_Vector /*xBdot*/, SUNMatrix JB
) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);
derived_state_.J_.refresh();
auto JSparseB = SUNMatrixWrapper(JB);
derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver);
}

void Model_ODE::fJDiag(realtype t, N_Vector JDiag, const_N_Vector x) {
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);

Check warning on line 383 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L383

Added line #L383 was not covered by tests
derived_state_.J_.refresh();
derived_state_.J_.to_diag(JDiag);
}
Expand All @@ -390,14 +390,14 @@
const_N_Vector xB
) {
N_VConst(0.0, JvB);
fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get());
fJSparseB(t, x, xB, nullptr, derived_state_.JB_);

Check warning on line 393 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L393

Added line #L393 was not covered by tests
derived_state_.JB_.refresh();
derived_state_.JB_.multiply(JvB, vB);
}

void Model_ODE::fxBdot(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot) {
N_VConst(0.0, xBdot);
fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get());
fJSparseB(t, x, xB, nullptr, derived_state_.JB_);
derived_state_.JB_.refresh();
derived_state_.JB_.multiply(xBdot, xB);
}
Expand Down Expand Up @@ -456,7 +456,7 @@

void Model_ODE::fJSparseB_ss(SUNMatrix JB) {
/* Just copy the model Jacobian */
SUNMatCopy(derived_state_.JB_.get(), JB);
SUNMatCopy(derived_state_.JB_, JB);
derived_state_.JB_.refresh();
}

Expand All @@ -468,7 +468,7 @@
/* Get backward Jacobian */
fJSparseB(
t, x.getNVector(), xB.getNVector(), xBdot.getNVector(),
derived_state_.JB_.get()
derived_state_.JB_
);
derived_state_.JB_.refresh();
/* Switch sign, as we integrate forward in time, not backward */
Expand All @@ -492,7 +492,7 @@
// we only need to call this for the first parameter index will be
// the same for all remaining
fdxdotdp(t, x);
fJSparse(t, x, derived_state_.J_.get());
fJSparse(t, x, derived_state_.J_);
derived_state_.J_.refresh();
}
if (pythonGenerated) {
Expand Down
26 changes: 13 additions & 13 deletions src/newton_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
NewtonSolverDense::NewtonSolverDense(Model const& model)
: NewtonSolver(model)
, Jtmp_(model.nx_solver, model.nx_solver)
, linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_.get())) {
, linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_)) {
auto status = SUNLinSolInitialize_Dense(linsol_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolInitialize_Dense");
Expand All @@ -117,26 +117,26 @@
void NewtonSolverDense::prepareLinearSystem(
Model& model, SimulationState const& state
) {
model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get());
model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_);
Jtmp_.refresh();
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_Dense");
}

void NewtonSolverDense::prepareLinearSystemB(
Model& model, SimulationState const& state
) {
model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get());
model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_);

Check warning on line 130 in src/newton_solver.cpp

View check run for this annotation

Codecov / codecov/patch

src/newton_solver.cpp#L130

Added line #L130 was not covered by tests
Jtmp_.refresh();
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_);

Check warning on line 132 in src/newton_solver.cpp

View check run for this annotation

Codecov / codecov/patch

src/newton_solver.cpp#L132

Added line #L132 was not covered by tests
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_Dense");
}

void NewtonSolverDense::solveLinearSystem(AmiVector& rhs) {
auto status = SUNLinSolSolve_Dense(
linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0
linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0
);
Jtmp_.refresh();
// last argument is tolerance and does not have any influence on result
Expand Down Expand Up @@ -167,7 +167,7 @@
NewtonSolverSparse::NewtonSolverSparse(Model const& model)
: NewtonSolver(model)
, Jtmp_(model.nx_solver, model.nx_solver, model.nnz, CSC_MAT)
, linsol_(SUNKLU(x_.getNVector(), Jtmp_.get())) {
, linsol_(SUNKLU(x_.getNVector(), Jtmp_)) {
auto status = SUNLinSolInitialize_KLU(linsol_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolInitialize_KLU");
Expand All @@ -177,9 +177,9 @@
Model& model, SimulationState const& state
) {
/* Get sparse Jacobian */
model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get());
model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_);
Jtmp_.refresh();
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_KLU");
}
Expand All @@ -189,18 +189,18 @@
) {
/* Get sparse Jacobian */
model.fJSparseB(
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get()
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_
);
Jtmp_.refresh();
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get());
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSolSetup_KLU");
}

void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) {
/* Pass pointer to the linear solver */
auto status = SUNLinSolSolve_KLU(
linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0
linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0
);
// last argument is tolerance and does not have any influence on result

Expand All @@ -211,7 +211,7 @@
void NewtonSolverSparse::reinitialize() {
/* partial reinitialization, don't need to reallocate Jtmp_ */
auto status = SUNLinSol_KLUReInit(
linsol_, Jtmp_.get(), Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL
linsol_, Jtmp_, Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL
);
if (status != SUNLS_SUCCESS)
throw NewtonFailure(status, "SUNLinSol_KLUReInit");
Expand Down
8 changes: 4 additions & 4 deletions src/sundials_linsol_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@

SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) {
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_.get());
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_);

Check warning on line 164 in src/sundials_linsol_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/sundials_linsol_wrapper.cpp#L164

Added line #L164 was not covered by tests
if (!solver_)
throw AmiException("Failed to create solver.");
}
Expand All @@ -170,7 +170,7 @@

SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength())) {
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_.get());
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");
}
Expand All @@ -187,7 +187,7 @@
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_.get());
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");

Expand All @@ -197,7 +197,7 @@
SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); }

void SUNLinSolKLU::reInit(int nnz, int reinit_type) {
int status = SUNLinSol_KLUReInit(solver_, A_.get(), nnz, reinit_type);
int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type);

Check warning on line 200 in src/sundials_linsol_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/sundials_linsol_wrapper.cpp#L200

Added line #L200 was not covered by tests
if (status != SUNLS_SUCCESS)
throw AmiException("SUNLinSol_KLUReInit failed with %d", status);
}
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/unittests/testMisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ TEST(UnravelIndex, UnravelIndexSunMatDense)
A.set_data(2, 1, 5);

for(int i = 0; i < 6; ++i) {
auto idx = unravel_index(i, A.get());
auto idx = unravel_index(i, A);
EXPECT_EQ(A.get_data(idx.first, idx.second), i);
}
}
Expand All @@ -706,7 +706,7 @@ TEST(UnravelIndex, UnravelIndexSunMatSparse)
D.set_data(2, 1, 0);
D.set_data(3, 1, 0);

auto S = SUNSparseFromDenseMatrix(D.get(), 1e-15, CSC_MAT);
auto S = SUNSparseFromDenseMatrix(D, 1e-15, CSC_MAT);

EXPECT_EQ(unravel_index(0, S), std::make_pair((sunindextype) 2, (sunindextype) 0));
EXPECT_EQ(unravel_index(1, S), std::make_pair((sunindextype) 3, (sunindextype) 0));
Expand All @@ -720,8 +720,8 @@ TEST(UnravelIndex, UnravelIndexSunMatSparseMissingIndices)
{
// Sparse matrix without any indices set
SUNMatrixWrapper mat = SUNMatrixWrapper(2, 3, 2, CSC_MAT);
EXPECT_EQ(unravel_index(0, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1));
EXPECT_EQ(unravel_index(1, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1));
EXPECT_EQ(unravel_index(0, mat), std::make_pair((sunindextype) -1, (sunindextype) -1));
EXPECT_EQ(unravel_index(1, mat), std::make_pair((sunindextype) -1, (sunindextype) -1));
}


Expand Down