Skip to content

Commit

Permalink
Handle events occuring at fixed timepoints without root-finding
Browse files Browse the repository at this point in the history
A first attempt towards AMICI-dev#2185

For events that occur at known timepoints, we don't need sundials'
root-finding. We can just stop the solver at the respective timepoints
and handle the events.

To be extended to parameterized but state-independent trigger functions.
  • Loading branch information
dweindl committed Dec 4, 2023
1 parent 81872cc commit baa13af
Show file tree
Hide file tree
Showing 26 changed files with 234 additions and 43 deletions.
4 changes: 3 additions & 1 deletion include/amici/forwardproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ class ForwardProblem {
SimulationState const& getSimulationStateTimepoint(int it) const {
if (model->getTimepoint(it) == initial_state_.t)
return getInitialSimulationState();
return timepoint_states_.find(model->getTimepoint(it))->second;
auto map_iter = timepoint_states_.find(model->getTimepoint(it));
assert(map_iter != timepoint_states_.end());
return map_iter->second;
};

/**
Expand Down
10 changes: 8 additions & 2 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "amici/vector.h"

#include <map>
#include <memory>
#include <vector>

namespace amici {
Expand Down Expand Up @@ -117,14 +116,17 @@ class Model : public AbstractModel, public ModelDimensions {
* @param ndxdotdp_explicit Number of nonzero elements in `dxdotdp_explicit`
* @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit`
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
amici::SecondOrderMode o2mode, std::vector<amici::realtype> idlist,
std::vector<int> z2event, bool pythonGenerated = false,
int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
int w_recursion_depth = 0
int w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events = {}
);

/** Destructor. */
Expand Down Expand Up @@ -1449,6 +1451,8 @@ class Model : public AbstractModel, public ModelDimensions {
*/
SUNMatrixWrapper const& get_dxdotdp_full() const;

virtual std::vector<double> get_trigger_timepoints() const;

/**
* Flag indicating whether for
* `amici::Solver::sensi_` == `amici::SensitivityOrder::second`
Expand All @@ -1462,6 +1466,8 @@ class Model : public AbstractModel, public ModelDimensions {
/** Logger */
Logger* logger = nullptr;

std::map<realtype, std::vector<int>> state_independent_events_ = {};

protected:
/**
* @brief Write part of a slice to a buffer according to indices specified
Expand Down
20 changes: 14 additions & 6 deletions include/amici/model_dimensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct ModelDimensions {
* @param nz Number of event observables
* @param nztrue Number of event observables of the non-augmented model
* @param ne Number of events
* @param ne_solver Number of events that require root-finding
* @param nspl Number of splines
* @param nJ Number of objective functions
* @param nw Number of repeating elements
Expand Down Expand Up @@ -58,11 +59,12 @@ struct ModelDimensions {
int const nx_rdata, int const nxtrue_rdata, int const nx_solver,
int const nxtrue_solver, int const nx_solver_reinit, int const np,
int const nk, int const ny, int const nytrue, int const nz,
int const nztrue, int const ne, int const nspl, int const nJ,
int const nw, int const ndwdx, int const ndwdp, int const ndwdw,
int const ndxdotdw, std::vector<int> ndJydy, int const ndxrdatadxsolver,
int const ndxrdatadtcl, int const ndtotal_cldx_rdata, int const nnz,
int const ubw, int const lbw
int const nztrue, int const ne, int const ne_solver, int const nspl,
int const nJ, int const nw, int const ndwdx, int const ndwdp,
int const ndwdw, int const ndxdotdw, std::vector<int> ndJydy,
int const ndxrdatadxsolver, int const ndxrdatadtcl,
int const ndtotal_cldx_rdata, int const nnz, int const ubw,
int const lbw
)
: nx_rdata(nx_rdata)
, nxtrue_rdata(nxtrue_rdata)
Expand All @@ -76,6 +78,7 @@ struct ModelDimensions {
, nz(nz)
, nztrue(nztrue)
, ne(ne)
, ne_solver(ne_solver)
, nspl(nspl)
, nw(nw)
, ndwdx(ndwdx)
Expand Down Expand Up @@ -104,6 +107,8 @@ struct ModelDimensions {
Expects(nztrue >= 0);
Expects(nztrue <= nz);
Expects(ne >= 0);
Expects(ne_solver >= 0);
Expects(ne >= ne_solver);
Expects(nspl >= 0);
Expects(nw >= 0);
Expects(ndwdx >= 0);
Expand Down Expand Up @@ -164,7 +169,10 @@ struct ModelDimensions {
/** Number of events */
int ne{0};

/** numer of spline functions in the model */
/** Number of events that require root-finding */
int ne_solver{0};

/** Number of spline functions in the model */
int nspl{0};

/** Number of common expressions */
Expand Down
8 changes: 6 additions & 2 deletions include/amici/model_ode.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,23 @@ class Model_ODE : public Model {
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
* @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model_ODE(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
const SecondOrderMode o2mode, std::vector<realtype> const& idlist,
std::vector<int> const& z2event, bool const pythonGenerated = false,
int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
int const w_recursion_depth = 0
int const w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events
= {}
)
: Model(
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
w_recursion_depth
w_recursion_depth, state_independent_events
) {}

void
Expand Down
1 change: 1 addition & 0 deletions include/amici/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ void serialize(
ar& m.nz;
ar& m.nztrue;
ar& m.ne;
ar& m.ne_solver;
ar& m.nspl;
ar& m.nw;
ar& m.ndwdx;
Expand Down
3 changes: 2 additions & 1 deletion models/model_calvetti/model_calvetti.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Model_model_calvetti : public amici::Model_DAE {
0,
4,
0,
0,
1,
38,
53,
Expand Down Expand Up @@ -207,6 +208,6 @@ class Model_model_calvetti : public amici::Model_DAE {

} // namespace model_model_calvetti

} // namespace amici
} // namespace amici

#endif /* _amici_model_calvetti_h */
3 changes: 2 additions & 1 deletion models/model_dirac/model_dirac.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Model_model_dirac : public amici::Model_ODE {
0,
2,
0,
0,
1,
0,
0,
Expand Down Expand Up @@ -204,6 +205,6 @@ class Model_model_dirac : public amici::Model_ODE {

} // namespace model_model_dirac

} // namespace amici
} // namespace amici

#endif /* _amici_model_dirac_h */
3 changes: 2 additions & 1 deletion models/model_events/model_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Model_model_events : public amici::Model_ODE {
2,
6,
0,
0,
1,
0,
0,
Expand Down Expand Up @@ -232,6 +233,6 @@ class Model_model_events : public amici::Model_ODE {

} // namespace model_model_events

} // namespace amici
} // namespace amici

#endif /* _amici_model_events_h */
3 changes: 2 additions & 1 deletion models/model_jakstat_adjoint/model_jakstat_adjoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE {
0,
0,
0,
0,
1,
2,
1,
Expand Down Expand Up @@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE {

} // namespace model_model_jakstat_adjoint

} // namespace amici
} // namespace amici

