From 40ed43812a10f3a622572bd8c82baa68d15053a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:21:55 +0100 Subject: [PATCH] Numpy fill accepts also variables (#1420) This PR is for addressing issue [#1389](https://github.com/spcl/dace/issues/1389). --------- Co-authored-by: acalotoiu <61420859+acalotoiu@users.noreply.github.com> Co-authored-by: BenWeber42 --- dace/frontend/common/op_repository.py | 7 +-- dace/frontend/python/astutils.py | 25 ++++++----- dace/frontend/python/replacements.py | 45 +++++++++++++++---- .../numpy/ndarray_attributes_methods_test.py | 14 ++++++ 4 files changed, 66 insertions(+), 25 deletions(-) diff --git a/dace/frontend/common/op_repository.py b/dace/frontend/common/op_repository.py index 32e10417dc..067c19ac57 100644 --- a/dace/frontend/common/op_repository.py +++ b/dace/frontend/common/op_repository.py @@ -17,12 +17,7 @@ def _get_all_bases(class_or_name: Union[str, Type]) -> List[str]: """ if isinstance(class_or_name, str): return [class_or_name] - - classes = [class_or_name.__name__] - for base in class_or_name.__bases__: - classes.extend(_get_all_bases(base)) - - return deduplicate(classes) + return [base.__name__ for base in class_or_name.__mro__] class Replacements(object): diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 67d8b6aded..c9a400e5f1 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -442,9 +442,10 @@ class ExtNodeTransformer(ast.NodeTransformer): bodies in order to discern DaCe statements from others. """ def visit_TopLevel(self, node): - clsname = type(node).__name__ - if getattr(self, "visit_TopLevel" + clsname, False): - return getattr(self, "visit_TopLevel" + clsname)(node) + visitor_name = "visit_TopLevel" + type(node).__name__ + if hasattr(self, visitor_name): + visitor = getattr(self, visitor_name) + return visitor(node) else: return self.visit(node) @@ -480,21 +481,23 @@ class ExtNodeVisitor(ast.NodeVisitor): top-level expressions in bodies in order to discern DaCe statements from others. """ def visit_TopLevel(self, node): - clsname = type(node).__name__ - if getattr(self, "visit_TopLevel" + clsname, False): - getattr(self, "visit_TopLevel" + clsname)(node) + visitor_name = "visit_TopLevel" + type(node).__name__ + if hasattr(self, visitor_name): + visitor = getattr(self, visitor_name) + return visitor(node) else: - self.visit(node) + return self.visit(node) def generic_visit(self, node): for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): for value in old_value: if isinstance(value, ast.AST): - if (field == 'body' or field == 'orelse'): - clsname = type(value).__name__ - if getattr(self, "visit_TopLevel" + clsname, False): - getattr(self, "visit_TopLevel" + clsname)(value) + if field == 'body' or field == 'orelse': + visitor_name = "visit_TopLevel" + type(value).__name__ + if hasattr(self, visitor_name): + visitor = getattr(self, visitor_name) + visitor(value) else: self.visit(value) else: diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index eace0c8336..f55a65eabb 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -605,11 +605,10 @@ def _elementwise(pv: 'ProgramVisitor', else: state.add_mapped_tasklet( name="_elementwise_", - map_ranges={'__i%d' % i: '0:%s' % n - for i, n in enumerate(inparr.shape)}, - inputs={'__inp': Memlet.simple(in_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))}, + map_ranges={f'__i{dim}': f'0:{N}' for dim, N in enumerate(inparr.shape)}, + inputs={'__inp': Memlet.simple(in_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))}, code=code, - outputs={'__out': Memlet.simple(out_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))}, + outputs={'__out': Memlet.simple(out_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))}, external_edges=True) return out_array @@ -4232,10 +4231,40 @@ def _ndarray_copy(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str) -> @oprepo.replaces_method('Array', 'fill') @oprepo.replaces_method('Scalar', 'fill') @oprepo.replaces_method('View', 'fill') -def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, value: Number) -> str: - if not isinstance(value, (Number, np.bool_)): - raise mem_parser.DaceSyntaxError(pv, None, "Fill value {f} must be a number!".format(f=value)) - return _elementwise(pv, sdfg, state, "lambda x: {}".format(value), arr, arr) +def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, value: Union[str, Number, + sp.Expr]) -> str: + assert arr in sdfg.arrays + + if isinstance(value, sp.Expr): + raise NotImplementedError( + f"{arr}.fill is not implemented for symbolic expressions ({value}).") # Look at `full`. + + if isinstance(value, (Number, np.bool_)): + body = value + inputs = {} + elif isinstance(value, str) and value in sdfg.arrays: + value_array = sdfg.arrays[value] + if not isinstance(value_array, data.Scalar): + raise mem_parser.DaceSyntaxError( + pv, None, f"{arr}.fill requires a scalar argument, but {type(value_array)} was given.") + body = '__inp' + inputs = {'__inp': dace.Memlet(data=value, subset='0')} + else: + raise mem_parser.DaceSyntaxError(pv, None, f"Unsupported argument '{value}' for {arr}.fill.") + + shape = sdfg.arrays[arr].shape + state.add_mapped_tasklet( + '_numpy_fill_', + map_ranges={ + f"__i{dim}": f"0:{s}" + for dim, s in enumerate(shape) + }, + inputs=inputs, + code=f"__out = {body}", + outputs={'__out': dace.Memlet.simple(arr, ",".join([f"__i{dim}" for dim in range(len(shape))]))}, + external_edges=True) + + return arr @oprepo.replaces_method('Array', 'reshape') diff --git a/tests/numpy/ndarray_attributes_methods_test.py b/tests/numpy/ndarray_attributes_methods_test.py index 40a6db7a6c..c9c38e245c 100644 --- a/tests/numpy/ndarray_attributes_methods_test.py +++ b/tests/numpy/ndarray_attributes_methods_test.py @@ -38,6 +38,18 @@ def test_fill(A: dace.int32[M, N]): return A # return A.fill(5) doesn't work because A is not copied +@compare_numpy_output() +def test_fill2(A: dace.int32[M, N], a: dace.int32): + A.fill(a) + return A # return A.fill(5) doesn't work because A is not copied + + +@compare_numpy_output() +def test_fill3(A: dace.int32[M, N], a: dace.int32): + A.fill(a + 1) + return A + + @compare_numpy_output() def test_reshape(A: dace.float32[N, N]): return A.reshape([1, N * N]) @@ -124,6 +136,8 @@ def test_any(): test_copy() test_astype() test_fill() + test_fill2() + test_fill3() test_reshape() test_transpose1() test_transpose2()