From 440a474a35cd0c948565620cb6b34c2f747f9081 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Dec 2024 11:59:20 +0100 Subject: [PATCH 1/4] Split handling of let-statement lambdas from stencil body --- .../gtir_builtin_translators.py | 26 +++--- .../runners/dace_fieldview/gtir_dataflow.py | 84 ++++++++++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 5 +- 3 files changed, 81 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ff011c4193..580ba64881 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -30,7 +30,7 @@ gtir_python_codegen, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: @@ -366,36 +366,36 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args + fieldop_expr, domain_expr = fun_node.args - if isinstance(stencil_expr, gtir.Lambda): - # Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - pass - elif cpm.is_ref_to(stencil_expr, "deref"): + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] + stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( - f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." ) # parse the domain of the field operator domain = extract_domain(domain_expr) # visit the list of arguments to be passed to the lambda expression - stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) + input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge @@ -654,7 +654,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) + assert ti.is_integral(node.args[0].type) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cfba4d61e5..c1b64e31da 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,10 +10,22 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + Final, + List, + Optional, + Protocol, + Sequence, + Set, + Tuple, + TypeAlias, + Union, +) import dace -import dace.subsets as sbs +from dace import subsets as sbs from gt4py import eve from gt4py.next import common as gtx_common @@ -68,7 +80,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Indices | sbs.Range + subset: sbs.Range @dataclasses.dataclass(frozen=True) @@ -1264,39 +1276,38 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: elif cpm.is_applied_shift(node): return self._visit_shift(node) + elif isinstance(node.fun, gtir.Lambda): + raise AssertionError("Lambda node should be visited with 'apply()' method.") + elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda( - self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: DataExpr = self.visit(node.expr) - if isinstance(output_expr, ValueExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: + result: DataExpr = self.visit(node.expr) + + if isinstance(result, ValueExpr): + return DataflowOutputEdge(self.state, result) - if isinstance(output_expr, MemletExpr): + if isinstance(result, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + output_dtype = result.dc_node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_data_edge( - output_expr.dc_node, - output_expr.subset, + result.dc_node, + result.subset, tasklet_node, "__inp", ) else: - assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") + output_dtype = result.dc_dtype + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + return DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) @@ -1309,3 +1320,38 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE # if not in the lambda symbol map, this must be a symref to a builtin function assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) + + def apply( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + # lambda arguments are mapped to symbols defined in lambda scope + prev_symbols: dict[ + str, + Optional[IteratorExpr | MemletExpr | SymbolExpr], + ] = {} + for p, arg in zip(node.params, args, strict=True): + symbol_name = str(p.id) + prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) + self.symbol_map[symbol_name] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + input_edges, output_edge = self.apply(let_node.fun, args=let_args) + + else: + # this lambda is not a let-statement, but a stencil expression + output_edge = self.visit_Lambda(node) + input_edges = self.input_edges + + # remove locally defined lambda symbols and restore previous symbols + for symbol_name, prev_value in prev_symbols.items(): + if prev_value is None: + self.symbol_map.pop(symbol_name) + else: + self.symbol_map[symbol_name] = prev_value + + return input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 6b5e164458..9bd40f75f8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -602,7 +602,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - args: list[gtir_builtin_translators.FieldopResult], + args: Sequence[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -679,7 +679,7 @@ def get_field_domain_offset( self.offset_provider_type, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nstate = nsdfg.add_state("lambda") + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # add sdfg storage for the symbols that need to be passed as input parameters lambda_params = [ @@ -690,6 +690,7 @@ def get_field_domain_offset( nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) + nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( node.expr, sdfg=nsdfg, From 55811dcfa187854bf98e0a18472bcf602a1dada1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Dec 2024 12:24:41 +0100 Subject: [PATCH 2/4] review comments --- .../gtir_builtin_translators.py | 14 ++-- .../runners/dace_fieldview/gtir_dataflow.py | 79 +++++++++++-------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 580ba64881..3d78cbbafe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -39,7 +39,7 @@ def _get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None -) -> sbs.Indices: +) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying an optional offset in each dimension as start index. @@ -55,9 +55,9 @@ def _get_domain_indices( """ index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] if offsets is None: - return sbs.Indices(index_variables) + return dace_subsets.Indices(index_variables) else: - return sbs.Indices( + return dace_subsets.Indices( [ index - offset if offset != 0 else index for index, offset in zip(index_variables, offsets, strict=True) @@ -96,7 +96,7 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): @@ -263,7 +263,7 @@ def _create_field_operator( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = sbs.Range.from_indices(field_indices) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -280,7 +280,7 @@ def _create_field_operator( field_dims.append(output_edge.result.gt_dtype.offset_type) field_shape.extend(dataflow_output_desc.shape) field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index c1b64e31da..b102e9638f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -25,7 +25,7 @@ ) import dace -from dace import subsets as sbs +from dace import subsets as dace_subsets from gt4py import eve from gt4py.next import common as gtx_common @@ -80,7 +80,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Range + subset: dace_subsets.Range @dataclasses.dataclass(frozen=True) @@ -116,7 +116,7 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] - def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -129,7 +129,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: assert len(field_desc.shape) == len(self.field_domain) field_domain = self.field_domain - return sbs.Range.from_string( + return dace_subsets.Range.from_string( ",".join( str(self.indices[dim].value - offset) # type: ignore[union-attr] if dim in self.indices @@ -164,7 +164,7 @@ class MemletInputEdge(DataflowInputEdge): state: dace.SDFGState source: dace.nodes.AccessNode - subset: sbs.Range + subset: dace_subsets.Range dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] @@ -214,7 +214,7 @@ def connect( self, mx: dace.nodes.MapExit, dest: dace.nodes.AccessNode, - subset: sbs.Range, + subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src @@ -270,8 +270,9 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: class LambdaToDataflow(eve.NodeVisitor): """ - Translates an `ir.Lambda` expression to a dataflow graph. + Visitor class to translate a `Lambda` expression to a dataflow graph. + This visitor should be applied by calling `apply()` method on a `Lambda` IR. The dataflow graph generated here typically represents the stencil function of a field operator. It only computes single elements or pure local fields, in case of neighbor values. In case of local fields, the dataflow contains @@ -305,7 +306,7 @@ def __init__( def _add_input_data_edge( self, src: dace.nodes.AccessNode, - src_subset: sbs.Range, + src_subset: dace_subsets.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, src_offset: Optional[list[dace.symbolic.SymExpr]] = None, @@ -313,7 +314,7 @@ def _add_input_data_edge( input_subset = ( src_subset if src_offset is None - else sbs.Range( + else dace_subsets.Range( (start - off, stop - off, step) for (start, stop, step), off in zip(src_subset, src_offset, strict=True) ) @@ -524,7 +525,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, - sbs.Range.from_array(field_desc), + dace_subsets.Range.from_array(field_desc), deref_node, "field", src_offset=[offset for (_, offset) in arg_expr.field_domain], @@ -592,7 +593,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=it.field, gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( ",".join( str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain @@ -608,7 +609,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=self.state.add_access(connectivity), gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_index.value}, 0:{offset_provider.max_neighbors}" ), ) @@ -770,7 +771,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=itir_ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) @@ -920,7 +921,9 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), nsdfg_node, "neighbor_indices", ) @@ -1093,7 +1096,7 @@ def _make_dynamic_neighbor_offset( ) self._add_input_data_edge( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) @@ -1139,7 +1142,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=sbs.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1277,7 +1280,8 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - raise AssertionError("Lambda node should be visited with 'apply()' method.") + # Lambda node should be visited with 'apply()' method. + raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) @@ -1326,32 +1330,39 @@ def apply( node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - # lambda arguments are mapped to symbols defined in lambda scope - prev_symbols: dict[ - str, - Optional[IteratorExpr | MemletExpr | SymbolExpr], - ] = {} - for p, arg in zip(node.params, args, strict=True): - symbol_name = str(p.id) - prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) - self.symbol_map[symbol_name] = arg + """ + Entry point for this visitor class. + + This visitor will translate a `Lambda` node into a dataflow graph to be + instantiated inside a map scope implementing the field operator. + However, this `apply()` method is responsible to recognize the usage of + the `Lambda` node, which can be either a let-statement or the stencil expression + in local view. The usage of a `Lambda` as let-statement corresponds to computing + some results and making them available inside the lambda scope, represented + as a nested SDFG. All let-statements, if any, are supposed to be encountered + before the stencil expression. In other words, the `Lambda` node representing + the stencil expression is always the innermost node. + Therefore, the lowering of let-statements results in recursive calls to + `apply()` until the stencil expression is found. At that point, it falls + back to the `visit()` function. + """ + + # lambda arguments are mapped to symbols defined in lambda scope. + prev_symbol_map = self.symbol_map + self.symbol_map = self.symbol_map.copy() + self.symbol_map |= {str(p.id): arg for p, arg in zip(node.params, args, strict=True)} if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) input_edges, output_edge = self.apply(let_node.fun, args=let_args) - else: - # this lambda is not a let-statement, but a stencil expression - output_edge = self.visit_Lambda(node) + # this lambda node is not a let-statement, but a stencil expression + output_edge = self.visit(node) input_edges = self.input_edges # remove locally defined lambda symbols and restore previous symbols - for symbol_name, prev_value in prev_symbols.items(): - if prev_value is None: - self.symbol_map.pop(symbol_name) - else: - self.symbol_map[symbol_name] = prev_value + self.symbol_map = prev_symbol_map return input_edges, output_edge From 62e1648dd0d9f28fdd2404ca5f92db9209c14f7a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 11:39:41 +0100 Subject: [PATCH 3/4] review comments (1) --- .../runners/dace_fieldview/gtir_dataflow.py | 65 +++++++++++++------ 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index b102e9638f..8083dc6b09 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -1325,44 +1325,69 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) - def apply( + def _visit_let( self, node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + ) -> DataflowOutputEdge: """ - Entry point for this visitor class. - - This visitor will translate a `Lambda` node into a dataflow graph to be - instantiated inside a map scope implementing the field operator. - However, this `apply()` method is responsible to recognize the usage of - the `Lambda` node, which can be either a let-statement or the stencil expression - in local view. The usage of a `Lambda` as let-statement corresponds to computing - some results and making them available inside the lambda scope, represented - as a nested SDFG. All let-statements, if any, are supposed to be encountered - before the stencil expression. In other words, the `Lambda` node representing - the stencil expression is always the innermost node. + Maps lambda arguments to internal parameters. + + This method is responsible to recognize the usage of the `Lambda` node, + which can be either a let-statement or the stencil expression in local view. + The usage of a `Lambda` as let-statement corresponds to computing some results + and making them available inside the lambda scope, represented as a nested SDFG. + All let-statements, if any, are supposed to be encountered before the stencil + expression. In other words, the `Lambda` node representing the stencil expression + is always the innermost node. Therefore, the lowering of let-statements results in recursive calls to - `apply()` until the stencil expression is found. At that point, it falls + `_visit_let()` until the stencil expression is found. At that point, it falls back to the `visit()` function. """ # lambda arguments are mapped to symbols defined in lambda scope. prev_symbol_map = self.symbol_map - self.symbol_map = self.symbol_map.copy() - self.symbol_map |= {str(p.id): arg for p, arg in zip(node.params, args, strict=True)} + self.symbol_map = self.symbol_map | { + str(p.id): arg for p, arg in zip(node.params, args, strict=True) + } if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) - input_edges, output_edge = self.apply(let_node.fun, args=let_args) + output_edge = self._visit_let(let_node.fun, args=let_args) else: # this lambda node is not a let-statement, but a stencil expression output_edge = self.visit(node) - input_edges = self.input_edges - # remove locally defined lambda symbols and restore previous symbols + # restore previous symbols, thus removing locally defined lambda symbols in the let-scope self.symbol_map = prev_symbol_map - return input_edges, output_edge + return output_edge + + def apply( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point for this visitor class, that will translate a `Lambda` node + into a dataflow graph to be instantiated inside a map scope implementing + the field operator. + + It calls `_visit_let()` to maps lambda arguments to internal parameters and + visit let-statements (if any), which always appear as outermost nodes. + The visitor will return the output edge of the dataflow. + + Args: + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output connection edge. + """ + + output_edge = self._visit_let(node, args) + return self.input_edges, output_edge From de4a80e06e699bbf7779fe1dce38475c4054ae78 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 13:53:48 +0100 Subject: [PATCH 4/4] review comments (2) --- .../gtir_builtin_translators.py | 5 +- .../runners/dace_fieldview/gtir_dataflow.py | 91 ++++++++----------- 2 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 3d78cbbafe..cffbd74c90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -394,8 +394,9 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args) + input_edges, output_edge = gtir_dataflow.visit_lambda( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 8083dc6b09..a3653fb519 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -268,6 +268,7 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +@dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ Visitor class to translate a `Lambda` expression to a dataflow graph. @@ -288,20 +289,10 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.input_edges = [] - self.symbol_map = {} + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( + default_factory=lambda: {} + ) def _add_input_data_edge( self, @@ -1280,7 +1271,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - # Lambda node should be visited with 'apply()' method. + # Lambda node should be visited with 'visit_let()' method. raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") elif isinstance(node.fun, gtir.SymRef): @@ -1325,7 +1316,7 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) - def _visit_let( + def visit_let( self, node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], @@ -1341,53 +1332,51 @@ def _visit_let( expression. In other words, the `Lambda` node representing the stencil expression is always the innermost node. Therefore, the lowering of let-statements results in recursive calls to - `_visit_let()` until the stencil expression is found. At that point, it falls + `visit_let()` until the stencil expression is found. At that point, it falls back to the `visit()` function. """ # lambda arguments are mapped to symbols defined in lambda scope. - prev_symbol_map = self.symbol_map - self.symbol_map = self.symbol_map | { - str(p.id): arg for p, arg in zip(node.params, args, strict=True) - } + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) - output_edge = self._visit_let(let_node.fun, args=let_args) + return self.visit_let(let_node.fun, args=let_args) else: # this lambda node is not a let-statement, but a stencil expression - output_edge = self.visit(node) + return self.visit(node) - # restore previous symbols, thus removing locally defined lambda symbols in the let-scope - self.symbol_map = prev_symbol_map - return output_edge +def visit_lambda( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], +) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point to visit a `Lambda` node and lower it to a dataflow graph, + that can be instantiated inside a map scope implementing the field operator. - def apply( - self, - node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - """ - Entry point for this visitor class, that will translate a `Lambda` node - into a dataflow graph to be instantiated inside a map scope implementing - the field operator. - - It calls `_visit_let()` to maps lambda arguments to internal parameters and - visit let-statements (if any), which always appear as outermost nodes. - The visitor will return the output edge of the dataflow. - - Args: - node: Lambda node to visit. - args: Arguments passed to lambda node. - - Returns: - A tuple of two elements: - - List of connections for data inputs to the dataflow. - - Output connection edge. - """ + It calls `LambdaToDataflow.visit_let()` to map the lambda arguments to internal + parameters and visit the let-statements (if any), which always appear as outermost + nodes. Finally, the visitor returns the output edge of the dataflow. - output_edge = self._visit_let(node, args) - return self.input_edges, output_edge + Args: + sdfg: The SDFG where the dataflow graph will be instantiated. + state: The SDFG state where the dataflow graph will be instantiated. + sdfg_builder: Helper class to build the SDFG. + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output data connection. + """ + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + output_edge = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edge