Skip to content

Commit

Permalink
DeadDataFlowElimination will add type hint when removing a connector (
Browse files Browse the repository at this point in the history
#1499)

Issue #1150 that `DeadDataflowElimination` removes a connector from a
Tasklet which leaves a variable without type hint.
This PR tries to fix this bug by adding a type hint expression for a
variable which is used in the tasklet. It adds the type hint only if the
variable is used inside the tasklet code (I checked using
`ASTFindReplace`).
The PR also adds a test which is literaly the code presented in #1150
and asserts the presence of the type hint and checks if it compiles.

### May need confirmation
- [ ] Did I use `ASTFindReplace` correctly?
- [ ] If the type inference fails no type hint is added. Is it the right
solution?
- [ ] Does the test even make sense? (I don't have much experience in
unit testing).

This is my first PR for this project so be patient with me.

---------

Co-authored-by: alexnick83 <[email protected]>
  • Loading branch information
luca-patrignani and alexnick83 authored Feb 22, 2024
1 parent ab6647b commit bc08e9a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
21 changes: 17 additions & 4 deletions dace/transformation/passes/dead_dataflow_elimination.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.

import ast
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

from dace import SDFG, Memlet, SDFGState, data, dtypes, properties
from dace.frontend.python import astutils
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.sdfg.analysis import cfg
Expand Down Expand Up @@ -137,16 +139,27 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D
predecessor_nsdfgs[leaf.src].add(leaf.src_conn)

# Pruning connectors on tasklets sometimes needs to change their code
elif (isinstance(leaf.src, nodes.Tasklet)
and leaf.src.code.language != dtypes.Language.Python):
elif isinstance(leaf.src, nodes.Tasklet):
ctype = infer_types.infer_out_connector_type(sdfg, state, leaf.src, leaf.src_conn)
# Add definition
if leaf.src.code.language == dtypes.Language.CPP:
ctype = infer_types.infer_out_connector_type(sdfg, state, leaf.src, leaf.src_conn)
if ctype is None:
raise NotImplementedError(
f'Cannot eliminate dead connector "{leaf.src_conn}" on '
'tasklet due to connector type inference failure.')
# Add definition
leaf.src.code.code = f'{ctype.as_arg(leaf.src_conn)};\n' + leaf.src.code.code
elif leaf.src.code.language == dtypes.Language.Python:
if ctype is not None:
# ASTFindReplace won't do any replacement (note that repldict is empty), it is
# used only to check if leaf.src_conn is used in tasklet's code.
ast_find = astutils.ASTFindReplace(repldict={}, trigger_names={leaf.src_conn})
# if leaf.src_conn is found in leaf.src.code.code
try:
for code in leaf.src.code.code:
ast_find.generic_visit(code)
except astutils.NameFound:
# then add the hint expression
leaf.src.code.code = ast.parse(f'{leaf.src_conn}: dace.{ctype.to_string()}\n').body + leaf.src.code.code
else:
raise NotImplementedError(f'Cannot eliminate dead connector "{leaf.src_conn}" on '
'tasklet due to its code language.')
Expand Down
39 changes: 39 additions & 0 deletions tests/passes/dead_code_elimination_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
""" Various tests for dead code elimination passes. """

import numpy as np
import pytest
import dace
from dace.transformation.pass_pipeline import Pipeline
Expand Down Expand Up @@ -253,6 +254,43 @@ def test_dce_callback_manual():
sdfg.validate()


def test_dce_add_type_hint_of_variable():
"""
The code of this test comes from this issue: https://github.com/spcl/dace/issues/1150#issue-1445418361
"""
sdfg = dace.SDFG("test")
state = sdfg.add_state()
sdfg.add_array("out", dtype=dace.float64, shape=(10,))
sdfg.add_array("cond", dtype=dace.bool, shape=(10,))
sdfg.add_array("tmp", dtype=dace.float64, shape=(10,), transient=True)
tasklet, *_ = state.add_mapped_tasklet(
code="""
if _cond:
_tmp = 3.0
else:
_tmp = 7.0
_out = _tmp
""",
inputs={"_cond": dace.Memlet(subset="k", data="cond")},
outputs={
"_out": dace.Memlet(subset="k", data="out"),
"_tmp": dace.Memlet(subset="k", data="tmp"),
},
map_ranges={"k": "0:10"},
name="test_tasklet",
external_edges=True,
)
sdfg.simplify()
assert tasklet.code.as_string.startswith("_tmp: dace.float64")

compiledsdfg = sdfg.compile()
cond = np.random.choice(a=[True, False], size=(10,))
out = np.zeros((10,))
compiledsdfg(cond=cond, out=out)
assert np.all(out == np.where(cond, 3.0, 7.0))



if __name__ == '__main__':
test_dse_simple()
test_dse_unconditional()
Expand All @@ -267,3 +305,4 @@ def test_dce_callback_manual():
test_dce()
test_dce_callback()
test_dce_callback_manual()
test_dce_add_type_hint_of_variable()

0 comments on commit bc08e9a

Please sign in to comment.