#endif /* _amici_model_jakstat_adjoint_h */
3 changes: 2 additions & 1 deletion models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE {
0,
0,
0,
0,
18,
10,
2,
Expand Down Expand Up @@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE {

} // namespace model_model_jakstat_adjoint_o2

} // namespace amici
} // namespace amici

#endif /* _amici_model_jakstat_adjoint_o2_h */
3 changes: 2 additions & 1 deletion models/model_nested_events/model_nested_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Model_model_nested_events : public amici::Model_ODE {
0,
4,
0,
0,
1,
0,
0,
Expand Down Expand Up @@ -210,6 +211,6 @@ class Model_model_nested_events : public amici::Model_ODE {

} // namespace model_model_nested_events

} // namespace amici
} // namespace amici

#endif /* _amici_model_nested_events_h */
3 changes: 2 additions & 1 deletion models/model_neuron/model_neuron.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Model_model_neuron : public amici::Model_ODE {
1,
1,
0,
0,
1,
0,
0,
Expand Down Expand Up @@ -238,6 +239,6 @@ class Model_model_neuron : public amici::Model_ODE {

} // namespace model_model_neuron

} // namespace amici
} // namespace amici

#endif /* _amici_model_neuron_h */
3 changes: 2 additions & 1 deletion models/model_neuron_o2/model_neuron_o2.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Model_model_neuron_o2 : public amici::Model_ODE {
1,
1,
0,
0,
5,
2,
2,
Expand Down Expand Up @@ -242,6 +243,6 @@ class Model_model_neuron_o2 : public amici::Model_ODE {

} // namespace model_model_neuron_o2

} // namespace amici
} // namespace amici

