Skip to content

Commit

Permalink
fix[next][dace]: Fix translation of if statement from tasklet to inte…
Browse files Browse the repository at this point in the history
…r-state condition (#1469)

The bug addressed by this PR is that if-nodes were translated to tasklets. Tasklets assume that all inputs are evaluated. For if-nodes, we need to enforce exclusive execution of one of the two branches. That means that only one of the two arguments will be evaluated at runtime. We achieve this by implementing the true/false branches as separate states and checking the if-statement as condition on the inter-state edge.
  • Loading branch information
edopao authored Feb 26, 2024
1 parent 4228624 commit b86a347
Showing 1 changed file with 114 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import copy
import dataclasses
import itertools
from collections.abc import Sequence
Expand Down Expand Up @@ -566,37 +567,120 @@ def builtin_can_deref(
def builtin_if(
transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
di = dace_debuginfo(node, transformer.context.body.debuginfo)
args = transformer.visit(node_args)
assert len(args) == 3
if_node = args[0][0] if isinstance(args[0], list) else args[0]

# the argument could be a list of elements on each branch representing the result of `make_tuple`
# however, the normal case is to find one value expression
assert len(args[1]) == len(args[2])
if_expr_args = [
(a[0] if isinstance(a, list) else a, b[0] if isinstance(b, list) else b)
for a, b in zip(args[1], args[2])
]

# in case of tuple arguments, generate one if-tasklet for each element of the output tuple
if_expr_values = []
for a, b in if_expr_args:
assert a.dtype == b.dtype
expr_args = [
(arg, f"{arg.value.data}_v")
for arg in (if_node, a, b)
if not isinstance(arg, SymbolExpr)
]
internals = [
arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v"
for arg in (if_node, a, b)
]
expr = "({1} if {0} else {2})".format(*internals)
if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if", dace_debuginfo=di)
if_expr_values.append(if_expr[0])
assert len(node_args) == 3
sdfg = transformer.context.body
current_state = transformer.context.state
is_start_state = sdfg.start_block == current_state

# build an empty state to join true and false branches
join_state = sdfg.add_state_before(current_state, "join")

def build_if_state(arg, state):
symbol_map = copy.deepcopy(transformer.context.symbol_map)
node_context = Context(sdfg, state, symbol_map)
node_taskgen = PythonTaskletCodegen(
transformer.offset_provider, node_context, transformer.node_types
)
return node_taskgen.visit(arg)

# represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state
stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state)
stmt_node = build_if_state(node_args[0], stmt_state)[0]
assert isinstance(stmt_node, ValueExpr)
assert stmt_node.dtype == dace.dtypes.bool
assert sdfg.arrays[stmt_node.value.data].shape == (1,)

# visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state
tbr_state = sdfg.add_state("true_branch")
sdfg.add_edge(
stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True")
)
sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge())
tbr_values = build_if_state(node_args[1], tbr_state)
#
fbr_state = sdfg.add_state("false_branch")
sdfg.add_edge(
stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False")
)
sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge())
fbr_values = build_if_state(node_args[2], fbr_state)

assert isinstance(stmt_node, ValueExpr)
assert stmt_node.dtype == dace.dtypes.bool
# make the result of the if-statement evaluation available inside current state
ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype)

# we distinguish between select if-statements, where both true and false branches are symbolic expressions,
# and therefore do not require exclusive branch execution, and regular if-statements where at least one branch
# is a value expression, which has to be evaluated at runtime with conditional state transition
result_values = []
assert len(tbr_values) == len(fbr_values)
for tbr_value, fbr_value in zip(tbr_values, fbr_values):
assert isinstance(tbr_value, (SymbolExpr, ValueExpr))
assert isinstance(fbr_value, (SymbolExpr, ValueExpr))
assert tbr_value.dtype == fbr_value.dtype

if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)):
# both branches return symbolic expressions, therefore the if-node can be translated
# to a select-tasklet inside current state
# TODO: use select-memlet when it becomes available in dace
code = f"{tbr_value.value} if _cond else {fbr_value.value}"
if_expr = transformer.add_expr_tasklet(
[(ctx_stmt_node, "_cond")], code, tbr_value.dtype, "if_select"
)[0]
result_values.append(if_expr)
else:
# at least one of the two branches contains a value expression, which should be evaluated
# only if the corresponding true/false condition is satisfied
desc = sdfg.arrays[
tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data
]
var = unique_var_name()
if isinstance(desc, dace.data.Scalar):
sdfg.add_scalar(var, desc.dtype, transient=True)
else:
sdfg.add_array(var, desc.shape, desc.dtype, transient=True)

# write result to transient data container and access it in the original state
for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]:
val_node = state.add_access(var)
if isinstance(expr, ValueExpr):
state.add_nedge(
expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc)
)
else:
assert desc.shape == (1,)
state.add_edge(
state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"),
"_out",
val_node,
None,
dace.Memlet(var, "0"),
)
result_values.append(ValueExpr(current_state.add_access(var), desc.dtype))

if tbr_state.is_empty() and fbr_state.is_empty():
# if all branches are symbolic expressions, the true/false and join states can be removed
# as well as the conditional state transition
sdfg.remove_nodes_from([join_state, tbr_state, fbr_state])
sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge())
elif tbr_state.is_empty():
# use direct edge from if-statement to join state for true branch
tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition
sdfg.edges_between(stmt_state, join_state)[0].contition = tbr_condition
sdfg.remove_node(tbr_state)
elif fbr_state.is_empty():
# use direct edge from if-statement to join state for false branch
fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition
sdfg.edges_between(stmt_state, join_state)[0].contition = fbr_condition
sdfg.remove_node(fbr_state)
else:
# remove direct edge from if-statement to join state
sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0])
# the if-statement condition is not used in current state
current_state.remove_node(ctx_stmt_node.value)

return if_expr_values
return result_values


def builtin_list_get(
Expand Down

0 comments on commit b86a347

Please sign in to comment.