From dbc7747bf9189f860915d504351cc9034f183e83 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 23 Dec 2024 06:52:03 -0800 Subject: [PATCH] Python frontend fixes for return values and arguments (#1834) * Properly parse `return` statements without a return value * Fix missing argument detection when no arguments are given --------- Co-authored-by: Philipp Schaad --- dace/codegen/compiled_sdfg.py | 2 + dace/frontend/python/newast.py | 32 ++++--- tests/codegen/external_memory_test.py | 2 +- tests/python_frontend/argument_test.py | 13 ++- tests/python_frontend/return_value_test.py | 99 ++++++++++++++++++++++ 5 files changed, 132 insertions(+), 16 deletions(-) create mode 100644 tests/python_frontend/return_value_test.py diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index bae8531e62..24b52aa02b 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -498,6 +498,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: raise KeyError("Missing program argument \"{}\"".format(a)) else: + if len(sig) > 0: + raise KeyError(f"Missing program arguments: {', '.join(sig)}") arglist = [] argtypes = [] argnames = [] diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 57db9d7089..e625a004a9 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4735,22 +4735,26 @@ def visit_TopLevelExpr(self, node: ast.Expr): self.generic_visit(node) def visit_Return(self, node: ast.Return): - # Modify node value to become an expression - new_node = ast.copy_location(ast.Expr(value=node.value), node) - - # Return values can either be tuples or a single object - if isinstance(node.value, (ast.Tuple, ast.List)): - ast_tuple = ast.copy_location( - ast.parse('(%s,)' % ','.join('__return_%d' % i for i in range(len(node.value.elts)))).body[0].value, - node) - self._visit_assign(new_node, ast_tuple, None, is_return=True) + if node.value is None: + # If there's no value on the return node or it is None, insert just a return block. + self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}')) else: - ast_name = ast.copy_location(ast.Name(id='__return'), node) - self._visit_assign(new_node, ast_name, None, is_return=True) + # Modify node value to become an expression + new_node = ast.copy_location(ast.Expr(value=node.value), node) + + # Return values can either be tuples or a single object + if isinstance(node.value, (ast.Tuple, ast.List)): + ast_tuple = ast.copy_location( + ast.parse('(%s,)' % ','.join('__return_%d' % i for i in range(len(node.value.elts)))).body[0].value, + node) + self._visit_assign(new_node, ast_tuple, None, is_return=True) + else: + ast_name = ast.copy_location(ast.Name(id='__return'), node) + self._visit_assign(new_node, ast_name, None, is_return=True) - if not isinstance(self.cfg_target, SDFG): - # In a nested control flow region, a return needs to be explicitly marked with a return block. - self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}')) + if not isinstance(self.cfg_target, SDFG): + # In a nested control flow region, a return needs to be explicitly marked with a return block. + self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}')) def visit_With(self, node: ast.With, is_async=False): # "with dace.tasklet" syntax diff --git a/tests/codegen/external_memory_test.py b/tests/codegen/external_memory_test.py index c72c574806..169e050914 100644 --- a/tests/codegen/external_memory_test.py +++ b/tests/codegen/external_memory_test.py @@ -36,7 +36,7 @@ def tester(a: dace.float64[N]): # Test workspace size csdfg = sdfg.compile() - csdfg.initialize(**extra_args) + csdfg.initialize(a, **extra_args) sizes = csdfg.get_workspace_sizes() assert sizes == {dace.StorageType.CPU_Heap: 20 * 8} diff --git a/tests/python_frontend/argument_test.py b/tests/python_frontend/argument_test.py index cb47188029..ab6666c687 100644 --- a/tests/python_frontend/argument_test.py +++ b/tests/python_frontend/argument_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace import pytest @@ -40,6 +40,17 @@ def tester(x: dace.float64[20, 20]): tester.to_sdfg().compile() +def test_missing_arguments_2_regression(): + + @dace.program + def tester(x: dace.float64[20]): + x[:] = 0 + + with pytest.raises(KeyError): + tester() + + if __name__ == '__main__': test_extra_args() test_missing_arguments_regression() + test_missing_arguments_2_regression() diff --git a/tests/python_frontend/return_value_test.py b/tests/python_frontend/return_value_test.py new file mode 100644 index 0000000000..93870c41ce --- /dev/null +++ b/tests/python_frontend/return_value_test.py @@ -0,0 +1,99 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np + + +def test_return_scalar(): + + @dace.program + def return_scalar(): + return 5 + + assert return_scalar() == 5 + + +def test_return_array(): + + @dace.program + def return_array(): + return 5 * np.ones(5) + + res = return_array() + assert np.allclose(res, 5 * np.ones(5)) + + +def test_return_tuple(): + + @dace.program + def return_tuple(): + return 5, 6 + + res = return_tuple() + assert res == (5, 6) + + +def test_return_array_tuple(): + + @dace.program + def return_array_tuple(): + return 5 * np.ones(5), 6 * np.ones(6) + + res = return_array_tuple() + assert np.allclose(res[0], 5 * np.ones(5)) + assert np.allclose(res[1], 6 * np.ones(6)) + + +def test_return_void(): + + @dace.program + def return_void(a: dace.float64[20]): + a[:] += 1 + return + a[:] = 5 + + a = np.random.rand(20) + ref = a + 1 + return_void(a) + assert np.allclose(a, ref) + + +def test_return_void_in_if(): + + @dace.program + def return_void(a: dace.float64[20]): + if a[0] < 0: + return + a[:] = 5 + + a = np.random.rand(20) + return_void(a) + assert np.allclose(a, 5) + a[:] = np.random.rand(20) + a[0] = -1 + ref = a.copy() + return_void(a) + assert np.allclose(a, ref) + + +def test_return_void_in_for(): + + @dace.program + def return_void(a: dace.float64[20]): + for _ in range(20): + return + a[:] = 5 + + a = np.random.rand(20) + ref = a.copy() + return_void(a) + assert np.allclose(a, ref) + + +if __name__ == '__main__': + test_return_scalar() + test_return_array() + test_return_tuple() + test_return_array_tuple() + test_return_void() + test_return_void_in_if() + test_return_void_in_for()