Skip to content

Commit

Permalink
Python frontend fixes for return values and arguments (#1834)
Browse files Browse the repository at this point in the history
* Properly parse `return` statements without a return value
* Fix missing argument detection when no arguments are given

---------

Co-authored-by: Philipp Schaad <[email protected]>
  • Loading branch information
tbennun and phschaad authored Dec 23, 2024
1 parent 65ca11f commit dbc7747
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 16 deletions.
2 changes: 2 additions & 0 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
raise KeyError("Missing program argument \"{}\"".format(a))

else:
if len(sig) > 0:
raise KeyError(f"Missing program arguments: {', '.join(sig)}")
arglist = []
argtypes = []
argnames = []
Expand Down
32 changes: 18 additions & 14 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4735,22 +4735,26 @@ def visit_TopLevelExpr(self, node: ast.Expr):
self.generic_visit(node)

def visit_Return(self, node: ast.Return):
# Modify node value to become an expression
new_node = ast.copy_location(ast.Expr(value=node.value), node)

# Return values can either be tuples or a single object
if isinstance(node.value, (ast.Tuple, ast.List)):
ast_tuple = ast.copy_location(
ast.parse('(%s,)' % ','.join('__return_%d' % i for i in range(len(node.value.elts)))).body[0].value,
node)
self._visit_assign(new_node, ast_tuple, None, is_return=True)
if node.value is None:
# If there's no value on the return node or it is None, insert just a return block.
self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}'))
else:
ast_name = ast.copy_location(ast.Name(id='__return'), node)
self._visit_assign(new_node, ast_name, None, is_return=True)
# Modify node value to become an expression
new_node = ast.copy_location(ast.Expr(value=node.value), node)

# Return values can either be tuples or a single object
if isinstance(node.value, (ast.Tuple, ast.List)):
ast_tuple = ast.copy_location(
ast.parse('(%s,)' % ','.join('__return_%d' % i for i in range(len(node.value.elts)))).body[0].value,
node)
self._visit_assign(new_node, ast_tuple, None, is_return=True)
else:
ast_name = ast.copy_location(ast.Name(id='__return'), node)
self._visit_assign(new_node, ast_name, None, is_return=True)

if not isinstance(self.cfg_target, SDFG):
# In a nested control flow region, a return needs to be explicitly marked with a return block.
self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}'))
if not isinstance(self.cfg_target, SDFG):
# In a nested control flow region, a return needs to be explicitly marked with a return block.
self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}'))

def visit_With(self, node: ast.With, is_async=False):
# "with dace.tasklet" syntax
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/external_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def tester(a: dace.float64[N]):

# Test workspace size
csdfg = sdfg.compile()
csdfg.initialize(**extra_args)
csdfg.initialize(a, **extra_args)
sizes = csdfg.get_workspace_sizes()
assert sizes == {dace.StorageType.CPU_Heap: 20 * 8}

Expand Down
13 changes: 12 additions & 1 deletion tests/python_frontend/argument_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.

import dace
import pytest
Expand Down Expand Up @@ -40,6 +40,17 @@ def tester(x: dace.float64[20, 20]):
tester.to_sdfg().compile()


def test_missing_arguments_2_regression():

@dace.program
def tester(x: dace.float64[20]):
x[:] = 0

with pytest.raises(KeyError):
tester()


if __name__ == '__main__':
test_extra_args()
test_missing_arguments_regression()
test_missing_arguments_2_regression()
99 changes: 99 additions & 0 deletions tests/python_frontend/return_value_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np


def test_return_scalar():

@dace.program
def return_scalar():
return 5

assert return_scalar() == 5


def test_return_array():

@dace.program
def return_array():
return 5 * np.ones(5)

res = return_array()
assert np.allclose(res, 5 * np.ones(5))


def test_return_tuple():

@dace.program
def return_tuple():
return 5, 6

res = return_tuple()
assert res == (5, 6)


def test_return_array_tuple():

@dace.program
def return_array_tuple():
return 5 * np.ones(5), 6 * np.ones(6)

res = return_array_tuple()
assert np.allclose(res[0], 5 * np.ones(5))
assert np.allclose(res[1], 6 * np.ones(6))


def test_return_void():

@dace.program
def return_void(a: dace.float64[20]):
a[:] += 1
return
a[:] = 5

a = np.random.rand(20)
ref = a + 1
return_void(a)
assert np.allclose(a, ref)


def test_return_void_in_if():

@dace.program
def return_void(a: dace.float64[20]):
if a[0] < 0:
return
a[:] = 5

a = np.random.rand(20)
return_void(a)
assert np.allclose(a, 5)
a[:] = np.random.rand(20)
a[0] = -1
ref = a.copy()
return_void(a)
assert np.allclose(a, ref)


def test_return_void_in_for():

@dace.program
def return_void(a: dace.float64[20]):
for _ in range(20):
return
a[:] = 5

a = np.random.rand(20)
ref = a.copy()
return_void(a)
assert np.allclose(a, ref)


if __name__ == '__main__':
test_return_scalar()
test_return_array()
test_return_tuple()
test_return_array_tuple()
test_return_void()
test_return_void_in_if()
test_return_void_in_for()

0 comments on commit dbc7747

Please sign in to comment.