Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dace-gtir-iterator_view' into da…
Browse files Browse the repository at this point in the history
…ce-gtir-scan
  • Loading branch information
edopao committed Jan 10, 2025
2 parents 76168b3 + 87b5bd5 commit 4914de5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -792,11 +792,11 @@ def translate_scalar_expr(
visit_expr = True
if isinstance(arg_expr, gtir.SymRef):
try:
# check if symbol is defined in the GT4Py program, returns `None` if undefined
# check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined
sdfg_builder.get_symbol_type(arg_expr.id)
except KeyError:
# this is the case of non-variable argument, e.g. target type such as `float64`,
# used in a casting expression like `cast_(variable, float64)`
# all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument,
# e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)`
visit_expr = False

if visit_expr:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ class DataflowInputEdge(Protocol):
This protocol represents an open connection into the dataflow.
It provides the `connect` method to setup an input edge from an external data source.
Since the dataflow represents a stencil, we instantiate the dataflow inside a map scope
and connect its inputs and outputs to external data nodes by means of memlets that
traverse the map entry and exit nodes.
The most common case is that the dataflow represents a stencil, which is instantied
inside a map scope and whose inputs and outputs are connected to external data nodes
by means of memlets that traverse the map entry and exit nodes.
The dataflow can also be instatiated without a map, in which case the `map_entry`
argument is set to `None`.
"""

@abc.abstractmethod
Expand Down Expand Up @@ -195,7 +197,7 @@ class EmptyInputEdge(DataflowInputEdge):
node: dace.nodes.Tasklet

def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None:
# cannot create empty edge from MapEntry node, if this is not present
# the empty edge is not created if the dataflow is instatiated without a map
if map_entry is not None:
self.state.add_nedge(map_entry, self.node, dace.Memlet())

Expand All @@ -206,10 +208,12 @@ class DataflowOutputEdge:
Allows to setup an output memlet through a map exit node.
The result of a dataflow subgraph needs to be written to an external data node.
Since the dataflow represents a stencil and the dataflow is computed over
a field domain, the dataflow is instatiated inside a map scope. The `connect`
method creates a memlet that writes the dataflow result to the external array
passing through the map exit node.
The most common case is that the dataflow represents a stencil and the dataflow
is computed over a field domain, therefore the dataflow is instatiated inside
a map scope. The `connect` method creates a memlet that writes the dataflow
result to the external array passing through the `map_exit` node.
The dataflow can also be instatiated without a map, in which case the `map_exit`
argument is set to `None`.
"""

state: dace.SDFGState
Expand Down Expand Up @@ -575,7 +579,7 @@ def _visit_if_branch(
DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...],
]:
"""
Helper method to visit an if branch expression and lower it to a dtaflow inside the given nested SDFG and state.
Helper method to visit an if-branch expression and lower it to a dtaflow inside the given nested SDFG and state.
Args:
if_sdfg: The nested SDFG where the if expression is lowered.
Expand Down Expand Up @@ -646,16 +650,22 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr:
lambda_args.append(inner_arg)
lambda_params.append(im.sym(p))

# visit each branch of the if-statement as it was a Lambda node
# visit each branch of the if-statement as if it was a Lambda node
lambda_node = gtir.Lambda(params=lambda_params, expr=expr)
return apply(if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args)

def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]:
"""
Lowers an if-expression with exclusive branch execution into a nested SDFG, in which
each branch is lowered into a dataflow in a separate state and the if-condition is represented
as the inter-state edge condtion.
"""
assert len(node.args) == 3

# TODO(edopao): enable once supported in next DaCe release
use_conditional_block: Final[bool] = False

# evaluate the if-condition that will write to a boolean scalar node
condition_value = self.visit(node.args[0])
assert (
(
Expand All @@ -669,6 +679,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A
nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt"))
nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo)

# create states inside the nested SDFG for the if-branches
if use_conditional_block:
if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region)
Expand All @@ -693,6 +704,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A
nsdfg_symbol_mapping = {}
input_memlets: dict[str, MemletExpr | ValueExpr] = {}

# define scalar or symbol for the condition value inside the nested SDFG
if isinstance(condition_value, SymbolExpr):
nsdfg.add_symbol("__cond", dace.dtypes.bool)
nsdfg_symbol_mapping["__cond"] = condition_value.value
Expand All @@ -707,18 +719,23 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A
input_memlets["__cond"] = condition_value

for if_branch_state, arg in zip([tstate, fstate], node.args[1:3]):
# visit each if-branch in the corresponding state of the nested SDFG
in_edges, out_edge = self._visit_if_branch(nsdfg, if_branch_state, arg, input_memlets)
for edge in in_edges:
edge.connect(map_entry=None)

# the result of each branch needs to be moved to the parent SDFG
def construct_output(
output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym
) -> ValueExpr:
# the output data node has the same name as the nested SDFG output connector
output_data = str(sym.id)
try:
output_desc = nsdfg.data(output_data)
assert not output_desc.transient
except KeyError:
# if the result is currently written to a transient node, inside the nested SDFG,
# we need to allocate a non-transient data node
result_desc = edge.result.dc_node.desc(nsdfg)
output_desc = result_desc.clone()
output_desc.transient = False
Expand Down Expand Up @@ -781,6 +798,7 @@ def construct_output(
)

def connect_output(inner_value: ValueExpr) -> ValueExpr:
# each output connector of the nested SDFG writes to a transient node in the parent SDFG
inner_data = inner_value.dc_node.data
inner_desc = inner_value.dc_node.desc(nsdfg)
assert not inner_desc.transient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,12 @@
}


def builtin_cast(*args: Any) -> str:
val, target_type = args
def builtin_cast(val: str, target_type: str) -> str:
assert target_type in gtir.TYPEBUILTINS
return MATH_BUILTINS_MAPPING[target_type].format(val)


def builtin_if(*args: Any) -> str:
cond, true_val, false_val = args
return f"{true_val} if {cond} else {false_val}"


def make_const_list(arg: str) -> str:
def builtin_const_list(arg: str) -> str:
"""
Takes a single scalar argument and broadcasts this value on the local dimension
of map expression. In a dataflow, we represent it as a tasklet that writes
Expand All @@ -93,10 +87,19 @@ def make_const_list(arg: str) -> str:
return arg


GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = {
def builtin_if(cond: str, true_val: str, false_val: str) -> str:
return f"{true_val} if {cond} else {false_val}"


def builtin_tuple_get(index: str, tuple_name: str) -> str:
return f"{tuple_name}_{index}"


GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = {
"cast_": builtin_cast,
"if_": builtin_if,
"make_const_list": make_const_list,
"make_const_list": builtin_const_list,
"tuple_get": builtin_tuple_get,
}


Expand Down

0 comments on commit 4914de5

Please sign in to comment.