Skip to content

Commit

Permalink
add support for .update()
Browse files Browse the repository at this point in the history
  • Loading branch information
abidlabs committed Feb 28, 2025
1 parent 02df14e commit 4b5b42d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
18 changes: 16 additions & 2 deletions groovy/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def visit_Call(self, node: ast.Call): # noqa: N802
# Try to resolve if this is a Gradio component.
if has_gradio:
try:
# Handle direct update() call
if node.func.id == "update":
return self._handle_gradio_component_updates(node)

component_class = getattr(gradio, node.func.id, None)
if component_class and issubclass(
component_class, gradio.blocks.Block
Expand All @@ -342,6 +346,10 @@ def visit_Call(self, node: ast.Call): # noqa: N802
"gradio",
"gr",
}:
# Handle gr.update() call
if node.func.attr == "update":
return self._handle_gradio_component_updates(node)

component_class = getattr(gradio, node.func.attr, None)
if component_class and issubclass(
component_class, gradio.blocks.Block
Expand Down Expand Up @@ -626,6 +634,9 @@ def _is_valid_gradio_return(node: ast.AST) -> bool:
try:
import gradio

if node.func.attr == "update":
return True

component_class = getattr(gradio, node.func.attr, None)
if component_class and issubclass(
component_class, gradio.blocks.Block
Expand All @@ -643,6 +654,9 @@ def _is_valid_gradio_return(node: ast.AST) -> bool:
try:
import gradio

if node.func.id == "update":
return True

component_class = getattr(gradio, node.func.id, None)
if component_class and issubclass(component_class, gradio.blocks.Block):
if node.args:
Expand All @@ -669,7 +683,7 @@ def _is_valid_gradio_return(node: ast.AST) -> bool:
import gradio as gr

def filter_rows_by_term():
return gr.Tabs(selected=2, visible=True, info=None)
return gr.update(selected=2, visible=True, info=None)

js_code = transpile(filter_rows_by_term)
js_code = transpile(filter_rows_by_term, validate=True)
print(js_code)
48 changes: 48 additions & 0 deletions tests/test_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,51 @@ def component_with_none():
return {"visible": true, "info": null, "__type__": "update"};
}"""
assert transpile(component_with_none).strip() == expected.strip()


def test_gradio_update_function():
def update_component():
return gradio.update(visible=False, interactive=True)

expected = """function update_component() {
return {"visible": false, "interactive": true, "__type__": "update"};
}"""
assert transpile(update_component).strip() == expected.strip()


def test_update_with_none_values():
def update_with_none():
return gradio.update(info=None, label="Updated")

expected = """function update_with_none() {
return {"info": null, "label": "Updated", "__type__": "update"};
}"""
assert transpile(update_with_none).strip() == expected.strip()


def test_mixed_update_and_components():
def mixed_updates():
return gradio.update(visible=True), gradio.Textbox(placeholder="Test")

expected = """function mixed_updates() {
return [{"visible": true, "__type__": "update"}, {"placeholder": "Test", "__type__": "update"}];
}"""
assert transpile(mixed_updates).strip() == expected.strip()


def test_conditional_update():
def conditional_update(x: int):
if x > 10:
return gradio.update(visible=True)
else:
return gradio.update(visible=False)

expected = """function conditional_update(x) {
if ((x > 10)) {
return {"visible": true, "__type__": "update"};
}
else {
return {"visible": false, "__type__": "update"};
}
}"""
assert transpile(conditional_update).strip() == expected.strip()

0 comments on commit 4b5b42d

Please sign in to comment.