diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2e9a66c435..fa28793187 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -69,28 +69,16 @@ def preprocess_program( program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: itir_transforms.LiftMode, + unroll_reduce: bool = False, ): - node = itir_transforms.apply_common_transforms( + return itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, + force_inline_lambda_args=True, lift_mode=lift_mode, offset_provider=offset_provider, - unroll_reduce=False, + unroll_reduce=unroll_reduce, ) - # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG. - # In this case, just retry with unrolled reductions. - if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): - fencil_definition = node - else: - fencil_definition = itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider=offset_provider, - unroll_reduce=True, - ) - return fencil_definition def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 073c856d86..eaff9f467e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -124,6 +124,24 @@ def _make_array_shape_and_strides( return shape, strides +def _check_no_lifts(node: itir.StencilClosure): + """ + Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. 
+ +    Returns +    ------- +    True if lifts do not appear in the ITIR except for lift expressions in neighbor reductions. False otherwise. +    """ +    neighbors_call_count = 0 +    for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): +        if getattr(fun, "id", "") == "neighbors": +            neighbors_call_count = 3 +        elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: +            return False +        neighbors_call_count = max(0, neighbors_call_count - 1) +    return True + + class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] @@ -262,7 +280,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] ) -> tuple[dace.SDFG, list[str], list[str]]: - assert ItirToSDFG._check_no_lifts(node) + assert _check_no_lifts(node) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") @@ -681,15 +699,6 @@ def _visit_domain( return tuple(sorted(bounds, key=lambda item: item[0])) - @staticmethod - def _check_no_lifts(node: itir.StencilClosure): - if any( - getattr(fun, "id", "") == "lift" - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun") - ): - return False - return True - @staticmethod def _check_shift_offsets_are_literals(node: itir.StencilClosure): fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 56ffe7e104..773a3a61f7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -181,6 +181,126 @@ def __init__( self.reduce_identity = reduce_identity +def _visit_lift_in_neighbors_reduction( + transformer: "PythonTaskletCodegen", + node: itir.FunCall, + node_args: 
Sequence[IteratorExpr | list[ValueExpr]], + offset_provider: NeighborTableOffsetProvider, + map_entry: dace.nodes.MapEntry, + map_exit: dace.nodes.MapExit, + neighbor_index_node: dace.nodes.AccessNode, + neighbor_value_node: dace.nodes.AccessNode, +) -> list[ValueExpr]: + neighbor_dim = offset_provider.neighbor_axis.value + origin_dim = offset_provider.origin_axis.value + + lifted_args: list[IteratorExpr | ValueExpr] = [] + for arg in node_args: + if isinstance(arg, IteratorExpr): + if origin_dim in arg.indices: + lifted_indices = arg.indices.copy() + lifted_indices.pop(origin_dim) + lifted_indices[neighbor_dim] = neighbor_index_node + lifted_args.append( + IteratorExpr( + arg.field, + lifted_indices, + arg.dtype, + arg.dimensions, + ) + ) + else: + lifted_args.append(arg) + else: + lifted_args.append(arg[0]) + + lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) + assert len(inner_outputs) == 1 + inner_out_connector = inner_outputs[0].value.data + + input_nodes = {} + iterator_index_nodes = {} + lifted_index_connectors = set() + + for x, y in inner_inputs: + if isinstance(y, IteratorExpr): + field_connector, inner_index_table = x + input_nodes[field_connector] = y.field + for dim, connector in inner_index_table.items(): + if dim == neighbor_dim: + lifted_index_connectors.add(connector) + iterator_index_nodes[connector] = y.indices[dim] + else: + assert isinstance(y, ValueExpr) + input_nodes[x] = y.value + + neighbor_tables = filter_neighbor_tables(transformer.offset_provider) + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] + + parent_sdfg = transformer.context.body + parent_state = transformer.context.state + + input_mapping = { + connector: create_memlet_full(node.data, node.desc(parent_sdfg)) + for connector, node in input_nodes.items() + } + connectivity_mapping = { + name: create_memlet_full(name, parent_sdfg.arrays[name]) for name in connectivity_names + } + array_mapping 
= {**input_mapping, **connectivity_mapping} + symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) + + nested_sdfg_node = parent_state.add_nested_sdfg( + lift_context.body, + parent_sdfg, + inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, + outputs={inner_out_connector}, + symbol_mapping=symbol_mapping, + debuginfo=lift_context.body.debuginfo, + ) + + for connectivity_connector, memlet in connectivity_mapping.items(): + parent_state.add_memlet_path( + parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), + map_entry, + nested_sdfg_node, + dst_conn=connectivity_connector, + memlet=memlet, + ) + + for inner_connector, access_node in input_nodes.items(): + parent_state.add_memlet_path( + access_node, + map_entry, + nested_sdfg_node, + dst_conn=inner_connector, + memlet=input_mapping[inner_connector], + ) + + for inner_connector, access_node in iterator_index_nodes.items(): + memlet = dace.Memlet(data=access_node.data, subset="0") + if inner_connector in lifted_index_connectors: + parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) + else: + parent_state.add_memlet_path( + access_node, + map_entry, + nested_sdfg_node, + dst_conn=inner_connector, + memlet=memlet, + ) + + parent_state.add_memlet_path( + nested_sdfg_node, + map_exit, + neighbor_value_node, + src_conn=inner_out_connector, + memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), + ) + + return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] + + def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -198,7 +318,16 @@ def builtin_neighbors( "Neighbor reduction only implemented for connectivity based on neighbor tables." 
) - iterator = transformer.visit(data) + lift_node = None + if isinstance(data, FunCall): + assert isinstance(data.fun, itir.FunCall) + fun_node = data.fun + if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": + lift_node = fun_node + lift_args = transformer.visit(data.args) + iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) + if lift_node is None: + iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) origin_index_node = iterator.indices[offset_provider.origin_axis.value] @@ -259,44 +388,56 @@ def builtin_neighbors( dace.Memlet(data=neighbor_index_var, subset="0"), ) - data_access_tasklet = state.add_tasklet( - "data_access", - code="__data = __field[__idx]" - + ( - f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values - else "" - ), - inputs={"__field", "__idx"}, - outputs={"__data"}, - debuginfo=di, - ) - # select full shape only in the neighbor-axis dimension - field_subset = tuple( - f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" - for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=create_memlet_at(iterator.field.data, field_subset), - dst_conn="__field", - ) - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), - src_conn="__data", - ) + if lift_node is not None: + _visit_lift_in_neighbors_reduction( + transformer, + lift_node, + lift_args, + offset_provider, + me, + mx, + neighbor_index_node, + neighbor_value_node, + ) + else: + data_access_tasklet = state.add_tasklet( + 
"data_access", + code="__data = __field[__idx]" + + ( + f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" + if offset_provider.has_skip_values + else "" + ), + inputs={"__field", "__idx"}, + outputs={"__data"}, + debuginfo=di, + ) + # select full shape only in the neighbor-axis dimension + field_subset = tuple( + f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) + ) + state.add_memlet_path( + iterator.field, + me, + data_access_tasklet, + memlet=create_memlet_at(iterator.field.data, field_subset), + dst_conn="__field", + ) + state.add_edge( + neighbor_index_node, + None, + data_access_tasklet, + "__idx", + dace.Memlet(data=neighbor_index_var, subset="0"), + ) + state.add_memlet_path( + data_access_tasklet, + mx, + neighbor_value_node, + memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), + src_conn="__data", + ) if not offset_provider.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] @@ -377,9 +518,8 @@ def builtin_can_deref( # create tasklet to check that field indices are non-negative (-1 is invalid) args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join([f"{v} >= 0" for v in internals]) + expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( list(zip(args, internals)), expr_code, @@ -946,7 +1086,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch 
it + # TODO: remove this special case when ITIR pass is able to catch it assert isinstance(iterator, list) and len(iterator) == 1 assert isinstance(iterator[0], ValueExpr) return iterator