Skip to content

Commit

Permalink
Refactor SUNLinSolWrapper
Browse files Browse the repository at this point in the history
* Make it consistent that SUNLinSolWrapper always holds the associated matrix.
* Always use SUNMatrixWrapper instead of raw SUNMatrix objects

This makes it a bit easier to finally address #1164.
  • Loading branch information
dweindl committed Sep 30, 2024
1 parent 99a1636 commit ea082ee
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 66 deletions.
62 changes: 22 additions & 40 deletions include/amici/sundials_linsol_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,20 @@ class SUNLinSolWrapper {

/**
* @brief Wrap existing SUNLinearSolver
* @param linsol
*
* @param linsol SUNLinSolWrapper takes ownership of `linsol`.
*/
explicit SUNLinSolWrapper(SUNLinearSolver linsol);

/**
* @brief Wrap existing SUNLinearSolver
*
* @param linsol SUNLinSolWrapper takes ownership of `linsol`.
* @param A Matrix
*/
explicit SUNLinSolWrapper(SUNLinearSolver linsol, SUNMatrixWrapper const& A);


virtual ~SUNLinSolWrapper();

/**
Expand Down Expand Up @@ -80,26 +90,17 @@ class SUNLinSolWrapper {
/**
* @brief Performs any linear solver setup needed, based on an updated
* system matrix A.
* @param A
*/
void setup(SUNMatrix A) const;

/**
* @brief Performs any linear solver setup needed, based on an updated
* system matrix A.
* @param A
*/
void setup(SUNMatrixWrapper const& A) const;
void setup() const;

/**
* @brief Solves a linear system A*x = b
* @param A
* @param x A template for cloning vectors needed within the solver.
* @param b
* @param tol Tolerance (weighted 2-norm), iterative solvers only
* @return error flag
*/
int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const;
int solve(N_Vector x, N_Vector b, realtype tol) const;

/**
* @brief Returns the last error flag encountered within the linear solver
Expand Down Expand Up @@ -131,6 +132,10 @@ class SUNLinSolWrapper {

/** Wrapped solver */
SUNLinearSolver solver_{nullptr};

/** Matrix A for solver. */
SUNMatrixWrapper A_;

};

/**
Expand All @@ -139,12 +144,12 @@ class SUNLinSolWrapper {
class SUNLinSolBand : public SUNLinSolWrapper {
public:
/**
* @brief Create solver using existing matrix A without taking ownership of
* A.
* @brief Create solver using existing matrix A
*
* @param x A template for cloning vectors needed within the solver.
* @param A square matrix
*/
SUNLinSolBand(N_Vector x, SUNMatrix A);
SUNLinSolBand(N_Vector x, SUNMatrixWrapper A);

/**
* @brief Create new band solver and matrix A.
Expand All @@ -154,11 +159,6 @@ class SUNLinSolBand : public SUNLinSolWrapper {
*/
SUNLinSolBand(AmiVector const& x, int ubw, int lbw);

SUNMatrix getMatrix() const override;

private:
/** Matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -171,12 +171,6 @@ class SUNLinSolDense : public SUNLinSolWrapper {
* @param x A template for cloning vectors needed within the solver.
*/
explicit SUNLinSolDense(AmiVector const& x);

SUNMatrix getMatrix() const override;

private:
/** Matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -192,7 +186,7 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param x A template for cloning vectors needed within the solver.
* @param A sparse matrix
*/
SUNLinSolKLU(N_Vector x, SUNMatrix A);
SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A);

/**
* @brief Create KLU solver and matrix to operate on
Expand All @@ -205,8 +199,6 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
);

SUNMatrix getMatrix() const override;

/**
* @brief Reinitializes memory and flags for a new factorization
* (symbolic and numeric) to be conducted at the next solver setup call.
Expand All @@ -223,10 +215,6 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
void setOrdering(StateOrdering ordering);

private:
/** Sparse matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

#ifdef SUNDIALS_SUPERLUMT
Expand All @@ -249,7 +237,7 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
* @param A sparse matrix
* @param numThreads Number of threads to be used by SuperLUMT
*/
SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads);
SUNLinSolSuperLUMT(N_Vector x, SUNMatrixWrapper A, int numThreads);

/**
* @brief Create SuperLUMT solver and matrix to operate on
Expand Down Expand Up @@ -279,18 +267,12 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
int numThreads
);

SUNMatrix getMatrix() const override;

/**
* @brief Sets the ordering used by SuperLUMT for reducing fill in the
* linear solve.
* @param ordering
*/
void setOrdering(StateOrdering ordering);

