Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
abidlabs committed Feb 26, 2025
1 parent 0e5fb9b commit 8c97900
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 4 deletions.
72 changes: 70 additions & 2 deletions groovy/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def transpile(fn: Callable, validate: bool = False) -> str:
if sig.parameters:
param_names = list(sig.parameters.keys())
raise TranspilerError(
message=f"Function must take no arguments, but got: {param_names}"
message=f"Function must take no arguments for client-side use, but got: {param_names}"
)

try:
source = inspect.getsource(fn)
source = textwrap.dedent(source)
Expand All @@ -543,6 +543,37 @@ def transpile(fn: Callable, validate: bool = False) -> str:
except SyntaxError as e:
raise TranspilerError(message="Could not parse function source.") from e

if validate:
try:
import gradio
has_gradio = True
except ImportError:
has_gradio = False
raise TranspilerError(message="Gradio must be installed for validation.")

func_node = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == fn.__name__:
func_node = node
break

if func_node:
return_nodes = []
for node in ast.walk(func_node):
if isinstance(node, ast.Return) and node.value is not None:
return_nodes.append(node)

if not return_nodes:
raise TranspilerError(message="Function must return Gradio component updates, but no return statement found.")

for return_node in return_nodes:
if not _is_valid_gradio_return(return_node.value):
line_no = return_node.lineno
line_text = source.splitlines()[line_no - 1].strip()
raise TranspilerError(
message=f"Function must only return Gradio component updates. Invalid return at line {line_no}: {line_text}"
)

func_node = None
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.Lambda)):
Expand Down Expand Up @@ -575,6 +606,43 @@ def transpile(fn: Callable, validate: bool = False) -> str:
return "\n".join(visitor.js_lines)


def _is_valid_gradio_return(node: ast.AST) -> bool:
"""
Check if a return value is a valid Gradio component or collection of components.
Args:
node: The AST node representing the return value
Returns:
bool: True if the return value is valid, False otherwise
"""
# Check for direct Gradio component call
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
# Check for gr.Component or gradio.Component
if node.func.value.id in {"gr", "gradio"}:
return True
elif isinstance(node.func, ast.Name):
# Check for direct Component call if imported
return True

# Check for tuple or list of Gradio components
elif isinstance(node, (ast.Tuple, ast.List)):
# Empty tuple/list is not valid
if not node.elts:
return False
# All elements must be valid Gradio returns
return all(_is_valid_gradio_return(elt) for elt in node.elts)

# Check for variable that might be a Gradio component
elif isinstance(node, ast.Name):
# We can't easily determine if a variable is a Gradio component
# without executing the code, so we'll allow it
return True

return False


# === Example Usage ===

if __name__ == "__main__":
Expand Down
36 changes: 34 additions & 2 deletions tests/test_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def test_validate_no_arguments():
def no_args_function():
return gradio.Textbox(placeholder="This is valid")

# This should pass validation
result = transpile(no_args_function, validate=True)
expected = """function no_args_function() {
return {"placeholder": "This is valid", "__type__": "update"};
Expand All @@ -237,8 +236,41 @@ def test_validate_with_arguments():
def function_with_args(text_input):
return gradio.Textbox(placeholder=f"You entered: {text_input}")

# This should fail validation
with pytest.raises(TranspilerError) as e:
transpile(function_with_args, validate=True)

assert "text_input" in str(e.value)


def test_validate_non_gradio_return():
def invalid_return_function():
return "This is not a Gradio component"

with pytest.raises(TranspilerError) as e:
transpile(invalid_return_function, validate=True)

assert "Function must only return Gradio component updates" in str(e.value)


def test_validate_mixed_return_paths():
def mixed_return_function():
if 5:
return gradio.Textbox(placeholder="Valid path")
else:
return "Invalid path"

with pytest.raises(TranspilerError) as e:
transpile(mixed_return_function, validate=True)

assert "Function must only return Gradio component updates" in str(e.value)


def test_validate_multiple_gradio_returns():
def multiple_components():
return gradio.Textbox(placeholder="Component 1"), gradio.Button(variant="primary")

result = transpile(multiple_components, validate=True)
expected = """function multiple_components() {
return [{"placeholder": "Component 1", "__type__": "update"}, {"variant": "primary", "__type__": "update"}];
}"""
assert result.strip() == expected.strip()

0 comments on commit 8c97900

Please sign in to comment.