#endif /* _amici_model_neuron_o2_h */
3 changes: 2 additions & 1 deletion models/model_robertson/model_robertson.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Model_model_robertson : public amici::Model_DAE {
0,
0,
0,
0,
1,
1,
2,
Expand Down Expand Up @@ -209,6 +210,6 @@ class Model_model_robertson : public amici::Model_DAE {

} // namespace model_model_robertson

} // namespace amici
} // namespace amici

#endif /* _amici_model_robertson_h */
3 changes: 2 additions & 1 deletion models/model_steadystate/model_steadystate.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Model_model_steadystate : public amici::Model_ODE {
0,
0,
0,
0,
1,
2,
2,
Expand Down Expand Up @@ -204,6 +205,6 @@ class Model_model_steadystate : public amici::Model_ODE {

} // namespace model_model_steadystate

} // namespace amici
} // namespace amici

#endif /* _amici_model_steadystate_h */
44 changes: 43 additions & 1 deletion python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,13 +1425,24 @@ def num_expr(self) -> int:
return len(self.sym("w"))

def num_events(self) -> int:
"""
Total number of Events (those for which root-functions are added and those without).
:return:
number of events
"""
return len(self.sym("h"))

def num_events_solver(self) -> int:
"""
Number of Events.
:return:
number of event symbols (length of the root vector in AMICI)
"""
return len(self.sym("h"))
return sum(
not event.triggers_at_fixed_timepoint() for event in self.events()
)

def sym(self, name: str) -> sp.Matrix:
"""
Expand Down Expand Up @@ -1750,6 +1761,16 @@ def parse_events(self) -> None:
# add roots of heaviside functions
self.add_component(root)

# re-order events - first those that require root tracking, then the others
self._events = list(
chain(
itertools.filterfalse(
Event.triggers_at_fixed_timepoint, self._events
),
filter(Event.triggers_at_fixed_timepoint, self._events),
)
)

def get_appearance_counts(self, idxs: List[int]) -> List[int]:
"""
Counts how often a state appears in the time derivative of
Expand Down Expand Up @@ -3642,6 +3663,7 @@ def _write_model_header_cpp(self) -> None:
"NZ": self.model.num_eventobs(),
"NZTRUE": self.model.num_eventobs(),
"NEVENT": self.model.num_events(),
"NEVENT_SOLVER": self.model.num_events_solver(),
"NOBJECTIVE": "1",
"NSPL": len(self.model.splines),
"NW": len(self.model.sym("w")),
Expand Down Expand Up @@ -3736,6 +3758,7 @@ def _write_model_header_cpp(self) -> None:
)
),
"Z2EVENT": ", ".join(map(str, self.model._z2event)),
"STATE_INDEPENDENT_EVENTS": self._get_state_independent_event_intializer(),
"ID": ", ".join(
(
str(float(isinstance(s, DifferentialState)))
Expand Down Expand Up @@ -3871,6 +3894,25 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
for idx, symbol in enumerate(self.model.sym(name))
)

def _get_state_independent_event_intializer(self) -> str:
tmp_map = {}
for event_idx, event in enumerate(self.model.events()):
if not event.triggers_at_fixed_timepoint():
continue
trigger_time = float(event.get_trigger_time())
try:
tmp_map[trigger_time].append(event_idx)
except KeyError:
tmp_map[trigger_time] = [event_idx]

def vector_initializer(v):
return f"{{{', '.join(map(str, v))}}}"

return ", ".join(
f"{{{trigger_time}, {vector_initializer(event_idxs)}}}"
for trigger_time, event_idxs in tmp_map.items()
)

def _write_c_make_file(self):
"""Write CMake ``CMakeLists.txt`` file for this model."""
sources = "\n".join(
Expand Down
Loading

0 comments on commit baa13af

Please sign in to comment.