diff --git a/python/examples/example_steadystate/ExampleSteadystate.ipynb b/python/examples/example_steadystate/ExampleSteadystate.ipynb index 09590a3b1a..b57ed522aa 100644 --- a/python/examples/example_steadystate/ExampleSteadystate.ipynb +++ b/python/examples/example_steadystate/ExampleSteadystate.ipynb @@ -354,7 +354,7 @@ "source": [ "### Importing the module and loading the model\n", "\n", - "If everything went well, we need to add the previously selected model output directory to our PYTHON_PATH and are then ready to load newly generated model:" + "If everything went well, we can now import the newly generated Python module containing our model:" ] }, { @@ -392,7 +392,7 @@ "source": [ "model = model_module.getModel()\n", "\n", - "print(\"Model name:\", model.getName())\n", + "print(\"Model name: \", model.getName())\n", "print(\"Model parameters:\", model.getParameterIds())\n", "print(\"Model outputs: \", model.getObservableIds())\n", "print(\"Model states: \", model.getStateIds())" @@ -917,10 +917,32 @@ "source": [ "import amici.plotting\n", "\n", - "amici.plotting.plotStateTrajectories(rdata, model=None)\n", - "amici.plotting.plotObservableTrajectories(rdata, model=None)" + "amici.plotting.plot_state_trajectories(rdata, model=None)\n", + "amici.plotting.plot_observable_trajectories(rdata, model=None)" ] }, + { + "cell_type": "markdown", + "source": [ + "We can also evaluate symbolic expressions of model quantities using `amici.numpy.evaluate`, or directly plot the results using `amici.plotting.plot_expressions`:" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "amici.plotting.plot_expressions(\n", + " \"observable_x1 + observable_x2 + observable_x3\", rdata=rdata\n", + ")" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index 23ebfdbbc4..d9b34b6447 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -10,9 +10,12 @@ import amici import numpy as np +import sympy as sp from . import ExpData, ExpDataPtr, Model, ReturnData, ReturnDataPtr +StrOrExpr = Union[str, sp.Expr] + class SwigPtrView(collections.abc.Mapping): """ @@ -429,3 +432,28 @@ def _entity_type_from_id( return symbol raise KeyError(f"Unknown symbol {entity_id}.") + + +def evaluate(expr: StrOrExpr, rdata: ReturnDataView) -> np.array: + """Evaluate a symbolic expression based on the given simulation outputs. + + :param expr: + A symbolic expression, e.g. a sympy expression or a string that can be sympified. + Can include state variable, expression, and observable IDs, depending on whether + the respective data is available in the simulation results. + Parameters are not yet supported. + :param rdata: + The simulation results. + + :return: + The evaluated expression for the simulation output timepoints. + """ + from sympy.utilities.lambdify import lambdify + + if isinstance(expr, str): + expr = sp.sympify(expr) + + arg_names = list(sorted(expr.free_symbols, key=lambda x: x.name)) + func = lambdify(arg_names, expr, "numpy") + args = [rdata.by_id(arg.name) for arg in arg_names] + return func(*args) diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index da718c1ec7..bd1f3a8ba1 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -3,7 +3,7 @@ -------- Plotting related functions """ -from typing import Iterable, Optional +from typing import Iterable, Optional, Sequence, Union import matplotlib.pyplot as plt import pandas as pd @@ -11,6 +11,7 @@ from matplotlib.axes import Axes from . import Model, ReturnDataView +from .numpy import StrOrExpr, evaluate def plot_state_trajectories( @@ -115,3 +116,26 @@ def plot_jacobian(rdata: ReturnDataView): # backwards compatibility plotStateTrajectories = plot_state_trajectories plotObservableTrajectories = plot_observable_trajectories + + +def plot_expressions( + exprs: Union[Sequence[StrOrExpr], StrOrExpr], rdata: ReturnDataView +) -> None: + """Plot the given expressions evaluated on the given simulation outputs. + + :param exprs: + A symbolic expression, e.g. a sympy expression or a string that can be sympified. + Can include state variable, expression, and observable IDs, depending on whether + the respective data is available in the simulation results. + Parameters are not yet supported. + :param rdata: + The simulation results. + """ + if not isinstance(exprs, Sequence) or isinstance(exprs, str): + exprs = [exprs] + + for expr in exprs: + plt.plot(rdata.t, evaluate(expr, rdata), label=str(expr)) + + plt.legend() + plt.gca().set_xlabel("$t$") diff --git a/python/tests/test_rdata.py b/python/tests/test_rdata.py index 29ea401932..ac7659f363 100644 --- a/python/tests/test_rdata.py +++ b/python/tests/test_rdata.py @@ -2,7 +2,8 @@ import amici import numpy as np import pytest -from numpy.testing import assert_array_equal +from amici.numpy import evaluate +from numpy.testing import assert_almost_equal, assert_array_equal @pytest.fixture(scope="session") @@ -39,3 +40,27 @@ def test_rdata_by_id(rdata_by_id_fixture): assert_array_equal( rdata.by_id(model.getStateIds()[1], "sx", model), rdata.sx[:, :, 1] ) + + +def test_evaluate(rdata_by_id_fixture): + # get IDs of model components + model, rdata = rdata_by_id_fixture + expr0_id = model.getExpressionIds()[0] + state1_id = model.getStateIds()[1] + observable0_id = model.getObservableIds()[0] + + # ensure `evaluate` works for atoms + expr0 = rdata.by_id(expr0_id) + assert_array_equal(expr0, evaluate(expr0_id, rdata=rdata)) + + state1 = rdata.by_id(state1_id) + assert_array_equal(state1, evaluate(state1_id, rdata=rdata)) + + observable0 = rdata.by_id(observable0_id) + assert_array_equal(observable0, evaluate(observable0_id, rdata=rdata)) + + # ensure `evaluate` works for expressions + assert_almost_equal( + expr0 + state1 * observable0, + evaluate(f"{expr0_id} + {state1_id} * {observable0_id}", rdata=rdata), + )