From c6528c4b69a8e557500793cdc7817f60ffd591b1 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 26 Mar 2024 07:34:14 +0100 Subject: [PATCH] Allow subselection of state variables for steady-state simulations Closes #2368 --- include/amici/model.h | 43 +++++++++++++++++++ include/amici/steadystateproblem.h | 6 ++- python/tests/test_preequilibration.py | 59 +++++++++++++++++++++++++++ src/model.cpp | 17 ++++++++ src/steadystateproblem.cpp | 27 ++++++++---- 5 files changed, 143 insertions(+), 9 deletions(-) diff --git a/include/amici/model.h b/include/amici/model.h index 29b98aa913..1c020400ac 100644 --- a/include/amici/model.h +++ b/include/amici/model.h @@ -1481,6 +1481,40 @@ class Model : public AbstractModel, public ModelDimensions { */ virtual std::vector get_trigger_timepoints() const; + /** + * @brief Get steady-state mask as std::vector. + * + * See `set_steadystate_mask` for details. + * + * @return Steady-state mask + */ + std::vector get_steadystate_mask() const { + return steadystate_mask_.getVector(); + }; + + /** + * @brief Get steady-state mask as AmiVector. + * + * See `set_steadystate_mask` for details. + * @return Steady-state mask + */ + AmiVector const& get_steadystate_mask_av() const { + return steadystate_mask_; + }; + + /** + * @brief Set steady-state mask. + * + * The mask is used to exclude certain state variables from the steady-state + * convergence check. Positive values indicate that the corresponding state + * variable should be included in the convergence check, while non-positive + * values indicate that the corresponding state variable should be excluded. + * An empty mask is interpreted as including all state variables. + * + * @param mask Mask of length `nx_solver`. + */ + void set_steadystate_mask(std::vector const& mask); + /** * Flag indicating whether for * `amici::Solver::sensi_` == `amici::SensitivityOrder::second` @@ -2087,6 +2121,15 @@ class Model : public AbstractModel, public ModelDimensions { /** Simulation parameters, initial state, etc. */ SimulationParameters simulation_parameters_; + + /** + * Mask for state variables that should be checked for steady state + * during pre-/post-equilibration. Positive values indicate that the + * corresponding state variable should be checked for steady state. + * Negative values indicate that the corresponding state variable should + * be ignored. + */ + AmiVector steadystate_mask_; }; bool operator==(Model const& a, Model const& b); diff --git a/include/amici/steadystateproblem.h b/include/amici/steadystateproblem.h index 55c9aaca77..72d248f8bf 100644 --- a/include/amici/steadystateproblem.h +++ b/include/amici/steadystateproblem.h @@ -246,14 +246,16 @@ class SteadystateProblem { * w_i = 1 / ( rtol * x_i + atol ) * @param x current state (sx[ip] for sensitivities) * @param xdot current rhs (sxdot[ip] for sensitivities) + * @param mask mask for state variables to include in WRMS norm. + * Positive value: include; non-positive value: exclude; empty: include all. * @param atol absolute tolerance * @param rtol relative tolerance * @param ewt error weight vector * @return root-mean-square norm */ realtype getWrmsNorm( - AmiVector const& x, AmiVector const& xdot, realtype atol, realtype rtol, - AmiVector& ewt + AmiVector const& x, AmiVector const& xdot, AmiVector const& mask, + realtype atol, realtype rtol, AmiVector& ewt ) const; /** diff --git a/python/tests/test_preequilibration.py b/python/tests/test_preequilibration.py index d003507199..8f9756937f 100644 --- a/python/tests/test_preequilibration.py +++ b/python/tests/test_preequilibration.py @@ -8,6 +8,7 @@ from amici.debugging import get_model_for_preeq from numpy.testing import assert_allclose, assert_equal from test_pysb import get_data +from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory @pytest.fixture @@ -658,3 +659,61 @@ def test_get_model_for_preeq(preeq_fixture): rdata1.sx, rdata2.sx, ) + + +def test_partial_eq(): + """Check that partial equilibration is possible.""" + from amici.antimony_import import antimony2amici + + ant_str = """ + model test_partial_eq + explodes = 1 + explodes' = explodes + A = 1 + B = 0 + R: A -> B; k*A - k*B + k = 1 + end + """ + module_name = "test_partial_eq" + with TemporaryDirectory(prefix=module_name) as outdir: + antimony2amici( + ant_str, + model_name=module_name, + output_dir=outdir, + ) + model_module = amici.import_model_module( + module_name=module_name, module_path=outdir + ) + amici_model = model_module.getModel() + amici_model.setTimepoints([np.inf]) + amici_solver = amici_model.getSolver() + amici_solver.setRelativeToleranceSteadyState(1e-12) + + # equilibration of `explodes` will fail + rdata = amici.runAmiciSimulation(amici_model, amici_solver) + assert rdata.status == amici.AMICI_ERROR + assert rdata.messages[0].identifier == "EQUILIBRATION_FAILURE" + + # excluding `explodes` should enable equilibration + amici_model.set_steadystate_mask( + [ + 0 if state_id == "explodes" else 1 + for state_id in amici_model.getStateIdsSolver() + ] + ) + rdata = amici.runAmiciSimulation(amici_model, amici_solver) + assert rdata.status == amici.AMICI_SUCCESS + assert_allclose( + rdata.by_id("A"), + 0.5, + atol=amici_solver.getAbsoluteToleranceSteadyState(), + rtol=amici_solver.getRelativeToleranceSteadyState(), + ) + assert_allclose( + rdata.by_id("B"), + 0.5, + atol=amici_solver.getAbsoluteToleranceSteadyState(), + rtol=amici_solver.getRelativeToleranceSteadyState(), + ) + assert rdata.t_last < 100 diff --git a/src/model.cpp b/src/model.cpp index cefdf1ac97..9ba435b338 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -3138,6 +3138,23 @@ std::vector Model::get_trigger_timepoints() const { return trigger_timepoints; } +void Model::set_steadystate_mask(std::vector const& mask) { + if (mask.size() == 0) { + if (steadystate_mask_.getLength() != 0) { + steadystate_mask_ = AmiVector(); + } + return; + } + + if (gsl::narrow(mask.size()) != nx_solver) + throw AmiException( + "Steadystate mask has wrong size: %d, expected %d", + gsl::narrow(mask.size()), nx_solver + ); + + steadystate_mask_ = AmiVector(mask); +} + const_N_Vector Model::computeX_pos(const_N_Vector x) { if (any_state_non_negative_) { for (int ix = 0; ix < derived_state_.x_pos_tmp_.getLength(); ++ix) { diff --git a/src/steadystateproblem.cpp b/src/steadystateproblem.cpp index 98c36589f7..d78f9a8705 100644 --- a/src/steadystateproblem.cpp +++ b/src/steadystateproblem.cpp @@ -509,8 +509,8 @@ bool SteadystateProblem::getSensitivityFlag( } realtype SteadystateProblem::getWrmsNorm( - AmiVector const& x, AmiVector const& xdot, realtype atol, realtype rtol, - AmiVector& ewt + AmiVector const& x, AmiVector const& xdot, AmiVector const& mask, + realtype atol, realtype rtol, AmiVector& ewt ) const { /* Depending on what convergence we want to check (xdot, sxdot, xQBdot) we need to pass ewt[QB], as xdot and xQBdot have different sizes */ @@ -522,7 +522,14 @@ realtype SteadystateProblem::getWrmsNorm( N_VAddConst(ewt.getNVector(), atol, ewt.getNVector()); /* ewt = 1/ewt (ewt = 1/(rtol*x+atol)) */ N_VInv(ewt.getNVector(), ewt.getNVector()); - /* wrms = sqrt(sum((xdot/ewt)**2)/n) where n = size of state vector */ + + // wrms = sqrt(sum((xdot/ewt)**2)/n) where n = size of state vector + if (mask.getLength()) { + return N_VWrmsNormMask( + const_cast(xdot.getNVector()), ewt.getNVector(), + const_cast(mask.getNVector()) + ); + } return N_VWrmsNorm( const_cast(xdot.getNVector()), ewt.getNVector() ); @@ -543,7 +550,10 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) { "Newton type convergence check is not implemented for adjoint " "steady state computations. Stopping." ); - wrms = getWrmsNorm(xQB_, xQBdot_, atol_quad_, rtol_quad_, ewtQB_); + wrms = getWrmsNorm( + xQB_, xQBdot_, model.get_steadystate_mask_av(), atol_quad_, + rtol_quad_, ewtQB_ + ); } else { /* If we're doing a forward simulation (with or without sensitivities: Get RHS and compute weighted error norm */ @@ -552,7 +562,8 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) { else updateRightHandSide(model); wrms = getWrmsNorm( - state_.x, newton_step_conv_ ? delta_ : xdot_, atol_, rtol_, ewt_ + state_.x, newton_step_conv_ ? delta_ : xdot_, + model.get_steadystate_mask_av(), atol_, rtol_, ewt_ ); } return wrms; @@ -573,8 +584,10 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) { ); if (newton_step_conv_) newton_solver_->solveLinearSystem(xdot_); - wrms - = getWrmsNorm(state_.sx[ip], xdot_, atol_sensi_, rtol_sensi_, ewt_); + wrms = getWrmsNorm( + state_.sx[ip], xdot_, model.get_steadystate_mask_av(), atol_sensi_, + rtol_sensi_, ewt_ + ); /* ideally this function would report the maximum of all wrms over all ip, but for practical purposes we can just report the wrms for the first ip where we know that the convergence threshold is not