diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index d9131385d6..017fe2cc86 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -1,10 +1,12 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import ast from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type from dace import SDFG, Memlet, SDFGState, data, dtypes, properties +from dace.frontend.python import astutils from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg @@ -137,16 +139,27 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D predecessor_nsdfgs[leaf.src].add(leaf.src_conn) # Pruning connectors on tasklets sometimes needs to change their code - elif (isinstance(leaf.src, nodes.Tasklet) - and leaf.src.code.language != dtypes.Language.Python): + elif isinstance(leaf.src, nodes.Tasklet): + ctype = infer_types.infer_out_connector_type(sdfg, state, leaf.src, leaf.src_conn) + # Add definition if leaf.src.code.language == dtypes.Language.CPP: - ctype = infer_types.infer_out_connector_type(sdfg, state, leaf.src, leaf.src_conn) if ctype is None: raise NotImplementedError( f'Cannot eliminate dead connector "{leaf.src_conn}" on ' 'tasklet due to connector type inference failure.') - # Add definition leaf.src.code.code = f'{ctype.as_arg(leaf.src_conn)};\n' + leaf.src.code.code + elif leaf.src.code.language == dtypes.Language.Python: + if ctype is not None: + # ASTFindReplace won't do any replacement (note that repldict is empty), it is + # used only to check if leaf.src_conn is used in tasklet's code. + ast_find = astutils.ASTFindReplace(repldict={}, trigger_names={leaf.src_conn}) + # if leaf.src_conn is found in leaf.src.code.code + try: + for code in leaf.src.code.code: + ast_find.generic_visit(code) + except astutils.NameFound: + # then add the hint expression + leaf.src.code.code = ast.parse(f'{leaf.src_conn}: dace.{ctype.to_string()}\n').body + leaf.src.code.code else: raise NotImplementedError(f'Cannot eliminate dead connector "{leaf.src_conn}" on ' 'tasklet due to its code language.') diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index 2f84d333e0..f8920b0538 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. """ Various tests for dead code elimination passes. """ +import numpy as np import pytest import dace from dace.transformation.pass_pipeline import Pipeline @@ -253,6 +254,43 @@ def test_dce_callback_manual(): sdfg.validate() +def test_dce_add_type_hint_of_variable(): + """ + The code of this test comes from this issue: https://github.com/spcl/dace/issues/1150#issue-1445418361 + """ + sdfg = dace.SDFG("test") + state = sdfg.add_state() + sdfg.add_array("out", dtype=dace.float64, shape=(10,)) + sdfg.add_array("cond", dtype=dace.bool, shape=(10,)) + sdfg.add_array("tmp", dtype=dace.float64, shape=(10,), transient=True) + tasklet, *_ = state.add_mapped_tasklet( + code=""" +if _cond: + _tmp = 3.0 +else: + _tmp = 7.0 +_out = _tmp + """, + inputs={"_cond": dace.Memlet(subset="k", data="cond")}, + outputs={ + "_out": dace.Memlet(subset="k", data="out"), + "_tmp": dace.Memlet(subset="k", data="tmp"), + }, + map_ranges={"k": "0:10"}, + name="test_tasklet", + external_edges=True, + ) + sdfg.simplify() + assert tasklet.code.as_string.startswith("_tmp: dace.float64") + + compiledsdfg = sdfg.compile() + cond = np.random.choice(a=[True, False], size=(10,)) + out = np.zeros((10,)) + compiledsdfg(cond=cond, out=out) + assert np.all(out == np.where(cond, 3.0, 7.0)) + + + if __name__ == '__main__': test_dse_simple() test_dse_unconditional() @@ -267,3 +305,4 @@ def test_dce_callback_manual(): test_dce() test_dce_callback() test_dce_callback_manual() + test_dce_add_type_hint_of_variable()