Skip to content

Commit

Permalink
fix jax imports
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Apr 12, 2024
1 parent e09bb2f commit d8d1900
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
2 changes: 2 additions & 0 deletions documentation/rtd_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ sphinx
mock>=5.0.2
setuptools>=67.7.2
pysb>=1.11.0
jax>=0.4.26
diffrax>=0.5.0
matplotlib==3.7.1
nbsphinx==0.9.1
nbformat==5.8.0
Expand Down
17 changes: 5 additions & 12 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""AMICI-generated module for model TPL_MODELNAME"""

from pathlib import Path

from typing import TYPE_CHECKING
import amici

try:
if TYPE_CHECKING:
from amici.jax import JAXModel
except (ModuleNotFoundError, ImportError):
JAXModel = object

# Ensure we are binary-compatible, see #556
if "TPL_AMICI_VERSION" != amici.__version__:
Expand All @@ -23,16 +21,11 @@
from .TPL_MODELNAME import * # noqa: F403, F401
from .TPL_MODELNAME import getModel as get_model # noqa: F401

try:
from .jax import JAXModel_TPL_MODELNAME

def get_jax_model() -> JAXModel:
return JAXModel_TPL_MODELNAME()
except (ModuleNotFoundError, ImportError) as exc:
error = str(exc)
def get_jax_model() -> "JAXModel":
from .jax import JAXModel_TPL_MODELNAME

def get_jax_model() -> JAXModel:
raise NotImplementedError(error)
return JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"

0 comments on commit d8d1900

Please sign in to comment.