diff --git a/python/sdist/amici/swig.py b/python/sdist/amici/swig.py index 49d1ada4a2..5ba8017005 100644 --- a/python/sdist/amici/swig.py +++ b/python/sdist/amici/swig.py @@ -6,7 +6,7 @@ class TypeHintFixer(ast.NodeTransformer): - """Replaces SWIG-generated C++ typehints by corresponding Python types""" + """Replaces SWIG-generated C++ typehints by corresponding Python types.""" mapping = { "void": None, @@ -58,6 +58,9 @@ class TypeHintFixer(ast.NodeTransformer): } def visit_FunctionDef(self, node): + # convert type/rtype from docstring to annotation, if possible. + # those may be c++ types, not valid in python, that need to be + # converted to python types below. self._annotation_from_docstring(node) # Has a return type annotation? @@ -122,6 +125,8 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of ``:type solver: Solver``. Those need special treatment. + + Overloaded functions are skipped. """ docstring = ast.get_docstring(node, clean=False) if not docstring or "*Overload 1:*" in docstring: @@ -158,13 +163,15 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): @staticmethod def extract_type(line: str) -> tuple[str, str] | tuple[None, None]: - """Extract argument name and type string from ``:type:`` docstring line.""" + """Extract argument name and type string from ``:type:`` docstring + line.""" match = re.match(r"\s*:type\s+(\w+):\s+(.+?)(?:, optional)?\s*$", line) if not match: return None, None arg_name = match.group(1) + # get rid of any :py:class`...` in the type string if necessary if not match.group(2).startswith(":py:"): return arg_name, match.group(2) @@ -179,6 +186,7 @@ def extract_rtype(line: str) -> str | None: if not match: return None + # get rid of any :py:class`...` in the type string if necessary if not match.group(1).startswith(":py:"): return match.group(1)