Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Apr 16, 2024
1 parent ccecfe4 commit bbab4f2
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TypeHintFixer(ast.NodeTransformer):
"std::allocator< amici::ParameterScaling > > const &": ast.Constant(
"ParameterScalingVector"
),
"H5::H5File": None,
}

def visit_FunctionDef(self, node):
Expand Down Expand Up @@ -127,26 +128,50 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
# skip overloaded methods
return

def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:
"""Extract argument name and type string from :type: docstring line."""
match = re.match(r"\s*:type\s+(\w+):\s+(.+)\s*$", line)
if not match:
return None, None

arg_name = match.group(1)

if not match.group(2).startswith(":py:"):
return arg_name, match.group(2)

match = re.match(r":py:\w+:`(.+)`", match.group(2))
assert match
return arg_name, match.group(1)

def extract_rtype(line: str) -> str | None:
"""Extract type string from :rtype: docstring line."""
match = re.match(r"\s*:rtype:\s+(.+)\s*$", line)
if not match:
return None

if not match.group(1).startswith(":py:"):
return match.group(1)

match = re.match(r":py:\w+:`(.+)`", match.group(1))
assert match
return match.group(1)

docstring = docstring.split("\n")
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if (
match := re.match(
r"\s*:rtype:\s*(?::py:class:`)?(.+)`?\s*$", line
)
) and not match.group(1).startswith(":"):
node.returns = ast.Constant(match.group(1))
if type_str := extract_rtype(line):
# handle `:rtype:`
node.returns = ast.Constant(type_str)
lines_to_remove.add(line_no)
continue

if (
match := re.match(
r"\s*:type\s+(\w+):\s+(?::py:class:`)?(.+)`?\s*$", line
)
) and not match.group(1).startswith(":"):
arg_name, type_str = extract_type(line)
if arg_name is not None:
# handle `:type ...:`
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
if arg.arg == arg_name:
arg.annotation = ast.Constant(type_str)
lines_to_remove.add(line_no)

if lines_to_remove:
Expand Down

0 comments on commit bbab4f2

Please sign in to comment.