Skip to content

Commit

Permalink
Use proper labels for plotting if IDs are available in ReturnData
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 18, 2023
1 parent a3b1c2b commit 4aafd45
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,46 @@ def plot_state_trajectories(
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot state trajectories
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:param state_indices:
Indices of states for which trajectories are to be plotted
:param ax:
matplotlib Axes instance to plot into
:param model:
amici model instance
:param prefer_names:
Whether state names should be preferred over IDs, if available.
:param marker:
Point marker for plotting.
"""
if not ax:
fig, ax = plt.subplots()
if not state_indices:
state_indices = range(rdata["x"].shape[1])

if marker is None:
# Show marker if only one time point is available,
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

for ix in state_indices:
if model is None:
if model is None and rdata._swigptr.state_ids is None:
label = f"$x_{{{ix}}}$"
elif prefer_names and model.getStateNames()[ix]:
elif model is not None and prefer_names and model.getStateNames()[ix]:
label = model.getStateNames()[ix]
else:
elif model is not None:
label = model.getStateIds()[ix]
ax.plot(rdata["t"], rdata["x"][:, ix], label=label)
else:
label = rdata._swigptr.state_ids[ix]

ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$x(t)$")
ax.legend()
Expand Down

0 comments on commit 4aafd45

Please sign in to comment.