private:
/** Sparse matrix A for solver, only if created by here. */
SUNMatrixWrapper A;
};

#endif
Expand Down
46 changes: 20 additions & 26 deletions src/sundials_linsol_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ namespace amici {
SUNLinSolWrapper::SUNLinSolWrapper(SUNLinearSolver linsol)
: solver_(linsol) {}

SUNLinSolWrapper::SUNLinSolWrapper(
SUNLinearSolver linsol, SUNMatrixWrapper const& A
) : solver_(linsol),
A_(A){}

SUNLinSolWrapper::~SUNLinSolWrapper() {
if (solver_)
SUNLinSolFree(solver_);
Expand All @@ -31,19 +36,15 @@ int SUNLinSolWrapper::initialize() {
return res;
}

void SUNLinSolWrapper::setup(SUNMatrix A) const {
auto res = SUNLinSolSetup(solver_, A);
void SUNLinSolWrapper::setup() const {
auto res = SUNLinSolSetup(solver_, A_.get());
if (res != SUNLS_SUCCESS)
throw AmiException("Solver setup failed with code %d", res);
}

void SUNLinSolWrapper::setup(SUNMatrixWrapper const& A) const {
return setup(A.get());
}

int SUNLinSolWrapper::Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
int SUNLinSolWrapper::solve(N_Vector x, N_Vector b, realtype tol)
const {
return SUNLinSolSolve(solver_, A, x, b, tol);
return SUNLinSolSolve(solver_, A_.get(), x, b, tol);
}

long SUNLinSolWrapper::getLastFlag() const {
Expand All @@ -54,7 +55,7 @@ int SUNLinSolWrapper::space(long* lenrwLS, long* leniwLS) const {
return SUNLinSolSpace(solver_, lenrwLS, leniwLS);
}

SUNMatrix SUNLinSolWrapper::getMatrix() const { return nullptr; }
SUNMatrix SUNLinSolWrapper::getMatrix() const { return A_.get(); }

SUNNonLinSolWrapper::SUNNonLinSolWrapper(SUNNonlinearSolver sol)
: solver(sol) {}
Expand Down Expand Up @@ -153,31 +154,28 @@ void SUNNonLinSolWrapper::initialize() {
);
}

SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A)
SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrixWrapper A)
: SUNLinSolWrapper(SUNLinSol_Band(x, A)) {
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) {
: SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), ubw, lbw))
{
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); }

SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength())) {
: SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), x.getLength())) {
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolDense::getMatrix() const { return A_.get(); }

SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A)
SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A)
: SUNLinSolWrapper(SUNLinSol_KLU(x, A)) {
if (!solver_)
throw AmiException("Failed to create solver.");
Expand All @@ -186,16 +184,14 @@ SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A)
SUNLinSolKLU::SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");

setOrdering(ordering);
}

SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); }

void SUNLinSolKLU::reInit(int nnz, int reinit_type) {
int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type);
if (status != SUNLS_SUCCESS)
Expand Down Expand Up @@ -413,8 +409,8 @@ int SUNNonLinSolFixedPoint::getSysFn(SUNNonlinSolSysFn* SysFn) const {

#ifdef SUNDIALS_SUPERLUMT

SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads)
: SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads)) {
SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(N_Vector x, SUNMatrixWrapper A, int numThreads)
: SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads), A) {
if (!solver)
throw AmiException("Failed to create solver.");
}
Expand All @@ -423,7 +419,7 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
AmiVector const& x, int nnz, int sparsetype,
SUNLinSolSuperLUMT::StateOrdering ordering
)
: A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
int numThreads = 1;
if (auto env = std::getenv("AMICI_SUPERLUMT_NUM_THREADS")) {
numThreads = std::max(1, std::stoi(env));
Expand All @@ -440,16 +436,14 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering,
int numThreads
)
: A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
solver = SUNLinSol_SuperLUMT(x.getNVector(), A.get(), numThreads);
if (!solver)
throw AmiException("Failed to create solver.");

setOrdering(ordering);
}

SUNMatrix SUNLinSolSuperLUMT::getMatrix() const { return A.get(); }

void SUNLinSolSuperLUMT::setOrdering(StateOrdering ordering) {
auto status
= SUNLinSol_SuperLUMTSetOrdering(solver, static_cast<int>(ordering));
Expand Down

0 comments on commit ea082ee

Please sign in to comment.