Skip to content

Commit

Permalink
Refactor event handling (#2216)
Browse files Browse the repository at this point in the history
Just extracting 2 functions. No change in functionality.
  • Loading branch information
dweindl authored Dec 1, 2023
1 parent c174c40 commit 06f6217
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 65 deletions.
15 changes: 15 additions & 0 deletions include/amici/forwardproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,21 @@ class ForwardProblem {

void handleEvent(realtype* tlastroot, bool seflag, bool initial_event);

/**
* @brief Store pre-event model state
*
* @param seflag Secondary event flag
* @param initial_event initial event flag
*/
void store_pre_event_state(bool seflag, bool initial_event);

/**
* @brief Check for, and if applicable, handle any secondary events
*
* @param tlastroot pointer to the timepoint of the last event
*/
void handle_secondary_event(realtype* tlastroot);

/**
* @brief Extract output information for events
*/
Expand Down
138 changes: 73 additions & 65 deletions src/forwardproblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,77 @@ void ForwardProblem::handleEvent(
if (model->nz > 0)
storeEvent();

store_pre_event_state(seflag, initial_event);

if (!initial_event)
model->updateHeaviside(roots_found_);

applyEventBolus();

if (solver->computingFSA()) {
/* compute the new xdot */
model->fxdot(t_, x_, dx_, xdot_);
applyEventSensiBolusFSA();
}

handle_secondary_event(tlastroot);

/* only reinitialise in the first event fired */
if (!seflag) {
solver->reInit(t_, x_, dx_);
if (solver->computingFSA()) {
solver->sensReInit(sx_, sdx_);
}
}
}

void ForwardProblem::storeEvent() {
if (t_ == model->getTimepoint(model->nt() - 1)) {
// call from fillEvent at last timepoint
model->froot(t_, x_, dx_, rootvals_);
for (int ie = 0; ie < model->ne; ie++) {
roots_found_.at(ie) = (nroots_.at(ie) < model->nMaxEvent()) ? 1 : 0;
}
root_idx_.push_back(roots_found_);
}

if (getRootCounter() < getEventCounter()) {
/* update stored state (sensi) */
event_states_.at(getRootCounter()) = getSimulationState();
} else {
/* add stored state (sensi) */
event_states_.push_back(getSimulationState());
}

/* EVENT OUTPUT */
for (int ie = 0; ie < model->ne; ie++) {
/* only look for roots of the rootfunction not discontinuities */
if (nroots_.at(ie) >= model->nMaxEvent())
continue;

/* only consider transitions false -> true or event filling */
if (roots_found_.at(ie) != 1
&& t_ != model->getTimepoint(model->nt() - 1)) {
continue;
}

if (edata && solver->computingASA())
model->getAdjointStateEventUpdate(
slice(dJzdx_, nroots_.at(ie), model->nx_solver * model->nJ), ie,
nroots_.at(ie), t_, x_, *edata
);

nroots_.at(ie)++;
}

if (t_ == model->getTimepoint(model->nt() - 1)) {
// call from fillEvent at last timepoint
// loop until all events are filled
fillEvents(model->nMaxEvent());
}
}

void ForwardProblem::store_pre_event_state(bool seflag, bool initial_event) {
/* if we need to do forward sensitivities later on we need to store the old
* x and the old xdot */
if (solver->getSensitivityOrder() >= SensitivityOrder::first) {
Expand Down Expand Up @@ -212,18 +283,9 @@ void ForwardProblem::handleEvent(
xdot_disc_.push_back(xdot_);
xdot_old_disc_.push_back(xdot_old_);
}
}

if (!initial_event)
model->updateHeaviside(roots_found_);

applyEventBolus();

if (solver->computingFSA()) {
/* compute the new xdot */
model->fxdot(t_, x_, dx_, xdot_);
applyEventSensiBolusFSA();
}

void ForwardProblem::handle_secondary_event(realtype* tlastroot) {
int secondevent = 0;

/* check whether we need to fire a secondary event */
Expand Down Expand Up @@ -260,60 +322,6 @@ void ForwardProblem::handleEvent(
);
handleEvent(tlastroot, true, false);
}

/* only reinitialise in the first event fired */
if (!seflag) {
solver->reInit(t_, x_, dx_);
if (solver->computingFSA()) {
solver->sensReInit(sx_, sdx_);
}
}
}

void ForwardProblem::storeEvent() {
if (t_ == model->getTimepoint(model->nt() - 1)) {
// call from fillEvent at last timepoint
model->froot(t_, x_, dx_, rootvals_);
for (int ie = 0; ie < model->ne; ie++) {
roots_found_.at(ie) = (nroots_.at(ie) < model->nMaxEvent()) ? 1 : 0;
}
root_idx_.push_back(roots_found_);
}

if (getRootCounter() < getEventCounter()) {
/* update stored state (sensi) */
event_states_.at(getRootCounter()) = getSimulationState();
} else {
/* add stored state (sensi) */
event_states_.push_back(getSimulationState());
}

/* EVENT OUTPUT */
for (int ie = 0; ie < model->ne; ie++) {
/* only look for roots of the rootfunction not discontinuities */
if (nroots_.at(ie) >= model->nMaxEvent())
continue;

/* only consider transitions false -> true or event filling */
if (roots_found_.at(ie) != 1
&& t_ != model->getTimepoint(model->nt() - 1)) {
continue;
}

if (edata && solver->computingASA())
model->getAdjointStateEventUpdate(
slice(dJzdx_, nroots_.at(ie), model->nx_solver * model->nJ), ie,
nroots_.at(ie), t_, x_, *edata
);

nroots_.at(ie)++;
}

if (t_ == model->getTimepoint(model->nt() - 1)) {
// call from fillEvent at last timepoint
// loop until all events are filled
fillEvents(model->nMaxEvent());
}
}

void ForwardProblem::handleDataPoint(int /*it*/) {
Expand Down

0 comments on commit 06f6217

Please sign in to comment.