Skip to content

Commit

Permalink
Merge branch 'develop' into jax_reinitialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich authored Dec 5, 2024
2 parents 757ffa1 + 449041d commit 01e12fc
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 3 deletions.
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ filterwarnings =
ignore:Conservation laws for non-constant species in models with Species-AssignmentRules are currently not supported and will be turned off.:UserWarning
ignore:Conservation laws for non-constant species in combination with parameterized stoichiometric coefficients are not currently supported and will be turned off.:UserWarning
ignore:Support for PEtab2.0 is experimental!:UserWarning
ignore:The JAX module is experimental and the API may change in the future.:ImportWarning
# hundreds of SBML <=5.17 warnings
ignore:.*inspect.getargspec\(\) is deprecated.*:DeprecationWarning
# pysb warnings
Expand Down
2 changes: 1 addition & 1 deletion python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" results (dict): Simulation results from run_simulations.\n",
" \"\"\"\n",
" # Extract the simulation results for the specific condition\n",
" sim_results = results[simulation_condition][1]\n",
" sim_results = results[simulation_condition]\n",
"\n",
" # Create a new figure for the state trajectories\n",
" plt.figure(figsize=(8, 6))\n",
Expand Down
17 changes: 16 additions & 1 deletion python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
"""Interface to facilitate AMICI generated models using JAX"""
"""
JAX
---
This module provides an interface to generate and use AMICI models with JAX. Please note that this module is
experimental, the API may substantially change in the future. Use at your own risk and do not expect backward
compatibility.
"""

from warnings import warn

from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.model import JAXModel

warn(
"The JAX module is experimental and the API may change in the future.",
ImportWarning,
stacklevel=2,
)

__all__ = ["JAXModel", "JAXProblem", "run_simulations"]
1 change: 1 addition & 0 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def simulate_condition(
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand Down
5 changes: 4 additions & 1 deletion python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,4 +490,7 @@ def run_simulations(
)
for sc in simulation_conditions
}
return sum(llh for llh, _ in results.values()), results | preeqs
return sum(llh for llh, _ in results.values()), {
sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
for sc, res in results.items()
}

0 comments on commit 01e12fc

Please sign in to comment.