diff --git a/pyproject.toml b/pyproject.toml index d086363ec4..08203bf25a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -254,6 +254,7 @@ markers = [ 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', 'uses_scan_nested: tests that use nested scans', 'uses_scan_requiring_projector: tests need a projector implementation in gtfn', + 'uses_scan_1d_field: tests scan on a 1D vertical field', 'uses_sparse_fields: tests that require backend support for sparse fields', 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index cffbd74c90..131321f77e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias +from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace from dace import subsets as dace_subsets @@ -28,6 +28,7 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, gtir_python_codegen, + gtir_sdfg, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -158,6 +159,33 @@ def get_local_view( """Data type used for field indexing.""" +def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of `FieldopResult`. + """ + return ts.TupleType( + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + +def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + """ + Visit a `FieldopResult`, potentially containing nested tuples, and construct a list + of pairs `(str, FieldopData)` containing the symbol name of each tuple field and + the corresponding `FieldopData`. 
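+
+    For example, a nested tuple argument `(x, (y, z))` named "a" is flattened to
+    `[("a_0", x), ("a_1_0", y), ("a_1_1", z)]`, following the same naming scheme
+    as `flatten_tuple_fields`.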
+ """ + if isinstance(arg, tuple): + tuple_type = get_tuple_type(arg) + tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) + tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) + return [ + (str(tsym.id), tfield) + for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) + ] + else: + return [(name, arg)] + + class PrimitiveTranslator(Protocol): @abc.abstractmethod def __call__( @@ -192,16 +220,39 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, -) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + by_value: bool = False, +) -> ( + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | gtir_dataflow.ValueExpr + | tuple[ + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | gtir_dataflow.ValueExpr + | tuple[Any, ...], + ..., + ] +): """Helper method to visit an expression passed as argument to a field operator.""" arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - # arguments passed to field operator should be plain fields, not tuples of fields - if not isinstance(arg, FieldopData): - raise ValueError(f"Received {node} as argument to field operator, expected a field.") + def get_arg_value( + arg: FieldopData, + ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + arg_expr = arg.get_local_view(domain) + if not by_value or isinstance(arg_expr, gtir_dataflow.MemletExpr): + return arg_expr + # In case of scan field operator, the arguments to the vertical stencil are passed by value. + return gtir_dataflow.MemletExpr( + arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg) + ) - return arg.get_local_view(domain) + if isinstance(arg, FieldopData): + return get_arg_value(arg) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda x: get_arg_value(x))(arg) def _get_field_layout( @@ -237,11 +288,13 @@ def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, - node_type: ts.FieldType, + node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_sdfg.SDFGBuilder, - input_edges: Sequence[gtir_dataflow.DataflowInputEdge], - output_edge: gtir_dataflow.DataflowOutputEdge, -) -> FieldopData: + input_edges: Iterable[gtir_dataflow.DataflowInputEdge], + output_edges: gtir_dataflow.DataflowOutputEdge + | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], + scan_dim: Optional[gtx_common.Dimension] = None, +) -> FieldopResult: """ Helper method to allocate a temporary field to store the output of a field operator. @@ -252,62 +305,110 @@ def _create_field_operator( node_type: The GT4Py type of the IR node that produces this field. sdfg_builder: The object used to build the map scope in the provided SDFG. input_edges: List of edges to pass input data into the dataflow. - output_edge: Edge representing the dataflow output data. + output_edges: Single edge or tuple of edges representing the dataflow output data. + scan_dim: Column dimension used in scan field operators. Returns: The field data descriptor, which includes the field access node in the given `state` and the field domain offset. 
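+        If `output_edges` is a tuple of edges, a tuple of field data descriptors
+        with the same structure is returned.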
""" - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_indices = _get_domain_indices(field_dims, field_offset) - - dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + domain_dims, domain_offset, domain_shape = _get_field_layout(domain) + domain_indices = _get_domain_indices(domain_dims, domain_offset) + domain_subset = dace_subsets.Range.from_indices(domain_indices) + + scan_dim_index: Optional[int] = None + if scan_dim is not None: + scan_dim_index = domain_dims.index(scan_dim) + # we construct the field operator only on the horizontal domain + domain_subset = dace_subsets.Range( + domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :] + ) - field_subset = dace_subsets.Range.from_indices(field_indices) - if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - assert output_edge.result.gt_dtype == node_type.dtype - assert isinstance(dataflow_output_desc, dace.data.Scalar) - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - field_dtype = output_edge.result.gt_dtype + # now check, after removal of the vertical dimension, whether the domain is empty + if len(domain_subset) == 0: + # no need to create a map scope, the field operator domain is empty + me, mx = (None, None) else: - assert isinstance(node_type.dtype, itir_ts.ListType) - assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type - assert isinstance(dataflow_output_desc, dace.data.Array) - assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = output_edge.result.gt_dtype.element_type - # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - assert output_edge.result.gt_dtype.offset_type is not None - field_dims.append(output_edge.result.gt_dtype.offset_type) - field_shape.extend(dataflow_output_desc.shape) - field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) - - # allocate local temporary storage - field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) - field_node = state.add_access(field_name) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + if dim != scan_dim + }, + ) # here we setup the edges passing through the map entry node for edge in input_edges: edge.connect(me) - # and here the edge writing the dataflow result data through the map exit node - output_edge.connect(mx, field_node, field_subset) + def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -> FieldopData: + assert isinstance(sym.type, ts.FieldType) + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == sym.type.dtype + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(sym.type.dtype) + field_dtype = output_edge.result.gt_dtype + field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + if scan_dim is not None: + # the scan field operator produces a 1D vertical 
field + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 + # the vertical dimension should not belong to the field operator domain + # but we need to write it to the output field + field_subset = ( + dace_subsets.Range(domain_subset[:scan_dim_index]) + + dace_subsets.Range.from_array(dataflow_output_desc) + + dace_subsets.Range(domain_subset[scan_dim_index:]) + ) + else: + assert isinstance(dataflow_output_desc, dace.data.Scalar) + field_subset = domain_subset + else: + assert isinstance(sym.type.dtype, itir_ts.ListType) + assert output_edge.result.gt_dtype.element_type == sym.type.dtype.element_type + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 + # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) + assert output_edge.result.gt_dtype.offset_type is not None + field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] + field_shape = [*domain_shape, dataflow_output_desc.shape[0]] + field_offset = [*domain_offset, dataflow_output_desc.offset[0]] + field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) + + # allocate local temporary storage + field_name, field_desc = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) + + if scan_dim is not None: + # By default, we leave `strides=None` which corresponds to use DaCe default memory layout + # for transient arrays. However, for scan field operators we need to ensure that the same + # stride is used for the vertical dimension in inner and outer array. 
+ scan_output_stride = field_desc.strides[scan_dim_index] + dataflow_output_desc.strides = (scan_output_stride,) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) + + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) - return FieldopData( - field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), - ) + if isinstance(output_edges, gtir_dataflow.DataflowOutputEdge): + assert isinstance(node_type, ts.FieldType) + return create_field(output_edges, im.sym("x", node_type)) + else: + # handle tuples of fields + assert isinstance(node_type, ts.TupleType) + return gtx_utils.tree_map(create_field)( + output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type) + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -371,6 +472,9 @@ def translate_as_fieldop( assert len(fun_node.args) == 2 fieldop_expr, domain_expr = fun_node.args + if cpm.is_call_to(fieldop_expr, "scan"): + return translate_scan(node, sdfg, state, sdfg_builder) + assert isinstance(node.type, ts.FieldType) if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar @@ -397,6 +501,7 @@ def translate_as_fieldop( input_edges, output_edge = gtir_dataflow.visit_lambda( sdfg, state, sdfg_builder, stencil_expr, fieldop_args ) + assert isinstance(output_edge, gtir_dataflow.DataflowOutputEdge) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge @@ -570,11 +675,10 @@ def _get_data_nodes( return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) - return tuple( - _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) - for fname, ftype in tuple_fields - ) + tuple_syms = dace_gtir_utils.make_symbol_tuple(data_name, data_type) + return gtx_utils.tree_map( + lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) + )(tuple_syms) else: raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") @@ -691,10 +795,8 @@ def translate_scalar_expr( visit_expr = True if isinstance(arg_expr, gtir.SymRef): try: - # `gt_symbol` refers to symbols defined in the GT4Py program - gt_symbol_type = sdfg_builder.get_symbol_type(arg_expr.id) - if not isinstance(gt_symbol_type, ts.ScalarType): - raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") + # check if symbol is defined in the GT4Py program, returns `None` if undefined + sdfg_builder.get_symbol_type(arg_expr.id) except KeyError: # this is the case of non-variable argument, e.g. 
target type such as `float64`, # used in a casting expression like `cast_(variable, float64)` @@ -708,7 +810,7 @@ def translate_scalar_expr( sdfg=sdfg, head_state=state, ) - if not (isinstance(arg, FieldopData) and isinstance(arg.gt_type, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(node.type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" args.append(arg.dc_node) @@ -756,6 +858,298 @@ def translate_scalar_expr( return FieldopData(temp_node, node.type, offset=None) +def translate_scan( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) + + fun_node = node.fun + assert len(fun_node.args) == 2 + scan_expr, domain_expr = fun_node.args + assert cpm.is_call_to(scan_expr, "scan") + + # parse the domain of the scan field operator + domain = extract_domain(domain_expr) + + # use the vertical dimension in the domain as scan dimension + scan_domain = [ + (dim, lower_bound, upper_bound) + for dim, lower_bound, upper_bound in domain + if dim.kind == gtx_common.DimensionKind.VERTICAL + ] + assert len(scan_domain) == 1 + scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] + assert sdfg_builder.is_column_dimension(scan_dim) + + # parse scan parameters + assert len(scan_expr.args) == 3 + stencil_expr = scan_expr.args[0] + assert isinstance(stencil_expr, gtir.Lambda) + + # params[0]: the lambda parameter to propagate the scan carry on the vertical dimension + scan_carry = str(stencil_expr.params[0].id) + + # params[1]: boolean flag for forward/backward scan + assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) + scan_forward = scan_expr.args[1].value == "True" + + # params[2]: the value for scan initialization + init_value = scan_expr.args[2] + + # make naming consistent throughut this function scope + def scan_input_name(input_name: str) -> str: + return f"__input_{input_name}" + + def scan_output_name(input_name: str) -> str: + return f"__output_{input_name}" + + # visit the initialization value of the scan expression + init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state) + + # extract type definition of the scan carry + scan_carry_type = ( + init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data) + ) + + # create list of params to the lambda function with associated node type + lambda_symbols = {scan_carry: scan_carry_type} | { + str(p.id): arg.type + for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + if isinstance(arg.type, ts.DataType) + } + + # visit the arguments to be passed to the lambda expression + # obs. 
this must be executed before visiting the lambda expression, in order to populate + # the data descriptor with the correct field domain offsets for field arguments + lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] + lambda_args_mapping = { + scan_input_name(scan_carry): init_data, + } | { + str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + } + + # parse the dataflow input and output symbols + lambda_flat_args: dict[str, FieldopData] = {} + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for param, arg in lambda_args_mapping.items(): + tuple_fields = flatten_tuples(param, arg) + lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} + lambda_flat_args |= dict(tuple_fields) + lambda_flat_outs = ( + { + str(sym.id): sym.type + for sym in dace_gtir_utils.flatten_tuple_fields( + scan_output_name(scan_carry), scan_carry_type + ) + } + if isinstance(scan_carry_type, ts.TupleType) + else {scan_output_name(scan_carry): scan_carry_type} + ) + + # the scan operator is implemented as an nested SDFG implementing the lambda expression + nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) + + # extract the scan loop range + scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim) + _, scan_output_offset, scan_output_shape = _get_field_layout(scan_domain) + + # create a loop region for lambda call over the scan dimension + if scan_forward: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_upper_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lower_bound}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lower_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + + nsdfg.add_node(scan_loop) + compute_state = scan_loop.add_state("scan_compute", is_start_block=True) + update_state = scan_loop.add_state("scan_update") + scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) + + init_state = nsdfg.add_state("scan_init", is_start_block=True) + nsdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) + + # visit the list of arguments to be passed to the scan expression + stencil_builder = sdfg_builder.nested_context(nsdfg, lambda_symbols, lambda_field_offsets) + stencil_args = [ + _parse_fieldop_arg( + im.ref(p.id), nsdfg, compute_state, stencil_builder, domain, by_value=True + ) + for p in stencil_expr.params + ] + + # generate the dataflow representing the scan field operator + input_edges, result = gtir_dataflow.visit_lambda( + nsdfg, compute_state, stencil_builder, stencil_expr, args=stencil_args + ) + + # now initialize the scan carry + scan_carry_input = ( + dace_gtir_utils.make_symbol_tuple(scan_carry, scan_carry_type) + if isinstance(scan_carry_type, ts.TupleType) + else im.sym(scan_carry, scan_carry_type) + ) + + def init_scan_carry(sym: gtir.Sym) -> None: + scan_state = str(sym.id) + scan_state_desc = nsdfg.data(scan_state) + input_state = scan_input_name(scan_state) + input_state_desc = scan_state_desc.clone() + nsdfg.add_datadesc(input_state, input_state_desc) + scan_state_desc.transient = True + init_state.add_nedge( + 
init_state.add_access(input_state), + init_state.add_access(scan_state), + nsdfg.make_array_memlet(input_state), + ) + + if isinstance(scan_carry_input, tuple): + gtx_utils.tree_map(init_scan_carry)(scan_carry_input) + else: + init_scan_carry(scan_carry_input) + + # connect the dataflow input directly to the source data nodes, without passing through a map node; + # the reason is that the map for horizontal domain is outside the scan loop region + for edge in input_edges: + edge.connect(map_entry=None) + + # connect the dataflow result nodes to the carry variables + def connect_scan_output( + scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym + ) -> FieldopData: + scan_result = scan_output_edge.result + assert isinstance(scan_result.gt_dtype, ts.ScalarType) + assert scan_result.gt_dtype == sym.type + scan_result_data = scan_result.dc_node.data + scan_result_desc = scan_result.dc_node.desc(nsdfg) + + output, _ = nsdfg.add_array( + scan_output_name(sym.id), scan_output_shape, scan_result_desc.dtype, find_new_name=True + ) + output_node = compute_state.add_access(output) + output_subset = str(dace.symbolic.SymExpr(scan_loop_var) - scan_lower_bound) + compute_state.add_nedge( + scan_result.dc_node, output_node, dace.Memlet(data=output, subset=output_subset) + ) + + update_state.add_nedge( + update_state.add_access(scan_result_data), + update_state.add_access(sym.id), + dace.Memlet(data=sym.id, subset="0"), + ) + + output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) + return FieldopData(output_node, output_type, scan_output_offset) + + lambda_output = ( + gtx_utils.tree_map(connect_scan_output)(result, scan_carry_input) + if (isinstance(result, tuple) and isinstance(scan_carry_input, tuple)) + else connect_scan_output(result, scan_carry_input) + if ( + isinstance(result, gtir_dataflow.DataflowOutputEdge) + and isinstance(scan_carry_input, gtir.Sym) + ) + else None + ) + assert lambda_output + + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily accessed in the lambda scope + for data_node in compute_state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if (compute_state.degree(data_node) == 0) and ( + (not data_desc.transient) + or data_node.data.startswith( + scan_carry + ) # exceptional case where the carry variable is not used, not a scan indeed + ): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) + temp_node = compute_state.add_access(temp) + compute_state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + + # build the mapping of symbols from nested SDFG to parent SDFG + nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} + for dim, _, _ in domain: + if dim != scan_dim: + dim_map_variable = dace_gtir_utils.get_map_variable(dim) + nsdfg_symbols_mapping[dim_map_variable] = dim_map_variable + for inner, arg in lambda_flat_args.items(): + inner_desc = nsdfg.data(inner) + outer_desc = arg.dc_node.desc(sdfg) + nsdfg_symbols_mapping |= { + str(nested_symbol): parent_symbol + for nested_symbol, parent_symbol in zip( + [*inner_desc.shape, *inner_desc.strides], + [*outer_desc.shape, *outer_desc.strides], + strict=True, + ) + if isinstance(nested_symbol, dace.symbol) + } + + # the scan nested SDFG is ready, now we need to instantiate it inside the map implementing the field operator + nsdfg_node = state.add_nested_sdfg( + nsdfg, + sdfg, + 
inputs=set(lambda_flat_args.keys()), + outputs=set(lambda_flat_outs.keys()), + symbol_mapping=nsdfg_symbols_mapping, + ) + + input_edges = [] + for input_connector, arg in lambda_flat_args.items(): + arg_desc = arg.dc_node.desc(sdfg) + input_subset = dace_subsets.Range.from_array(arg_desc) + input_edge = gtir_dataflow.MemletInputEdge( + state, arg.dc_node, input_subset, nsdfg_node, input_connector + ) + input_edges.append(input_edge) + + def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutputEdge: + assert isinstance(scan_data.gt_type, ts.FieldType) + inner_data = scan_data.dc_node.data + inner_desc = nsdfg.data(inner_data) + output_data, output_desc = sdfg.add_temp_transient_like(inner_desc) + output_node = state.add_access(output_data) + state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output_data, output_desc), + ) + output_expr = gtir_dataflow.ValueExpr(output_node, scan_data.gt_type.dtype) + return gtir_dataflow.DataflowOutputEdge(state, output_expr) + + output_edges = ( + construct_output_edge(lambda_output) + if isinstance(lambda_output, FieldopData) + else gtx_utils.tree_map(construct_output_edge)(lambda_output) + ) + + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges, scan_dim + ) + + def translate_symbol_ref( node: gtir.Node, sdfg: dace.SDFG, @@ -785,5 +1179,6 @@ def translate_symbol_ref( translate_make_tuple, translate_tuple_get, translate_scalar_expr, + translate_scan, translate_symbol_ref, ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index a3653fb519..4f53a9dcad 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -28,9 +28,10 @@ from dace import subsets as dace_subsets from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -150,7 +151,7 @@ class DataflowInputEdge(Protocol): """ @abc.abstractmethod - def connect(self, me: dace.nodes.MapEntry) -> None: ... + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ... 
@dataclasses.dataclass(frozen=True) @@ -168,15 +169,18 @@ class MemletInputEdge(DataflowInputEdge): dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] - def connect(self, me: dace.nodes.MapEntry) -> None: + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: memlet = dace.Memlet(data=self.source.data, subset=self.subset) - self.state.add_memlet_path( - self.source, - me, - self.dest, - dst_conn=self.dest_conn, - memlet=memlet, - ) + if map_entry is None: + self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet) + else: + self.state.add_memlet_path( + self.source, + map_entry, + self.dest, + dst_conn=self.dest_conn, + memlet=memlet, + ) @dataclasses.dataclass(frozen=True) @@ -191,8 +195,10 @@ class EmptyInputEdge(DataflowInputEdge): state: dace.SDFGState node: dace.nodes.Tasklet - def connect(self, me: dace.nodes.MapEntry) -> None: - self.state.add_nedge(me, self.node, dace.Memlet()) + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + # cannot create empty edge from MapEntry node, if this is not present + if map_entry is not None: + self.state.add_nedge(map_entry, self.node, dace.Memlet()) @dataclasses.dataclass(frozen=True) @@ -212,13 +218,13 @@ class DataflowOutputEdge: def connect( self, - mx: dace.nodes.MapExit, + map_exit: Optional[dace.nodes.MapExit], dest: dace.nodes.AccessNode, subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): + if isinstance(last_node, (dace.nodes.Tasklet, dace.nodes.NestedSDFG)): # the last transient node can be deleted last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn self.state.remove_node(self.result.dc_node) @@ -226,13 +232,22 @@ def connect( last_node = self.result.dc_node last_node_connector = None - self.state.add_memlet_path( - last_node, - mx, - dest, - src_conn=last_node_connector, - memlet=dace.Memlet(data=dest.data, subset=subset), - ) + if map_exit is None: + self.state.add_edge( + last_node, + last_node_connector, + dest, + None, + dace.Memlet(data=dest.data, subset=subset), + ) + else: + self.state.add_memlet_path( + last_node, + map_exit, + dest, + src_conn=last_node_connector, + memlet=dace.Memlet(data=dest.data, subset=subset), + ) DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { @@ -290,9 +305,10 @@ class LambdaToDataflow(eve.NodeVisitor): state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( - default_factory=lambda: {} - ) + symbol_map: dict[ + str, + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + ] = dataclasses.field(default_factory=lambda: {}) def _add_input_data_edge( self, @@ -546,6 +562,222 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: + assert len(node.args) == 3 + + # TODO(edopao): enable once supported in next DaCe release + use_conditional_block: Final[bool] = False + + condition_value = self.visit(node.args[0]) + assert ( + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, 
(MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool) + ) + + nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) + nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo) + + if use_conditional_block: + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + if_body = dace.sdfg.state.ControlFlowRegion("if_body", sdfg=nsdfg) + tstate = if_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), if_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) + + else: + entry_state = nsdfg.add_state("entry", is_start_block=True) + tstate = nsdfg.add_state("true_branch") + nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond")) + fstate = nsdfg.add_state("false_branch") + nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) + + nsdfg_symbol_mapping = {} + input_memlets: dict[str, tuple[dace.nodes.AccessNode, Optional[dace_subsets.Range]]] = {} + + if isinstance(condition_value, SymbolExpr): + nsdfg.add_symbol("__cond", dace.dtypes.bool) + nsdfg_symbol_mapping["__cond"] = condition_value.value + else: + nsdfg.add_scalar("__cond", dace.dtypes.bool) + if isinstance(condition_value, ValueExpr): + input_memlets["__cond"] = ( + condition_value.dc_node, + dace_subsets.Range.from_string("0"), + ) + else: + assert isinstance(condition_value, MemletExpr) + input_memlets["__cond"] = (condition_value.dc_node, condition_value.subset) + nsdfg_symbol_mapping.update( + {sym: sym for sym in condition_value.subset.free_symbols} + ) + + def visit_branch( + state: dace.SDFGState, expr: gtir.Expr + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: + assert state in nsdfg.states() + + def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: + if isinstance(arg, IteratorExpr): + arg_node = arg.field + arg_desc = arg_node.desc(self.sdfg) + arg_subset = dace_subsets.Range.from_array(arg_desc) + + else: + assert isinstance(arg, (MemletExpr, ValueExpr)) + arg_node = arg.dc_node + if isinstance(arg, MemletExpr): + assert set(arg.subset.size()) == {1} + arg_desc = dace.data.Scalar(arg_node.desc(self.sdfg).dtype) + arg_subset = arg.subset + else: + arg_desc = arg_node.desc(self.sdfg) + arg_subset = None + + arg_data = arg_node.data + # SDFG data containers with name prefix '__tmp' are expected to be transients + inner_data = ( + arg_data.replace("__tmp", "__input") + if arg_data.startswith("__tmp") + else arg_data + ) + + try: + inner_desc = nsdfg.data(inner_data) + assert not inner_desc.transient + except KeyError: + inner_desc = arg_desc.clone() + inner_desc.transient = False + nsdfg.add_datadesc(inner_data, inner_desc) + input_memlets[inner_data] = (arg_node, arg_subset) + + inner_node = state.add_access(inner_data) + if isinstance(arg, IteratorExpr): + return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) + else: + assert isinstance(inner_desc, dace.data.Scalar) + return ValueExpr(inner_node, arg.gt_dtype) + + lambda_params = [] + lambda_args: list[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | 
tuple[Any, ...], ...] + ] = [] + for p in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): + arg = self.symbol_map[p] + if isinstance(arg, tuple): + inner_arg = gtx_utils.tree_map(visit_arg)(arg) + else: + inner_arg = visit_arg(arg) + lambda_args.append(inner_arg) + lambda_params.append(im.sym(p)) + + # visit each branch of the if-statement as it was a Lambda node + lambda_node = gtir.Lambda(params=lambda_params, expr=expr) + return visit_lambda(nsdfg, state, self.subgraph_builder, lambda_node, lambda_args) + + for state, arg in zip([tstate, fstate], node.args[1:3]): + in_edges, out_edge = visit_branch(state, arg) + for edge in in_edges: + edge.connect(map_entry=None) + + def construct_output( + output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym + ) -> ValueExpr: + output_data = str(sym.id) + try: + output_desc = nsdfg.data(output_data) + assert not output_desc.transient + except KeyError: + result_desc = edge.result.dc_node.desc(nsdfg) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = nsdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_node = output_state.add_access(output_data) + output_state.add_nedge( + edge.result.dc_node, + output_node, + dace.Memlet.from_array(output_data, output_desc), + ) + return ValueExpr(output_node, edge.result.gt_dtype) + + if isinstance(out_edge, tuple): + assert isinstance(node.type, ts.TupleType) + out_symbol = dace_gtir_utils.make_symbol_tuple("__output", node.type) + outer_value = gtx_utils.tree_map( + lambda x, y, output_state=state: construct_output(output_state, x, y) + )(out_edge, out_symbol) + else: + assert isinstance(node.type, ts.FieldType | ts.ScalarType) + outer_value = construct_output(state, out_edge, im.sym("__output", node.type)) + + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + for data_node in state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if (not data_desc.transient) and (state.degree(data_node) == 0): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) + temp_node = state.add_access(temp) + state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + + else: + result = outer_value + + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, + self.sdfg, + inputs=set(input_memlets.keys()), + outputs=outputs, + symbol_mapping=nsdfg_symbol_mapping | {str(sym): sym for sym in nsdfg.free_symbols}, + ) + + for inner, (src_node, src_subset) in input_memlets.items(): + if src_subset is None: + self._add_edge( + src_node, None, nsdfg_node, inner, self.sdfg.make_array_memlet(src_node.data) + ) + else: + self._add_input_data_edge(src_node, src_subset, nsdfg_node, inner) + + def connect_output(inner_value: ValueExpr) -> ValueExpr: + inner_data = inner_value.dc_node.data + inner_desc = inner_value.dc_node.desc(nsdfg) + assert not inner_desc.transient + output, output_desc = self.sdfg.add_temp_transient_like(inner_desc) + output_node = self.state.add_access(output) + self.state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output, output_desc), + ) + return ValueExpr(output_node, inner_value.gt_dtype) + + return ( + gtx_utils.tree_map(connect_output)(result) + if isinstance(result, tuple) + else 
connect_output(result) + ) + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, itir_ts.ListType) assert len(node.args) == 2 @@ -1254,13 +1486,42 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: + def _visit_make_tuple(self, node: gtir.FunCall) -> tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "make_tuple") + return tuple(self.visit(arg) for arg in node.args) + + def _visit_tuple_get( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "tuple_get") + assert len(node.args) == 2 + + if not isinstance(node.args[0], gtir.Literal): + raise ValueError("Tuple can only be subscripted with compile-time constants.") + assert ti.is_integral(node.args[0].type) + index = int(node.args[0].value) + + tuple_fields = self.visit(node.args[1]) + return tuple_fields[index] + + def visit_FunCall( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_call_to(node, "if_"): + return self._visit_if(node) + elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) + elif cpm.is_call_to(node, "make_tuple"): + return self._visit_make_tuple(node) + + elif cpm.is_call_to(node, "tuple_get"): + return self._visit_tuple_get(node) + elif cpm.is_applied_map(node): return self._visit_map(node) @@ -1280,35 +1541,57 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: - result: DataExpr = self.visit(node.expr) + def visit_Lambda( + self, node: gtir.Lambda + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: + result = self.visit(node.expr) + + def make_output_edge( + output_expr: ValueExpr | MemletExpr | SymbolExpr, + ) -> DataflowOutputEdge: + if isinstance(output_expr, ValueExpr): + return DataflowOutputEdge(self.state, output_expr) + + if isinstance(output_expr, MemletExpr): + # special case where the field operator is simply copying data from source to destination node + output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self._add_input_data_edge( + output_expr.dc_node, + output_expr.subset, + tasklet_node, + "__inp", + ) + else: + # even simpler case, where a constant value is written to destination node + output_dtype = output_expr.dc_dtype + tasklet_node = self._add_tasklet( + "write", {}, {"__out"}, f"__out = {output_expr.value}" + ) - if isinstance(result, ValueExpr): - return DataflowOutputEdge(self.state, result) + output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") + return DataflowOutputEdge(self.state, output_expr) - if isinstance(result, MemletExpr): - # special case where the field operator is simply copying data from source to destination node - output_dtype = result.dc_node.desc(self.sdfg).dtype - tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_data_edge( - result.dc_node, - result.subset, - tasklet_node, - "__inp", - ) - else: - # even simpler case, where a constant value is written to destination node 
- output_dtype = result.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") + def parse_result( + r: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr, + ) -> DataflowOutputEdge: + if isinstance(r, DataflowOutputEdge): + return r + return make_output_edge(r) - output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return DataflowOutputEdge(self.state, output_expr) + return ( + gtx_utils.tree_map(parse_result)(result) + if isinstance(result, tuple) + else parse_result(result) + ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: + def visit_SymRef( + self, node: gtir.SymRef + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] @@ -1319,8 +1602,13 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE def visit_let( self, node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> DataflowOutputEdge: + args: Sequence[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: """ Maps lambda arguments to internal parameters. @@ -1353,10 +1641,18 @@ def visit_let( def visit_lambda( sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, + sdfg_builder: gtir_sdfg.DataflowBuilder, node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], -) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + args: Sequence[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], +) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], +]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph, that can be instantiated inside a map scope implementing the field operator. @@ -1368,7 +1664,7 @@ def visit_lambda( Args: sdfg: The SDFG where the dataflow graph will be instantiated. state: The SDFG state where the dataflow graph will be instantiated. - sdfg_builder: Helper class to build the SDFG. + sdfg_builder: Helper class to build the dataflow inside the given SDFG. node: Lambda node to visit. args: Arguments passed to lambda node. @@ -1378,5 +1674,5 @@ def visit_lambda( - Output data connection. 
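+          When the lambda expression returns a tuple, this is a corresponding
+          (possibly nested) tuple of output edges.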
""" taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) - output_edge = taskgen.visit_let(node, args) - return taskgen.input_edges, output_edge + output_edges = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edges diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 9bd40f75f8..30185697a5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace +from dace.sdfg import utils as dace_sdfg_utils from gt4py import eve from gt4py.eve import concepts @@ -111,6 +112,21 @@ def get_symbol_type(self, symbol_name: str) -> ts.DataType: """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... + @abc.abstractmethod + def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: + """Check if the given dimension is the column dimension.""" + ... + + @abc.abstractmethod + def nested_context( + self, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], + ) -> SDFGBuilder: + """Create a new empty context, useful to build a nested SDFG.""" + ... + @abc.abstractmethod def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: """Visit a node of the GT4Py IR.""" @@ -149,15 +165,6 @@ def _collect_symbols_in_domain_expressions( ) -def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. 
@@ -173,6 +180,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider_type: gtx_common.OffsetProviderType + column_dim: Optional[gtx_common.Dimension] global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( default_factory=lambda: {} @@ -199,6 +207,25 @@ def make_field( def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] + def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: + assert self.column_dim + return dim == self.column_dim + + def nested_context( + self, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], + ) -> SDFGBuilder: + nsdfg_builder = GTIRToSDFG( + self.offset_provider_type, self.column_dim, global_symbols, field_offsets + ) + nsdfg_params = [ + gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() + ] + nsdfg_builder._add_sdfg_params(sdfg, node_params=nsdfg_params, symbolic_arguments=None) + return nsdfg_builder + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: nsdfg_list = [ nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) @@ -277,10 +304,11 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): + for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): + assert isinstance(sym.type, ts.DataType) tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name + sdfg, symbolic_arguments, sym.id, sym.type, transient, tuple_name=name ) ) return tuple_fields @@ -379,7 +407,7 @@ def _add_sdfg_params( self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym], - symbolic_arguments: set[str], + symbolic_arguments: Optional[set[str]], ) -> list[str]: """ Helper function to add storage for node parameters and connectivity tables. @@ -389,6 +417,9 @@ def _add_sdfg_params( except when they are listed in 'symbolic_arguments', in which case they will be represented in the SDFG as DaCe symbols. """ + if symbolic_arguments is None: + symbolic_arguments = set() + # add non-transient arrays and/or SDFG symbols for the program arguments sdfg_args = [] for param in node_params: @@ -436,7 +467,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: assert len(self.field_offsets) == 0 sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) + sdfg.debuginfo = dace_utils.debug_info(node) # DaCe requires C-compatible strings for the names of data containers, # such as arrays and scalars. 
GT4Py uses a unicode symbols ('ᐞ') as name @@ -619,24 +650,13 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] - def flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = _get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - lambda_arg_nodes = dict( - itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + itertools.chain( + *[ + gtir_builtin_translators.flatten_tuples(pname, arg) + for pname, arg in lambda_args_mapping + ] + ) ) # inherit symbols from parent scope but eventually override with local symbols @@ -644,7 +664,9 @@ def flatten_tuples( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: gtir_builtin_translators.get_tuple_type(arg) + if isinstance(arg, tuple) + else arg.gt_type for pname, arg in lambda_args_mapping } @@ -659,12 +681,12 @@ def get_field_domain_offset( elif field_domain_offset := self.field_offsets.get(p_name, None): return {p_name: field_domain_offset} elif isinstance(p_type, ts.TupleType): - p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) return functools.reduce( - lambda field_offsets, field: ( - field_offsets | get_field_domain_offset(field[0], field[1]) + lambda field_offsets, sym: ( + field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] ), - p_fields, + tsyms, {}, ) return {} @@ -676,7 +698,7 @@ def get_field_domain_offset( # lower let-statement lambda node as a nested SDFG lambda_translator = GTIRToSDFG( - self.offset_provider_type, lambda_symbols, lambda_field_offsets + self.offset_provider_type, self.column_dim, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -853,6 +875,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType, + column_dim: Optional[gtx_common.Dimension] = None, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -863,6 +886,7 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG offset_provider_type: The definitions of offset providers used by the program node + column_dim: Vertical dimension used for scan expressions. 
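+            Only required when the program contains scan field operators; defaults to `None`.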
Returns: An SDFG in the DaCe canonical form (simplified) @@ -870,8 +894,11 @@ def build_sdfg_from_gtir( ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider_type) + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_dim) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) + # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct + dace_sdfg_utils.inline_loop_blocks(sdfg) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 118f0449c8..ad120e2502 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -8,14 +8,14 @@ from __future__ import annotations -import itertools from typing import Dict, TypeVar import dace from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -27,35 +27,47 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: return f"i_{dim.value}_gtx_{dim.kind}{suffix}" -def get_tuple_fields( - tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False -) -> list[tuple[str, ts.DataType]]: +def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: """ - Creates a list of names with the corresponding data type for all elements of the given tuple. + Creates a tuple representation of the symbols corresponding to the tuple fields. + The constructed tuple preserves the nested nature of the type, if any. Examples -------- >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] - >>> assert get_tuple_fields("a", t, flatten=True) == [ - ... ("a_0", sty), - ... ("a_1_0", fty), - ... ("a_1_1", sty), - ... ] + >>> assert make_symbol_tuple("a", t) == ( + ... im.sym("a_0", sty), + ... (im.sym("a_1_0", fty), im.sym("a_1_1", sty)), + ... ) """ fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] - if flatten: - expanded_fields = [ - get_tuple_fields(field_name, field_type) - if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] - for field_name, field_type in fields - ] - return list(itertools.chain(*expanded_fields)) - else: - return fields + return tuple( + make_symbol_tuple(field_name, field_type) # type: ignore[misc] + if isinstance(field_type, ts.TupleType) + else im.sym(field_name, field_type) + for field_name, field_type in fields + ) + + +def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: + """ + Creates a list of symbols, annotated with the data type, for all elements of the given tuple. + + Examples + -------- + >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) + >>> assert flatten_tuple_fields("a", t) == [ + ... im.sym("a_0", sty), + ... 
im.sym("a_1_0", fty), + ... im.sym("a_1_1", sty), + ... ] + """ + symbol_tuple = make_symbol_tuple(tuple_name, tuple_type) + return list(gtx_utils.flatten_nested_tuple(symbol_tuple)) def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index a38a50d886..407faf7ec1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -47,13 +47,13 @@ def generate_sdfg( self, ir: itir.Program, offset_provider: common.OffsetProvider, - column_axis: Optional[common.Dimension], + column_dim: Optional[common.Dimension], auto_opt: bool, on_gpu: bool, ) -> dace.SDFG: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( - ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ir, common.offset_provider_to_type(offset_provider), column_dim ) if auto_opt: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index bed6e89a52..67b566d35b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -100,6 +100,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" +USES_SCAN_1D_FIELD = "uses_scan_1d_field" USES_SPARSE_FIELDS = "uses_sparse_fields" USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" @@ -134,7 +135,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ @@ -169,9 +169,17 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [ + # dace issue https://github.com/spcl/dace/issues/1773 + (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), + ], OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST + + [ + # dace issue https://github.com/spcl/dace/issues/1773 + (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9de4449ac2..caef13df3d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -818,6 +818,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan 
+@pytest.mark.uses_scan_1d_field def test_ternary_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: