Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 18, 2023
1 parent 4aafd45 commit 3f7f5bc
Showing 1 changed file with 63 additions and 41 deletions.
104 changes: 63 additions & 41 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
Expand All @@ -16,28 +17,29 @@

def plot_state_trajectories(
rdata: ReturnDataView,
state_indices: Optional[Iterable[int]] = None,
state_indices: Optional[Sequence[int]] = None,
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot state trajectories
Plot state trajectories.
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:func:`amici.amici.runAmiciSimulation`.
:param state_indices:
Indices of states for which trajectories are to be plotted
Indices of state variables for which trajectories are to be plotted.
:param ax:
matplotlib Axes instance to plot into
:class:`matplotlib.pyplot.Axes` instance to plot into.
:param model:
amici model instance
The model *rdata* was generated from.
:param prefer_names:
Whether state names should be preferred over IDs, if available.
:param marker:
Point marker for plotting.
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
"""
if not ax:
fig, ax = plt.subplots()
Expand All @@ -49,16 +51,20 @@ def plot_state_trajectories(
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

for ix in state_indices:
if model is None and rdata._swigptr.state_ids is None:
label = f"$x_{{{ix}}}$"
elif model is not None and prefer_names and model.getStateNames()[ix]:
label = model.getStateNames()[ix]
elif model is not None:
label = model.getStateIds()[ix]
else:
label = rdata._swigptr.state_ids[ix]

if model is None and rdata.ptr.state_ids is None:
labels = [f"$x_{{{ix}}}$" for ix in state_indices]
elif model is not None and prefer_names:
labels = np.asarray(model.getStateNames())[list(state_indices)]
labels = [
l if l else model.getStateNames()[ix]
for ix, l in enumerate(labels)
]
elif model is not None:
labels = np.asarray(model.getStateIds())[list(state_indices)]
else:
labels = np.asarray(rdata.ptr.state_ids)[list(state_indices)]

for ix, label in zip(state_indices, labels):
ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$x(t)$")
Expand All @@ -72,38 +78,54 @@ def plot_observable_trajectories(
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot observable trajectories
Plot observable trajectories.
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:func:`amici.amici.runAmiciSimulation`.
:param observable_indices:
Indices of observables for which trajectories are to be plotted
Indices of observables for which trajectories are to be plotted.
:param ax:
matplotlib Axes instance to plot into
:class:`matplotlib.pyplot.Axes` instance to plot into.
:param model:
amici model instance
The model *rdata* was generated from.
:param prefer_names:
Whether observables names should be preferred over IDs, if available.
Whether observable names should be preferred over IDs, if available.
:param marker:
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
"""
if not ax:
fig, ax = plt.subplots()
if not observable_indices:
observable_indices = range(rdata["y"].shape[1])
for iy in observable_indices:
if model is None:
label = f"$y_{{{iy}}}$"
elif prefer_names and model.getObservableNames()[iy]:
label = model.getObservableNames()[iy]
else:
label = model.getObservableIds()[iy]
ax.plot(rdata["t"], rdata["y"][:, iy], label=label)

if marker is None:
# Show marker if only one time point is available,
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

if model is None and rdata.ptr.observable_ids is None:
labels = [f"$y_{{{iy}}}$" for iy in observable_indices]
elif model is not None and prefer_names:
labels = np.asarray(model.getObservableNames())[
list(observable_indices)
]
labels = [
l if l else model.getObservableNames()[ix]
for ix, l in enumerate(labels)
]
elif model is not None:
labels = np.asarray(model.getObservableIds())[list(observable_indices)]
else:
labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)]

for iy, label in zip(observable_indices, labels):
ax.plot(rdata["t"], rdata["y"][:, iy], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$y(t)$")
ax.legend()
Expand All @@ -114,8 +136,8 @@ def plot_jacobian(rdata: ReturnDataView):
"""Plot Jacobian as heatmap."""
df = pd.DataFrame(
data=rdata.J,
index=rdata._swigptr.state_ids_solver,
columns=rdata._swigptr.state_ids_solver,
index=rdata.ptr.state_ids_solver,
columns=rdata.ptr.state_ids_solver,
)
sns.heatmap(df, center=0.0)
plt.title("Jacobian")
Expand All @@ -132,10 +154,10 @@ def plot_expressions(
"""Plot the given expressions evaluated on the given simulation outputs.
:param exprs:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
A symbolic expression, e.g., a sympy expression or a string that can be
sympified. It Can include state variable, expression, and
observable IDs, depending on whether the respective data is available
in the simulation results. Parameters are not yet supported.
:param rdata:
The simulation results.
"""
Expand Down

0 comments on commit 3f7f5bc

Please sign in to comment.