Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Apr 16, 2024
1 parent 3990e71 commit ea40015
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 31 deletions.
32 changes: 1 addition & 31 deletions python/sdist/amici/swig_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import amici
import amici.amici as amici_swig

from amici.amici import _get_ptr
from . import numpy
from .logging import get_logger

Expand Down Expand Up @@ -53,36 +53,6 @@ def _capture_cstdout():
yield


def _get_ptr(
obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData],
) -> Union[
"amici_swig.Model",
"amici_swig.ExpData",
"amici_swig.Solver",
"amici_swig.ReturnData",
]:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(
obj,
(
amici_swig.ModelPtr,
amici_swig.ExpDataPtr,
amici_swig.SolverPtr,
amici_swig.ReturnDataPtr,
),
):
return obj.get()
return obj


def runAmiciSimulation(
model: AmiciModel,
solver: AmiciSolver,
Expand Down
33 changes: 33 additions & 0 deletions swig/amici.i
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ if sys.platform == 'win32' and (dll_dirs := os.environ.get('AMICI_DLL_DIRS')):
// import additional types for typehints
// also import np for use in __repr__ functions
%pythonbegin %{
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Sequence
import numpy as np
if TYPE_CHECKING:
Expand All @@ -368,4 +370,35 @@ __all__ = [
if not x.startswith('_')
and x not in {"np", "sys", "os", "numpy", "IntEnum", "enum", "pi", "TYPE_CHECKING", "Iterable", "Sequence"}
]


def _get_ptr(
obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData],
) -> Union[
Model,
ExpData,
Solver,
ReturnData,
]:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(
obj,
(
ModelPtr,
ExpDataPtr,
SolverPtr,
ReturnDataPtr,
),
):
return obj.get()
return obj

%}

0 comments on commit ea40015

Please sign in to comment.