diff --git a/groovy/transpiler.py b/groovy/transpiler.py index 8de7146..c5fcf29 100644 --- a/groovy/transpiler.py +++ b/groovy/transpiler.py @@ -319,6 +319,10 @@ def visit_Call(self, node: ast.Call): # noqa: N802 # Try to resolve if this is a Gradio component. if has_gradio: try: + # Handle direct update() call + if node.func.id == "update": + return self._handle_gradio_component_updates(node) + component_class = getattr(gradio, node.func.id, None) if component_class and issubclass( component_class, gradio.blocks.Block @@ -342,6 +346,10 @@ def visit_Call(self, node: ast.Call): # noqa: N802 "gradio", "gr", }: + # Handle gr.update() call + if node.func.attr == "update": + return self._handle_gradio_component_updates(node) + component_class = getattr(gradio, node.func.attr, None) if component_class and issubclass( component_class, gradio.blocks.Block @@ -626,6 +634,9 @@ def _is_valid_gradio_return(node: ast.AST) -> bool: try: import gradio + if node.func.attr == "update": + return True + component_class = getattr(gradio, node.func.attr, None) if component_class and issubclass( component_class, gradio.blocks.Block @@ -643,6 +654,9 @@ def _is_valid_gradio_return(node: ast.AST) -> bool: try: import gradio + if node.func.id == "update": + return True + component_class = getattr(gradio, node.func.id, None) if component_class and issubclass(component_class, gradio.blocks.Block): if node.args: @@ -669,7 +683,7 @@ def _is_valid_gradio_return(node: ast.AST) -> bool: import gradio as gr def filter_rows_by_term(): - return gr.Tabs(selected=2, visible=True, info=None) + return gr.update(selected=2, visible=True, info=None) - js_code = transpile(filter_rows_by_term) + js_code = transpile(filter_rows_by_term, validate=True) print(js_code) diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index 5227252..db84e24 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -328,3 +328,51 @@ def component_with_none(): return {"visible": true, "info": null, "__type__": "update"}; }""" assert transpile(component_with_none).strip() == expected.strip() + + +def test_gradio_update_function(): + def update_component(): + return gradio.update(visible=False, interactive=True) + + expected = """function update_component() { + return {"visible": false, "interactive": true, "__type__": "update"}; +}""" + assert transpile(update_component).strip() == expected.strip() + + +def test_update_with_none_values(): + def update_with_none(): + return gradio.update(info=None, label="Updated") + + expected = """function update_with_none() { + return {"info": null, "label": "Updated", "__type__": "update"}; +}""" + assert transpile(update_with_none).strip() == expected.strip() + + +def test_mixed_update_and_components(): + def mixed_updates(): + return gradio.update(visible=True), gradio.Textbox(placeholder="Test") + + expected = """function mixed_updates() { + return [{"visible": true, "__type__": "update"}, {"placeholder": "Test", "__type__": "update"}]; +}""" + assert transpile(mixed_updates).strip() == expected.strip() + + +def test_conditional_update(): + def conditional_update(x: int): + if x > 10: + return gradio.update(visible=True) + else: + return gradio.update(visible=False) + + expected = """function conditional_update(x) { + if ((x > 10)) { + return {"visible": true, "__type__": "update"}; + } + else { + return {"visible": false, "__type__": "update"}; + } +}""" + assert transpile(conditional_update).strip() == expected.strip()