diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ab012b2a2a..d9007440d9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -792,11 +792,11 @@ def translate_scalar_expr( visit_expr = True if isinstance(arg_expr, gtir.SymRef): try: - # check if symbol is defined in the GT4Py program, returns `None` if undefined + # check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined sdfg_builder.get_symbol_type(arg_expr.id) except KeyError: - # this is the case of non-variable argument, e.g. target type such as `float64`, - # used in a casting expression like `cast_(variable, float64)` + # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, + # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` visit_expr = False if visit_expr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index fc9a6217c4..6048ee5671 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -144,9 +144,11 @@ class DataflowInputEdge(Protocol): This protocol represents an open connection into the dataflow. It provides the `connect` method to setup an input edge from an external data source. - Since the dataflow represents a stencil, we instantiate the dataflow inside a map scope - and connect its inputs and outputs to external data nodes by means of memlets that - traverse the map entry and exit nodes. + The most common case is that the dataflow represents a stencil, which is instantied + inside a map scope and whose inputs and outputs are connected to external data nodes + by means of memlets that traverse the map entry and exit nodes. + The dataflow can also be instatiated without a map, in which case the `map_entry` + argument is set to `None`. """ @abc.abstractmethod @@ -195,7 +197,7 @@ class EmptyInputEdge(DataflowInputEdge): node: dace.nodes.Tasklet def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: - # cannot create empty edge from MapEntry node, if this is not present + # the empty edge is not created if the dataflow is instatiated without a map if map_entry is not None: self.state.add_nedge(map_entry, self.node, dace.Memlet()) @@ -206,10 +208,12 @@ class DataflowOutputEdge: Allows to setup an output memlet through a map exit node. The result of a dataflow subgraph needs to be written to an external data node. - Since the dataflow represents a stencil and the dataflow is computed over - a field domain, the dataflow is instatiated inside a map scope. The `connect` - method creates a memlet that writes the dataflow result to the external array - passing through the map exit node. + The most common case is that the dataflow represents a stencil and the dataflow + is computed over a field domain, therefore the dataflow is instatiated inside + a map scope. The `connect` method creates a memlet that writes the dataflow + result to the external array passing through the `map_exit` node. + The dataflow can also be instatiated without a map, in which case the `map_exit` + argument is set to `None`. """ state: dace.SDFGState @@ -575,7 +579,7 @@ def _visit_if_branch( DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], ]: """ - Helper method to visit an if branch expression and lower it to a dtaflow inside the given nested SDFG and state. + Helper method to visit an if-branch expression and lower it to a dtaflow inside the given nested SDFG and state. Args: if_sdfg: The nested SDFG where the if expression is lowered. @@ -646,16 +650,22 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: lambda_args.append(inner_arg) lambda_params.append(im.sym(p)) - # visit each branch of the if-statement as it was a Lambda node + # visit each branch of the if-statement as if it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) return apply(if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args) def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: + """ + Lowers an if-expression with exclusive branch execution into a nested SDFG, in which + each branch is lowered into a dataflow in a separate state and the if-condition is represented + as the inter-state edge condtion. + """ assert len(node.args) == 3 # TODO(edopao): enable once supported in next DaCe release use_conditional_block: Final[bool] = False + # evaluate the if-condition that will write to a boolean scalar node condition_value = self.visit(node.args[0]) assert ( ( @@ -669,6 +679,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo) + # create states inside the nested SDFG for the if-branches if use_conditional_block: if_region = dace.sdfg.state.ConditionalBlock("if") nsdfg.add_node(if_region) @@ -693,6 +704,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A nsdfg_symbol_mapping = {} input_memlets: dict[str, MemletExpr | ValueExpr] = {} + # define scalar or symbol for the condition value inside the nested SDFG if isinstance(condition_value, SymbolExpr): nsdfg.add_symbol("__cond", dace.dtypes.bool) nsdfg_symbol_mapping["__cond"] = condition_value.value @@ -707,18 +719,23 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A input_memlets["__cond"] = condition_value for if_branch_state, arg in zip([tstate, fstate], node.args[1:3]): + # visit each if-branch in the corresponding state of the nested SDFG in_edges, out_edge = self._visit_if_branch(nsdfg, if_branch_state, arg, input_memlets) for edge in in_edges: edge.connect(map_entry=None) + # the result of each branch needs to be moved to the parent SDFG def construct_output( output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym ) -> ValueExpr: + # the output data node has the same name as the nested SDFG output connector output_data = str(sym.id) try: output_desc = nsdfg.data(output_data) assert not output_desc.transient except KeyError: + # if the result is currently written to a transient node, inside the nested SDFG, + # we need to allocate a non-transient data node result_desc = edge.result.dc_node.desc(nsdfg) output_desc = result_desc.clone() output_desc.transient = False @@ -781,6 +798,7 @@ def construct_output( ) def connect_output(inner_value: ValueExpr) -> ValueExpr: + # each output connector of the nested SDFG writes to a transient node in the parent SDFG inner_data = inner_value.dc_node.data inner_desc = inner_value.dc_node.desc(nsdfg) assert not inner_desc.transient diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 4bdb602f5f..956a5c6435 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -73,18 +73,12 @@ } -def builtin_cast(*args: Any) -> str: - val, target_type = args +def builtin_cast(val: str, target_type: str) -> str: assert target_type in gtir.TYPEBUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) -def builtin_if(*args: Any) -> str: - cond, true_val, false_val = args - return f"{true_val} if {cond} else {false_val}" - - -def make_const_list(arg: str) -> str: +def builtin_const_list(arg: str) -> str: """ Takes a single scalar argument and broadcasts this value on the local dimension of map expression. In a dataflow, we represent it as a tasklet that writes @@ -93,10 +87,19 @@ def make_const_list(arg: str) -> str: return arg -GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { +def builtin_if(cond: str, true_val: str, false_val: str) -> str: + return f"{true_val} if {cond} else {false_val}" + + +def builtin_tuple_get(index: str, tuple_name: str) -> str: + return f"{tuple_name}_{index}" + + +GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = { "cast_": builtin_cast, "if_": builtin_if, - "make_const_list": make_const_list, + "make_const_list": builtin_const_list, + "tuple_get": builtin_tuple_get, }