Skip to content

Commit

Permalink
Allow subselection of state variables for steady-state simulations
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Mar 26, 2024
1 parent 790ab44 commit c6528c4
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 9 deletions.
43 changes: 43 additions & 0 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,40 @@ class Model : public AbstractModel, public ModelDimensions {
*/
virtual std::vector<double> get_trigger_timepoints() const;

/**
* @brief Get steady-state mask as std::vector.
*
* See `set_steadystate_mask` for details.
*
* @return Steady-state mask
*/
std::vector<double> get_steadystate_mask() const {
return steadystate_mask_.getVector();
};

/**
* @brief Get steady-state mask as AmiVector.
*
* See `set_steadystate_mask` for details.
* @return Steady-state mask
*/
AmiVector const& get_steadystate_mask_av() const {
return steadystate_mask_;
};

/**
* @brief Set steady-state mask.
*
* The mask is used to exclude certain state variables from the steady-state
* convergence check. Positive values indicate that the corresponding state
* variable should be included in the convergence check, while non-positive
* values indicate that the corresponding state variable should be excluded.
* An empty mask is interpreted as including all state variables.
*
* @param mask Mask of length `nx_solver`.
*/
void set_steadystate_mask(std::vector<double> const& mask);

/**
* Flag indicating whether for
* `amici::Solver::sensi_` == `amici::SensitivityOrder::second`
Expand Down Expand Up @@ -2087,6 +2121,15 @@ class Model : public AbstractModel, public ModelDimensions {

/** Simulation parameters, initial state, etc. */
SimulationParameters simulation_parameters_;

/**
* Mask for state variables that should be checked for steady state
* during pre-/post-equilibration. Positive values indicate that the
* corresponding state variable should be checked for steady state.
* Negative values indicate that the corresponding state variable should
* be ignored.
*/
AmiVector steadystate_mask_;
};

