Skip to content

Commit

Permalink
Refactor parsing of function CodeInput using ast
Browse files Browse the repository at this point in the history
With `ast` we have to use less regex tricks and can rely that the
function is properly parsed.

Docstring is now also considering a linebreak at the beginning.

Annotations of input arguments are now supported.

Default arguments are now supported. We even support expressions as
default arguments e.g lambda functions.

Arbitrary keyword arguments now supported.
  • Loading branch information
agoscinski committed Dec 11, 2024
1 parent 934a46a commit c38e445
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 17 deletions.
88 changes: 82 additions & 6 deletions src/scwidgets/code/_widget_code_input.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import ast
import inspect
import re
import sys
import textwrap
import traceback
import types
import warnings
from functools import wraps
from typing import List, Optional
from typing import List, Optional, Tuple

from widget_code_input import WidgetCodeInput
from widget_code_input.utils import (
Expand All @@ -20,6 +22,18 @@
class CodeInput(WidgetCodeInput):
"""
Small wrapper around WidgetCodeInput that controls the output
:param function: We can automatically parse the function. Note that during
parsing the source code might be differently formatted and certain
python functionalities are not formatted. If you notice undesired
changes by the parsing, please directly specify the function as string
using the other parameters.
:param function_name: The name of the function
:param function_paramaters: The parameters as continuous string as specified in
the signature of the function. e.g for `foo(x, y = 5)` it should be
`"x, y = 5"`
:param docstring: The docstring of the function
:param function_body: The function definition without indentation
"""

valid_code_themes = ["nord", "solarizedLight", "basicLight"]
Expand All @@ -38,13 +52,15 @@ def __init__(
function.__name__ if function_name is None else function_name
)
function_parameters = (
", ".join(inspect.getfullargspec(function).args)
self.get_function_parameters(function)
if function_parameters is None
else function_parameters
)
docstring = inspect.getdoc(function) if docstring is None else docstring
docstring = self.get_docstring(function) if docstring is None else docstring
function_body = (
self.get_code(function) if function_body is None else function_body
self.get_function_body(function)
if function_body is None
else function_body
)

# default parameters from WidgetCodeInput
Expand Down Expand Up @@ -105,8 +121,68 @@ def function_parameters_name(self) -> List[str]:
return self.function_parameters.replace(",", "").split(" ")

@staticmethod
def get_code(func: types.FunctionType) -> str:
source_lines, _ = inspect.getsourcelines(func)
def get_docstring(function: types.FunctionType) -> str:
docstring = function.__doc__
return "" if docstring is None else textwrap.dedent(docstring)

@staticmethod
def _get_function_source_and_def(
function: types.FunctionType,
) -> Tuple[str, ast.FunctionDef]:
function_source = inspect.getsource(function)
function_source = textwrap.dedent(function_source)
module = ast.parse(function_source)
if len(module.body) != 1:
raise ValueError(
f"Expected code with one function definition but found {module.body}"
)
function_definition = module.body[0]
if not isinstance(function_definition, ast.FunctionDef):
raise ValueError(
f"While parsing code found {module.body[0]}"
" but only ast.FunctionDef is supported."
)
return function_source, function_definition

@staticmethod
def get_function_parameters(function: types.FunctionType) -> str:
function_parameters = []
function_source, function_definition = CodeInput._get_function_source_and_def(
function
)
idx_start_defaults = len(function_definition.args.args) - len(
function_definition.args.defaults
)
for i, arg in enumerate(function_definition.args.args):
function_parameter = ast.get_source_segment(function_source, arg)
# Following PEP 8 in formatting
if arg.annotation:
annotation = function_parameter = ast.get_source_segment(
function_source, arg.annotation
)
function_parameter = f"{arg.arg}: {annotation}"
else:
function_parameter = f"{arg.arg}"
if i >= idx_start_defaults:
default_val = ast.get_source_segment(
function_source,
function_definition.args.defaults[i - idx_start_defaults],
)
# Following PEP 8 in formatting
if arg.annotation:
function_parameter = f"{function_parameter} = {default_val}"
else:
function_parameter = f"{function_parameter}={default_val}"
function_parameters.append(function_parameter)

if function_definition.args.kwarg is not None:
function_parameters.append(f"**{function_definition.args.kwarg.arg}")

return ", ".join(function_parameters)

@staticmethod
def get_function_body(function: types.FunctionType) -> str:
source_lines, _ = inspect.getsourcelines(function)

found_def = False
def_index = 0
Expand Down
52 changes: 41 additions & 11 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def mock_function_0():
return 0

@staticmethod
def mock_function_1(x, y):
def mock_function_1(x: int, y: int = 5, z=lambda: 0):
"""
This is an example function.
It adds two numbers.
"""
if x > 0:
return x + y
else:
return y
return y + z()

@staticmethod
def mock_function_2(x):
Expand All @@ -53,26 +53,56 @@ def x():
@staticmethod
def mock_function_6(x: List[int]) -> List[int]:
return x

@staticmethod
def mock_function_7(x, **kwargs):
return kwargs
# fmt: on

def test_get_code(self):
def test_get_function_paramaters(self):
assert (
CodeInput.get_function_parameters(self.mock_function_1)
== "x: int, y: int = 5, z=lambda: 0"
)
assert CodeInput.get_function_parameters(self.mock_function_2) == "x"
assert CodeInput.get_function_parameters(self.mock_function_6) == "x: List[int]"
assert CodeInput.get_function_parameters(self.mock_function_7) == "x, **kwargs"

def test_get_docstring(self):
assert (
CodeInput.get_docstring(self.mock_function_1)
== "\nThis is an example function.\nIt adds two numbers.\n"
)
assert (
CodeInput.get_docstring(self.mock_function_2)
== "This is an example function. It adds two numbers."
)
assert (
CodeInput.get_docstring(self.mock_function_2)
== "This is an example function. It adds two numbers."
)

def test_get_function_body(self):
assert (
CodeInput.get_function_body(self.mock_function_1)
== "if x > 0:\n return x + y\nelse:\n return y + z()\n"
)
assert CodeInput.get_function_body(self.mock_function_2) == "return x\n"
assert CodeInput.get_function_body(self.mock_function_3) == "return x\n"
assert (
CodeInput.get_code(self.mock_function_1)
== "if x > 0:\n return x + y\nelse:\n return y\n"
CodeInput.get_function_body(self.mock_function_4)
== "return x # noqa: E702\n"
)
assert CodeInput.get_code(self.mock_function_2) == "return x\n"
assert CodeInput.get_code(self.mock_function_3) == "return x\n"
assert CodeInput.get_code(self.mock_function_4) == "return x # noqa: E702\n"
assert (
CodeInput.get_code(self.mock_function_5)
CodeInput.get_function_body(self.mock_function_5)
== "def x():\n return 5\nreturn x()\n"
)
assert CodeInput.get_code(self.mock_function_6) == "return x\n"
assert CodeInput.get_function_body(self.mock_function_6) == "return x\n"
with pytest.raises(
ValueError,
match=r"Did not find any def definition. .*",
):
CodeInput.get_code(lambda x: x)
CodeInput.get_function_body(lambda x: x)

def test_invalid_code_theme_raises_error(self):
with pytest.raises(
Expand Down

0 comments on commit c38e445

Please sign in to comment.