Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type annotations in swig-wrappers #2344

Merged
merged 4 commits into from
Mar 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"ptrdiff_t": ast.Name("int"),
"size_t": ast.Name("int"),
"bool": ast.Name("bool"),
"boolean": ast.Name("bool"),
"std::unique_ptr< amici::Solver >": ast.Constant("Solver"),
"amici::InternalSensitivityMethod": ast.Constant(
"InternalSensitivityMethod"
Expand All @@ -40,8 +41,10 @@
"SteadyStateSensitivityMode"
),
"amici::realtype": ast.Name("float"),
"DoubleVector": ast.Constant("Sequence[float]"),
"DoubleVector": ast.Name("Sequence[float]"),
"BoolVector": ast.Name("Sequence[bool]"),
"IntVector": ast.Name("Sequence[int]"),
"StringVector": ast.Name("Sequence[str]"),
"std::string": ast.Name("str"),
"std::string const &": ast.Name("str"),
"std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"),
Expand All @@ -53,6 +56,8 @@
}

def visit_FunctionDef(self, node):
self._annotation_from_docstring(node)

Check warning on line 59 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L59

Added line #L59 was not covered by tests

# Has a return type annotation?
if node.returns:
node.returns = self._new_annot(node.returns.value)
Expand Down Expand Up @@ -103,6 +108,42 @@

return ast.Constant(old_annot)

def _annotation_from_docstring(self, node: ast.FunctionDef):
"""Add annotations based on docstring.

If any argument or return type of the function is not annotated, but
the corresponding docstring contains a type hint, the type hint is used
as the annotation.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:

Check warning on line 119 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L118-L119

Added lines #L118 - L119 were not covered by tests
# skip overloaded methods
return

Check warning on line 121 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L121

Added line #L121 was not covered by tests

docstring = docstring.split("\n")
lines_to_remove = set()

Check warning on line 124 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L123-L124

Added lines #L123 - L124 were not covered by tests

for line_no, line in enumerate(docstring):
if match := re.match(r"\W*:rtype:\W*(.+)", line):
node.returns = ast.Constant(match.group(1))
lines_to_remove.add(line_no)

Check warning on line 129 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L126-L129

Added lines #L126 - L129 were not covered by tests

if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line):
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
lines_to_remove.add(line_no)

Check warning on line 135 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L131-L135

Added lines #L131 - L135 were not covered by tests

if lines_to_remove:

Check warning on line 137 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L137

Added line #L137 was not covered by tests
# Update docstring with type annotations removed
assert isinstance(node.body[0].value, ast.Constant)
new_docstring = "\n".join(

Check warning on line 140 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L139-L140

Added lines #L139 - L140 were not covered by tests
line
for line_no, line in enumerate(docstring)
if line_no not in lines_to_remove
)
node.body[0].value = ast.Str(new_docstring)

Check warning on line 145 in python/sdist/amici/swig.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/swig.py#L145

Added line #L145 was not covered by tests


def fix_typehints(infilename, outfilename):
"""Change SWIG-generated C++ typehints to Python typehints"""
Expand Down
Loading