bool operator==(Model const& a, Model const& b);
Expand Down
6 changes: 4 additions & 2 deletions include/amici/steadystateproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,16 @@ class SteadystateProblem {
* w_i = 1 / ( rtol * x_i + atol )
* @param x current state (sx[ip] for sensitivities)
* @param xdot current rhs (sxdot[ip] for sensitivities)
* @param mask mask for state variables to include in WRMS norm.
* Positive value: include; non-positive value: exclude; empty: include all.
* @param atol absolute tolerance
* @param rtol relative tolerance
* @param ewt error weight vector
* @return root-mean-square norm
*/
realtype getWrmsNorm(
AmiVector const& x, AmiVector const& xdot, realtype atol, realtype rtol,
AmiVector& ewt
AmiVector const& x, AmiVector const& xdot, AmiVector const& mask,
realtype atol, realtype rtol, AmiVector& ewt
) const;

/**
Expand Down
59 changes: 59 additions & 0 deletions python/tests/test_preequilibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from amici.debugging import get_model_for_preeq
from numpy.testing import assert_allclose, assert_equal
from test_pysb import get_data
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory


@pytest.fixture
Expand Down Expand Up @@ -658,3 +659,61 @@ def test_get_model_for_preeq(preeq_fixture):
rdata1.sx,
rdata2.sx,
)


def test_partial_eq():
"""Check that partial equilibration is possible."""
from amici.antimony_import import antimony2amici

ant_str = """
model test_partial_eq
explodes = 1
explodes' = explodes
A = 1
B = 0
R: A -> B; k*A - k*B
k = 1
end
"""
module_name = "test_partial_eq"
with TemporaryDirectory(prefix=module_name) as outdir:
antimony2amici(
ant_str,
model_name=module_name,
output_dir=outdir,
)
model_module = amici.import_model_module(
module_name=module_name, module_path=outdir
)
amici_model = model_module.getModel()
amici_model.setTimepoints([np.inf])
amici_solver = amici_model.getSolver()
amici_solver.setRelativeToleranceSteadyState(1e-12)

# equilibration of `explodes` will fail
rdata = amici.runAmiciSimulation(amici_model, amici_solver)
assert rdata.status == amici.AMICI_ERROR
assert rdata.messages[0].identifier == "EQUILIBRATION_FAILURE"

# excluding `explodes` should enable equilibration
amici_model.set_steadystate_mask(
[
0 if state_id == "explodes" else 1
for state_id in amici_model.getStateIdsSolver()
]
)
rdata = amici.runAmiciSimulation(amici_model, amici_solver)
assert rdata.status == amici.AMICI_SUCCESS
assert_allclose(
rdata.by_id("A"),
0.5,
atol=amici_solver.getAbsoluteToleranceSteadyState(),
rtol=amici_solver.getRelativeToleranceSteadyState(),
)
assert_allclose(
rdata.by_id("B"),
0.5,
atol=amici_solver.getAbsoluteToleranceSteadyState(),
rtol=amici_solver.getRelativeToleranceSteadyState(),
)
assert rdata.t_last < 100
17 changes: 17 additions & 0 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3138,6 +3138,23 @@ std::vector<double> Model::get_trigger_timepoints() const {
return trigger_timepoints;
}

void Model::set_steadystate_mask(std::vector<double> const& mask) {
if (mask.size() == 0) {
if (steadystate_mask_.getLength() != 0) {
steadystate_mask_ = AmiVector();
}
return;
}

if (gsl::narrow<int>(mask.size()) != nx_solver)
throw AmiException(
"Steadystate mask has wrong size: %d, expected %d",
gsl::narrow<int>(mask.size()), nx_solver
);

steadystate_mask_ = AmiVector(mask);
}

const_N_Vector Model::computeX_pos(const_N_Vector x) {
if (any_state_non_negative_) {
for (int ix = 0; ix < derived_state_.x_pos_tmp_.getLength(); ++ix) {
Expand Down
27 changes: 20 additions & 7 deletions src/steadystateproblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,8 @@ bool SteadystateProblem::getSensitivityFlag(
}

realtype SteadystateProblem::getWrmsNorm(
AmiVector const& x, AmiVector const& xdot, realtype atol, realtype rtol,
AmiVector& ewt
AmiVector const& x, AmiVector const& xdot, AmiVector const& mask,
realtype atol, realtype rtol, AmiVector& ewt
) const {
/* Depending on what convergence we want to check (xdot, sxdot, xQBdot)
we need to pass ewt[QB], as xdot and xQBdot have different sizes */
Expand All @@ -522,7 +522,14 @@ realtype SteadystateProblem::getWrmsNorm(
N_VAddConst(ewt.getNVector(), atol, ewt.getNVector());
/* ewt = 1/ewt (ewt = 1/(rtol*x+atol)) */
N_VInv(ewt.getNVector(), ewt.getNVector());
/* wrms = sqrt(sum((xdot/ewt)**2)/n) where n = size of state vector */

// wrms = sqrt(sum((xdot/ewt)**2)/n) where n = size of state vector
if (mask.getLength()) {
return N_VWrmsNormMask(
const_cast<N_Vector>(xdot.getNVector()), ewt.getNVector(),
const_cast<N_Vector>(mask.getNVector())
);
}
return N_VWrmsNorm(
const_cast<N_Vector>(xdot.getNVector()), ewt.getNVector()
);
Expand All @@ -543,7 +550,10 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
"Newton type convergence check is not implemented for adjoint "
"steady state computations. Stopping."
);
wrms = getWrmsNorm(xQB_, xQBdot_, atol_quad_, rtol_quad_, ewtQB_);
wrms = getWrmsNorm(
xQB_, xQBdot_, model.get_steadystate_mask_av(), atol_quad_,
rtol_quad_, ewtQB_
);
} else {
/* If we're doing a forward simulation (with or without sensitivities:
Get RHS and compute weighted error norm */
Expand All @@ -552,7 +562,8 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
else
updateRightHandSide(model);
wrms = getWrmsNorm(
state_.x, newton_step_conv_ ? delta_ : xdot_, atol_, rtol_, ewt_
state_.x, newton_step_conv_ ? delta_ : xdot_,
model.get_steadystate_mask_av(), atol_, rtol_, ewt_
);
}
return wrms;
Expand All @@ -573,8 +584,10 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) {
);
if (newton_step_conv_)
newton_solver_->solveLinearSystem(xdot_);
wrms
= getWrmsNorm(state_.sx[ip], xdot_, atol_sensi_, rtol_sensi_, ewt_);
wrms = getWrmsNorm(
state_.sx[ip], xdot_, model.get_steadystate_mask_av(), atol_sensi_,
rtol_sensi_, ewt_
);
/* ideally this function would report the maximum of all wrms over
all ip, but for practical purposes we can just report the wrms for
the first ip where we know that the convergence threshold is not
Expand Down

0 comments on commit c6528c4

Please sign in to comment.