Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
..
Browse files Browse the repository at this point in the history
dweindl committed Apr 16, 2024
1 parent bbab4f2 commit eb4d7e7
Showing 1 changed file with 32 additions and 34 deletions.
66 changes: 32 additions & 34 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
@@ -128,45 +128,17 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
# skip overloaded methods
return

def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:
"""Extract argument name and type string from :type: docstring line."""
match = re.match(r"\s*:type\s+(\w+):\s+(.+)\s*$", line)
if not match:
return None, None

arg_name = match.group(1)

if not match.group(2).startswith(":py:"):
return arg_name, match.group(2)

match = re.match(r":py:\w+:`(.+)`", match.group(2))
assert match
return arg_name, match.group(1)

def extract_rtype(line: str) -> str | None:
"""Extract type string from :rtype: docstring line."""
match = re.match(r"\s*:rtype:\s+(.+)\s*$", line)
if not match:
return None

if not match.group(1).startswith(":py:"):
return match.group(1)

match = re.match(r":py:\w+:`(.+)`", match.group(1))
assert match
return match.group(1)

docstring = docstring.split("\n")
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if type_str := extract_rtype(line):
if type_str := self.extract_rtype(line):
# handle `:rtype:`
node.returns = ast.Constant(type_str)
lines_to_remove.add(line_no)
continue

arg_name, type_str = extract_type(line)
arg_name, type_str = self.extract_type(line)
if arg_name is not None:
# handle `:type ...:`
for arg in node.args.args:
@@ -184,13 +156,39 @@ def extract_rtype(line: str) -> str | None:
)
node.body[0].value = ast.Str(new_docstring)

@staticmethod
def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:
"""Extract argument name and type string from ``:type:`` docstring line."""
match = re.match(r"\s*:type\s+(\w+):\s+(.+)\s*$", line)
if not match:
return None, None

arg_name = match.group(1)

if not match.group(2).startswith(":py:"):
return arg_name, match.group(2)

match = re.match(r":py:\w+:`(.+)`", match.group(2))
assert match
return arg_name, match.group(1)

@staticmethod
def extract_rtype(line: str) -> str | None:
"""Extract type string from ``:rtype:`` docstring line."""
match = re.match(r"\s*:rtype:\s+(.+)\s*$", line)
if not match:
return None

if not match.group(1).startswith(":py:"):
return match.group(1)

match = re.match(r":py:\w+:`(.+)`", match.group(1))
assert match
return match.group(1)


def fix_typehints(infilename, outfilename):
"""Change SWIG-generated C++ typehints to Python typehints"""
# Only available from Python3.9
if not getattr(ast, "unparse", None):
return

# file -> AST
with open(infilename) as f:
source = f.read()

0 comments on commit eb4d7e7

Please sign in to comment.