diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fa28793187..432bf3e1bf 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -22,6 +22,7 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.sdfg import utils as sdutils from dace.transformation.auto import auto_optimize as autoopt +from dace.transformation.interstate import RefineNestedAccess import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir @@ -71,7 +72,7 @@ def preprocess_program( lift_mode: itir_transforms.LiftMode, unroll_reduce: bool = False, ): - return itir_transforms.apply_common_transforms( + node = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -80,6 +81,21 @@ unroll_reduce=unroll_reduce, ) + if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries): + fencil_definition = node.fencil + tmps = node.tmps + + elif isinstance(node, itir.FencilDefinition): + fencil_definition = node + tmps = [] + + else: + raise TypeError( + f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'." 
+ ) + + return fencil_definition, tmps + def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names @@ -160,6 +176,7 @@ def get_stride_args( def get_cache_id( build_type: str, build_for_gpu: bool, + lift_mode: itir_transforms.LiftMode, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], column_axis: Optional[common.Dimension], @@ -185,6 +202,7 @@ def offset_invariants(offset): for arg in ( build_type, build_for_gpu, + lift_mode, program, *arg_types, column_axis, @@ -272,17 +290,17 @@ def build_sdfg_from_itir( sdfg.validate() return sdfg - # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force - # `lift_more` to `FORCE_INLINE` mode. - lift_mode = itir_transforms.LiftMode.FORCE_INLINE arg_types = [type_translation.from_value(arg) for arg in args] # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider, lift_mode) - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + program, tmps = preprocess_program(program, offset_provider, lift_mode) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis) sdfg = sdfg_genenerator.visit(program) if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") + elif tmps: + # This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets + sdfg.apply_transformations_repeated(RefineNestedAccess) for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: @@ -338,7 +356,9 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): arg_types = [type_translation.from_value(arg) for arg in args] - cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider) + cache_id = get_cache_id( + build_type, on_gpu, lift_mode, program, arg_types, column_axis, offset_provider + ) if build_cache is not None and cache_id in build_cache: # retrieve SDFG 
program from build cache sdfg_program = build_cache[cache_id] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index eaff9f467e..a578e9c19b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -19,8 +19,12 @@ import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing from gt4py.next.common import NeighborTable -from gt4py.next.iterator import ir as itir, type_inference as itir_typing -from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef +from gt4py.next.iterator import ( + ir as itir, + transforms as itir_transforms, + type_inference as itir_typing, +) +from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_tasklet import ( @@ -36,6 +40,7 @@ from .utility import ( add_mapped_nested_sdfg, as_dace_type, + as_scalar_type, connectivity_identifier, create_memlet_at, create_memlet_full, @@ -44,6 +49,7 @@ flatten_list, get_sorted_dims, map_nested_sdfg_symbols, + new_array_symbols, unique_name, unique_var_name, ) @@ -154,12 +160,14 @@ def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTable], + tmps: list[itir_transforms.global_tmps.Temporary], column_axis: Optional[Dimension] = None, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} + self.tmps = tmps def add_storage( self, @@ -189,6 +197,70 @@ def add_storage( raise NotImplementedError() self.storage_types[name] = type_ + def add_storage_for_temporaries( + self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG + ) -> dict[str, str]: + symbol_map: dict[str, TaskletExpr] = {} + # The shape of temporary 
arrays might be defined based on scalar values passed as program arguments. + # Here we collect these values in a symbol map. + tmp_ids = set(tmp.id for tmp in self.tmps) + for sym in node_params: + if sym.id not in tmp_ids and sym.kind != "Iterator": + name_ = str(sym.id) + type_ = self.storage_types[name_] + assert isinstance(type_, ts.ScalarType) + symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_)) + + tmp_symbols: dict[str, str] = {} + for tmp in self.tmps: + tmp_name = str(tmp.id) + + # We visit the domain of the temporary field, passing the set of available symbols. + assert isinstance(tmp.domain, itir.FunCall) + self.node_types.update(itir_typing.infer_all(tmp.domain)) + domain_ctx = Context(program_sdfg, defs_state, symbol_map) + tmp_domain = self._visit_domain(tmp.domain, domain_ctx) + + # We build the FieldType for this temporary array. + dims: list[Dimension] = [] + for dim, _ in tmp_domain: + dims.append( + Dimension( + value=dim, + kind=( + DimensionKind.VERTICAL + if self.column_axis is not None and self.column_axis.value == dim + else DimensionKind.HORIZONTAL + ), + ) + ) + assert isinstance(tmp.dtype, str) + type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype)) + self.storage_types[tmp_name] = type_ + + # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. + # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics + # to assign optimal stride values. + tmp_shape, _ = new_array_symbols(tmp_name, len(dims)) + tmp_offset = [ + dace.symbol(unique_name(f"{tmp_name}_offset{i}")) for i in range(len(dims)) + ] + _, tmp_array = program_sdfg.add_array( + tmp_name, tmp_shape, as_dace_type(type_.dtype), offset=tmp_offset, transient=True + ) + + # Loop through all dimensions to visit the symbolic expressions for array shape and offset. + # These expressions are later mapped to interstate symbols. 
+ for (_, (begin, end)), offset_sym, shape_sym in zip( + tmp_domain, + tmp_array.offset, + tmp_array.shape, + ): + tmp_symbols[str(offset_sym)] = f"0 - {begin.value}" + tmp_symbols[str(shape_sym)] = f"{end.value} - {begin.value}" + + return tmp_symbols + def get_output_nodes( self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState ) -> dict[str, dace.nodes.AccessNode]: @@ -204,7 +276,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) program_sdfg.debuginfo = dace_debuginfo(node) - last_state = program_sdfg.add_state("program_entry", True) + entry_state = program_sdfg.add_state("program_entry", is_start_block=True) self.node_types = itir_typing.infer_all(node) # Filter neighbor tables from offset providers. @@ -214,6 +286,20 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): for param, type_ in zip(node.params, self.param_types): self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables) + if self.tmps: + tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) + # on the first interstate edge define symbols for shape and offsets of temporary arrays + last_state = program_sdfg.add_state("init_symbols_for_temporaries") + program_sdfg.add_edge( + entry_state, + last_state, + dace.InterstateEdge( + assignments=tmp_symbols, + ), + ) + else: + last_state = entry_state + # Add connectivities as SDFG storages. 
for offset, offset_provider in neighbor_tables.items(): scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 49dd2472c5..0c3fd741d5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -51,6 +51,14 @@ def as_dace_type(type_: ts.ScalarType): raise ValueError(f"Scalar type '{type_}' not supported.") +def as_scalar_type(typestr: str) -> ts.ScalarType: + try: + kind = getattr(ts.ScalarKind, typestr.upper()) + except AttributeError: + raise ValueError(f"Data type {typestr} not supported.") + return ts.ScalarType(kind) + + def filter_neighbor_tables(offset_provider: dict[str, Any]): return { offset: table