Skip to content

Commit

Permalink
Move the amici.swig_wrappers.ExpData logic into the swig-generated …
Browse files Browse the repository at this point in the history
…`ExpData`.

This avoids shading the `ExpData` class by  `amici.swig_wrappers.ExpData`.
The old implementation prevented using e.g. `isinstance(x, amici.ExpData)`.

This required also:
* Moving some annotation types and `_get_ptr` from `amici.swig_wrappers` to the swig-generated `amici.py`
* Some smaller changes to allow for using `from __future__ import annotations` in `amici.py`

Closes AMICI-dev#2380
  • Loading branch information
dweindl committed Apr 20, 2024
1 parent e8493cf commit 7e7c3d7
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 78 deletions.
8 changes: 4 additions & 4 deletions documentation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,15 +558,15 @@ def fix_typehints(sig: str) -> str:
sig = sig.replace("sunindextype", "int")
sig = sig.replace("H5::H5File", "object")

# remove const
sig = sig.replace(" const ", r" ")
sig = re.sub(r" const$", r"", sig)
# remove const / const&
sig = sig.replace(" const&? ", r" ")
sig = re.sub(r" const&?$", r"", sig)

# remove pass by reference
sig = re.sub(r" &(,|\))", r"\1", sig)
sig = re.sub(r" &$", r"", sig)

# turn gsl_spans and pointers int Iterables
# turn gsl_spans and pointers into Iterables
sig = re.sub(r"([\w.]+) \*", r"Iterable[\1]", sig)
sig = re.sub(r"gsl::span< ([\w.]+) >", r"Iterable[\1]", sig)

Expand Down
2 changes: 0 additions & 2 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def get_model(self) -> amici.Model:
"""Create a model instance."""
...

AmiciModel = Union[amici.Model, amici.ModelPtr]


class add_path:
"""Context manager for temporarily changing PYTHONPATH"""
Expand Down
78 changes: 7 additions & 71 deletions python/sdist/amici/swig_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
import warnings
from contextlib import contextmanager, suppress
from typing import Any, Optional, Union
from collections.abc import Sequence

import amici
import amici.amici as amici_swig

from amici.amici import (
_get_ptr,
AmiciExpData,
AmiciExpDataVector,
AmiciModel,
AmiciSolver,
)
from . import numpy
from .logging import get_logger

Expand All @@ -18,25 +23,12 @@
__all__ = [
"runAmiciSimulation",
"runAmiciSimulations",
"ExpData",
"readSolverSettingsFromHDF5",
"writeSolverSettingsToHDF5",
"set_model_settings",
"get_model_settings",
"AmiciModel",
"AmiciSolver",
"AmiciExpData",
"AmiciReturnData",
"AmiciExpDataVector",
]

AmiciModel = Union["amici.Model", "amici.ModelPtr"]
AmiciSolver = Union["amici.Solver", "amici.SolverPtr"]
AmiciExpData = Union["amici.ExpData", "amici.ExpDataPtr"]
AmiciReturnData = Union["amici.ReturnData", "amici.ReturnDataPtr"]
AmiciExpDataVector = Union["amici.ExpDataPtrVector", Sequence[AmiciExpData]]


try:
from wurlitzer import sys_pipes
except ModuleNotFoundError:
Expand All @@ -54,36 +46,6 @@ def _capture_cstdout():
yield


def _get_ptr(
obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData],
) -> Union[
"amici_swig.Model",
"amici_swig.ExpData",
"amici_swig.Solver",
"amici_swig.ReturnData",
]:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(
obj,
(
amici_swig.ModelPtr,
amici_swig.ExpDataPtr,
amici_swig.SolverPtr,
amici_swig.ReturnDataPtr,
),
):
return obj.get()
return obj


def runAmiciSimulation(
model: AmiciModel,
solver: AmiciSolver,
Expand Down Expand Up @@ -128,32 +90,6 @@ def runAmiciSimulation(
return numpy.ReturnDataView(rdata)


def ExpData(*args) -> "amici_swig.ExpData":
"""
Convenience wrapper for :py:class:`amici.amici.ExpData` constructors
:param args: arguments
:returns: ExpData Instance
"""
if not args:
return amici_swig.ExpData()

if isinstance(args[0], numpy.ReturnDataView):
return amici_swig.ExpData(_get_ptr(args[0]["ptr"]), *args[1:])

if isinstance(args[0], (amici_swig.ExpData, amici_swig.ExpDataPtr)):
# the *args[:1] should be empty, but by the time you read this,
# the constructor signature may have changed, and you are glad this
# wrapper did not break.
return amici_swig.ExpData(_get_ptr(args[0]), *args[1:])

if isinstance(args[0], (amici_swig.Model, amici_swig.ModelPtr)):
return amici_swig.ExpData(_get_ptr(args[0]))

return amici_swig.ExpData(*args)


def runAmiciSimulations(
model: AmiciModel,
solver: AmiciSolver,
Expand Down
37 changes: 36 additions & 1 deletion swig/amici.i
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def __repr__(self):

// Handle AMICI_DLL_DIRS environment variable
%pythonbegin %{
from __future__ import annotations

import sys
import os

Expand All @@ -353,19 +355,52 @@ if sys.platform == 'win32' and (dll_dirs := os.environ.get('AMICI_DLL_DIRS')):
// import additional types for typehints
// also import np for use in __repr__ functions
%pythonbegin %{
from typing import TYPE_CHECKING, Iterable, Sequence
from typing import TYPE_CHECKING, Iterable, Union
from collections.abc import Sequence
import numpy as np
if TYPE_CHECKING:
import numpy
%}

%pythoncode %{

AmiciModel = Union[Model, ModelPtr]
AmiciSolver = Union[Solver, SolverPtr]
AmiciExpData = Union[ExpData, ExpDataPtr]
AmiciReturnData = Union[ReturnData, ReturnDataPtr]
AmiciExpDataVector = Union[ExpDataPtrVector, Sequence[AmiciExpData]]


def _get_ptr(
obj: AmiciModel | AmiciExpData | AmiciSolver | AmiciReturnData,
) -> Model | ExpData | Solver | ReturnData:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(
obj,
(
ModelPtr,
ExpDataPtr,
SolverPtr,
ReturnDataPtr,
),
):
return obj.get()
return obj


__all__ = [
x
for x in dir(sys.modules[__name__])
if not x.startswith('_')
and x not in {"np", "sys", "os", "numpy", "IntEnum", "enum", "pi", "TYPE_CHECKING", "Iterable", "Sequence"}
]

%}
18 changes: 18 additions & 0 deletions swig/edata.i
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ using namespace amici;

%ignore ConditionContext;

%feature("pythonprepend") amici::ExpData::ExpData %{
"""
Convenience wrapper for :py:class:`amici.amici.ExpData` constructors
:param args: arguments
:returns: ExpData Instance
"""
if args:
from amici.numpy import ReturnDataView

# Get the raw pointer if necessary
if isinstance(args[0], (ExpData, ExpDataPtr, Model, ModelPtr)):
args = (_get_ptr(args[0]), *args[1:])
elif isinstance(args[0], ReturnDataView):
args = (_get_ptr(args[0]["ptr"]), *args[1:])
%}

// ExpData.__repr__
%pythoncode %{
def _edata_repr(self: "ExpData"):
Expand Down

0 comments on commit 7e7c3d7

Please sign in to comment.