Skip to content

Commit

Permalink
Fix type annotations in swig-wrappers (again)
Browse files Browse the repository at this point in the history
Some things were missing in #2344
  • Loading branch information
dweindl committed Mar 7, 2024
1 parent 8d65524 commit 964bc22
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
"""Add annotations based on docstring.
If any argument or return type of the function is not annotated, but
the corresponding docstring contains a type hint, the type hint is used
as the annotation.
the corresponding docstring contains a type hint (``:rtype:`` or
``:type:``), the type hint is used as the annotation.
Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of
``:type solver: Solver``. Those need special treatment.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:
Expand All @@ -124,11 +127,19 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if match := re.match(r"\W*:rtype:\W*(.+)", line):
if (

Check warning on line 130 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L130

Added line #L130 was not covered by tests
match := re.match(
r"\s*:rtype:\s*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
node.returns = ast.Constant(match.group(1))
lines_to_remove.add(line_no)

if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line):
if (

Check warning on line 138 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L138

Added line #L138 was not covered by tests
match := re.match(
r"\s*:type\s*(\w+):\W*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
Expand Down

0 comments on commit 964bc22

Please sign in to comment.