Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 18, 2023
1 parent 4aafd45 commit 3f7f5bc
Showing 1 changed file with 63 additions and 41 deletions.
104 changes: 63 additions & 41 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
Expand All @@ -16,28 +17,29 @@

def plot_state_trajectories(
rdata: ReturnDataView,
state_indices: Optional[Iterable[int]] = None,
state_indices: Optional[Sequence[int]] = None,
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot state trajectories
Plot state trajectories.
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:func:`amici.amici.runAmiciSimulation`.
:param state_indices:
Indices of states for which trajectories are to be plotted
Indices of state variables for which trajectories are to be plotted.
:param ax:
matplotlib Axes instance to plot into
:class:`matplotlib.pyplot.Axes` instance to plot into.
:param model:
amici model instance
The model *rdata* was generated from.
:param prefer_names:
Whether state names should be preferred over IDs, if available.
:param marker:
Point marker for plotting.
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
"""
if not ax:
fig, ax = plt.subplots()
Expand All @@ -49,16 +51,20 @@ def plot_state_trajectories(
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

Check warning on line 52 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L52

Added line #L52 was not covered by tests

for ix in state_indices:
if model is None and rdata._swigptr.state_ids is None:
label = f"$x_{{{ix}}}$"
elif model is not None and prefer_names and model.getStateNames()[ix]:
label = model.getStateNames()[ix]
elif model is not None:
label = model.getStateIds()[ix]
else:
label = rdata._swigptr.state_ids[ix]

if model is None and rdata.ptr.state_ids is None:
labels = [f"$x_{{{ix}}}$" for ix in state_indices]
elif model is not None and prefer_names:
labels = np.asarray(model.getStateNames())[list(state_indices)]
labels = [

Check warning on line 58 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L54-L58

Added lines #L54 - L58 were not covered by tests
l if l else model.getStateNames()[ix]
for ix, l in enumerate(labels)
]
elif model is not None:
labels = np.asarray(model.getStateIds())[list(state_indices)]

Check warning on line 63 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L62-L63

Added lines #L62 - L63 were not covered by tests
else:
labels = np.asarray(rdata.ptr.state_ids)[list(state_indices)]

Check warning on line 65 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L65

Added line #L65 was not covered by tests

for ix, label in zip(state_indices, labels):
ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label)

Check warning on line 68 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L67-L68

Added lines #L67 - L68 were not covered by tests
ax.set_xlabel("$t$")
ax.set_ylabel("$x(t)$")
Expand All @@ -72,38 +78,54 @@ def plot_observable_trajectories(
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot observable trajectories
Plot observable trajectories.
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:func:`amici.amici.runAmiciSimulation`.
:param observable_indices:
Indices of observables for which trajectories are to be plotted
Indices of observables for which trajectories are to be plotted.
:param ax:
matplotlib Axes instance to plot into
:class:`matplotlib.pyplot.Axes` instance to plot into.
:param model:
amici model instance
The model *rdata* was generated from.
:param prefer_names:
Whether observables names should be preferred over IDs, if available.
Whether observable names should be preferred over IDs, if available.
:param marker:
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
"""
if not ax:
fig, ax = plt.subplots()
if not observable_indices:
observable_indices = range(rdata["y"].shape[1])
for iy in observable_indices:
if model is None:
label = f"$y_{{{iy}}}$"
elif prefer_names and model.getObservableNames()[iy]:
label = model.getObservableNames()[iy]
else:
label = model.getObservableIds()[iy]
ax.plot(rdata["t"], rdata["y"][:, iy], label=label)

if marker is None:

Check warning on line 107 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L107

Added line #L107 was not covered by tests
# Show marker if only one time point is available,
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

Check warning on line 110 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L110

Added line #L110 was not covered by tests

if model is None and rdata.ptr.observable_ids is None:
labels = [f"$y_{{{iy}}}$" for iy in observable_indices]
elif model is not None and prefer_names:
labels = np.asarray(model.getObservableNames())[

Check warning on line 115 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L112-L115

Added lines #L112 - L115 were not covered by tests
list(observable_indices)
]
labels = [

Check warning on line 118 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L118

Added line #L118 was not covered by tests
l if l else model.getObservableNames()[ix]
for ix, l in enumerate(labels)
]
elif model is not None:
labels = np.asarray(model.getObservableIds())[list(observable_indices)]

Check warning on line 123 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L122-L123

Added lines #L122 - L123 were not covered by tests
else:
labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)]

Check warning on line 125 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L125

Added line #L125 was not covered by tests

for iy, label in zip(observable_indices, labels):
ax.plot(rdata["t"], rdata["y"][:, iy], marker=marker, label=label)

Check warning on line 128 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L127-L128

Added lines #L127 - L128 were not covered by tests
ax.set_xlabel("$t$")
ax.set_ylabel("$y(t)$")
ax.legend()
Expand All @@ -114,8 +136,8 @@ def plot_jacobian(rdata: ReturnDataView):
"""Plot Jacobian as heatmap."""
df = pd.DataFrame(
data=rdata.J,
index=rdata._swigptr.state_ids_solver,
columns=rdata._swigptr.state_ids_solver,
index=rdata.ptr.state_ids_solver,
columns=rdata.ptr.state_ids_solver,
)
sns.heatmap(df, center=0.0)
plt.title("Jacobian")
Expand All @@ -132,10 +154,10 @@ def plot_expressions(
"""Plot the given expressions evaluated on the given simulation outputs.
:param exprs:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
A symbolic expression, e.g., a sympy expression or a string that can be
sympified. It Can include state variable, expression, and
observable IDs, depending on whether the respective data is available
in the simulation results. Parameters are not yet supported.
:param rdata:
The simulation results.
"""
Expand Down

0 comments on commit 3f7f5bc

Please sign in to comment.