Skip to content

Commit

Permalink
Fix type annotations in swig-wrappers (again)
Browse files Browse the repository at this point in the history
Some things were missing in AMICI-dev#2344
  • Loading branch information
dweindl committed Mar 7, 2024
1 parent 8d65524 commit 1388009
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def visit_FunctionDef(self, node):
self._annotation_from_docstring(node)

# Has a return type annotation?
if node.returns:
if node.returns and isinstance(node.returns, ast.Constant):
node.returns = self._new_annot(node.returns.value)

# Has arguments?
Expand Down Expand Up @@ -112,8 +112,11 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
"""Add annotations based on docstring.
If any argument or return type of the function is not annotated, but
the corresponding docstring contains a type hint, the type hint is used
as the annotation.
the corresponding docstring contains a type hint (``:rtype:`` or
``:type:``), the type hint is used as the annotation.
Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of
``:type solver: Solver``. Those need special treatment.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:
Expand All @@ -124,11 +127,19 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if match := re.match(r"\W*:rtype:\W*(.+)", line):
if (
match := re.match(
r"\s*:rtype:\s*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
node.returns = ast.Constant(match.group(1))
lines_to_remove.add(line_no)

if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line):
if (
match := re.match(
r"\s*:type\s*(\w+):\W*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
Expand Down

0 comments on commit 1388009

Please sign in to comment.