From e462a2ec0e72e3d7079fb4fdd909160448044b4d Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 2 Feb 2024 16:35:41 +0100 Subject: [PATCH] feat[next][dace]: Add support for lift expressions in neighbor reductions (no unrolling) (#1431) Baseline dace backend forced unroll of neighbor reductions, in the ITIR pass, in order to eliminate all lift expressions. This PR adds support for lowering of lift expressions in neighbor reductions, thus avoiding the need to unroll reduce expressions. The result is a more compact SDFG, which leaves to the optimization backend the option of unrolling neighbor reductions. --- .../runners/dace_iterator/__init__.py | 20 +- .../runners/dace_iterator/itir_to_sdfg.py | 29 ++- .../runners/dace_iterator/itir_to_tasklet.py | 224 ++++++++++++++---- 3 files changed, 205 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2e9a66c435..fa28793187 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -69,28 +69,16 @@ def preprocess_program( program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: itir_transforms.LiftMode, + unroll_reduce: bool = False, ): - node = itir_transforms.apply_common_transforms( + return itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, + force_inline_lambda_args=True, lift_mode=lift_mode, offset_provider=offset_provider, - unroll_reduce=False, + unroll_reduce=unroll_reduce, ) - # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG. - # In this case, just retry with unrolled reductions. - if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): - fencil_definition = node - else: - fencil_definition = itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider=offset_provider, - unroll_reduce=True, - ) - return fencil_definition def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 073c856d86..eaff9f467e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -124,6 +124,24 @@ def _make_array_shape_and_strides( return shape, strides +def _check_no_lifts(node: itir.StencilClosure): + """ + Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. + + Returns + ------- + True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise. + """ + neighbors_call_count = 0 + for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): + if getattr(fun, "id", "") == "neighbors": + neighbors_call_count = 3 + elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: + return False + neighbors_call_count = max(0, neighbors_call_count - 1) + return True + + class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] @@ -262,7 +280,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] ) -> tuple[dace.SDFG, list[str], list[str]]: - assert ItirToSDFG._check_no_lifts(node) + assert _check_no_lifts(node) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") @@ -681,15 +699,6 @@ def _visit_domain( return tuple(sorted(bounds, key=lambda item: item[0])) - @staticmethod - def _check_no_lifts(node: itir.StencilClosure): - if any( - getattr(fun, "id", "") == "lift" - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun") - ): - return False - return True - @staticmethod def _check_shift_offsets_are_literals(node: itir.StencilClosure): fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 56ffe7e104..773a3a61f7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -181,6 +181,126 @@ def __init__( self.reduce_identity = reduce_identity +def _visit_lift_in_neighbors_reduction( + transformer: "PythonTaskletCodegen", + node: itir.FunCall, + node_args: Sequence[IteratorExpr | list[ValueExpr]], + offset_provider: NeighborTableOffsetProvider, + map_entry: dace.nodes.MapEntry, + map_exit: dace.nodes.MapExit, + neighbor_index_node: dace.nodes.AccessNode, + neighbor_value_node: dace.nodes.AccessNode, +) -> list[ValueExpr]: + neighbor_dim = offset_provider.neighbor_axis.value + origin_dim = offset_provider.origin_axis.value + + lifted_args: list[IteratorExpr | ValueExpr] = [] + for arg in node_args: + if isinstance(arg, IteratorExpr): + if origin_dim in arg.indices: + lifted_indices = arg.indices.copy() + lifted_indices.pop(origin_dim) + lifted_indices[neighbor_dim] = neighbor_index_node + lifted_args.append( + IteratorExpr( + arg.field, + lifted_indices, + arg.dtype, + arg.dimensions, + ) + ) + else: + lifted_args.append(arg) + else: + lifted_args.append(arg[0]) + + lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) + assert len(inner_outputs) == 1 + inner_out_connector = inner_outputs[0].value.data + + input_nodes = {} + iterator_index_nodes = {} + lifted_index_connectors = set() + + for x, y in inner_inputs: + if isinstance(y, IteratorExpr): + field_connector, inner_index_table = x + input_nodes[field_connector] = y.field + for dim, connector in inner_index_table.items(): + if dim == neighbor_dim: + lifted_index_connectors.add(connector) + iterator_index_nodes[connector] = y.indices[dim] + else: + assert isinstance(y, ValueExpr) + input_nodes[x] = y.value + + neighbor_tables = filter_neighbor_tables(transformer.offset_provider) + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] + + parent_sdfg = transformer.context.body + parent_state = transformer.context.state + + input_mapping = { + connector: create_memlet_full(node.data, node.desc(parent_sdfg)) + for connector, node in input_nodes.items() + } + connectivity_mapping = { + name: create_memlet_full(name, parent_sdfg.arrays[name]) for name in connectivity_names + } + array_mapping = {**input_mapping, **connectivity_mapping} + symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) + + nested_sdfg_node = parent_state.add_nested_sdfg( + lift_context.body, + parent_sdfg, + inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, + outputs={inner_out_connector}, + symbol_mapping=symbol_mapping, + debuginfo=lift_context.body.debuginfo, + ) + + for connectivity_connector, memlet in connectivity_mapping.items(): + parent_state.add_memlet_path( + parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), + map_entry, + nested_sdfg_node, + dst_conn=connectivity_connector, + memlet=memlet, + ) + + for inner_connector, access_node in input_nodes.items(): + parent_state.add_memlet_path( + access_node, + map_entry, + nested_sdfg_node, + dst_conn=inner_connector, + memlet=input_mapping[inner_connector], + ) + + for inner_connector, access_node in iterator_index_nodes.items(): + memlet = dace.Memlet(data=access_node.data, subset="0") + if inner_connector in lifted_index_connectors: + parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) + else: + parent_state.add_memlet_path( + access_node, + map_entry, + nested_sdfg_node, + dst_conn=inner_connector, + memlet=memlet, + ) + + parent_state.add_memlet_path( + nested_sdfg_node, + map_exit, + neighbor_value_node, + src_conn=inner_out_connector, + memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), + ) + + return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] + + def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -198,7 +318,16 @@ def builtin_neighbors( "Neighbor reduction only implemented for connectivity based on neighbor tables." ) - iterator = transformer.visit(data) + lift_node = None + if isinstance(data, FunCall): + assert isinstance(data.fun, itir.FunCall) + fun_node = data.fun + if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": + lift_node = fun_node + lift_args = transformer.visit(data.args) + iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) + if lift_node is None: + iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) origin_index_node = iterator.indices[offset_provider.origin_axis.value] @@ -259,44 +388,56 @@ def builtin_neighbors( dace.Memlet(data=neighbor_index_var, subset="0"), ) - data_access_tasklet = state.add_tasklet( - "data_access", - code="__data = __field[__idx]" - + ( - f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values - else "" - ), - inputs={"__field", "__idx"}, - outputs={"__data"}, - debuginfo=di, - ) - # select full shape only in the neighbor-axis dimension - field_subset = tuple( - f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" - for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=create_memlet_at(iterator.field.data, field_subset), - dst_conn="__field", - ) - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), - src_conn="__data", - ) + if lift_node is not None: + _visit_lift_in_neighbors_reduction( + transformer, + lift_node, + lift_args, + offset_provider, + me, + mx, + neighbor_index_node, + neighbor_value_node, + ) + else: + data_access_tasklet = state.add_tasklet( + "data_access", + code="__data = __field[__idx]" + + ( + f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" + if offset_provider.has_skip_values + else "" + ), + inputs={"__field", "__idx"}, + outputs={"__data"}, + debuginfo=di, + ) + # select full shape only in the neighbor-axis dimension + field_subset = tuple( + f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) + ) + state.add_memlet_path( + iterator.field, + me, + data_access_tasklet, + memlet=create_memlet_at(iterator.field.data, field_subset), + dst_conn="__field", + ) + state.add_edge( + neighbor_index_node, + None, + data_access_tasklet, + "__idx", + dace.Memlet(data=neighbor_index_var, subset="0"), + ) + state.add_memlet_path( + data_access_tasklet, + mx, + neighbor_value_node, + memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), + src_conn="__data", + ) if not offset_provider.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] @@ -377,9 +518,8 @@ def builtin_can_deref( # create tasklet to check that field indices are non-negative (-1 is invalid) args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join([f"{v} >= 0" for v in internals]) + expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( list(zip(args, internals)), expr_code, @@ -946,7 +1086,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it + # TODO: remove this special case when ITIR pass is able to catch it assert isinstance(iterator, list) and len(iterator) == 1 assert isinstance(iterator[0], ValueExpr) return iterator