From bc08e9a18151986f65f3a7bf9d568b0e406c6f9f Mon Sep 17 00:00:00 2001 From: luca-patrignani <92518571+luca-patrignani@users.noreply.github.com> Date: Thu, 22 Feb 2024 17:10:09 +0100 Subject: [PATCH] `DeadDataFlowElimination` will add type hint when removing a connector (#1499) Issue #1150 that `DeadDataflowElimination` removes a connector from a Tasklet which leaves a variable without type hint. This PR tries to fix this bug by adding a type hint expression for a variable which is used in the tasklet. It adds the type hint only if the variable is used inside the tasklet code (I checked using `ASTFindReplace`). The PR also adds a test which is literaly the code presented in #1150 and asserts the presence of the type hint and checks if it compiles. ### May need confirmation - [ ] Did I use `ASTFindReplace` correctly? - [ ] If the type inference fails no type hint is added. Is it the right solution? - [ ] Does the test even make sense? (I don't have much experience in unit testing). This is my first PR for this project so be patient with me. --------- Co-authored-by: alexnick83 <31545860+alexnick83@users.noreply.github.com> --- .../passes/dead_dataflow_elimination.py | 21 ++++++++-- tests/passes/dead_code_elimination_test.py | 39 +++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index d9131385d6..017fe2cc86 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -1,10 +1,12 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import ast from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type from dace import SDFG, Memlet, SDFGState, data, dtypes, properties +from dace.frontend.python import astutils from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg @@ -137,16 +139,27 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D predecessor_nsdfgs[leaf.src].add(leaf.src_conn) # Pruning connectors on tasklets sometimes needs to change their code - elif (isinstance(leaf.src, nodes.Tasklet) - and leaf.src.code.language != dtypes.Language.Python): + elif isinstance(leaf.src, nodes.Tasklet): + ctype = infer_types.infer_out_connector_type(sdfg, state, leaf.src, leaf.src_conn) + # Add definition if leaf.src.code.language == dtypes.Language.CPP: - ctype = infer_types.infer_out_connector_type(sdfg, state, leaf.src, leaf.src_conn) if ctype is None: raise NotImplementedError( f'Cannot eliminate dead connector "{leaf.src_conn}" on ' 'tasklet due to connector type inference failure.') - # Add definition leaf.src.code.code = f'{ctype.as_arg(leaf.src_conn)};\n' + leaf.src.code.code + elif leaf.src.code.language == dtypes.Language.Python: + if ctype is not None: + # ASTFindReplace won't do any replacement (note that repldict is empty), it is + # used only to check if leaf.src_conn is used in tasklet's code. + ast_find = astutils.ASTFindReplace(repldict={}, trigger_names={leaf.src_conn}) + # if leaf.src_conn is found in leaf.src.code.code + try: + for code in leaf.src.code.code: + ast_find.generic_visit(code) + except astutils.NameFound: + # then add the hint expression + leaf.src.code.code = ast.parse(f'{leaf.src_conn}: dace.{ctype.to_string()}\n').body + leaf.src.code.code else: raise NotImplementedError(f'Cannot eliminate dead connector "{leaf.src_conn}" on ' 'tasklet due to its code language.') diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index 2f84d333e0..f8920b0538 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. """ Various tests for dead code elimination passes. """ +import numpy as np import pytest import dace from dace.transformation.pass_pipeline import Pipeline @@ -253,6 +254,43 @@ def test_dce_callback_manual(): sdfg.validate() +def test_dce_add_type_hint_of_variable(): + """ + The code of this test comes from this issue: https://github.com/spcl/dace/issues/1150#issue-1445418361 + """ + sdfg = dace.SDFG("test") + state = sdfg.add_state() + sdfg.add_array("out", dtype=dace.float64, shape=(10,)) + sdfg.add_array("cond", dtype=dace.bool, shape=(10,)) + sdfg.add_array("tmp", dtype=dace.float64, shape=(10,), transient=True) + tasklet, *_ = state.add_mapped_tasklet( + code=""" +if _cond: + _tmp = 3.0 +else: + _tmp = 7.0 +_out = _tmp + """, + inputs={"_cond": dace.Memlet(subset="k", data="cond")}, + outputs={ + "_out": dace.Memlet(subset="k", data="out"), + "_tmp": dace.Memlet(subset="k", data="tmp"), + }, + map_ranges={"k": "0:10"}, + name="test_tasklet", + external_edges=True, + ) + sdfg.simplify() + assert tasklet.code.as_string.startswith("_tmp: dace.float64") + + compiledsdfg = sdfg.compile() + cond = np.random.choice(a=[True, False], size=(10,)) + out = np.zeros((10,)) + compiledsdfg(cond=cond, out=out) + assert np.all(out == np.where(cond, 3.0, 7.0)) + + + if __name__ == '__main__': test_dse_simple() test_dse_unconditional() @@ -267,3 +305,4 @@ def test_dce_callback_manual(): test_dce() test_dce_callback() test_dce_callback_manual() + test_dce_add_type_hint_of_variable()