Skip to content

Commit

Permalink
Fix type annotations in swig-wrappers
Browse files Browse the repository at this point in the history
Depending on swig options/versions, swig may or may not generate type annotations. With these changes, type annotations will be extracted from the docstrings where available, independent of the swig-generated type annotations.

Closes #2336.
  • Loading branch information
dweindl committed Mar 4, 2024
1 parent 3a6b0df commit 42fb197
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class TypeHintFixer(ast.NodeTransformer):
"ptrdiff_t": ast.Name("int"),
"size_t": ast.Name("int"),
"bool": ast.Name("bool"),
"boolean": ast.Name("bool"),
"std::unique_ptr< amici::Solver >": ast.Constant("Solver"),
"amici::InternalSensitivityMethod": ast.Constant(
"InternalSensitivityMethod"
Expand All @@ -40,8 +41,10 @@ class TypeHintFixer(ast.NodeTransformer):
"SteadyStateSensitivityMode"
),
"amici::realtype": ast.Name("float"),
"DoubleVector": ast.Constant("Sequence[float]"),
"DoubleVector": ast.Name("Sequence[float]"),
"BoolVector": ast.Name("Sequence[bool]"),
"IntVector": ast.Name("Sequence[int]"),
"StringVector": ast.Name("Sequence[str]"),
"std::string": ast.Name("str"),
"std::string const &": ast.Name("str"),
"std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"),
Expand All @@ -53,6 +56,8 @@ class TypeHintFixer(ast.NodeTransformer):
}

def visit_FunctionDef(self, node):
self._annotation_from_docstring(node)

Check warning on line 59 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L59

Added line #L59 was not covered by tests

# Has a return type annotation?
if node.returns:
node.returns = self._new_annot(node.returns.value)
Expand Down Expand Up @@ -103,6 +108,42 @@ def _new_annot(self, old_annot: str):

return ast.Constant(old_annot)

def _annotation_from_docstring(self, node: ast.FunctionDef):
"""Add annotations based on docstring.
If any argument or return type of the function is not annotated, but
the corresponding docstring contains a type hint, the type hint is used
as the annotation.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:

Check warning on line 119 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L118-L119

Added lines #L118 - L119 were not covered by tests
# skip overloaded methods
return

Check warning on line 121 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L121

Added line #L121 was not covered by tests

docstring = docstring.split("\n")
lines_to_remove = set()

Check warning on line 124 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L123-L124

Added lines #L123 - L124 were not covered by tests

for line_no, line in enumerate(docstring):
if match := re.match(r"\W*:rtype:\W*(.+)", line):
node.returns = ast.Constant(match.group(1))
lines_to_remove.add(line_no)

Check warning on line 129 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L126-L129

Added lines #L126 - L129 were not covered by tests

if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line):
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
lines_to_remove.add(line_no)

Check warning on line 135 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L131-L135

Added lines #L131 - L135 were not covered by tests

if lines_to_remove:

Check warning on line 137 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L137

Added line #L137 was not covered by tests
# Update docstring with type annotations removed
assert isinstance(node.body[0].value, ast.Constant)
new_docstring = "\n".join(

Check warning on line 140 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L139-L140

Added lines #L139 - L140 were not covered by tests
line
for line_no, line in enumerate(docstring)
if line_no not in lines_to_remove
)
node.body[0].value = ast.Str(new_docstring)

Check warning on line 145 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L145

Added line #L145 was not covered by tests


def fix_typehints(infilename, outfilename):
"""Change SWIG-generated C++ typehints to Python typehints"""
Expand Down

0 comments on commit 42fb197

Please sign in to comment.