Skip to content

Commit

Permalink
Get rid of amici.swig_wrappers.ExpData
Browse files Browse the repository at this point in the history
The old implementation prevented using e.g. `isinstance(x, amici.ExpData)`.

Closes AMICI-dev#2380
  • Loading branch information
dweindl committed Apr 16, 2024
1 parent d993362 commit 924d161
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 59 deletions.
1 change: 1 addition & 0 deletions documentation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def install_doxygen():
"std::unique_ptr< amici::ExpData >": "ExpData",
"std::unique_ptr< amici::ReturnData >": "ReturnData",
"std::unique_ptr< amici::Solver >": "Solver",
"amici::realtype const": "float",
"amici::realtype": "float",
}

Expand Down
6 changes: 5 additions & 1 deletion python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions related to SWIG or SWIG-generated code"""
from __future__ import annotations
import ast
import contextlib
import re
Expand Down Expand Up @@ -74,7 +75,10 @@ def visit_FunctionDef(self, node):
arg.annotation = self._new_annot(arg.annotation.value)
return node

def _new_annot(self, old_annot: str):
def _new_annot(self, old_annot: str | ast.Name):
if isinstance(old_annot, ast.Name):
old_annot = old_annot.id

with contextlib.suppress(KeyError):
return self.mapping[old_annot]

Expand Down
59 changes: 1 addition & 58 deletions python/sdist/amici/swig_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import amici
import amici.amici as amici_swig

from amici.amici import _get_ptr
from . import numpy
from .logging import get_logger

Expand All @@ -18,7 +18,6 @@
__all__ = [
"runAmiciSimulation",
"runAmiciSimulations",
"ExpData",
"readSolverSettingsFromHDF5",
"writeSolverSettingsToHDF5",
"set_model_settings",
Expand Down Expand Up @@ -54,36 +53,6 @@ def _capture_cstdout():
yield


def _get_ptr(
obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData],
) -> Union[
"amici_swig.Model",
"amici_swig.ExpData",
"amici_swig.Solver",
"amici_swig.ReturnData",
]:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(
obj,
(
amici_swig.ModelPtr,
amici_swig.ExpDataPtr,
amici_swig.SolverPtr,
amici_swig.ReturnDataPtr,
),
):
return obj.get()
return obj


def runAmiciSimulation(
model: AmiciModel,
solver: AmiciSolver,
Expand Down Expand Up @@ -128,32 +97,6 @@ def runAmiciSimulation(
return numpy.ReturnDataView(rdata)


def ExpData(*args) -> "amici_swig.ExpData":
"""
Convenience wrapper for :py:class:`amici.amici.ExpData` constructors
:param args: arguments
:returns: ExpData Instance
"""
if not args:
return amici_swig.ExpData()

if isinstance(args[0], numpy.ReturnDataView):
return amici_swig.ExpData(_get_ptr(args[0]["ptr"]), *args[1:])

if isinstance(args[0], (amici_swig.ExpData, amici_swig.ExpDataPtr)):
# the *args[:1] should be empty, but by the time you read this,
# the constructor signature may have changed, and you are glad this
# wrapper did not break.
return amici_swig.ExpData(_get_ptr(args[0]), *args[1:])

if isinstance(args[0], (amici_swig.Model, amici_swig.ModelPtr)):
return amici_swig.ExpData(_get_ptr(args[0]))

return amici_swig.ExpData(*args)


def runAmiciSimulations(
model: AmiciModel,
solver: AmiciSolver,
Expand Down
33 changes: 33 additions & 0 deletions swig/amici.i
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def __repr__(self):

// Handle AMICI_DLL_DIRS environment variable
%pythonbegin %{
from __future__ import annotations

import sys
import os

Expand Down Expand Up @@ -368,4 +370,35 @@ __all__ = [
if not x.startswith('_')
and x not in {"np", "sys", "os", "numpy", "IntEnum", "enum", "pi", "TYPE_CHECKING", "Iterable", "Sequence"}
]


def _get_ptr(
obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData],
) -> Union[
Model,
ExpData,
Solver,
ReturnData,
]:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(
obj,
(
ModelPtr,
ExpDataPtr,
SolverPtr,
ReturnDataPtr,
),
):
return obj.get()
return obj

%}
18 changes: 18 additions & 0 deletions swig/edata.i
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ using namespace amici;

%ignore ConditionContext;

%feature("pythonprepend") amici::ExpData::ExpData %{
"""
Convenience wrapper for :py:class:`amici.amici.ExpData` constructors
:param args: arguments
:returns: ExpData Instance
"""
if args:
from amici.numpy import ReturnDataView

# Get the raw pointer if necessary
if isinstance(args[0], (ExpData, ExpDataPtr, Model, ModelPtr)):
args = (_get_ptr(args[0]), *args[1:])
elif isinstance(args[0], ReturnDataView):
args = (_get_ptr(args[0]["ptr"]), *args[1:])
%}

// ExpData.__repr__
%pythoncode %{
def _edata_repr(self: "ExpData"):
Expand Down

0 comments on commit 924d161

Please sign in to comment.