From 42fb197ebd8092b4a11d766ffe6b59e4ffab50da Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 4 Mar 2024 15:34:02 +0100 Subject: [PATCH] Fix type annotations in swig-wrappers Depending on swig options/versions, swig may or may not generate type annotations. With these changes, type annotations will be extracted from the docstrings where available, independent of the swig-generated type annotations. Closes #2336. --- python/sdist/amici/swig.py | 43 +++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/swig.py b/python/sdist/amici/swig.py index 902145ff3e..81a030ba3d 100644 --- a/python/sdist/amici/swig.py +++ b/python/sdist/amici/swig.py @@ -15,6 +15,7 @@ class TypeHintFixer(ast.NodeTransformer): "ptrdiff_t": ast.Name("int"), "size_t": ast.Name("int"), "bool": ast.Name("bool"), + "boolean": ast.Name("bool"), "std::unique_ptr< amici::Solver >": ast.Constant("Solver"), "amici::InternalSensitivityMethod": ast.Constant( "InternalSensitivityMethod" @@ -40,8 +41,10 @@ class TypeHintFixer(ast.NodeTransformer): "SteadyStateSensitivityMode" ), "amici::realtype": ast.Name("float"), - "DoubleVector": ast.Constant("Sequence[float]"), + "DoubleVector": ast.Name("Sequence[float]"), + "BoolVector": ast.Name("Sequence[bool]"), "IntVector": ast.Name("Sequence[int]"), + "StringVector": ast.Name("Sequence[str]"), "std::string": ast.Name("str"), "std::string const &": ast.Name("str"), "std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"), @@ -53,6 +56,8 @@ class TypeHintFixer(ast.NodeTransformer): } def visit_FunctionDef(self, node): + self._annotation_from_docstring(node) + # Has a return type annotation? if node.returns: node.returns = self._new_annot(node.returns.value) @@ -103,6 +108,42 @@ def _new_annot(self, old_annot: str): return ast.Constant(old_annot) + def _annotation_from_docstring(self, node: ast.FunctionDef): + """Add annotations based on docstring. + + If any argument or return type of the function is not annotated, but + the corresponding docstring contains a type hint, the type hint is used + as the annotation. + """ + docstring = ast.get_docstring(node, clean=False) + if not docstring or "*Overload 1:*" in docstring: + # skip overloaded methods + return + + docstring = docstring.split("\n") + lines_to_remove = set() + + for line_no, line in enumerate(docstring): + if match := re.match(r"\W*:rtype:\W*(.+)", line): + node.returns = ast.Constant(match.group(1)) + lines_to_remove.add(line_no) + + if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line): + for arg in node.args.args: + if arg.arg == match.group(1): + arg.annotation = ast.Constant(match.group(2)) + lines_to_remove.add(line_no) + + if lines_to_remove: + # Update docstring with type annotations removed + assert isinstance(node.body[0].value, ast.Constant) + new_docstring = "\n".join( + line + for line_no, line in enumerate(docstring) + if line_no not in lines_to_remove + ) + node.body[0].value = ast.Str(new_docstring) + def fix_typehints(infilename, outfilename): """Change SWIG-generated C++ typehints to Python typehints"""