From df1847ad46dd861bd3222048142a5cc422f2787e Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 29 Nov 2024 10:42:54 +0100 Subject: [PATCH 01/80] scan - working draft --- .../gtir_builtin_translators.py | 443 +++++++++++++++--- .../runners/dace_fieldview/gtir_dataflow.py | 147 ++++-- .../runners/dace_fieldview/gtir_sdfg.py | 89 ++-- .../runners/dace_fieldview/utility.py | 19 +- .../runners/dace_fieldview/workflow.py | 4 +- tests/next_tests/definitions.py | 1 - 6 files changed, 549 insertions(+), 154 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ff011c4193..702215f97d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,6 +10,7 @@ import abc import dataclasses +import itertools from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace @@ -28,9 +29,10 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, gtir_python_codegen, + gtir_sdfg, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: @@ -158,6 +160,27 @@ def get_local_view( """Data type used for field indexing.""" +def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + """ + return ts.TupleType( + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + +def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + if isinstance(arg, tuple): + tuple_type = get_tuple_type(arg) + tuple_field_names = [ + str(sym.id) for sym in dace_gtir_utils.get_tuple_fields(name, tuple_type) + ] + tuple_args = zip(tuple_field_names, arg, strict=True) + return list(itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args])) + else: + return [(name, arg)] + + class PrimitiveTranslator(Protocol): @abc.abstractmethod def __call__( @@ -192,16 +215,20 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, -) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: +) -> ( + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr, ...] +): """Helper method to visit an expression passed as argument to a field operator.""" arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - # arguments passed to field operator should be plain fields, not tuples of fields - if not isinstance(arg, FieldopData): - raise ValueError(f"Received {node} as argument to field operator, expected a field.") - - return arg.get_local_view(domain) + if isinstance(arg, FieldopData): + return arg.get_local_view(domain) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain))(arg) def _get_field_layout( @@ -237,11 +264,12 @@ def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, - node_type: ts.FieldType, + node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_sdfg.SDFGBuilder, - input_edges: Sequence[gtir_dataflow.DataflowInputEdge], - output_edge: gtir_dataflow.DataflowOutputEdge, -) -> FieldopData: + input_edges: Iterable[gtir_dataflow.DataflowInputEdge], + output_edges: gtir_dataflow.DataflowOutputEdge | tuple[gtir_dataflow.DataflowOutputEdge, ...], + scan_dim: Optional[gtx_common.Dimension] = None, +) -> FieldopResult: """ Helper method to allocate a temporary field to store the output of a field operator. @@ -252,39 +280,16 @@ def _create_field_operator( node_type: The GT4Py type of the IR node that produces this field. sdfg_builder: The object used to build the map scope in the provided SDFG. input_edges: List of edges to pass input data into the dataflow. - output_edge: Edge representing the dataflow output data. + output_edges: Single edge or tuple of edges representing the dataflow output data. + scan_dim: Column dimension used in scan field operators. Returns: The field data descriptor, which includes the field access node in the given `state` and the field domain offset. """ - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_indices = _get_domain_indices(field_dims, field_offset) - - dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - - field_subset = sbs.Range.from_indices(field_indices) - if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - assert output_edge.result.gt_dtype == node_type.dtype - assert isinstance(dataflow_output_desc, dace.data.Scalar) - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - field_dtype = output_edge.result.gt_dtype - else: - assert isinstance(node_type.dtype, itir_ts.ListType) - assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type - assert isinstance(dataflow_output_desc, dace.data.Array) - assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = output_edge.result.gt_dtype.element_type - # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - assert output_edge.result.gt_dtype.offset_type is not None - field_dims.append(output_edge.result.gt_dtype.offset_type) - field_shape.extend(dataflow_output_desc.shape) - field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) - - # allocate local temporary storage - field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) - field_node = state.add_access(field_name) + domain_dims, domain_offset, domain_shape = _get_field_layout(domain) + domain_indices = _get_domain_indices(domain_dims, domain_offset) + domain_subset = sbs.Range.from_indices(domain_indices) # create map range corresponding to the field operator domain me, mx = sdfg_builder.add_map( @@ -293,6 +298,7 @@ def _create_field_operator( ndrange={ dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" for dim, lower_bound, upper_bound in domain + if dim != scan_dim }, ) @@ -300,14 +306,60 @@ def _create_field_operator( for edge in input_edges: edge.connect(me) - # and here the edge writing the dataflow result data through the map exit node - output_edge.connect(mx, field_node, field_subset) + def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -> FieldopData: + assert isinstance(sym.type, ts.FieldType) + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == sym.type.dtype + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(sym.type.dtype) + field_dtype = output_edge.result.gt_dtype + field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + if scan_dim is not None: + # this is the case of scan expressions, that produce a 1D vertical field + assert domain_dims.index(scan_dim) == (len(domain_dims) - 1) + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 + # the vertical dimension should not belong to the field operator domain + field_subset = sbs.Range(domain_subset[:-1]) + sbs.Range.from_array( + dataflow_output_desc + ) + else: + assert isinstance(dataflow_output_desc, dace.data.Scalar) + field_subset = domain_subset + else: + assert isinstance(sym.type.dtype, itir_ts.ListType) + assert output_edge.result.gt_dtype.element_type == sym.type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type + # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) + assert output_edge.result.gt_dtype.offset_type is not None + field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] + field_shape = [*domain_shape, dataflow_output_desc.shape] + field_offset = [*domain_offset, dataflow_output_desc.offset] + field_subset = [*domain_subset, sbs.Range.from_array(dataflow_output_desc)] + + # allocate local temporary storage + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) + + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) - return FieldopData( - field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), - ) + if isinstance(output_edges, gtir_dataflow.DataflowOutputEdge): + assert isinstance(node_type, ts.FieldType) + return create_field(output_edges, im.sym("x", node_type)) + else: + assert isinstance(node_type, ts.TupleType) + return gtx_utils.tree_map(create_field)( + output_edges, dace_gtir_utils.get_tuple_fields("x", node_type) + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -366,39 +418,41 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args + fieldop_expr, domain_expr = fun_node.args - if isinstance(stencil_expr, gtir.Lambda): + if cpm.is_call_to(fieldop_expr, "scan"): + return translate_scan(node, sdfg, state, sdfg_builder) + elif isinstance(fieldop_expr, gtir.Lambda): # Default case, handled below: the argument expression is a lambda function # representing the stencil operation to be computed over the field domain. - pass - elif cpm.is_ref_to(stencil_expr, "deref"): + stencil_expr = fieldop_expr + elif cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] + stencil_expr.expr.type = node.type.dtype else: raise NotImplementedError( - f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." ) # parse the domain of the field operator domain = extract_domain(domain_expr) # visit the list of arguments to be passed to the lambda expression - stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) + input_edges, output_edges = taskgen.apply(stencil_expr, args=fieldop_args) + assert len(output_edges) == 1 return _create_field_operator( - sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges[0] ) @@ -569,11 +623,10 @@ def _get_data_nodes( return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) - return tuple( - _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) - for fname, ftype in tuple_fields - ) + tuple_syms = dace_gtir_utils.get_tuple_fields(data_name, data_type) + return gtx_utils.tree_map( + lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) + )(tuple_syms) else: raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") @@ -654,7 +707,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) + assert ti.is_integral(node.args[0].type) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( @@ -690,10 +743,8 @@ def translate_scalar_expr( visit_expr = True if isinstance(arg_expr, gtir.SymRef): try: - # `gt_symbol` refers to symbols defined in the GT4Py program - gt_symbol_type = sdfg_builder.get_symbol_type(arg_expr.id) - if not isinstance(gt_symbol_type, ts.ScalarType): - raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") + # check if symbol is defined in the GT4Py program, returns `None` if undefined + sdfg_builder.get_symbol_type(arg_expr.id) except KeyError: # this is the case of non-variable argument, e.g. target type such as `float64`, # used in a casting expression like `cast_(variable, float64)` @@ -707,7 +758,7 @@ def translate_scalar_expr( sdfg=sdfg, head_state=state, ) - if not (isinstance(arg, FieldopData) and isinstance(arg.gt_type, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(node.type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" args.append(arg.dc_node) @@ -755,6 +806,263 @@ def translate_scalar_expr( return FieldopData(temp_node, node.type, offset=None) +def translate_scan( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + + fun_node = node.fun + assert len(fun_node.args) == 2 + scan_expr, domain_expr = fun_node.args + assert cpm.is_call_to(scan_expr, "scan") + + # parse the domain of the scan field operator + domain = extract_domain(domain_expr) + + assert len(scan_expr.args) == 3 + stencil_expr = scan_expr.args[0] + assert isinstance(stencil_expr, gtir.Lambda) + + # params[0]: the lambda parameter to propagate the scan state on the vertical dimension + scan_state = str(stencil_expr.params[0].id) + + # params[1]: boolean flag for forward/backward scan + assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) + scan_forward = scan_expr.args[1].value == "True" + + # params[2]: the value for scan initialization + init_value = scan_expr.args[2] + + # the scan operator is implemented as an nested SDFG implementing the lambda expression + nsdfg = dace.SDFG(name="scan") + nsdfg.debuginfo = dace_utils.debug_info(node) + + # use the vertical dimension in the domain as scan dimension + scan_domain = [ + (dim, lower_bound, upper_bound) + for dim, lower_bound, upper_bound in domain + if dim.kind == gtx_common.DimensionKind.VERTICAL + ] + assert len(scan_domain) == 1 + scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] + assert sdfg_builder.is_column_dimension(scan_dim) + scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim) + _, scan_output_offset, scan_output_shape = _get_field_layout(scan_domain) + + # create field operator on the horizontal domain + horizontal_domain = [ + (dim, lower_bound, upper_bound) + for dim, lower_bound, upper_bound in domain + if dim.kind == gtx_common.DimensionKind.HORIZONTAL + ] + + # create a loop region for lambda call over the scan dimension + if scan_forward: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_upper_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lower_bound}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lower_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + + nsdfg.add_node(scan_loop) + compute_state = scan_loop.add_state("scan_compute", is_start_block=True) + update_state = scan_loop.add_state("scan_update") + scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) + + init_state = nsdfg.add_state("scan_init", is_start_block=True) + nsdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) + + def scan_input_name(input_name: str) -> str: + return f"__input_{input_name}" + + def scan_output_name(input_name: str) -> str: + return f"__output_{input_name}" + + # visit the initialization value of the scan expression + init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state) + + # extract type definition of the scan state + scan_state_type = ( + init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data) + ) + + # visit the list of arguments to be passed to the scan expression + nsdfg_symbols = {scan_state: scan_state_type} | { + str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + } + nsdfg_builder = sdfg_builder.nested_context(nsdfg, nsdfg_symbols) + fieldop_args = [ + _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, nsdfg_builder, domain) + for p in stencil_expr.params + ] + + # generate the dataflow representing the scan field operator + taskgen = gtir_dataflow.LambdaToDataflow(nsdfg, compute_state, nsdfg_builder) + input_edges, result = taskgen.apply(stencil_expr, args=fieldop_args) + + # now initialize the scan state + scan_state_input = ( + dace_gtir_utils.get_tuple_fields(scan_state, scan_state_type) + if isinstance(scan_state_type, ts.TupleType) + else im.sym(scan_state, scan_state_type) + ) + + def init_scan_state(outer_data: FieldopData, sym: gtir.Sym) -> None: + scan_state = str(sym.id) + scan_state_desc = nsdfg.data(scan_state) + input_state = scan_input_name(scan_state) + input_state_desc = scan_state_desc.clone() + nsdfg.add_datadesc(input_state, input_state_desc) + scan_state_desc.transient = True + init_state.add_nedge( + init_state.add_access(input_state), + init_state.add_access(scan_state), + nsdfg.make_array_memlet(input_state), + ) + + init_scan_state(init_data, scan_state_input) if isinstance( + init_data, FieldopData + ) else gtx_utils.tree_map(init_scan_state)(init_data, scan_state_input) + + # connect the dataflow input directly to the source data nodes, without passing through a map node; + # the reason is that the map for horizontal domain is outside the scan loop region + for edge in input_edges: + edge.connect(map_entry=None) + + # connect the dataflow result nodes to the variables that carry the scan state along the column axis + def connect_scan_output( + scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym + ) -> FieldopData: + scan_result = scan_output_edge.result + assert isinstance(scan_result, gtir_dataflow.ValueExpr) + assert isinstance(sym.type, ts.ScalarType) and scan_result.gt_dtype == sym.type + scan_result_data = scan_result.dc_node.data + scan_result_desc = scan_result.dc_node.desc(nsdfg) + + output, _ = nsdfg.add_array( + scan_output_name(sym.id), scan_output_shape, scan_result_desc.dtype, find_new_name=True + ) + output_node = compute_state.add_access(output) + compute_state.add_nedge( + scan_result.dc_node, output_node, dace.Memlet(data=output, subset=scan_loop_var) + ) + + update_state.add_nedge( + update_state.add_access(scan_result_data), + update_state.add_access(sym.id), + dace.Memlet(data=sym.id, subset="0"), + ) + + output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) + return FieldopData(output_node, output_type, scan_output_offset) + + if isinstance(scan_state_input, gtir.Sym): + assert isinstance(result, gtir_dataflow.DataflowOutputEdge) + lambda_output = connect_scan_output(result, scan_state_input) + else: + assert isinstance(result, tuple) + lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) + + # the scan nested SDFG is ready, now we need to instantiate it inside the map implementing the field operator + lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] + lambda_args_mapping = { + scan_input_name(scan_state): init_data, + } | { + str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + } + lambda_flat_args = dict( + itertools.chain(*[flatten_tuples(param, arg) for param, arg in lambda_args_mapping.items()]) + ) + lambda_flat_outs = ( + set( + str(sym.id) + for sym in dace_gtir_utils.get_tuple_fields( + scan_output_name(scan_state), scan_state_type, flatten=True + ) + ) + if isinstance(scan_state_type, ts.TupleType) + else {scan_output_name(scan_state)} + ) + + nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} + for dim, _, _ in horizontal_domain: + if dim != scan_dim: + dim_map_variable = dace_gtir_utils.get_map_variable(dim) + nsdfg_symbols_mapping[dim_map_variable] = dim_map_variable + for inner, arg in lambda_flat_args.items(): + inner_desc = nsdfg.data(inner) + outer_desc = arg.dc_node.desc(sdfg) + nsdfg_symbols_mapping |= { + str(nested_symbol): parent_symbol + for nested_symbol, parent_symbol in zip( + [*inner_desc.shape, *inner_desc.strides], + [*outer_desc.shape, *outer_desc.strides], + strict=True, + ) + if isinstance(nested_symbol, dace.symbol) + } + + nsdfg_node = state.add_nested_sdfg( + nsdfg, + sdfg, + inputs=set(lambda_flat_args.keys()), + outputs=lambda_flat_outs, + symbol_mapping=nsdfg_symbols_mapping, + ) + + input_edges = [] + for input_connector, arg in lambda_flat_args.items(): + arg_desc = arg.dc_node.desc(sdfg) + input_subset = sbs.Range.from_array(arg_desc) + input_edge = gtir_dataflow.MemletInputEdge( + state, arg.dc_node, input_subset, nsdfg_node, input_connector + ) + input_edges.append(input_edge) + + def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutputEdge: + assert isinstance(scan_data.gt_type, ts.FieldType) + inner_data = scan_data.dc_node.data + inner_desc = nsdfg.data(inner_data) + output_data, output_desc = sdfg.add_temp_transient_like(inner_desc) + output_node = state.add_access(output_data) + state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output_data, output_desc), + ) + output_expr = gtir_dataflow.MemletExpr( + output_node, scan_data.gt_type.dtype, sbs.Range.from_array(output_desc) + ) + return gtir_dataflow.DataflowOutputEdge(state, output_expr) + + if isinstance(lambda_output, FieldopData): + output_edges = construct_output_edge(lambda_output) + else: + output_edges = gtx_utils.tree_map(construct_output_edge)(lambda_output) + + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges, scan_dim + ) + + def translate_symbol_ref( node: gtir.Node, sdfg: dace.SDFG, @@ -784,5 +1092,6 @@ def translate_symbol_ref( translate_make_tuple, translate_tuple_get, translate_scalar_expr, + translate_scan, translate_symbol_ref, ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cfba4d61e5..b45b03b0ce 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -16,7 +16,7 @@ import dace.subsets as sbs from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import type_specifications as itir_ts @@ -68,7 +68,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Indices | sbs.Range + subset: sbs.Range @dataclasses.dataclass(frozen=True) @@ -138,7 +138,7 @@ class DataflowInputEdge(Protocol): """ @abc.abstractmethod - def connect(self, me: dace.nodes.MapEntry) -> None: ... + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ... @dataclasses.dataclass(frozen=True) @@ -156,15 +156,18 @@ class MemletInputEdge(DataflowInputEdge): dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] - def connect(self, me: dace.nodes.MapEntry) -> None: + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: memlet = dace.Memlet(data=self.source.data, subset=self.subset) - self.state.add_memlet_path( - self.source, - me, - self.dest, - dst_conn=self.dest_conn, - memlet=memlet, - ) + if map_entry is None: + self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet) + else: + self.state.add_memlet_path( + self.source, + map_entry, + self.dest, + dst_conn=self.dest_conn, + memlet=memlet, + ) @dataclasses.dataclass(frozen=True) @@ -179,7 +182,8 @@ class EmptyInputEdge(DataflowInputEdge): state: dace.SDFGState node: dace.nodes.Tasklet - def connect(self, me: dace.nodes.MapEntry) -> None: + def connect(self, me: Optional[dace.nodes.MapEntry]) -> None: + assert me is not None self.state.add_nedge(me, self.node, dace.Memlet()) @@ -276,7 +280,7 @@ class LambdaToDataflow(eve.NodeVisitor): state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] + symbol_map: dict[str, tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] def __init__( self, @@ -1248,13 +1252,39 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: + def _visit_make_tuple(self, node: gtir.FunCall) -> tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "make_tuple") + return tuple(self.visit(arg) for arg in node.args) + + def _visit_tuple_get( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "tuple_get") + assert len(node.args) == 2 + + if not isinstance(node.args[0], gtir.Literal): + raise ValueError("Tuple can only be subscripted with compile-time constants.") + assert ti.is_integral(node.args[0].type) + index = int(node.args[0].value) + + tuple_fields = self.visit(node.args[1]) + return tuple_fields[index] + + def visit_FunCall( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[DataflowOutputEdge, ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) + elif cpm.is_call_to(node, "make_tuple"): + return self._visit_make_tuple(node) + + elif cpm.is_call_to(node, "tuple_get"): + return self._visit_tuple_get(node) + elif cpm.is_applied_map(node): return self._visit_map(node) @@ -1264,6 +1294,10 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: elif cpm.is_applied_shift(node): return self._visit_shift(node) + elif isinstance(node.fun, gtir.Lambda): + lambda_args = [self.visit(arg) for arg in node.args] + return self.visit_Lambda(node.fun, args=lambda_args) + elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) @@ -1271,41 +1305,76 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") def visit_Lambda( - self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + self, node: gtir.Lambda, args: list[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge, ...]: + # lambda arguments are mapped to symbols defined in lambda scope + prev_symbols: dict[str, Optional[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]]] = {} for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: DataExpr = self.visit(node.expr) - if isinstance(output_expr, ValueExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) - - if isinstance(output_expr, MemletExpr): - # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.dc_node.desc(self.sdfg).dtype - tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_data_edge( - output_expr.dc_node, - output_expr.subset, - tasklet_node, - "__inp", - ) - else: - assert isinstance(output_expr, SymbolExpr) - # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") + symbol_name = str(p.id) + prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) + self.symbol_map[symbol_name] = arg + + result = self.visit(node.expr) + + # remove locally defined lambda symbols and restore previous symbols + for symbol_name, arg in prev_symbols.items(): + if arg is None: + self.symbol_map.pop(symbol_name) + else: + self.symbol_map[symbol_name] = arg + + def make_output_edge( + output_expr: ValueExpr | MemletExpr | SymbolExpr, + ) -> DataflowOutputEdge: + if isinstance(output_expr, ValueExpr): + return DataflowOutputEdge(self.state, output_expr) + + if isinstance(output_expr, MemletExpr): + # special case where the field operator is simply copying data from source to destination node + output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self._add_input_data_edge( + output_expr.dc_node, + output_expr.subset, + tasklet_node, + "__inp", + ) + else: + # even simpler case, where a constant value is written to destination node + output_dtype = output_expr.dc_dtype + tasklet_node = self._add_tasklet( + "write", {}, {"__out"}, f"__out = {output_expr.value}" + ) + + output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") + return DataflowOutputEdge(self.state, output_expr) - output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + def parse_result( + r: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr, + ) -> DataflowOutputEdge: + if isinstance(r, DataflowOutputEdge): + return r + return make_output_edge(r) + + if isinstance(result, tuple): + return gtx_utils.tree_map(parse_result)(result) + else: + return parse_result(result) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: + def visit_SymRef(self, node: gtir.SymRef) -> tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] # if not in the lambda symbol map, this must be a symref to a builtin function assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) + + def apply( + self, node: gtir.Lambda, args: list[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] + ) -> tuple[list[DataflowInputEdge], tuple[DataflowOutputEdge]]: + output_edges = self.visit_Lambda(node, args=args) + return self.input_edges, output_edges diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 6b5e164458..304a5b47fe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace +from dace.sdfg import utils as sdutils from gt4py import eve from gt4py.eve import concepts @@ -111,6 +112,18 @@ def get_symbol_type(self, symbol_name: str) -> ts.DataType: """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... + @abc.abstractmethod + def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: + """Check if the given dimension is the column dimension.""" + ... + + @abc.abstractmethod + def nested_context( + self, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType] + ) -> SDFGBuilder: + """Create a new empty context, useful to build a nested SDFG.""" + ... + @abc.abstractmethod def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: """Visit a node of the GT4Py IR.""" @@ -149,15 +162,6 @@ def _collect_symbols_in_domain_expressions( ) -def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -173,6 +177,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider_type: gtx_common.OffsetProviderType + column_dim: Optional[gtx_common.Dimension] global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( default_factory=lambda: {} @@ -199,6 +204,22 @@ def make_field( def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] + def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: + assert self.column_dim + return dim == self.column_dim + + def nested_context( + self, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType] + ) -> SDFGBuilder: + nsdfg_builder = GTIRToSDFG( + self.offset_provider_type, self.column_dim, global_symbols, self.field_offsets + ) + nsdfg_params = [ + gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() + ] + nsdfg_builder._add_sdfg_params(sdfg, node_params=nsdfg_params, symbolic_arguments={}) + return nsdfg_builder + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: nsdfg_list = [ nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) @@ -277,10 +298,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): + for sym in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name + sdfg, symbolic_arguments, sym.id, sym.type, transient, tuple_name=name ) ) return tuple_fields @@ -602,7 +623,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - args: list[gtir_builtin_translators.FieldopResult], + args: Sequence[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -619,24 +640,13 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] - def flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = _get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - lambda_arg_nodes = dict( - itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + itertools.chain( + *[ + gtir_builtin_translators.flatten_tuples(pname, arg) + for pname, arg in lambda_args_mapping + ] + ) ) # inherit symbols from parent scope but eventually override with local symbols @@ -644,7 +654,9 @@ def flatten_tuples( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: gtir_builtin_translators.get_tuple_type(arg) + if isinstance(arg, tuple) + else arg.gt_type for pname, arg in lambda_args_mapping } @@ -659,12 +671,12 @@ def get_field_domain_offset( elif field_domain_offset := self.field_offsets.get(p_name, None): return {p_name: field_domain_offset} elif isinstance(p_type, ts.TupleType): - p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + tsyms = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) return functools.reduce( - lambda field_offsets, field: ( - field_offsets | get_field_domain_offset(field[0], field[1]) + lambda field_offsets, sym: ( + field_offsets | get_field_domain_offset(sym.id, sym.type) ), - p_fields, + tsyms, {}, ) return {} @@ -676,7 +688,7 @@ def get_field_domain_offset( # lower let-statement lambda node as a nested SDFG lambda_translator = GTIRToSDFG( - self.offset_provider_type, lambda_symbols, lambda_field_offsets + self.offset_provider_type, self.column_dim, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -852,6 +864,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType, + column_dim: Optional[gtx_common.Dimension], ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -862,6 +875,7 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG offset_provider_type: The definitions of offset providers used by the program node + column_dim: Vertical dimension used for scan expressions. Returns: An SDFG in the DaCe canonical form (simplified) @@ -869,8 +883,11 @@ def build_sdfg_from_gtir( ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider_type) + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_dim) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) + # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct + sdutils.inline_loop_blocks(sdfg) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 118f0449c8..bad8e2f585 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -16,6 +16,7 @@ from gt4py import eve from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -29,7 +30,7 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: def get_tuple_fields( tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False -) -> list[tuple[str, ts.DataType]]: +) -> tuple[gtir.Sym, ...]: """ Creates a list of names with the corresponding data type for all elements of the given tuple. @@ -46,16 +47,16 @@ def get_tuple_fields( ... ] """ fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] + expanded_fields = tuple( + get_tuple_fields(field_name, field_type, flatten) + if isinstance(field_type, ts.TupleType) + else im.sym(field_name, field_type) + for field_name, field_type in fields + ) if flatten: - expanded_fields = [ - get_tuple_fields(field_name, field_type) - if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] - for field_name, field_type in fields - ] - return list(itertools.chain(*expanded_fields)) + return tuple(itertools.chain(expanded_fields)) else: - return fields + return expanded_fields def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 40d44f5ab0..1f61f159e9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -47,13 +47,13 @@ def generate_sdfg( self, ir: itir.Program, offset_provider: common.OffsetProvider, - column_axis: Optional[common.Dimension], + column_dim: Optional[common.Dimension], auto_opt: bool, on_gpu: bool, ) -> dace.SDFG: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( - ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ir, common.offset_provider_to_type(offset_provider), column_dim ) if auto_opt: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 349d3e9f70..26140afbc9 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -155,7 +155,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ From c26d90656fea882d06ebb61941e69f5b6302857d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 4 Dec 2024 13:50:34 +0100 Subject: [PATCH 02/80] Improve utility functions for tuples --- .../gtir_builtin_translators.py | 34 +++++++-------- .../runners/dace_fieldview/gtir_sdfg.py | 6 +-- .../runners/dace_fieldview/utility.py | 41 +++++++++++-------- 3 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 702215f97d..ffe4020a02 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -172,11 +172,12 @@ def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: if isinstance(arg, tuple): tuple_type = get_tuple_type(arg) - tuple_field_names = [ - str(sym.id) for sym in dace_gtir_utils.get_tuple_fields(name, tuple_type) + tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) + tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) + return [ + (str(tsym.id), tfield) + for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list(itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args])) else: return [(name, arg)] @@ -329,15 +330,16 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - else: assert isinstance(sym.type.dtype, itir_ts.ListType) assert output_edge.result.gt_dtype.element_type == sym.type.dtype.element_type - assert isinstance(dataflow_output_desc, dace.data.Array) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) field_dtype = output_edge.result.gt_dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) assert output_edge.result.gt_dtype.offset_type is not None field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] - field_shape = [*domain_shape, dataflow_output_desc.shape] - field_offset = [*domain_offset, dataflow_output_desc.offset] - field_subset = [*domain_subset, sbs.Range.from_array(dataflow_output_desc)] + field_shape = [*domain_shape, dataflow_output_desc.shape[0]] + field_offset = [*domain_offset, dataflow_output_desc.offset[0]] + field_subset = domain_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) @@ -358,7 +360,7 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - else: assert isinstance(node_type, ts.TupleType) return gtx_utils.tree_map(create_field)( - output_edges, dace_gtir_utils.get_tuple_fields("x", node_type) + output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type) ) @@ -448,11 +450,11 @@ def translate_as_fieldop( # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edges = taskgen.apply(stencil_expr, args=fieldop_args) - assert len(output_edges) == 1 + input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args) + assert isinstance(output_edge, gtir_dataflow.DataflowOutputEdge) return _create_field_operator( - sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges[0] + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge ) @@ -623,7 +625,7 @@ def _get_data_nodes( return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): - tuple_syms = dace_gtir_utils.get_tuple_fields(data_name, data_type) + tuple_syms = dace_gtir_utils.make_symbol_tuple(data_name, data_type) return gtx_utils.tree_map( lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) )(tuple_syms) @@ -918,7 +920,7 @@ def scan_output_name(input_name: str) -> str: # now initialize the scan state scan_state_input = ( - dace_gtir_utils.get_tuple_fields(scan_state, scan_state_type) + dace_gtir_utils.make_symbol_tuple(scan_state, scan_state_type) if isinstance(scan_state_type, ts.TupleType) else im.sym(scan_state, scan_state_type) ) @@ -992,8 +994,8 @@ def connect_scan_output( lambda_flat_outs = ( set( str(sym.id) - for sym in dace_gtir_utils.get_tuple_fields( - scan_output_name(scan_state), scan_state_type, flatten=True + for sym in dace_gtir_utils.flatten_tuple_fields( + scan_output_name(scan_state), scan_state_type ) ) if isinstance(scan_state_type, ts.TupleType) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 304a5b47fe..7c97144ab8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -298,7 +298,7 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for sym in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): + for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): tuple_fields.extend( self._add_storage( sdfg, symbolic_arguments, sym.id, sym.type, transient, tuple_name=name @@ -671,7 +671,7 @@ def get_field_domain_offset( elif field_domain_offset := self.field_offsets.get(p_name, None): return {p_name: field_domain_offset} elif isinstance(p_type, ts.TupleType): - tsyms = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) return functools.reduce( lambda field_offsets, sym: ( field_offsets | get_field_domain_offset(sym.id, sym.type) @@ -864,7 +864,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType, - column_dim: Optional[gtx_common.Dimension], + column_dim: Optional[gtx_common.Dimension] = None, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index bad8e2f585..842f06e899 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -8,13 +8,12 @@ from __future__ import annotations -import itertools from typing import Dict, TypeVar import dace from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -28,35 +27,41 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: return f"i_{dim.value}_gtx_{dim.kind}{suffix}" -def get_tuple_fields( - tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False -) -> tuple[gtir.Sym, ...]: +def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: """ - Creates a list of names with the corresponding data type for all elements of the given tuple. + Creates a tuple representation of the symbols corresponding to the tuple fields. + The constructed tuple presrves the nested nature of the type, is any. Examples -------- >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] - >>> assert get_tuple_fields("a", t, flatten=True) == [ - ... ("a_0", sty), - ... ("a_1_0", fty), - ... ("a_1_1", sty), - ... ] + >>> assert get_tuple_fields("a", t) == [("a_0", sty), (("a_1_0", fty), ("a_1_1", sty))] """ fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] - expanded_fields = tuple( - get_tuple_fields(field_name, field_type, flatten) + return tuple( + make_symbol_tuple(field_name, field_type) if isinstance(field_type, ts.TupleType) else im.sym(field_name, field_type) for field_name, field_type in fields ) - if flatten: - return tuple(itertools.chain(expanded_fields)) - else: - return expanded_fields + + +def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: + """ + Creates a list of names with the corresponding data type for all elements of the given tuple. + + Examples + -------- + >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) + >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] + >>> assert flatten_tuple_fields("a", t) == [("a_0", sty), ("a_1_0", fty), ("a_1_1", sty)] + """ + symbol_tuple = make_symbol_tuple(tuple_name, tuple_type) + return list(gtx_utils.flatten_nested_tuple(symbol_tuple)) def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: From ba0a9bab8c88cce330890ecb5b105d84fc46caea Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 4 Dec 2024 15:16:11 +0100 Subject: [PATCH 03/80] Fix for empty field domain --- .../gtir_builtin_translators.py | 39 +++++++++++-------- .../runners/dace_fieldview/gtir_dataflow.py | 19 +++++---- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ffe4020a02..9c8c636140 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -292,16 +292,26 @@ def _create_field_operator( domain_indices = _get_domain_indices(domain_dims, domain_offset) domain_subset = sbs.Range.from_indices(domain_indices) - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - if dim != scan_dim - }, - ) + if scan_dim is not None: + assert domain_dims.index(scan_dim) == (len(domain_dims) - 1) + # we construct the fieldo operator only on the horizontal domain + domain_subset = sbs.Range(domain_subset[:-1]) + + # now check, after removal of the vertical dimension, whether the domain is empty + if len(domain_subset) == 0: + # no need to create a map scope, the field operator domain is empty + me, mx = (None, None) + else: + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + if dim != scan_dim + }, + ) # here we setup the edges passing through the map entry node for edge in input_edges: @@ -316,14 +326,11 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - field_dtype = output_edge.result.gt_dtype field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) if scan_dim is not None: - # this is the case of scan expressions, that produce a 1D vertical field - assert domain_dims.index(scan_dim) == (len(domain_dims) - 1) + # the scan field operator produces a 1D vertical field assert isinstance(dataflow_output_desc, dace.data.Array) assert len(dataflow_output_desc.shape) == 1 # the vertical dimension should not belong to the field operator domain - field_subset = sbs.Range(domain_subset[:-1]) + sbs.Range.from_array( - dataflow_output_desc - ) + field_subset = domain_subset + sbs.Range.from_array(dataflow_output_desc) else: assert isinstance(dataflow_output_desc, dace.data.Scalar) field_subset = domain_subset @@ -840,7 +847,7 @@ def translate_scan( init_value = scan_expr.args[2] # the scan operator is implemented as an nested SDFG implementing the lambda expression - nsdfg = dace.SDFG(name="scan") + nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) nsdfg.debuginfo = dace_utils.debug_info(node) # use the vertical dimension in the domain as scan dimension diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index b45b03b0ce..649c2cad38 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -204,7 +204,7 @@ class DataflowOutputEdge: def connect( self, - mx: dace.nodes.MapExit, + map_exit: Optional[dace.nodes.MapExit], dest: dace.nodes.AccessNode, subset: sbs.Range, ) -> None: @@ -218,13 +218,16 @@ def connect( last_node = self.result.dc_node last_node_connector = None - self.state.add_memlet_path( - last_node, - mx, - dest, - src_conn=last_node_connector, - memlet=dace.Memlet(data=dest.data, subset=subset), - ) + if map_exit is None: + self.state.add_edge(last_node, last_node_connector, dest, None, dace.Memlet(data=dest.data, subset=subset)) + else: + self.state.add_memlet_path( + last_node, + map_exit, + dest, + src_conn=last_node_connector, + memlet=dace.Memlet(data=dest.data, subset=subset), + ) DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { From 784b573a6c96ec179af3df348d6e207dbbb821fd Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 5 Dec 2024 22:23:02 +0100 Subject: [PATCH 04/80] Add exclusive if_ in dataflow --- .../gtir_builtin_translators.py | 7 +- .../runners/dace_fieldview/gtir_dataflow.py | 232 +++++++++++++++++- .../runners/dace_fieldview/gtir_sdfg.py | 17 +- .../runners/dace_fieldview/utility.py | 2 +- 4 files changed, 242 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 9c8c636140..2c744b9d69 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -import dace.subsets as sbs +from dace import subsets as sbs from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -315,6 +315,9 @@ def _create_field_operator( # here we setup the edges passing through the map entry node for edge in input_edges: + if isinstance(edge, gtir_dataflow.EmptyInputEdge) and me is None: + # cannot create empty edge from MapEntry node, if this is not present + continue edge.connect(me) def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -> FieldopData: @@ -848,7 +851,7 @@ def translate_scan( # the scan operator is implemented as an nested SDFG implementing the lambda expression nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) - nsdfg.debuginfo = dace_utils.debug_info(node) + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # use the vertical dimension in the domain as scan dimension scan_domain = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 649c2cad38..3903f317b1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,15 +10,16 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union +from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, TypeVar, Union import dace -import dace.subsets as sbs +from dace import subsets as sbs from gt4py import eve from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -182,9 +183,9 @@ class EmptyInputEdge(DataflowInputEdge): state: dace.SDFGState node: dace.nodes.Tasklet - def connect(self, me: Optional[dace.nodes.MapEntry]) -> None: - assert me is not None - self.state.add_nedge(me, self.node, dace.Memlet()) + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + assert map_entry is not None + self.state.add_nedge(map_entry, self.node, dace.Memlet()) @dataclasses.dataclass(frozen=True) @@ -210,7 +211,7 @@ def connect( ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): + if isinstance(last_node, (dace.nodes.Tasklet, dace.nodes.NestedSDFG)): # the last transient node can be deleted last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn self.state.remove_node(self.result.dc_node) @@ -219,7 +220,13 @@ def connect( last_node_connector = None if map_exit is None: - self.state.add_edge(last_node, last_node_connector, dest, None, dace.Memlet(data=dest.data, subset=subset)) + self.state.add_edge( + last_node, + last_node_connector, + dest, + None, + dace.Memlet(data=dest.data, subset=subset), + ) else: self.state.add_memlet_path( last_node, @@ -549,6 +556,201 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: + assert len(node.args) == 3 + + # TODO(edopao): enable once DaCe supports it in next release + use_conditional_block: Final[bool] = False + + condition_value = self.visit(node.args[0]) + assert ( + isinstance(condition_value, DataExpr) + and isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + + nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) + nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo) + + if use_conditional_block: + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + if_body = dace.sdfg.state.ControlFlowRegion("if_body", sdfg=nsdfg) + tstate = if_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), if_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) + + else: + entry_state = nsdfg.add_state("entry", is_start_block=True) + tstate = nsdfg.add_state("true_branch") + nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond")) + fstate = nsdfg.add_state("false_branch") + nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) + + nsdfg_symbol_mapping = {} + input_memlets: dict[str, tuple[dace.nodes.AccessNode, Optional[sbs.Range]]] = {} + + if isinstance(condition_value, SymbolExpr): + nsdfg.add_symbol("__cond", dace.dtypes.bool) + nsdfg_symbol_mapping["__cond"] = condition_value.value + else: + nsdfg.add_scalar("__cond", dace.dtypes.bool) + if isinstance(condition_value, ValueExpr): + input_memlets["__cond"] = (condition_value.dc_node, sbs.Range.from_string("0")) + else: + assert isinstance(condition_value, MemletExpr) + input_memlets["__cond"] = (condition_value.dc_node, condition_value.subset) + nsdfg_symbol_mapping.update( + {sym: sym for sym in condition_value.subset.free_symbols} + ) + + def visit_branch( + state: dace.SDFGState, expr: gtir.Expr + ) -> tuple[list[DataflowInputEdge], tuple[DataflowOutputEdge]]: + assert state in nsdfg.states() + + T = TypeVar("T", IteratorExpr, MemletExpr, ValueExpr) + + def visit_arg(arg: T) -> T: + if isinstance(arg, IteratorExpr): + arg_node = arg.field + arg_desc = arg_node.desc(self.sdfg) + arg_subset = sbs.Range.from_array(arg_desc) + + else: + assert isinstance(arg, (MemletExpr | ValueExpr)) + arg_node = arg.dc_node + if isinstance(arg, MemletExpr): + assert set(arg.subset.size()) == {1} + arg_desc = dace.data.Scalar(arg_node.desc(self.sdfg).dtype) + arg_subset = arg.subset + else: + arg_desc = arg_node.desc(self.sdfg) + arg_subset = None + + arg_data = arg_node.data + # SDFG data containers with name prefix '__tmp' are expected to be transients + inner_data = ( + arg_data.replace("__tmp", "__arg") if arg_data.startswith("__tmp") else arg_data + ) + + try: + inner_desc = nsdfg.data(inner_data) + assert not inner_desc.transient + except KeyError: + inner_desc = arg_desc.clone() + inner_desc.transient = False + nsdfg.add_datadesc(inner_data, inner_desc) + input_memlets[inner_data] = (arg_node, arg_subset) + + if arg_subset: + # symbols used in memlet subset are not automatically mapped to the parent SDFG + nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols}) + + inner_node = state.add_access(inner_data) + if isinstance(arg, IteratorExpr): + return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) + else: + return ValueExpr(inner_node, arg.gt_dtype) + + lambda_params = [] + lambda_args = [] + for p in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): + arg = self.symbol_map[p] + if isinstance(arg, tuple): + inner_arg = gtx_utils.tree_map(visit_arg)(arg) + else: + inner_arg = visit_arg(arg) + lambda_args.append(inner_arg) + lambda_params.append(im.sym(p)) + + lambda_node = gtir.Lambda(params=lambda_params, expr=expr) + return LambdaToDataflow(nsdfg, state, self.subgraph_builder).apply( + lambda_node, lambda_args + ) + + for state, arg in zip([tstate, fstate], node.args[1:3]): + in_edges, out_edge = visit_branch(state, arg) + for edge in in_edges: + if isinstance(edge, EmptyInputEdge): + continue + edge.connect(map_entry=None) + + def construct_output( + output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym + ) -> ValueExpr: + output_data = str(sym.id) + try: + output_desc = nsdfg.data(output_data) + assert not output_desc.transient + except KeyError: + result_desc = edge.result.dc_node.desc(nsdfg) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = nsdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_node = output_state.add_access(output_data) + output_state.add_nedge( + edge.result.dc_node, + output_node, + dace.Memlet.from_array(output_data, output_desc), + ) + return ValueExpr(output_node, edge.result.gt_dtype) + + if isinstance(out_edge, tuple): + assert isinstance(node.type, ts.TupleType) + out_symbol = dace_gtir_utils.make_symbol_tuple("__output", node.type) + outer_value = gtx_utils.tree_map(lambda x, y: construct_output(state, x, y))(out_edge, out_symbol) + else: + assert isinstance(node.type, ts.FieldType | ts.ScalarType) + outer_value = construct_output(state, out_edge, im.sym("__output", node.type)) + + else: + result = outer_value + + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple(result)} + + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, + self.sdfg, + inputs=set(input_memlets.keys()), + outputs=outputs, + symbol_mapping=nsdfg_symbol_mapping, + ) + + for inner, (src_node, src_subset) in input_memlets.items(): + if src_subset is None: + self._add_edge( + src_node, None, nsdfg_node, inner, self.sdfg.make_array_memlet(src_node.data) + ) + else: + self._add_input_data_edge(src_node, src_subset, nsdfg_node, inner) + + def connect_output(inner_value: ValueExpr) -> ValueExpr: + inner_data = inner_value.dc_node.data + inner_desc = inner_value.dc_node.desc(nsdfg) + assert not inner_desc.transient + output, output_desc = self.sdfg.add_temp_transient_like(inner_desc) + output_node = self.state.add_access(output) + self.state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output, output_desc), + ) + return ValueExpr(output_node, inner_value.gt_dtype) + + if isinstance(result, tuple): + return gtx_utils.tree_map(connect_output)(result) + else: + return connect_output(result) + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, itir_ts.ListType) assert len(node.args) == 2 @@ -1279,6 +1481,9 @@ def visit_FunCall( if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_call_to(node, "if_"): + return self._visit_if(node) + elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) @@ -1319,6 +1524,19 @@ def visit_Lambda( result = self.visit(node.expr) + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + used_data = set( + edge.source.data for edge in self.input_edges if isinstance(edge, MemletInputEdge) + ) + for data_node in self.state.data_nodes(): + data_desc = data_node.desc(self.sdfg) + if (not data_desc.transient) and (data_node.data not in used_data): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = self.sdfg.add_temp_transient_like(data_desc) + temp_node = self.state.add_access(temp) + self.state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + # remove locally defined lambda symbols and restore previous symbols for symbol_name, arg in prev_symbols.items(): if arg is None: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 7c97144ab8..c2228c75b7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -217,7 +217,7 @@ def nested_context( nsdfg_params = [ gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() ] - nsdfg_builder._add_sdfg_params(sdfg, node_params=nsdfg_params, symbolic_arguments={}) + nsdfg_builder._add_sdfg_params(sdfg, node_params=nsdfg_params, symbolic_arguments=None) return nsdfg_builder def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: @@ -299,6 +299,7 @@ def _add_storage( if isinstance(gt_type, ts.TupleType): tuple_fields = [] for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): + assert isinstance(sym.type, ts.DataType) tuple_fields.extend( self._add_storage( sdfg, symbolic_arguments, sym.id, sym.type, transient, tuple_name=name @@ -400,7 +401,7 @@ def _add_sdfg_params( self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym], - symbolic_arguments: set[str], + symbolic_arguments: Optional[set[str]], ) -> list[str]: """ Helper function to add storage for node parameters and connectivity tables. @@ -410,6 +411,9 @@ def _add_sdfg_params( except when they are listed in 'symbolic_arguments', in which case they will be represented in the SDFG as DaCe symbols. """ + if symbolic_arguments is None: + symbolic_arguments = set() + # add non-transient arrays and/or SDFG symbols for the program arguments sdfg_args = [] for param in node_params: @@ -457,7 +461,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: assert len(self.field_offsets) == 0 sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) + sdfg.debuginfo = dace_utils.debug_info(node) # DaCe requires C-compatible strings for the names of data containers, # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name @@ -674,7 +678,7 @@ def get_field_domain_offset( tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) return functools.reduce( lambda field_offsets, sym: ( - field_offsets | get_field_domain_offset(sym.id, sym.type) + field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] ), tsyms, {}, @@ -691,7 +695,7 @@ def get_field_domain_offset( self.offset_provider_type, self.column_dim, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nstate = nsdfg.add_state("lambda") + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # add sdfg storage for the symbols that need to be passed as input parameters lambda_params = [ @@ -702,6 +706,7 @@ def get_field_domain_offset( nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) + nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( node.expr, sdfg=nsdfg, @@ -887,7 +892,7 @@ def build_sdfg_from_gtir( sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) - # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct + # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct sdutils.inline_loop_blocks(sdfg) return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 842f06e899..64f3ca04e7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -41,7 +41,7 @@ def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.S """ fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] return tuple( - make_symbol_tuple(field_name, field_type) + make_symbol_tuple(field_name, field_type) # type: ignore[misc] if isinstance(field_type, ts.TupleType) else im.sym(field_name, field_type) for field_name, field_type in fields From de9c9de57b94bf44902f764f0f3580251b3d1bfe Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 5 Dec 2024 23:06:19 +0100 Subject: [PATCH 05/80] Better handling of isolated nodes --- .../gtir_builtin_translators.py | 21 +++++++++--- .../runners/dace_fieldview/gtir_dataflow.py | 33 +++++++++---------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 2c744b9d69..ebe7a3a05c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -935,7 +935,7 @@ def scan_output_name(input_name: str) -> str: else im.sym(scan_state, scan_state_type) ) - def init_scan_state(outer_data: FieldopData, sym: gtir.Sym) -> None: + def init_scan_state(sym: gtir.Sym) -> None: scan_state = str(sym.id) scan_state_desc = nsdfg.data(scan_state) input_state = scan_input_name(scan_state) @@ -948,15 +948,28 @@ def init_scan_state(outer_data: FieldopData, sym: gtir.Sym) -> None: nsdfg.make_array_memlet(input_state), ) - init_scan_state(init_data, scan_state_input) if isinstance( - init_data, FieldopData - ) else gtx_utils.tree_map(init_scan_state)(init_data, scan_state_input) + init_scan_state(scan_state_input) if isinstance( + scan_state_input, FieldopData + ) else gtx_utils.tree_map(init_scan_state)(scan_state_input) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region for edge in input_edges: edge.connect(map_entry=None) + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + for data_node in compute_state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if (compute_state.degree(data_node) == 0) and ( + (not data_desc.transient) + or data_node.data.startswith(scan_state) # check for isolated scan state + ): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) + temp_node = compute_state.add_access(temp) + compute_state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + # connect the dataflow result nodes to the variables that carry the scan state along the column axis def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 3903f317b1..952afda6ab 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -184,8 +184,8 @@ class EmptyInputEdge(DataflowInputEdge): node: dace.nodes.Tasklet def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: - assert map_entry is not None - self.state.add_nedge(map_entry, self.node, dace.Memlet()) + if map_entry is not None: + self.state.add_nedge(map_entry, self.node, dace.Memlet()) @dataclasses.dataclass(frozen=True) @@ -678,10 +678,18 @@ def visit_arg(arg: T) -> T: for state, arg in zip([tstate, fstate], node.args[1:3]): in_edges, out_edge = visit_branch(state, arg) for edge in in_edges: - if isinstance(edge, EmptyInputEdge): - continue edge.connect(map_entry=None) + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + for data_node in state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if (not data_desc.transient) and (state.degree(data_node) == 0): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) + temp_node = state.add_access(temp) + state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + def construct_output( output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym ) -> ValueExpr: @@ -705,7 +713,9 @@ def construct_output( if isinstance(out_edge, tuple): assert isinstance(node.type, ts.TupleType) out_symbol = dace_gtir_utils.make_symbol_tuple("__output", node.type) - outer_value = gtx_utils.tree_map(lambda x, y: construct_output(state, x, y))(out_edge, out_symbol) + outer_value = gtx_utils.tree_map(lambda x, y: construct_output(state, x, y))( + out_edge, out_symbol + ) else: assert isinstance(node.type, ts.FieldType | ts.ScalarType) outer_value = construct_output(state, out_edge, im.sym("__output", node.type)) @@ -1524,19 +1534,6 @@ def visit_Lambda( result = self.visit(node.expr) - # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, - # because not all tuple fields are necessarily used in the lambda scope - used_data = set( - edge.source.data for edge in self.input_edges if isinstance(edge, MemletInputEdge) - ) - for data_node in self.state.data_nodes(): - data_desc = data_node.desc(self.sdfg) - if (not data_desc.transient) and (data_node.data not in used_data): - # isolated node, connect it to a transient to avoid SDFG validation errors - temp, temp_desc = self.sdfg.add_temp_transient_like(data_desc) - temp_node = self.state.add_access(temp) - self.state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) - # remove locally defined lambda symbols and restore previous symbols for symbol_name, arg in prev_symbols.items(): if arg is None: From 14e66e80f162161d7b167f9f3a1c67aa05deab1e Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 11:23:53 +0100 Subject: [PATCH 06/80] Fix field offset in nested SDFG context --- .../gtir_builtin_translators.py | 101 ++++++++++-------- .../runners/dace_fieldview/gtir_dataflow.py | 4 +- .../runners/dace_fieldview/gtir_sdfg.py | 12 ++- 3 files changed, 68 insertions(+), 49 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ebe7a3a05c..e57478ad8c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,6 @@ import abc import dataclasses -import itertools from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace @@ -849,6 +848,54 @@ def translate_scan( # params[2]: the value for scan initialization init_value = scan_expr.args[2] + # make naming consistent throughut this function scope + def scan_input_name(input_name: str) -> str: + return f"__input_{input_name}" + + def scan_output_name(input_name: str) -> str: + return f"__output_{input_name}" + + # visit the initialization value of the scan expression + init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state) + + # extract type definition of the scan state + scan_state_type = ( + init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data) + ) + + # create list of params to the lambda function with associated node type + lambda_symbols = {scan_state: scan_state_type} | { + str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + } + + # visit the arguments to be passed to the lambda expression + # obs. this must be executed before visiting the lambda expression, in order to populate + # the data descriptor with the correct field domain offsets for field arguments + lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] + lambda_args_mapping = { + scan_input_name(scan_state): init_data, + } | { + str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + } + + # parse the dataflow input and output symbols + lambda_flat_args = {} + lambda_field_offsets = {} + for param, arg in lambda_args_mapping.items(): + tuple_fields = flatten_tuples(param, arg) + lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} + lambda_flat_args |= dict(tuple_fields) + lambda_flat_outs = ( + { + str(sym.id): sym.type + for sym in dace_gtir_utils.flatten_tuple_fields( + scan_output_name(scan_state), scan_state_type + ) + } + if isinstance(scan_state_type, ts.TupleType) + else {scan_output_name(scan_state): scan_state_type} + ) + # the scan operator is implemented as an nested SDFG implementing the lambda expression nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -900,33 +947,16 @@ def translate_scan( init_state = nsdfg.add_state("scan_init", is_start_block=True) nsdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) - def scan_input_name(input_name: str) -> str: - return f"__input_{input_name}" - - def scan_output_name(input_name: str) -> str: - return f"__output_{input_name}" - - # visit the initialization value of the scan expression - init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state) - - # extract type definition of the scan state - scan_state_type = ( - init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data) - ) - # visit the list of arguments to be passed to the scan expression - nsdfg_symbols = {scan_state: scan_state_type} | { - str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) - } - nsdfg_builder = sdfg_builder.nested_context(nsdfg, nsdfg_symbols) - fieldop_args = [ - _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, nsdfg_builder, domain) + stencil_builder = sdfg_builder.nested_context(nsdfg, lambda_symbols, lambda_field_offsets) + stencil_args = [ + _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, stencil_builder, domain) for p in stencil_expr.params ] # generate the dataflow representing the scan field operator - taskgen = gtir_dataflow.LambdaToDataflow(nsdfg, compute_state, nsdfg_builder) - input_edges, result = taskgen.apply(stencil_expr, args=fieldop_args) + taskgen = gtir_dataflow.LambdaToDataflow(nsdfg, compute_state, stencil_builder) + input_edges, result = taskgen.apply(stencil_expr, args=stencil_args) # now initialize the scan state scan_state_input = ( @@ -1004,27 +1034,7 @@ def connect_scan_output( assert isinstance(result, tuple) lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) - # the scan nested SDFG is ready, now we need to instantiate it inside the map implementing the field operator - lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] - lambda_args_mapping = { - scan_input_name(scan_state): init_data, - } | { - str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) - } - lambda_flat_args = dict( - itertools.chain(*[flatten_tuples(param, arg) for param, arg in lambda_args_mapping.items()]) - ) - lambda_flat_outs = ( - set( - str(sym.id) - for sym in dace_gtir_utils.flatten_tuple_fields( - scan_output_name(scan_state), scan_state_type - ) - ) - if isinstance(scan_state_type, ts.TupleType) - else {scan_output_name(scan_state)} - ) - + # build the mapping of symbols from nested SDFG to parent SDFG nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} for dim, _, _ in horizontal_domain: if dim != scan_dim: @@ -1043,11 +1053,12 @@ def connect_scan_output( if isinstance(nested_symbol, dace.symbol) } + # the scan nested SDFG is ready, now we need to instantiate it inside the map implementing the field operator nsdfg_node = state.add_nested_sdfg( nsdfg, sdfg, inputs=set(lambda_flat_args.keys()), - outputs=lambda_flat_outs, + outputs=set(lambda_flat_outs.keys()), symbol_mapping=nsdfg_symbols_mapping, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 952afda6ab..89de719d4c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -637,7 +637,9 @@ def visit_arg(arg: T) -> T: arg_data = arg_node.data # SDFG data containers with name prefix '__tmp' are expected to be transients inner_data = ( - arg_data.replace("__tmp", "__arg") if arg_data.startswith("__tmp") else arg_data + arg_data.replace("__tmp", "__input") + if arg_data.startswith("__tmp") + else arg_data ) try: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index c2228c75b7..ebf98e453e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -119,7 +119,10 @@ def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: @abc.abstractmethod def nested_context( - self, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType] + self, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], ) -> SDFGBuilder: """Create a new empty context, useful to build a nested SDFG.""" ... @@ -209,10 +212,13 @@ def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: return dim == self.column_dim def nested_context( - self, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType] + self, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], ) -> SDFGBuilder: nsdfg_builder = GTIRToSDFG( - self.offset_provider_type, self.column_dim, global_symbols, self.field_offsets + self.offset_provider_type, self.column_dim, global_symbols, field_offsets ) nsdfg_params = [ gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() From fcfaf7252e9ce533e8c0714bf55c2d1af6feb12f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 13:46:59 +0100 Subject: [PATCH 07/80] fix problem with dereferencil of 1D vertical fields inside scan --- .../gtir_builtin_translators.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e57478ad8c..32277f5f77 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -215,6 +215,7 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, + scan_dim: Optional[gtx_common.Dimension] = None, ) -> ( gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr @@ -224,11 +225,24 @@ def _parse_fieldop_arg( arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) + def scan_arg_wrapper(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow.IteratorExpr: + # In case of scan field operator, input 1D fields with scan dimension do not need to be dereferenced. + # The iterator expression for such fields is converted into a memlet expression for data access. + arg_expr = arg.get_local_view(domain) + if scan_dim is None: + return arg_expr + if isinstance(arg_expr, gtir_dataflow.MemletExpr): + return arg_expr + head, *tail = arg_expr.field_domain + if tail or head[0] != scan_dim: + return arg_expr + return gtir_dataflow.MemletExpr(arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg)) + if isinstance(arg, FieldopData): - return arg.get_local_view(domain) + return scan_arg_wrapper(arg) else: # handle tuples of fields - return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain))(arg) + return gtx_utils.tree_map(lambda x: scan_arg_wrapper(x))(arg) def _get_field_layout( @@ -834,6 +848,17 @@ def translate_scan( # parse the domain of the scan field operator domain = extract_domain(domain_expr) + # use the vertical dimension in the domain as scan dimension + scan_domain = [ + (dim, lower_bound, upper_bound) + for dim, lower_bound, upper_bound in domain + if dim.kind == gtx_common.DimensionKind.VERTICAL + ] + assert len(scan_domain) == 1 + scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] + assert sdfg_builder.is_column_dimension(scan_dim) + + # parse scan parameters assert len(scan_expr.args) == 3 stencil_expr = scan_expr.args[0] assert isinstance(stencil_expr, gtir.Lambda) @@ -900,15 +925,7 @@ def scan_output_name(input_name: str) -> str: nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) - # use the vertical dimension in the domain as scan dimension - scan_domain = [ - (dim, lower_bound, upper_bound) - for dim, lower_bound, upper_bound in domain - if dim.kind == gtx_common.DimensionKind.VERTICAL - ] - assert len(scan_domain) == 1 - scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] - assert sdfg_builder.is_column_dimension(scan_dim) + # extract the scan loop range scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim) _, scan_output_offset, scan_output_shape = _get_field_layout(scan_domain) @@ -950,7 +967,7 @@ def scan_output_name(input_name: str) -> str: # visit the list of arguments to be passed to the scan expression stencil_builder = sdfg_builder.nested_context(nsdfg, lambda_symbols, lambda_field_offsets) stencil_args = [ - _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, stencil_builder, domain) + _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, stencil_builder, domain, scan_dim) for p in stencil_expr.params ] From 79204ee419a77c9ac5e66027b6473fb8cd5712b0 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 14:05:10 +0100 Subject: [PATCH 08/80] generalize previous fix to all scan input fields --- .../gtir_builtin_translators.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 32277f5f77..b30af3aeac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -215,7 +215,7 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, - scan_dim: Optional[gtx_common.Dimension] = None, + by_value: bool = False, ) -> ( gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr @@ -225,24 +225,18 @@ def _parse_fieldop_arg( arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - def scan_arg_wrapper(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow.IteratorExpr: - # In case of scan field operator, input 1D fields with scan dimension do not need to be dereferenced. - # The iterator expression for such fields is converted into a memlet expression for data access. + def get_arg_value(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow.IteratorExpr: + # In case of scan field operator, the arguments to the vertical stencil are passed by value. arg_expr = arg.get_local_view(domain) - if scan_dim is None: - return arg_expr - if isinstance(arg_expr, gtir_dataflow.MemletExpr): - return arg_expr - head, *tail = arg_expr.field_domain - if tail or head[0] != scan_dim: + if not by_value or isinstance(arg_expr, gtir_dataflow.MemletExpr): return arg_expr return gtir_dataflow.MemletExpr(arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg)) if isinstance(arg, FieldopData): - return scan_arg_wrapper(arg) + return get_arg_value(arg) else: # handle tuples of fields - return gtx_utils.tree_map(lambda x: scan_arg_wrapper(x))(arg) + return gtx_utils.tree_map(lambda x: get_arg_value(x))(arg) def _get_field_layout( @@ -967,7 +961,7 @@ def scan_output_name(input_name: str) -> str: # visit the list of arguments to be passed to the scan expression stencil_builder = sdfg_builder.nested_context(nsdfg, lambda_symbols, lambda_field_offsets) stencil_args = [ - _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, stencil_builder, domain, scan_dim) + _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, stencil_builder, domain, by_value=True) for p in stencil_expr.params ] From 5fe461a0e5ca2a607ba68a93227abf28d8047258 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 14:07:58 +0100 Subject: [PATCH 09/80] minor edit --- .../dace_fieldview/gtir_builtin_translators.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index b30af3aeac..a0833e2d72 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -230,7 +230,9 @@ def get_arg_value(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow. arg_expr = arg.get_local_view(domain) if not by_value or isinstance(arg_expr, gtir_dataflow.MemletExpr): return arg_expr - return gtir_dataflow.MemletExpr(arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg)) + return gtir_dataflow.MemletExpr( + arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg) + ) if isinstance(arg, FieldopData): return get_arg_value(arg) @@ -923,13 +925,6 @@ def scan_output_name(input_name: str) -> str: scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim) _, scan_output_offset, scan_output_shape = _get_field_layout(scan_domain) - # create field operator on the horizontal domain - horizontal_domain = [ - (dim, lower_bound, upper_bound) - for dim, lower_bound, upper_bound in domain - if dim.kind == gtx_common.DimensionKind.HORIZONTAL - ] - # create a loop region for lambda call over the scan dimension if scan_forward: scan_loop = dace.sdfg.state.LoopRegion( @@ -961,7 +956,9 @@ def scan_output_name(input_name: str) -> str: # visit the list of arguments to be passed to the scan expression stencil_builder = sdfg_builder.nested_context(nsdfg, lambda_symbols, lambda_field_offsets) stencil_args = [ - _parse_fieldop_arg(im.ref(p.id), nsdfg, compute_state, stencil_builder, domain, by_value=True) + _parse_fieldop_arg( + im.ref(p.id), nsdfg, compute_state, stencil_builder, domain, by_value=True + ) for p in stencil_expr.params ] @@ -1047,7 +1044,7 @@ def connect_scan_output( # build the mapping of symbols from nested SDFG to parent SDFG nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} - for dim, _, _ in horizontal_domain: + for dim, _, _ in domain: if dim != scan_dim: dim_map_variable = dace_gtir_utils.get_map_variable(dim) nsdfg_symbols_mapping[dim_map_variable] = dim_map_variable From a4bde3a5f43e843c48efb5076482c3cf03048d53 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 14:24:54 +0100 Subject: [PATCH 10/80] fix out-of-bound access --- .../runners/dace_fieldview/gtir_builtin_translators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index a0833e2d72..d1c3c2a0e4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -1022,8 +1022,9 @@ def connect_scan_output( scan_output_name(sym.id), scan_output_shape, scan_result_desc.dtype, find_new_name=True ) output_node = compute_state.add_access(output) + output_subset = str(dace.symbolic.SymExpr(scan_loop_var) - scan_lower_bound) compute_state.add_nedge( - scan_result.dc_node, output_node, dace.Memlet(data=output, subset=scan_loop_var) + scan_result.dc_node, output_node, dace.Memlet(data=output, subset=output_subset) ) update_state.add_nedge( From c75a8e47e94504afc844ff59a801e1e74a3e8b50 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 14:34:57 +0100 Subject: [PATCH 11/80] Better handling of isolated nodes --- .../gtir_builtin_translators.py | 28 ++++++++++--------- .../runners/dace_fieldview/gtir_dataflow.py | 20 ++++++------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index d1c3c2a0e4..aa83aed163 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -995,19 +995,6 @@ def init_scan_state(sym: gtir.Sym) -> None: for edge in input_edges: edge.connect(map_entry=None) - # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, - # because not all tuple fields are necessarily used in the lambda scope - for data_node in compute_state.data_nodes(): - data_desc = data_node.desc(nsdfg) - if (compute_state.degree(data_node) == 0) and ( - (not data_desc.transient) - or data_node.data.startswith(scan_state) # check for isolated scan state - ): - # isolated node, connect it to a transient to avoid SDFG validation errors - temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) - temp_node = compute_state.add_access(temp) - compute_state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) - # connect the dataflow result nodes to the variables that carry the scan state along the column axis def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym @@ -1043,6 +1030,21 @@ def connect_scan_output( assert isinstance(result, tuple) lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily accessed in the lambda scope + for data_node in compute_state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if (compute_state.degree(data_node) == 0) and ( + (not data_desc.transient) + or data_node.data.startswith( + scan_state + ) # exceptional case where the state is not used, not a scan indeed + ): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) + temp_node = compute_state.add_access(temp) + compute_state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + # build the mapping of symbols from nested SDFG to parent SDFG nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} for dim, _, _ in domain: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 89de719d4c..5767a86c42 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -682,16 +682,6 @@ def visit_arg(arg: T) -> T: for edge in in_edges: edge.connect(map_entry=None) - # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, - # because not all tuple fields are necessarily used in the lambda scope - for data_node in state.data_nodes(): - data_desc = data_node.desc(nsdfg) - if (not data_desc.transient) and (state.degree(data_node) == 0): - # isolated node, connect it to a transient to avoid SDFG validation errors - temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) - temp_node = state.add_access(temp) - state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) - def construct_output( output_state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym ) -> ValueExpr: @@ -722,6 +712,16 @@ def construct_output( assert isinstance(node.type, ts.FieldType | ts.ScalarType) outer_value = construct_output(state, out_edge, im.sym("__output", node.type)) + # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + for data_node in state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if (not data_desc.transient) and (state.degree(data_node) == 0): + # isolated node, connect it to a transient to avoid SDFG validation errors + temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) + temp_node = state.add_access(temp) + state.add_nedge(data_node, temp_node, dace.Memlet.from_array(temp, temp_desc)) + else: result = outer_value From 397acae9a613021bd6217561644d01db40006ced Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 14:53:18 +0100 Subject: [PATCH 12/80] exclude scan tests on dace backend with optimizations --- tests/next_tests/definitions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b3064701b4..640fde6236 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -168,8 +168,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST + + [(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE)], # TODO(edopao): result validation fails with dace optimization + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE)], # TODO(edopao): result validation fails with dace optimization OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST From a706b27e6283490baaa6bc663cafbfeb94ce56c8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 18:06:20 +0100 Subject: [PATCH 13/80] fix pre-commit --- .../gtir_builtin_translators.py | 82 +++++++---- .../runners/dace_fieldview/gtir_dataflow.py | 136 +++++++++++------- tests/next_tests/definitions.py | 8 +- 3 files changed, 145 insertions(+), 81 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index aa83aed163..524091cbe3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias +from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace from dace import subsets as sbs @@ -219,17 +219,26 @@ def _parse_fieldop_arg( ) -> ( gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr - | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr, ...] + | gtir_dataflow.ValueExpr + | tuple[ + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | gtir_dataflow.ValueExpr + | tuple[Any, ...], + ..., + ] ): """Helper method to visit an expression passed as argument to a field operator.""" arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - def get_arg_value(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow.IteratorExpr: - # In case of scan field operator, the arguments to the vertical stencil are passed by value. + def get_arg_value( + arg: FieldopData, + ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: arg_expr = arg.get_local_view(domain) if not by_value or isinstance(arg_expr, gtir_dataflow.MemletExpr): return arg_expr + # In case of scan field operator, the arguments to the vertical stencil are passed by value. return gtir_dataflow.MemletExpr( arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg) ) @@ -277,7 +286,8 @@ def _create_field_operator( node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output_edges: gtir_dataflow.DataflowOutputEdge | tuple[gtir_dataflow.DataflowOutputEdge, ...], + output_edges: gtir_dataflow.DataflowOutputEdge + | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], scan_dim: Optional[gtx_common.Dimension] = None, ) -> FieldopResult: """ @@ -446,16 +456,18 @@ def translate_as_fieldop( if cpm.is_call_to(fieldop_expr, "scan"): return translate_scan(node, sdfg, state, sdfg_builder) - elif isinstance(fieldop_expr, gtir.Lambda): - # Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - stencil_expr = fieldop_expr - elif cpm.is_ref_to(fieldop_expr, "deref"): + + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." @@ -835,6 +847,7 @@ def translate_scan( ) -> FieldopResult: assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) fun_node = node.fun assert len(fun_node.args) == 2 @@ -886,7 +899,9 @@ def scan_output_name(input_name: str) -> str: # create list of params to the lambda function with associated node type lambda_symbols = {scan_state: scan_state_type} | { - str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + str(p.id): arg.type + for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + if isinstance(arg.type, ts.DataType) } # visit the arguments to be passed to the lambda expression @@ -900,8 +915,8 @@ def scan_output_name(input_name: str) -> str: } # parse the dataflow input and output symbols - lambda_flat_args = {} - lambda_field_offsets = {} + lambda_flat_args: dict[str, FieldopData] = {} + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} for param, arg in lambda_args_mapping.items(): tuple_fields = flatten_tuples(param, arg) lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} @@ -986,9 +1001,10 @@ def init_scan_state(sym: gtir.Sym) -> None: nsdfg.make_array_memlet(input_state), ) - init_scan_state(scan_state_input) if isinstance( - scan_state_input, FieldopData - ) else gtx_utils.tree_map(init_scan_state)(scan_state_input) + if isinstance(scan_state_input, tuple): + gtx_utils.tree_map(init_scan_state)(scan_state_input) + else: + init_scan_state(scan_state_input) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region @@ -1000,8 +1016,8 @@ def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym ) -> FieldopData: scan_result = scan_output_edge.result - assert isinstance(scan_result, gtir_dataflow.ValueExpr) - assert isinstance(sym.type, ts.ScalarType) and scan_result.gt_dtype == sym.type + assert isinstance(scan_result.gt_dtype, ts.ScalarType) + assert scan_result.gt_dtype == sym.type scan_result_data = scan_result.dc_node.data scan_result_desc = scan_result.dc_node.desc(nsdfg) @@ -1023,12 +1039,17 @@ def connect_scan_output( output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) return FieldopData(output_node, output_type, scan_output_offset) - if isinstance(scan_state_input, gtir.Sym): - assert isinstance(result, gtir_dataflow.DataflowOutputEdge) - lambda_output = connect_scan_output(result, scan_state_input) - else: - assert isinstance(result, tuple) - lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) + lambda_output = ( + gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) + if (isinstance(result, tuple) and isinstance(scan_state_input, tuple)) + else connect_scan_output(result, scan_state_input) + if ( + isinstance(result, gtir_dataflow.DataflowOutputEdge) + and isinstance(scan_state_input, gtir.Sym) + ) + else None + ) + assert lambda_output # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, # because not all tuple fields are necessarily accessed in the lambda scope @@ -1095,15 +1116,14 @@ def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutpu None, dace.Memlet.from_array(output_data, output_desc), ) - output_expr = gtir_dataflow.MemletExpr( - output_node, scan_data.gt_type.dtype, sbs.Range.from_array(output_desc) - ) + output_expr = gtir_dataflow.ValueExpr(output_node, scan_data.gt_type.dtype) return gtir_dataflow.DataflowOutputEdge(state, output_expr) - if isinstance(lambda_output, FieldopData): - output_edges = construct_output_edge(lambda_output) - else: - output_edges = gtx_utils.tree_map(construct_output_edge)(lambda_output) + output_edges = ( + construct_output_edge(lambda_output) + if isinstance(lambda_output, FieldopData) + else gtx_utils.tree_map(construct_output_edge)(lambda_output) + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges, scan_dim diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 5767a86c42..fa54942049 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, TypeVar, Union +from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union import dace from dace import subsets as sbs @@ -290,7 +290,10 @@ class LambdaToDataflow(eve.NodeVisitor): state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder input_edges: list[DataflowInputEdge] - symbol_map: dict[str, tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] + symbol_map: dict[ + str, + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + ] def __init__( self, @@ -556,7 +559,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") - def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: assert len(node.args) == 3 # TODO(edopao): enable once DaCe supports it in next release @@ -564,9 +567,12 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: condition_value = self.visit(node.args[0]) assert ( - isinstance(condition_value, DataExpr) - and isinstance(condition_value.gt_dtype, ts.ScalarType) - and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, (MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool) ) nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) @@ -612,19 +618,20 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: def visit_branch( state: dace.SDFGState, expr: gtir.Expr - ) -> tuple[list[DataflowInputEdge], tuple[DataflowOutputEdge]]: + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: assert state in nsdfg.states() - T = TypeVar("T", IteratorExpr, MemletExpr, ValueExpr) - - def visit_arg(arg: T) -> T: + def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: if isinstance(arg, IteratorExpr): arg_node = arg.field arg_desc = arg_node.desc(self.sdfg) arg_subset = sbs.Range.from_array(arg_desc) else: - assert isinstance(arg, (MemletExpr | ValueExpr)) + assert isinstance(arg, (MemletExpr, ValueExpr)) arg_node = arg.dc_node if isinstance(arg, MemletExpr): assert set(arg.subset.size()) == {1} @@ -659,10 +666,16 @@ def visit_arg(arg: T) -> T: if isinstance(arg, IteratorExpr): return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) else: + assert isinstance(inner_desc, dace.data.Scalar) return ValueExpr(inner_node, arg.gt_dtype) lambda_params = [] - lambda_args = [] + lambda_args: list[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ] = [] for p in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): arg = self.symbol_map[p] if isinstance(arg, tuple): @@ -705,9 +718,9 @@ def construct_output( if isinstance(out_edge, tuple): assert isinstance(node.type, ts.TupleType) out_symbol = dace_gtir_utils.make_symbol_tuple("__output", node.type) - outer_value = gtx_utils.tree_map(lambda x, y: construct_output(state, x, y))( - out_edge, out_symbol - ) + outer_value = gtx_utils.tree_map( + lambda x, y, output_state=state: construct_output(output_state, x, y) + )(out_edge, out_symbol) else: assert isinstance(node.type, ts.FieldType | ts.ScalarType) outer_value = construct_output(state, out_edge, im.sym("__output", node.type)) @@ -725,7 +738,7 @@ def construct_output( else: result = outer_value - outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple(result)} + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} nsdfg_node = self.state.add_nested_sdfg( nsdfg, @@ -758,10 +771,11 @@ def connect_output(inner_value: ValueExpr) -> ValueExpr: ) return ValueExpr(output_node, inner_value.gt_dtype) - if isinstance(result, tuple): - return gtx_utils.tree_map(connect_output)(result) - else: - return connect_output(result) + return ( + gtx_utils.tree_map(connect_output)(result) + if isinstance(result, tuple) + else connect_output(result) + ) def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, itir_ts.ListType) @@ -1489,7 +1503,7 @@ def _visit_tuple_get( def visit_FunCall( self, node: gtir.FunCall - ) -> IteratorExpr | DataExpr | tuple[DataflowOutputEdge, ...]: + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -1515,8 +1529,7 @@ def visit_FunCall( return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - lambda_args = [self.visit(arg) for arg in node.args] - return self.visit_Lambda(node.fun, args=lambda_args) + raise AssertionError("Lambda node should be visited with 'apply()' method.") elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) @@ -1525,24 +1538,10 @@ def visit_FunCall( raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") def visit_Lambda( - self, node: gtir.Lambda, args: list[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] - ) -> DataflowOutputEdge | tuple[DataflowOutputEdge, ...]: - # lambda arguments are mapped to symbols defined in lambda scope - prev_symbols: dict[str, Optional[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]]] = {} - for p, arg in zip(node.params, args, strict=True): - symbol_name = str(p.id) - prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) - self.symbol_map[symbol_name] = arg - + self, node: gtir.Lambda + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: result = self.visit(node.expr) - # remove locally defined lambda symbols and restore previous symbols - for symbol_name, arg in prev_symbols.items(): - if arg is None: - self.symbol_map.pop(symbol_name) - else: - self.symbol_map[symbol_name] = arg - def make_output_edge( output_expr: ValueExpr | MemletExpr | SymbolExpr, ) -> DataflowOutputEdge: @@ -1576,16 +1575,19 @@ def parse_result( return r return make_output_edge(r) - if isinstance(result, tuple): - return gtx_utils.tree_map(parse_result)(result) - else: - return parse_result(result) + return ( + gtx_utils.tree_map(parse_result)(result) + if isinstance(result, tuple) + else parse_result(result) + ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef(self, node: gtir.SymRef) -> tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]: + def visit_SymRef( + self, node: gtir.SymRef + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] @@ -1594,7 +1596,45 @@ def visit_SymRef(self, node: gtir.SymRef) -> tuple[IteratorExpr | MemletExpr | S return SymbolExpr(param, dace.string) def apply( - self, node: gtir.Lambda, args: list[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] - ) -> tuple[list[DataflowInputEdge], tuple[DataflowOutputEdge]]: - output_edges = self.visit_Lambda(node, args=args) - return self.input_edges, output_edges + self, + node: gtir.Lambda, + args: list[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: + # lambda arguments are mapped to symbols defined in lambda scope + prev_symbols: dict[ + str, + Optional[ + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...] + ], + ] = {} + for p, arg in zip(node.params, args, strict=True): + symbol_name = str(p.id) + prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) + self.symbol_map[symbol_name] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + input_edges, output_edges = self.apply(let_node.fun, args=let_args) + + else: + output_edges = self.visit_Lambda(node) + input_edges = self.input_edges + + # remove locally defined lambda symbols and restore previous symbols + for symbol_name, prev_value in prev_symbols.items(): + if prev_value is None: + self.symbol_map.pop(symbol_name) + else: + self.symbol_map[symbol_name] = prev_value + + return input_edges, output_edges diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 640fde6236..66058f5711 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -169,9 +169,13 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST - + [(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE)], # TODO(edopao): result validation fails with dace optimization + + [ + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): result validation fails with dace optimization OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE)], # TODO(edopao): result validation fails with dace optimization + + [ + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): result validation fails with dace optimization OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST From c22cfc84fe6f7cd23943c57a27d492d12166e2ae Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 6 Dec 2024 21:17:13 +0100 Subject: [PATCH 14/80] fix doctest --- .../runners/dace_fieldview/utility.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 64f3ca04e7..33c333a9f3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -37,7 +37,10 @@ def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.S >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), (("a_1_0", fty), ("a_1_1", sty))] + >>> assert make_symbol_tuple("a", t) == ( + ... im.sym("a_0", sty), + ... (im.sym("a_1_0", fty), im.sym("a_1_1", sty)), + ... ) """ fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] return tuple( @@ -57,8 +60,11 @@ def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] - >>> assert flatten_tuple_fields("a", t) == [("a_0", sty), ("a_1_0", fty), ("a_1_1", sty)] + >>> assert flatten_tuple_fields("a", t) == [ + ... im.sym("a_0", sty), + ... im.sym("a_1_0", fty), + ... im.sym("a_1_1", sty), + ... ] """ symbol_tuple = make_symbol_tuple(tuple_name, tuple_type) return list(gtx_utils.flatten_nested_tuple(symbol_tuple)) From 792a8ebef2f49677aab648d4ff6e0a5c8cdc4f7f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 9 Dec 2024 12:09:05 +0100 Subject: [PATCH 15/80] temporarily disable one optimize transformation --- .../dace_fieldview/transformations/auto_optimize.py | 3 ++- tests/next_tests/definitions.py | 10 ++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index 4a06d2f416..f50c70f3ce 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -173,7 +173,8 @@ def gt_auto_optimize( sdfg.apply_transformations_repeated( [ gtx_transformations.GT4PyMoveTaskletIntoMap, - gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), + # TODO(edopao): investigate correct mapping of stride symbols on scan output + # gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), ], validate=validate, validate_all=validate_all, diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index bf045bc6b8..321ebb85c7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -167,14 +167,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST - + [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): result validation fails with dace optimization - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): result validation fails with dace optimization + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST From 61985f71138a4867d166d7475fb71e7eb276bfa7 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 9 Dec 2024 13:24:32 +0100 Subject: [PATCH 16/80] Revert "temporarily disable one optimize transformation" This reverts commit 792a8ebef2f49677aab648d4ff6e0a5c8cdc4f7f. --- .../dace_fieldview/transformations/auto_optimize.py | 3 +-- tests/next_tests/definitions.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index f50c70f3ce..4a06d2f416 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -173,8 +173,7 @@ def gt_auto_optimize( sdfg.apply_transformations_repeated( [ gtx_transformations.GT4PyMoveTaskletIntoMap, - # TODO(edopao): investigate correct mapping of stride symbols on scan output - # gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), ], validate=validate, validate_all=validate_all, diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 321ebb85c7..bf045bc6b8 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -167,8 +167,14 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST + + [ + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): result validation fails with dace optimization + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [ + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): result validation fails with dace optimization OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST From aa236a2f2ad3daf0c457395304ead051489640d6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 10 Dec 2024 18:37:55 +0100 Subject: [PATCH 17/80] fix for scan output stride --- .../gtir_builtin_translators.py | 25 +++++++++++++++---- .../transformations/simplify.py | 25 +++++++++++++++++++ tests/next_tests/definitions.py | 10 ++------ 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 524091cbe3..32bb3b218c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -311,10 +311,13 @@ def _create_field_operator( domain_indices = _get_domain_indices(domain_dims, domain_offset) domain_subset = sbs.Range.from_indices(domain_indices) + scan_dim_index: Optional[int] = None if scan_dim is not None: - assert domain_dims.index(scan_dim) == (len(domain_dims) - 1) - # we construct the fieldo operator only on the horizontal domain - domain_subset = sbs.Range(domain_subset[:-1]) + scan_dim_index = domain_dims.index(scan_dim) + # we construct the field operator only on the horizontal domain + domain_subset = sbs.Range( + domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :] + ) # now check, after removal of the vertical dimension, whether the domain is empty if len(domain_subset) == 0: @@ -352,7 +355,12 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - assert isinstance(dataflow_output_desc, dace.data.Array) assert len(dataflow_output_desc.shape) == 1 # the vertical dimension should not belong to the field operator domain - field_subset = domain_subset + sbs.Range.from_array(dataflow_output_desc) + # but we need to write it to the output field + field_subset = ( + sbs.Range(domain_subset[:scan_dim_index]) + + sbs.Range.from_array(dataflow_output_desc) + + sbs.Range(domain_subset[scan_dim_index:]) + ) else: assert isinstance(dataflow_output_desc, dace.data.Scalar) field_subset = domain_subset @@ -371,9 +379,16 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - field_subset = domain_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_name, field_desc = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) field_node = state.add_access(field_name) + if scan_dim is not None: + # By default, we leave `strides=None` which corresponds to use DaCe default memory layout + # for transient arrays. However, for scan field operators we need to ensure that the same + # stride is used for the vertical dimension in inner and outer array. + scan_output_stride = field_desc.strides[scan_dim_index] + dataflow_output_desc.strides = (scan_output_stride,) + # and here the edge writing the dataflow result data through the map exit node output_edge.connect(mx, field_node, field_subset) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..958fe10a9b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -971,6 +971,31 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + map_exit_in_conn = map_to_tmp_edge.src_conn.replace("OUT_", "IN_") + src_to_map_exit_edge = next( + edge for edge in graph.in_edges(map_exit) if edge.dst_conn == map_exit_in_conn + ) + if isinstance(src_to_map_exit_edge.src, dace.nodes.NestedSDFG): + nsdfg_node = src_to_map_exit_edge.src + # We need to propagate the strides inside the nested SDFG + # TODO: the stride should be propagate recursively to nested SDFGs, if directly connected + new_strides = tuple( + stride + for stride, to_map_size, from_map_size in zip( + glob_ac.desc(sdfg).strides, + src_to_map_exit_edge.data.subset, + glob_in_subset, + strict=True, + ) + if to_map_size == from_map_size + ) + inner_data = src_to_map_exit_edge.src_conn + inner_desc = nsdfg_node.sdfg.arrays[inner_data] + inner_desc.set_shape(inner_desc.shape, new_strides) + for stride in new_strides: + for sym in stride.free_symbols: + nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) + nsdfg_node.symbol_mapping |= {str(sym): sym} # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index bf045bc6b8..321ebb85c7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -167,14 +167,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST - + [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): result validation fails with dace optimization - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): result validation fails with dace optimization + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST From 9bdc75b5c34adab5fc0db29967b1f0c160c8d3d6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 10 Dec 2024 19:55:20 +0100 Subject: [PATCH 18/80] fix previous commit --- .../transformations/simplify.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 958fe10a9b..c4e9be3835 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -971,31 +971,36 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + # Find the source of the edge entering the map exit node map_exit_in_conn = map_to_tmp_edge.src_conn.replace("OUT_", "IN_") src_to_map_exit_edge = next( edge for edge in graph.in_edges(map_exit) if edge.dst_conn == map_exit_in_conn ) if isinstance(src_to_map_exit_edge.src, dace.nodes.NestedSDFG): nsdfg_node = src_to_map_exit_edge.src - # We need to propagate the strides inside the nested SDFG - # TODO: the stride should be propagate recursively to nested SDFGs, if directly connected + # We need to propagate the strides inside the nested SDFG on the global arrays + # TODO: the stride should be propagated recursively to nested SDFGs, if directly connected new_strides = tuple( stride - for stride, to_map_size, from_map_size in zip( + for stride, to_map_size in zip( glob_ac.desc(sdfg).strides, - src_to_map_exit_edge.data.subset, - glob_in_subset, + src_to_map_exit_edge.data.subset.size(), strict=True, ) - if to_map_size == from_map_size + if to_map_size != 1 ) inner_data = src_to_map_exit_edge.src_conn inner_desc = nsdfg_node.sdfg.arrays[inner_data] - inner_desc.set_shape(inner_desc.shape, new_strides) + if isinstance(inner_desc, dace.data.Array): + inner_desc.set_shape(inner_desc.shape, new_strides) + else: + assert isinstance(inner_desc, dace.data.Scalar) + assert len(new_strides) == 0 for stride in new_strides: for sym in stride.free_symbols: nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) nsdfg_node.symbol_mapping |= {str(sym): sym} + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge( From 746f9d8b9344fcbe899328c2f8d7aa925f7f1bab Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 10 Dec 2024 20:15:54 +0100 Subject: [PATCH 19/80] converto scalar to array on nsdfg output --- .../runners/dace_fieldview/transformations/simplify.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index c4e9be3835..c8f80d9c97 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -996,6 +996,10 @@ def apply( else: assert isinstance(inner_desc, dace.data.Scalar) assert len(new_strides) == 0 + # we convert the scalar data to array to avoid a gpu codegen error + nsdfg_node.sdfg.arrays[inner_data] = dace.data.Array( + inner_desc.dtype, (1,), inner_desc.transient + ) for stride in new_strides: for sym in stride.free_symbols: nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) From 0d894ffbe3016804b8405f166d74ba9e832fefc3 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Dec 2024 10:08:49 +0100 Subject: [PATCH 20/80] Revert "converto scalar to array on nsdfg output" This reverts commit 746f9d8b9344fcbe899328c2f8d7aa925f7f1bab. --- .../runners/dace_fieldview/transformations/simplify.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index c8f80d9c97..c4e9be3835 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -996,10 +996,6 @@ def apply( else: assert isinstance(inner_desc, dace.data.Scalar) assert len(new_strides) == 0 - # we convert the scalar data to array to avoid a gpu codegen error - nsdfg_node.sdfg.arrays[inner_data] = dace.data.Array( - inner_desc.dtype, (1,), inner_desc.transient - ) for stride in new_strides: for sym in stride.free_symbols: nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) From 440a474a35cd0c948565620cb6b34c2f747f9081 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Dec 2024 11:59:20 +0100 Subject: [PATCH 21/80] Split handling of let-statement lambdas from stencil body --- .../gtir_builtin_translators.py | 26 +++--- .../runners/dace_fieldview/gtir_dataflow.py | 84 ++++++++++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 5 +- 3 files changed, 81 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ff011c4193..580ba64881 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -30,7 +30,7 @@ gtir_python_codegen, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: @@ -366,36 +366,36 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args + fieldop_expr, domain_expr = fun_node.args - if isinstance(stencil_expr, gtir.Lambda): - # Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - pass - elif cpm.is_ref_to(stencil_expr, "deref"): + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] + stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( - f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." ) # parse the domain of the field operator domain = extract_domain(domain_expr) # visit the list of arguments to be passed to the lambda expression - stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) + input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge @@ -654,7 +654,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) + assert ti.is_integral(node.args[0].type) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cfba4d61e5..c1b64e31da 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,10 +10,22 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + Final, + List, + Optional, + Protocol, + Sequence, + Set, + Tuple, + TypeAlias, + Union, +) import dace -import dace.subsets as sbs +from dace import subsets as sbs from gt4py import eve from gt4py.next import common as gtx_common @@ -68,7 +80,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Indices | sbs.Range + subset: sbs.Range @dataclasses.dataclass(frozen=True) @@ -1264,39 +1276,38 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: elif cpm.is_applied_shift(node): return self._visit_shift(node) + elif isinstance(node.fun, gtir.Lambda): + raise AssertionError("Lambda node should be visited with 'apply()' method.") + elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda( - self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: DataExpr = self.visit(node.expr) - if isinstance(output_expr, ValueExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: + result: DataExpr = self.visit(node.expr) + + if isinstance(result, ValueExpr): + return DataflowOutputEdge(self.state, result) - if isinstance(output_expr, MemletExpr): + if isinstance(result, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + output_dtype = result.dc_node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_data_edge( - output_expr.dc_node, - output_expr.subset, + result.dc_node, + result.subset, tasklet_node, "__inp", ) else: - assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") + output_dtype = result.dc_dtype + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + return DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) @@ -1309,3 +1320,38 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE # if not in the lambda symbol map, this must be a symref to a builtin function assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) + + def apply( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + # lambda arguments are mapped to symbols defined in lambda scope + prev_symbols: dict[ + str, + Optional[IteratorExpr | MemletExpr | SymbolExpr], + ] = {} + for p, arg in zip(node.params, args, strict=True): + symbol_name = str(p.id) + prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) + self.symbol_map[symbol_name] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + input_edges, output_edge = self.apply(let_node.fun, args=let_args) + + else: + # this lambda is not a let-statement, but a stencil expression + output_edge = self.visit_Lambda(node) + input_edges = self.input_edges + + # remove locally defined lambda symbols and restore previous symbols + for symbol_name, prev_value in prev_symbols.items(): + if prev_value is None: + self.symbol_map.pop(symbol_name) + else: + self.symbol_map[symbol_name] = prev_value + + return input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 6b5e164458..9bd40f75f8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -602,7 +602,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - args: list[gtir_builtin_translators.FieldopResult], + args: Sequence[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -679,7 +679,7 @@ def get_field_domain_offset( self.offset_provider_type, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nstate = nsdfg.add_state("lambda") + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # add sdfg storage for the symbols that need to be passed as input parameters lambda_params = [ @@ -690,6 +690,7 @@ def get_field_domain_offset( nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) + nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( node.expr, sdfg=nsdfg, From 500590be377588b3a4c0f8f66786d35c0bdc5622 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Dec 2024 15:57:27 +0100 Subject: [PATCH 22/80] minor edit --- .../runners/dace_fieldview/transformations/gpu_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 2cd3020180..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -95,7 +95,7 @@ def gt_gpu_transformation( if try_removing_trivial_maps: # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So # we might end up with lots of these trivial Maps, each requiring a separate # kernel launch. To prevent this we will combine these trivial maps, if # possible, with their downstream maps. From 5d5992a9a726d31f8f7c61254026cd81ccb658b4 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 12 Dec 2024 08:51:10 +0100 Subject: [PATCH 23/80] use dace auto-optimize on gpu --- .../runners/dace_fieldview/workflow.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 407faf7ec1..f8d93bed9e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -13,6 +13,9 @@ from typing import Optional import dace +import dace.transformation +import dace.transformation.auto +import dace.transformation.auto.auto_optimize import factory from gt4py._core import definitions as core_defs @@ -56,15 +59,31 @@ def generate_sdfg( ir, common.offset_provider_to_type(offset_provider), column_dim ) + # TODO(phimuell): check auto-optimize pipeline on gpu for scan field operators + use_gtx_transformations = not on_gpu + if auto_opt: - gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) + if use_gtx_transformations: + gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) + else: + sdfg.simplify() + device_type = dace.dtypes.DeviceType.GPU if on_gpu else dace.dtypes.DeviceType.CPU + sdfg = dace.transformation.auto.auto_optimize.auto_optimize( + sdfg, device_type, use_gpu_storage=on_gpu + ) elif on_gpu: - # We run simplify to bring the SDFG into a canonical form that the gpu transformations - # can handle. This is a workaround for an issue with scalar expressions that are - # promoted to symbolic expressions and computed on the host (CPU), but the intermediate - # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). - gtx_transformations.gt_simplify(sdfg) - gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) + if use_gtx_transformations: + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + gtx_transformations.gt_simplify(sdfg) + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) + else: + sdfg.simplify() + dace.transformation.auto.auto_optimize.apply_gpu_storage(sdfg) + sdfg.apply_gpu_transformations() + sdfg.simplify() return sdfg From eb173456935c58ee5677baefc531e5a0c974badc Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 12 Dec 2024 09:22:02 +0100 Subject: [PATCH 24/80] Revert "use dace auto-optimize on gpu" This reverts commit 5d5992a9a726d31f8f7c61254026cd81ccb658b4. --- .../runners/dace_fieldview/workflow.py | 33 ++++--------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index f8d93bed9e..407faf7ec1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -13,9 +13,6 @@ from typing import Optional import dace -import dace.transformation -import dace.transformation.auto -import dace.transformation.auto.auto_optimize import factory from gt4py._core import definitions as core_defs @@ -59,31 +56,15 @@ def generate_sdfg( ir, common.offset_provider_to_type(offset_provider), column_dim ) - # TODO(phimuell): check auto-optimize pipeline on gpu for scan field operators - use_gtx_transformations = not on_gpu - if auto_opt: - if use_gtx_transformations: - gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) - else: - sdfg.simplify() - device_type = dace.dtypes.DeviceType.GPU if on_gpu else dace.dtypes.DeviceType.CPU - sdfg = dace.transformation.auto.auto_optimize.auto_optimize( - sdfg, device_type, use_gpu_storage=on_gpu - ) + gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) elif on_gpu: - if use_gtx_transformations: - # We run simplify to bring the SDFG into a canonical form that the gpu transformations - # can handle. This is a workaround for an issue with scalar expressions that are - # promoted to symbolic expressions and computed on the host (CPU), but the intermediate - # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). - gtx_transformations.gt_simplify(sdfg) - gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) - else: - sdfg.simplify() - dace.transformation.auto.auto_optimize.apply_gpu_storage(sdfg) - sdfg.apply_gpu_transformations() - sdfg.simplify() + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + gtx_transformations.gt_simplify(sdfg) + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg From 8b163daa02ce69585c5424665197e2abb0512d76 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 12 Dec 2024 16:02:58 +0100 Subject: [PATCH 25/80] make map_strides recursive --- .../transformations/simplify.py | 49 +++++++++++++------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index c4e9be3835..3debb6a5eb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -971,36 +971,55 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None - # Find the source of the edge entering the map exit node - map_exit_in_conn = map_to_tmp_edge.src_conn.replace("OUT_", "IN_") - src_to_map_exit_edge = next( - edge for edge in graph.in_edges(map_exit) if edge.dst_conn == map_exit_in_conn - ) - if isinstance(src_to_map_exit_edge.src, dace.nodes.NestedSDFG): - nsdfg_node = src_to_map_exit_edge.src + # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension + def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) -> None: + if isinstance(edge.src, dace.nodes.MapExit): + # Find the source of the edge entering the map exit node + map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") + for edge_to_map_exit_edge in graph.in_edges_by_connector( + edge.src, map_exit_in_conn + ): + map_strides(edge_to_map_exit_edge, outer_node) + return + + if not isinstance(edge.src, dace.nodes.NestedSDFG): + return + # We need to propagate the strides inside the nested SDFG on the global arrays - # TODO: the stride should be propagated recursively to nested SDFGs, if directly connected + nsdfg_node = edge.src new_strides = tuple( stride for stride, to_map_size in zip( - glob_ac.desc(sdfg).strides, - src_to_map_exit_edge.data.subset.size(), + outer_node.desc(sdfg).strides, + edge.data.subset.size(), strict=True, ) if to_map_size != 1 ) - inner_data = src_to_map_exit_edge.src_conn + inner_data = edge.src_conn inner_desc = nsdfg_node.sdfg.arrays[inner_data] - if isinstance(inner_desc, dace.data.Array): - inner_desc.set_shape(inner_desc.shape, new_strides) - else: - assert isinstance(inner_desc, dace.data.Scalar) + assert not inner_desc.transient + + if isinstance(inner_desc, dace.data.Scalar): assert len(new_strides) == 0 + return + + assert isinstance(inner_desc, dace.data.Array) + inner_desc.set_shape(inner_desc.shape, new_strides) + for stride in new_strides: for sym in stride.free_symbols: nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) nsdfg_node.symbol_mapping |= {str(sym): sym} + for inner_state in nsdfg_node.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == inner_data: + for inner_edge in inner_state.in_edges(inner_node): + map_strides(inner_edge, inner_node) + + map_strides(map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge( From d15213a140bd173b3818dff963d1053ca699c8ac Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Dec 2024 11:30:17 +0100 Subject: [PATCH 26/80] rename module alias --- .../gtir_builtin_translators.py | 24 +++++----- .../runners/dace_fieldview/gtir_dataflow.py | 44 +++++++++++-------- .../runners/dace_fieldview/gtir_sdfg.py | 4 +- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 32bb3b218c..47894f14c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -from dace import subsets as sbs +from dace import subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -40,7 +40,7 @@ def _get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None -) -> sbs.Indices: +) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying an optional offset in each dimension as start index. @@ -56,9 +56,9 @@ def _get_domain_indices( """ index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] if offsets is None: - return sbs.Indices(index_variables) + return dace_subsets.Indices(index_variables) else: - return sbs.Indices( + return dace_subsets.Indices( [ index - offset if offset != 0 else index for index, offset in zip(index_variables, offsets, strict=True) @@ -97,7 +97,7 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): @@ -309,13 +309,13 @@ def _create_field_operator( """ domain_dims, domain_offset, domain_shape = _get_field_layout(domain) domain_indices = _get_domain_indices(domain_dims, domain_offset) - domain_subset = sbs.Range.from_indices(domain_indices) + domain_subset = dace_subsets.Range.from_indices(domain_indices) scan_dim_index: Optional[int] = None if scan_dim is not None: scan_dim_index = domain_dims.index(scan_dim) # we construct the field operator only on the horizontal domain - domain_subset = sbs.Range( + domain_subset = dace_subsets.Range( domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :] ) @@ -357,9 +357,9 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - # the vertical dimension should not belong to the field operator domain # but we need to write it to the output field field_subset = ( - sbs.Range(domain_subset[:scan_dim_index]) - + sbs.Range.from_array(dataflow_output_desc) - + sbs.Range(domain_subset[scan_dim_index:]) + dace_subsets.Range(domain_subset[:scan_dim_index]) + + dace_subsets.Range.from_array(dataflow_output_desc) + + dace_subsets.Range(domain_subset[scan_dim_index:]) ) else: assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -376,7 +376,7 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] field_shape = [*domain_shape, dataflow_output_desc.shape[0]] field_offset = [*domain_offset, dataflow_output_desc.offset[0]] - field_subset = domain_subset + sbs.Range.from_array(dataflow_output_desc) + field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, field_desc = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) @@ -1112,7 +1112,7 @@ def connect_scan_output( input_edges = [] for input_connector, arg in lambda_flat_args.items(): arg_desc = arg.dc_node.desc(sdfg) - input_subset = sbs.Range.from_array(arg_desc) + input_subset = dace_subsets.Range.from_array(arg_desc) input_edge = gtir_dataflow.MemletInputEdge( state, arg.dc_node, input_subset, nsdfg_node, input_connector ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index fa54942049..1a394aec20 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -13,7 +13,7 @@ from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union import dace -from dace import subsets as sbs +from dace import subsets as dace_subsets from gt4py import eve from gt4py.next import common as gtx_common, utils as gtx_utils @@ -69,7 +69,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Range + subset: dace_subsets.Range @dataclasses.dataclass(frozen=True) @@ -105,7 +105,7 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] - def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -118,7 +118,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: assert len(field_desc.shape) == len(self.field_domain) field_domain = self.field_domain - return sbs.Range.from_string( + return dace_subsets.Range.from_string( ",".join( str(self.indices[dim].value - offset) # type: ignore[union-attr] if dim in self.indices @@ -153,7 +153,7 @@ class MemletInputEdge(DataflowInputEdge): state: dace.SDFGState source: dace.nodes.AccessNode - subset: sbs.Range + subset: dace_subsets.Range dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] @@ -207,7 +207,7 @@ def connect( self, map_exit: Optional[dace.nodes.MapExit], dest: dace.nodes.AccessNode, - subset: sbs.Range, + subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src @@ -310,7 +310,7 @@ def __init__( def _add_input_data_edge( self, src: dace.nodes.AccessNode, - src_subset: sbs.Range, + src_subset: dace_subsets.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, src_offset: Optional[list[dace.symbolic.SymExpr]] = None, @@ -318,7 +318,7 @@ def _add_input_data_edge( input_subset = ( src_subset if src_offset is None - else sbs.Range( + else dace_subsets.Range( (start - off, stop - off, step) for (start, stop, step), off in zip(src_subset, src_offset, strict=True) ) @@ -529,7 +529,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, - sbs.Range.from_array(field_desc), + dace_subsets.Range.from_array(field_desc), deref_node, "field", src_offset=[offset for (_, offset) in arg_expr.field_domain], @@ -600,7 +600,7 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) nsdfg_symbol_mapping = {} - input_memlets: dict[str, tuple[dace.nodes.AccessNode, Optional[sbs.Range]]] = {} + input_memlets: dict[str, tuple[dace.nodes.AccessNode, Optional[dace_subsets.Range]]] = {} if isinstance(condition_value, SymbolExpr): nsdfg.add_symbol("__cond", dace.dtypes.bool) @@ -608,7 +608,10 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A else: nsdfg.add_scalar("__cond", dace.dtypes.bool) if isinstance(condition_value, ValueExpr): - input_memlets["__cond"] = (condition_value.dc_node, sbs.Range.from_string("0")) + input_memlets["__cond"] = ( + condition_value.dc_node, + dace_subsets.Range.from_string("0"), + ) else: assert isinstance(condition_value, MemletExpr) input_memlets["__cond"] = (condition_value.dc_node, condition_value.subset) @@ -628,7 +631,7 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: if isinstance(arg, IteratorExpr): arg_node = arg.field arg_desc = arg_node.desc(self.sdfg) - arg_subset = sbs.Range.from_array(arg_desc) + arg_subset = dace_subsets.Range.from_array(arg_desc) else: assert isinstance(arg, (MemletExpr, ValueExpr)) @@ -815,7 +818,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=it.field, gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( ",".join( str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain @@ -831,7 +834,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=self.state.add_access(connectivity), gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_index.value}, 0:{offset_provider.max_neighbors}" ), ) @@ -993,7 +996,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=itir_ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) @@ -1143,7 +1146,9 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), nsdfg_node, "neighbor_indices", ) @@ -1316,7 +1321,7 @@ def _make_dynamic_neighbor_offset( ) self._add_input_data_edge( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) @@ -1362,7 +1367,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=sbs.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1529,7 +1534,8 @@ def visit_FunCall( return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - raise AssertionError("Lambda node should be visited with 'apply()' method.") + # Lambda node should be visited with 'apply()' method. + raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index ebf98e453e..30185697a5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -22,7 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace -from dace.sdfg import utils as sdutils +from dace.sdfg import utils as dace_sdfg_utils from gt4py import eve from gt4py.eve import concepts @@ -899,6 +899,6 @@ def build_sdfg_from_gtir( assert isinstance(sdfg, dace.SDFG) # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct - sdutils.inline_loop_blocks(sdfg) + dace_sdfg_utils.inline_loop_blocks(sdfg) return sdfg From 55811dcfa187854bf98e0a18472bcf602a1dada1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Dec 2024 12:24:41 +0100 Subject: [PATCH 27/80] review comments --- .../gtir_builtin_translators.py | 14 ++-- .../runners/dace_fieldview/gtir_dataflow.py | 79 +++++++++++-------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 580ba64881..3d78cbbafe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -39,7 +39,7 @@ def _get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None -) -> sbs.Indices: +) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying an optional offset in each dimension as start index. @@ -55,9 +55,9 @@ def _get_domain_indices( """ index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] if offsets is None: - return sbs.Indices(index_variables) + return dace_subsets.Indices(index_variables) else: - return sbs.Indices( + return dace_subsets.Indices( [ index - offset if offset != 0 else index for index, offset in zip(index_variables, offsets, strict=True) @@ -96,7 +96,7 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): @@ -263,7 +263,7 @@ def _create_field_operator( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = sbs.Range.from_indices(field_indices) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -280,7 +280,7 @@ def _create_field_operator( field_dims.append(output_edge.result.gt_dtype.offset_type) field_shape.extend(dataflow_output_desc.shape) field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index c1b64e31da..b102e9638f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -25,7 +25,7 @@ ) import dace -from dace import subsets as sbs +from dace import subsets as dace_subsets from gt4py import eve from gt4py.next import common as gtx_common @@ -80,7 +80,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Range + subset: dace_subsets.Range @dataclasses.dataclass(frozen=True) @@ -116,7 +116,7 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] - def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -129,7 +129,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: assert len(field_desc.shape) == len(self.field_domain) field_domain = self.field_domain - return sbs.Range.from_string( + return dace_subsets.Range.from_string( ",".join( str(self.indices[dim].value - offset) # type: ignore[union-attr] if dim in self.indices @@ -164,7 +164,7 @@ class MemletInputEdge(DataflowInputEdge): state: dace.SDFGState source: dace.nodes.AccessNode - subset: sbs.Range + subset: dace_subsets.Range dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] @@ -214,7 +214,7 @@ def connect( self, mx: dace.nodes.MapExit, dest: dace.nodes.AccessNode, - subset: sbs.Range, + subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src @@ -270,8 +270,9 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: class LambdaToDataflow(eve.NodeVisitor): """ - Translates an `ir.Lambda` expression to a dataflow graph. + Visitor class to translate a `Lambda` expression to a dataflow graph. + This visitor should be applied by calling `apply()` method on a `Lambda` IR. The dataflow graph generated here typically represents the stencil function of a field operator. It only computes single elements or pure local fields, in case of neighbor values. In case of local fields, the dataflow contains @@ -305,7 +306,7 @@ def __init__( def _add_input_data_edge( self, src: dace.nodes.AccessNode, - src_subset: sbs.Range, + src_subset: dace_subsets.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, src_offset: Optional[list[dace.symbolic.SymExpr]] = None, @@ -313,7 +314,7 @@ def _add_input_data_edge( input_subset = ( src_subset if src_offset is None - else sbs.Range( + else dace_subsets.Range( (start - off, stop - off, step) for (start, stop, step), off in zip(src_subset, src_offset, strict=True) ) @@ -524,7 +525,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, - sbs.Range.from_array(field_desc), + dace_subsets.Range.from_array(field_desc), deref_node, "field", src_offset=[offset for (_, offset) in arg_expr.field_domain], @@ -592,7 +593,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=it.field, gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( ",".join( str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain @@ -608,7 +609,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=self.state.add_access(connectivity), gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_index.value}, 0:{offset_provider.max_neighbors}" ), ) @@ -770,7 +771,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=itir_ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) @@ -920,7 +921,9 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), nsdfg_node, "neighbor_indices", ) @@ -1093,7 +1096,7 @@ def _make_dynamic_neighbor_offset( ) self._add_input_data_edge( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) @@ -1139,7 +1142,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=sbs.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1277,7 +1280,8 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - raise AssertionError("Lambda node should be visited with 'apply()' method.") + # Lambda node should be visited with 'apply()' method. + raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) @@ -1326,32 +1330,39 @@ def apply( node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - # lambda arguments are mapped to symbols defined in lambda scope - prev_symbols: dict[ - str, - Optional[IteratorExpr | MemletExpr | SymbolExpr], - ] = {} - for p, arg in zip(node.params, args, strict=True): - symbol_name = str(p.id) - prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) - self.symbol_map[symbol_name] = arg + """ + Entry point for this visitor class. + + This visitor will translate a `Lambda` node into a dataflow graph to be + instantiated inside a map scope implementing the field operator. + However, this `apply()` method is responsible to recognize the usage of + the `Lambda` node, which can be either a let-statement or the stencil expression + in local view. The usage of a `Lambda` as let-statement corresponds to computing + some results and making them available inside the lambda scope, represented + as a nested SDFG. All let-statements, if any, are supposed to be encountered + before the stencil expression. In other words, the `Lambda` node representing + the stencil expression is always the innermost node. + Therefore, the lowering of let-statements results in recursive calls to + `apply()` until the stencil expression is found. At that point, it falls + back to the `visit()` function. + """ + + # lambda arguments are mapped to symbols defined in lambda scope. + prev_symbol_map = self.symbol_map + self.symbol_map = self.symbol_map.copy() + self.symbol_map |= {str(p.id): arg for p, arg in zip(node.params, args, strict=True)} if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) input_edges, output_edge = self.apply(let_node.fun, args=let_args) - else: - # this lambda is not a let-statement, but a stencil expression - output_edge = self.visit_Lambda(node) + # this lambda node is not a let-statement, but a stencil expression + output_edge = self.visit(node) input_edges = self.input_edges # remove locally defined lambda symbols and restore previous symbols - for symbol_name, prev_value in prev_symbols.items(): - if prev_value is None: - self.symbol_map.pop(symbol_name) - else: - self.symbol_map[symbol_name] = prev_value + self.symbol_map = prev_symbol_map return input_edges, output_edge From f01d291e0130bc7b1224063a60763ce9c914254b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Dec 2024 17:13:50 +0100 Subject: [PATCH 28/80] add test case for sdfg transformation --- .../transformations/simplify.py | 7 +- .../test_map_buffer_elimination.py | 93 ++++++++++++++++++- 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 3debb6a5eb..f202f79ede 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -1008,9 +1008,10 @@ def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) - inner_desc.set_shape(inner_desc.shape, new_strides) for stride in new_strides: - for sym in stride.free_symbols: - nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) - nsdfg_node.symbol_mapping |= {str(sym): sym} + if isinstance(stride, dace.symbolic.symbol): + for sym in stride.free_symbols: + nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) + nsdfg_node.symbol_mapping |= {str(sym): sym} for inner_state in nsdfg_node.sdfg.states(): for inner_node in inner_state.data_nodes(): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index 1a4ce6d047..a98eac3c2c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -22,10 +22,6 @@ import dace -def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: - return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} - - def _make_test_sdfg( output_name: str = "G", input_name: str = "G", @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply(): validate_all=True, ) assert count == 0 + + +def test_map_buffer_elimination_with_nested_sdfgs(): + """ + After removing a transient connected to a nested SDFG node, ensure that the strides + are propagated to the arrays in nested SDFG. + """ + + stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] + + # top-level sdfg + sdfg = dace.SDFG(util.unique_name("map_buffer")) + inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) + out, out_desc = sdfg.add_array( + "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) + ) + tmp, _ = sdfg.add_temp_transient_like(out_desc) + state = sdfg.add_state() + tmp_node = state.add_access(tmp) + + nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) + out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) + tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) + state1 = nsdfg1.add_state() + tmp1_node = state1.add_access(tmp1) + + nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) + out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) + tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) + state2 = nsdfg2.add_state() + tmp2_node = state2.add_access(tmp2) + + state2.add_mapped_tasklet( + "broadcast2", + map_ranges={"__i": "0:10"}, + code="__oval = __ival + 1.0", + inputs={ + "__ival": dace.Memlet(f"{inp2}[__i]"), + }, + outputs={ + "__oval": dace.Memlet(f"{tmp2}[__i]"), + }, + output_nodes={tmp2_node}, + external_edges=True, + ) + state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc)) + + nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"}) + me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"}) + state1.add_memlet_path( + state1.add_access(inp1), + me1, + nsdfg2_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp1, inp1_desc), + ) + state1.add_memlet_path( + nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]") + ) + state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc)) + + nsdfg1_node = state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"}) + me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"}) + state.add_memlet_path( + state.add_access(inp), + me, + nsdfg1_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp, inp_desc), + ) + state.add_memlet_path( + nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]") + ) + state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc)) + + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 3 + assert out1_desc.strides == out_desc.strides[1:] + assert out2_desc.strides == out_desc.strides[2:] From 62e1648dd0d9f28fdd2404ca5f92db9209c14f7a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 11:39:41 +0100 Subject: [PATCH 29/80] review comments (1) --- .../runners/dace_fieldview/gtir_dataflow.py | 65 +++++++++++++------ 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index b102e9638f..8083dc6b09 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -1325,44 +1325,69 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) - def apply( + def _visit_let( self, node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + ) -> DataflowOutputEdge: """ - Entry point for this visitor class. - - This visitor will translate a `Lambda` node into a dataflow graph to be - instantiated inside a map scope implementing the field operator. - However, this `apply()` method is responsible to recognize the usage of - the `Lambda` node, which can be either a let-statement or the stencil expression - in local view. The usage of a `Lambda` as let-statement corresponds to computing - some results and making them available inside the lambda scope, represented - as a nested SDFG. All let-statements, if any, are supposed to be encountered - before the stencil expression. In other words, the `Lambda` node representing - the stencil expression is always the innermost node. + Maps lambda arguments to internal parameters. + + This method is responsible to recognize the usage of the `Lambda` node, + which can be either a let-statement or the stencil expression in local view. + The usage of a `Lambda` as let-statement corresponds to computing some results + and making them available inside the lambda scope, represented as a nested SDFG. + All let-statements, if any, are supposed to be encountered before the stencil + expression. In other words, the `Lambda` node representing the stencil expression + is always the innermost node. Therefore, the lowering of let-statements results in recursive calls to - `apply()` until the stencil expression is found. At that point, it falls + `_visit_let()` until the stencil expression is found. At that point, it falls back to the `visit()` function. """ # lambda arguments are mapped to symbols defined in lambda scope. prev_symbol_map = self.symbol_map - self.symbol_map = self.symbol_map.copy() - self.symbol_map |= {str(p.id): arg for p, arg in zip(node.params, args, strict=True)} + self.symbol_map = self.symbol_map | { + str(p.id): arg for p, arg in zip(node.params, args, strict=True) + } if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) - input_edges, output_edge = self.apply(let_node.fun, args=let_args) + output_edge = self._visit_let(let_node.fun, args=let_args) else: # this lambda node is not a let-statement, but a stencil expression output_edge = self.visit(node) - input_edges = self.input_edges - # remove locally defined lambda symbols and restore previous symbols + # restore previous symbols, thus removing locally defined lambda symbols in the let-scope self.symbol_map = prev_symbol_map - return input_edges, output_edge + return output_edge + + def apply( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point for this visitor class, that will translate a `Lambda` node + into a dataflow graph to be instantiated inside a map scope implementing + the field operator. + + It calls `_visit_let()` to maps lambda arguments to internal parameters and + visit let-statements (if any), which always appear as outermost nodes. + The visitor will return the output edge of the dataflow. + + Args: + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output connection edge. + """ + + output_edge = self._visit_let(node, args) + return self.input_edges, output_edge From 72e8830b991745b2af33c0c452b3af9f90b829ec Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 13:53:48 +0100 Subject: [PATCH 30/80] review comments (2) --- .../gtir_builtin_translators.py | 5 +- .../runners/dace_fieldview/gtir_dataflow.py | 91 ++++++++----------- 2 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 3d78cbbafe..cffbd74c90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -394,8 +394,9 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args) + input_edges, output_edge = gtir_dataflow.visit_lambda( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 8083dc6b09..d0ebf9ee3b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -268,6 +268,7 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +@dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ Visitor class to translate a `Lambda` expression to a dataflow graph. @@ -288,20 +289,10 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.input_edges = [] - self.symbol_map = {} + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( + default_factory=lambda: {} + ) def _add_input_data_edge( self, @@ -1280,7 +1271,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - # Lambda node should be visited with 'apply()' method. + # Lambda node should be visited with 'visit_let()' method. raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") elif isinstance(node.fun, gtir.SymRef): @@ -1325,7 +1316,7 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) - def _visit_let( + def visit_let( self, node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], @@ -1341,53 +1332,51 @@ def _visit_let( expression. In other words, the `Lambda` node representing the stencil expression is always the innermost node. Therefore, the lowering of let-statements results in recursive calls to - `_visit_let()` until the stencil expression is found. At that point, it falls + `visit_let()` until the stencil expression is found. At that point, it falls back to the `visit()` function. """ # lambda arguments are mapped to symbols defined in lambda scope. - prev_symbol_map = self.symbol_map - self.symbol_map = self.symbol_map | { - str(p.id): arg for p, arg in zip(node.params, args, strict=True) - } + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) - output_edge = self._visit_let(let_node.fun, args=let_args) + return self.visit_let(let_node.fun, args=let_args) else: # this lambda node is not a let-statement, but a stencil expression - output_edge = self.visit(node) + return self.visit(node) - # restore previous symbols, thus removing locally defined lambda symbols in the let-scope - self.symbol_map = prev_symbol_map - return output_edge +def visit_lambda( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], +) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point to visit a `Lambda` node and lower it to a dataflow graph, + that can be instantiated inside a map scope implementing the field operator. - def apply( - self, - node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - """ - Entry point for this visitor class, that will translate a `Lambda` node - into a dataflow graph to be instantiated inside a map scope implementing - the field operator. - - It calls `_visit_let()` to maps lambda arguments to internal parameters and - visit let-statements (if any), which always appear as outermost nodes. - The visitor will return the output edge of the dataflow. - - Args: - node: Lambda node to visit. - args: Arguments passed to lambda node. - - Returns: - A tuple of two elements: - - List of connections for data inputs to the dataflow. - - Output connection edge. - """ + It calls `_visit_let()` to map the lambda arguments to internal parameters + and visit let-statements (if any), which always appear as outermost nodes. + The visitor will return the output edge of the dataflow. - output_edge = self._visit_let(node, args) - return self.input_edges, output_edge + Args: + sdfg: The SDFG where the dataflow graph will be instantiated. + state: The SDFG state where the dataflow graph will be instantiated. + sdfg_builder: Helper class to build the SDFG. + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output data connection. + """ + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + output_edge = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edge From de4a80e06e699bbf7779fe1dce38475c4054ae78 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 13:53:48 +0100 Subject: [PATCH 31/80] review comments (2) --- .../gtir_builtin_translators.py | 5 +- .../runners/dace_fieldview/gtir_dataflow.py | 91 ++++++++----------- 2 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 3d78cbbafe..cffbd74c90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -394,8 +394,9 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args) + input_edges, output_edge = gtir_dataflow.visit_lambda( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 8083dc6b09..a3653fb519 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -268,6 +268,7 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +@dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ Visitor class to translate a `Lambda` expression to a dataflow graph. @@ -288,20 +289,10 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.input_edges = [] - self.symbol_map = {} + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( + default_factory=lambda: {} + ) def _add_input_data_edge( self, @@ -1280,7 +1271,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - # Lambda node should be visited with 'apply()' method. + # Lambda node should be visited with 'visit_let()' method. raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") elif isinstance(node.fun, gtir.SymRef): @@ -1325,7 +1316,7 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) - def _visit_let( + def visit_let( self, node: gtir.Lambda, args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], @@ -1341,53 +1332,51 @@ def _visit_let( expression. In other words, the `Lambda` node representing the stencil expression is always the innermost node. Therefore, the lowering of let-statements results in recursive calls to - `_visit_let()` until the stencil expression is found. At that point, it falls + `visit_let()` until the stencil expression is found. At that point, it falls back to the `visit()` function. """ # lambda arguments are mapped to symbols defined in lambda scope. - prev_symbol_map = self.symbol_map - self.symbol_map = self.symbol_map | { - str(p.id): arg for p, arg in zip(node.params, args, strict=True) - } + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg if cpm.is_let(node.expr): let_node = node.expr let_args = [self.visit(arg) for arg in let_node.args] assert isinstance(let_node.fun, gtir.Lambda) - output_edge = self._visit_let(let_node.fun, args=let_args) + return self.visit_let(let_node.fun, args=let_args) else: # this lambda node is not a let-statement, but a stencil expression - output_edge = self.visit(node) + return self.visit(node) - # restore previous symbols, thus removing locally defined lambda symbols in the let-scope - self.symbol_map = prev_symbol_map - return output_edge +def visit_lambda( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], +) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point to visit a `Lambda` node and lower it to a dataflow graph, + that can be instantiated inside a map scope implementing the field operator. - def apply( - self, - node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - """ - Entry point for this visitor class, that will translate a `Lambda` node - into a dataflow graph to be instantiated inside a map scope implementing - the field operator. - - It calls `_visit_let()` to maps lambda arguments to internal parameters and - visit let-statements (if any), which always appear as outermost nodes. - The visitor will return the output edge of the dataflow. - - Args: - node: Lambda node to visit. - args: Arguments passed to lambda node. - - Returns: - A tuple of two elements: - - List of connections for data inputs to the dataflow. - - Output connection edge. - """ + It calls `LambdaToDataflow.visit_let()` to map the lambda arguments to internal + parameters and visit the let-statements (if any), which always appear as outermost + nodes. Finally, the visitor returns the output edge of the dataflow. - output_edge = self._visit_let(node, args) - return self.input_edges, output_edge + Args: + sdfg: The SDFG where the dataflow graph will be instantiated. + state: The SDFG state where the dataflow graph will be instantiated. + sdfg_builder: Helper class to build the SDFG. + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output data connection. + """ + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + output_edge = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edge From 4b0ac602c247f5cb3a67bbe143b93a83878eb439 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 15:48:27 +0100 Subject: [PATCH 32/80] Propagate strides to nested SDFG when changing transient strides --- .../transformations/__init__.py | 3 +- .../transformations/simplify.py | 71 ++++--------------- .../dace_fieldview/transformations/strides.py | 56 +++++++++++++++ 3 files changed, 70 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4f3efb19b0..3f995c3db4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,7 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides +from .strides import gt_change_transient_strides, gt_map_strides from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,6 +59,7 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", + "gt_map_strides", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index f202f79ede..708834413e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - graph: dace.SDFGState | dace.SDFG, + state: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -962,68 +962,21 @@ def apply( glob_ac: dace_nodes.AccessNode = self.glob_ac glob_data = glob_ac.data - map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) - tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + map_to_tmp_edge = next(edge for edge in state.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in state.out_edges(tmp_ac)) - glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) - tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, state) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, state) if tmp_out_subset is None: tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension - def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) -> None: - if isinstance(edge.src, dace.nodes.MapExit): - # Find the source of the edge entering the map exit node - map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") - for edge_to_map_exit_edge in graph.in_edges_by_connector( - edge.src, map_exit_in_conn - ): - map_strides(edge_to_map_exit_edge, outer_node) - return - - if not isinstance(edge.src, dace.nodes.NestedSDFG): - return - - # We need to propagate the strides inside the nested SDFG on the global arrays - nsdfg_node = edge.src - new_strides = tuple( - stride - for stride, to_map_size in zip( - outer_node.desc(sdfg).strides, - edge.data.subset.size(), - strict=True, - ) - if to_map_size != 1 - ) - inner_data = edge.src_conn - inner_desc = nsdfg_node.sdfg.arrays[inner_data] - assert not inner_desc.transient - - if isinstance(inner_desc, dace.data.Scalar): - assert len(new_strides) == 0 - return - - assert isinstance(inner_desc, dace.data.Array) - inner_desc.set_shape(inner_desc.shape, new_strides) - - for stride in new_strides: - if isinstance(stride, dace.symbolic.symbol): - for sym in stride.free_symbols: - nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) - nsdfg_node.symbol_mapping |= {str(sym): sym} - - for inner_state in nsdfg_node.sdfg.states(): - for inner_node in inner_state.data_nodes(): - if inner_node.data == inner_data: - for inner_edge in inner_state.in_edges(inner_node): - map_strides(inner_edge, inner_node) - - map_strides(map_to_tmp_edge, glob_ac) + gtx_transformations.gt_map_strides(sdfg, state, map_to_tmp_edge, glob_ac) # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. - new_map_to_glob_edge = graph.add_edge( + new_map_to_glob_edge = state.add_edge( map_exit, map_to_tmp_edge.src_conn, glob_ac, @@ -1033,9 +986,9 @@ def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) - subset=copy.deepcopy(glob_in_subset), ), ) - graph.remove_edge(map_to_tmp_edge) - graph.remove_edge(tmp_to_glob_edge) - graph.remove_node(tmp_ac) + state.remove_edge(map_to_tmp_edge) + state.remove_edge(tmp_to_glob_edge) + state.remove_node(tmp_ac) # We can not unconditionally remove the data `tmp` refers to, because # it could be that in a parallel branch the `tmp` is also defined. @@ -1050,10 +1003,10 @@ def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) - # offset. # NOTE: Assumes that `tmp_out_subset` and `tmp_in_subset` are the same. correcting_offset = glob_in_subset.offset_new(tmp_out_subset, negative=True) - mtree = graph.memlet_tree(new_map_to_glob_edge) + mtree = state.memlet_tree(new_map_to_glob_edge) for tree in mtree.traverse_children(include_self=False): curr_edge = tree.edge - curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, graph) + curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, state) if curr_edge.data.data == tmp_data: curr_edge.data.data = glob_data if curr_dst_subset is not None: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 4e254f2880..40bc0fb984 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -64,6 +64,11 @@ def _gt_change_transient_strides_non_recursive_impl( # we simply have to reverse the order. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) + for state in sdfg.states(): + for data_node in state.data_nodes(): + if data_node.data == top_level_transient: + for in_edge in state.in_edges(data_node): + gt_map_strides(sdfg, state, in_edge, data_node) def _find_toplevel_transients( @@ -97,3 +102,54 @@ def _find_toplevel_transients( continue top_level_transients.add(data) return top_level_transients + + +def gt_map_strides( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, +) -> None: + if isinstance(edge.src, dace.nodes.MapExit): + # Find the source of the edge entering the map exit node + map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") + for edge_to_map_exit_edge in state.in_edges_by_connector(edge.src, map_exit_in_conn): + gt_map_strides(sdfg, state, edge_to_map_exit_edge, outer_node) + return + + if not isinstance(edge.src, dace.nodes.NestedSDFG): + return + + # We need to propagate the strides inside the nested SDFG on the global arrays + nsdfg_node = edge.src + new_strides = tuple( + stride + for stride, to_map_size in zip( + outer_node.desc(sdfg).strides, + edge.data.subset.size(), + strict=True, + ) + if to_map_size != 1 + ) + inner_data = edge.src_conn + inner_desc = nsdfg_node.sdfg.arrays[inner_data] + assert not inner_desc.transient + + if isinstance(inner_desc, dace.data.Scalar): + assert len(new_strides) == 0 + return + + assert isinstance(inner_desc, dace.data.Array) + inner_desc.set_shape(inner_desc.shape, new_strides) + + for stride in new_strides: + if isinstance(stride, dace.symbolic.symbol): + for sym in stride.free_symbols: + nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) + nsdfg_node.symbol_mapping |= {str(sym): sym} + + for inner_state in nsdfg_node.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == inner_data: + for inner_edge in inner_state.in_edges(inner_node): + gt_map_strides(sdfg, state, inner_edge, inner_node) From f7016050b8dd4f5696eb452bd8a1e4fcb03b3ec8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 15:57:36 +0100 Subject: [PATCH 33/80] rename function --- .../transformations/__init__.py | 4 ++-- .../transformations/simplify.py | 2 +- .../dace_fieldview/transformations/strides.py | 19 +++++++++++++++---- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 3f995c3db4..7a4450aa4b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,7 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides, gt_map_strides +from .strides import gt_change_transient_strides, gt_map_strides_to_nested_sdfg from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,7 +59,7 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", - "gt_map_strides", + "gt_map_strides_to_nested_sdfg", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 708834413e..c4b127ca67 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -972,7 +972,7 @@ def apply( assert glob_in_subset is not None # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension - gtx_transformations.gt_map_strides(sdfg, state, map_to_tmp_edge, glob_ac) + gtx_transformations.gt_map_strides_to_nested_sdfg(sdfg, state, map_to_tmp_edge, glob_ac) # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 40bc0fb984..df909fb4cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -68,7 +68,7 @@ def _gt_change_transient_strides_non_recursive_impl( for data_node in state.data_nodes(): if data_node.data == top_level_transient: for in_edge in state.in_edges(data_node): - gt_map_strides(sdfg, state, in_edge, data_node) + gt_map_strides_to_nested_sdfg(sdfg, state, in_edge, data_node) def _find_toplevel_transients( @@ -104,17 +104,28 @@ def _find_toplevel_transients( return top_level_transients -def gt_map_strides( +def gt_map_strides_to_nested_sdfg( sdfg: dace.SDFG, state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, ) -> None: + """Propagates the strides of the given data node to the nested SDFGs. + + This function will recursively visit the nested SDFGs connected to the given + data node and apply mapping from inner to outer strides. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node. + outer_node: The data node whose strides should be propagated. + """ if isinstance(edge.src, dace.nodes.MapExit): # Find the source of the edge entering the map exit node map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") for edge_to_map_exit_edge in state.in_edges_by_connector(edge.src, map_exit_in_conn): - gt_map_strides(sdfg, state, edge_to_map_exit_edge, outer_node) + gt_map_strides_to_nested_sdfg(sdfg, state, edge_to_map_exit_edge, outer_node) return if not isinstance(edge.src, dace.nodes.NestedSDFG): @@ -152,4 +163,4 @@ def gt_map_strides( for inner_node in inner_state.data_nodes(): if inner_node.data == inner_data: for inner_edge in inner_state.in_edges(inner_node): - gt_map_strides(sdfg, state, inner_edge, inner_node) + gt_map_strides_to_nested_sdfg(sdfg, state, inner_edge, inner_node) From a19019ffbcb909699d589b1e66ed7a38920b262d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 16:23:10 +0100 Subject: [PATCH 34/80] fix bug --- .../runners/dace_fieldview/transformations/strides.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index df909fb4cd..15c9e7085a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -154,7 +154,7 @@ def gt_map_strides_to_nested_sdfg( inner_desc.set_shape(inner_desc.shape, new_strides) for stride in new_strides: - if isinstance(stride, dace.symbolic.symbol): + if dace.symbolic.issymbolic(stride): for sym in stride.free_symbols: nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) nsdfg_node.symbol_mapping |= {str(sym): sym} From c03492c84666d088a79a9c8ada6a65477d24c210 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 16:46:04 +0100 Subject: [PATCH 35/80] fix previous commit --- .../dace_fieldview/transformations/strides.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 15c9e7085a..72dc2d4b6d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import functools + import dace from dace import data as dace_data @@ -153,11 +155,15 @@ def gt_map_strides_to_nested_sdfg( assert isinstance(inner_desc, dace.data.Array) inner_desc.set_shape(inner_desc.shape, new_strides) - for stride in new_strides: - if dace.symbolic.issymbolic(stride): - for sym in stride.free_symbols: - nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) - nsdfg_node.symbol_mapping |= {str(sym): sym} + new_strides_symbols: list[dace.symbol] = functools.reduce( + lambda acc, itm: acc + list(itm.free_symbols), new_strides, [] + ) + new_strides_free_symbols = { + sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols + } + for sym in new_strides_free_symbols: + nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) + nsdfg_node.symbol_mapping |= {sym.name: sym} for inner_state in nsdfg_node.sdfg.states(): for inner_node in inner_state.data_nodes(): From 310fcceb087a1a6e85e75ff08fa3e95ccfd84dc7 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 17:04:09 +0100 Subject: [PATCH 36/80] Test commit --- .../runners/dace_fieldview/transformations/auto_optimize.py | 2 +- .../next/program_processors/runners/dace_fieldview/workflow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index 4a06d2f416..e27e37499b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -269,7 +269,7 @@ def gt_auto_optimize( dace_aoptimize.move_small_arrays_to_stack(sdfg) # Now we modify the strides. - gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + # TODO: re-enable gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) if make_persistent: gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 407faf7ec1..5f5a00712b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -63,7 +63,7 @@ def generate_sdfg( # can handle. This is a workaround for an issue with scalar expressions that are # promoted to symbolic expressions and computed on the host (CPU), but the intermediate # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). - gtx_transformations.gt_simplify(sdfg) + sdfg.simplify() gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg From 4b487ea68cd6749b5242f3839551f1f2e8ffaba5 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 17:55:21 +0100 Subject: [PATCH 37/80] propagate strides also to destination nested SDFG --- .../transformations/__init__.py | 9 +- .../transformations/auto_optimize.py | 2 +- .../transformations/simplify.py | 2 +- .../dace_fieldview/transformations/strides.py | 102 +++++++++++++----- .../runners/dace_fieldview/workflow.py | 2 +- 5 files changed, 86 insertions(+), 31 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 7a4450aa4b..439084674e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,11 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides, gt_map_strides_to_nested_sdfg +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, +) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,7 +63,8 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", - "gt_map_strides_to_nested_sdfg", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index e27e37499b..4a06d2f416 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -269,7 +269,7 @@ def gt_auto_optimize( dace_aoptimize.move_small_arrays_to_stack(sdfg) # Now we modify the strides. - # TODO: re-enable gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) if make_persistent: gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index c4b127ca67..89aeda4740 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -972,7 +972,7 @@ def apply( assert glob_in_subset is not None # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension - gtx_transformations.gt_map_strides_to_nested_sdfg(sdfg, state, map_to_tmp_edge, glob_ac) + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, state, map_to_tmp_edge, glob_ac) # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 72dc2d4b6d..7c97e3fee9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -70,7 +70,9 @@ def _gt_change_transient_strides_non_recursive_impl( for data_node in state.data_nodes(): if data_node.data == top_level_transient: for in_edge in state.in_edges(data_node): - gt_map_strides_to_nested_sdfg(sdfg, state, in_edge, data_node) + gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node) + for out_edge in state.out_edges(data_node): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node) def _find_toplevel_transients( @@ -106,13 +108,13 @@ def _find_toplevel_transients( return top_level_transients -def gt_map_strides_to_nested_sdfg( +def gt_map_strides_to_dst_nested_sdfg( sdfg: dace.SDFG, state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, ) -> None: - """Propagates the strides of the given data node to the nested SDFGs. + """Propagates the strides of the given data node to the nested SDFGs on the edge destination. This function will recursively visit the nested SDFGs connected to the given data node and apply mapping from inner to outer strides. @@ -120,31 +122,81 @@ def gt_map_strides_to_nested_sdfg( Args: sdfg: The SDFG to process. state: The state where the data node is used. - edge: The edge that writes to the data node. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + """ + if isinstance(edge.dst, dace.nodes.MapEntry): + # Find the destinaion of the edge entering the map entry node + map_entry_out_conn = edge.dst_conn.replace("IN_", "OUT_") + for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node) + return + + if not isinstance(edge.dst, dace.nodes.NestedSDFG): + return + + _gt_map_strides_to_nested_sdfg(sdfg, edge.dst, edge.dst_conn, edge.data, outer_node) + + for inner_state in edge.dst.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == edge.dst: + for inner_edge in inner_state.out_edges(inner_node): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node) + + +def gt_map_strides_to_src_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, +) -> None: + """Propagates the strides of the given data node to the nested SDFGs on the edge source. + + This function will recursively visit the nested SDFGs connected to the given + data node and apply mapping from inner to outer strides. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. """ if isinstance(edge.src, dace.nodes.MapExit): # Find the source of the edge entering the map exit node map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") - for edge_to_map_exit_edge in state.in_edges_by_connector(edge.src, map_exit_in_conn): - gt_map_strides_to_nested_sdfg(sdfg, state, edge_to_map_exit_edge, outer_node) + for edge_to_map_exit in state.in_edges_by_connector(edge.src, map_exit_in_conn): + gt_map_strides_to_src_nested_sdfg(sdfg, state, edge_to_map_exit, outer_node) return if not isinstance(edge.src, dace.nodes.NestedSDFG): return + _gt_map_strides_to_nested_sdfg(sdfg, edge.src, edge.src_conn, edge.data, outer_node) + + for inner_state in edge.src.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == edge.src_conn: + for inner_edge in inner_state.in_edges(inner_node): + gt_map_strides_to_src_nested_sdfg(sdfg, state, inner_edge, inner_node) + + +def _gt_map_strides_to_nested_sdfg( + sdfg: dace.SDFG, + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: str, + edge_data: dace.Memlet, + outer_node: dace.nodes.AccessNode, +) -> None: # We need to propagate the strides inside the nested SDFG on the global arrays - nsdfg_node = edge.src new_strides = tuple( stride for stride, to_map_size in zip( outer_node.desc(sdfg).strides, - edge.data.subset.size(), + edge_data.subset.size(), strict=True, ) if to_map_size != 1 ) - inner_data = edge.src_conn inner_desc = nsdfg_node.sdfg.arrays[inner_data] assert not inner_desc.transient @@ -153,20 +205,18 @@ def gt_map_strides_to_nested_sdfg( return assert isinstance(inner_desc, dace.data.Array) - inner_desc.set_shape(inner_desc.shape, new_strides) - - new_strides_symbols: list[dace.symbol] = functools.reduce( - lambda acc, itm: acc + list(itm.free_symbols), new_strides, [] - ) - new_strides_free_symbols = { - sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols - } - for sym in new_strides_free_symbols: - nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) - nsdfg_node.symbol_mapping |= {sym.name: sym} - - for inner_state in nsdfg_node.sdfg.states(): - for inner_node in inner_state.data_nodes(): - if inner_node.data == inner_data: - for inner_edge in inner_state.in_edges(inner_node): - gt_map_strides_to_nested_sdfg(sdfg, state, inner_edge, inner_node) + if all(isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides): + for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): + nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride + else: + inner_desc.set_shape(inner_desc.shape, new_strides) + + new_strides_symbols: list[dace.symbol] = functools.reduce( + lambda acc, itm: acc + list(itm.free_symbols), new_strides, [] + ) + new_strides_free_symbols = { + sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols + } + for sym in new_strides_free_symbols: + nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) + nsdfg_node.symbol_mapping[sym.name] = sym diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 5f5a00712b..407faf7ec1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -63,7 +63,7 @@ def generate_sdfg( # can handle. This is a workaround for an issue with scalar expressions that are # promoted to symbolic expressions and computed on the host (CPU), but the intermediate # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). - sdfg.simplify() + gtx_transformations.gt_simplify(sdfg) gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg From 4cf66e728fc6a6048cd656079b571cadc8b61826 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 16 Dec 2024 18:24:14 +0100 Subject: [PATCH 38/80] fix previous commit (skip scalar inner nodes) --- .../runners/dace_fieldview/transformations/strides.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 7c97e3fee9..3f714e9d4f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -171,6 +171,9 @@ def gt_map_strides_to_src_nested_sdfg( if not isinstance(edge.src, dace.nodes.NestedSDFG): return + if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar): + return # no strides to propagate + _gt_map_strides_to_nested_sdfg(sdfg, edge.src, edge.src_conn, edge.data, outer_node) for inner_state in edge.src.sdfg.states(): From ab7ee5f0d4c2ffe7c63d838bb6c72e6588c1c070 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 09:09:31 +0100 Subject: [PATCH 39/80] fix - do not call free_symbols on int stride --- .../runners/dace_fieldview/transformations/strides.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 3f714e9d4f..70668c08d1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -215,7 +215,9 @@ def _gt_map_strides_to_nested_sdfg( inner_desc.set_shape(inner_desc.shape, new_strides) new_strides_symbols: list[dace.symbol] = functools.reduce( - lambda acc, itm: acc + list(itm.free_symbols), new_strides, [] + lambda acc, itm: acc + list(itm.free_symbols) if dace.symbolic.issymbolic(itm) else acc, + new_strides, + [], ) new_strides_free_symbols = { sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols From 82cf4910b32c6c2a8e5de5f1b6f564e67915128a Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 09:38:31 +0100 Subject: [PATCH 40/80] run simplify before gpu transformations --- .../runners/dace_fieldview/transformations/gpu_utils.py | 8 ++++++++ .../program_processors/runners/dace_fieldview/workflow.py | 5 ----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 7b14144ead..b7d87d8217 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -76,6 +76,14 @@ def gt_gpu_transformation( len(kwargs) == 0 ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + dace_transformation.passes.SimplifyPass( + validate=validate, validate_all=validate_all, skip={"ConstantPropagation"} + ).apply_pass(sdfg, {}) + # Turn all global arrays (which we identify as input) into GPU memory. # This way the GPU transformation will not create this copying stuff. if use_gpu_storage: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 407faf7ec1..07fc6713f4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -59,11 +59,6 @@ def generate_sdfg( if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) elif on_gpu: - # We run simplify to bring the SDFG into a canonical form that the gpu transformations - # can handle. This is a workaround for an issue with scalar expressions that are - # promoted to symbolic expressions and computed on the host (CPU), but the intermediate - # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). - gtx_transformations.gt_simplify(sdfg) gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg From a0dbea54d1a9c288c4ec4e24b4c3e64ca097e6bb Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 10:01:13 +0100 Subject: [PATCH 41/80] undo renaming graph -> state --- .../transformations/gpu_utils.py | 5 ++-- .../transformations/simplify.py | 24 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index b7d87d8217..80a1f77cb3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -78,8 +78,9 @@ def gt_gpu_transformation( # We run simplify to bring the SDFG into a canonical form that the gpu transformations # can handle. This is a workaround for an issue with scalar expressions that are - # promoted to symbolic expressions and computed on the host (CPU), but the intermediate - # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + # promoted to symbolic expressions and accessed on the host (CPU) as arguments of + # interstate edge conditions, but the scalar data is stored as a GPU global variable. + # For details, see the dace issue https://github.com/spcl/dace/issues/1773 dace_transformation.passes.SimplifyPass( validate=validate, validate_all=validate_all, skip={"ConstantPropagation"} ).apply_pass(sdfg, {}) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 89aeda4740..1a132cacb2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - state: dace.SDFGState, + graph: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -962,21 +962,21 @@ def apply( glob_ac: dace_nodes.AccessNode = self.glob_ac glob_data = glob_ac.data - map_to_tmp_edge = next(edge for edge in state.in_edges(tmp_ac)) - tmp_to_glob_edge = next(edge for edge in state.out_edges(tmp_ac)) + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) - glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, state) - tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, state) + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) if tmp_out_subset is None: tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension - gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, state, map_to_tmp_edge, glob_ac) + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. - new_map_to_glob_edge = state.add_edge( + new_map_to_glob_edge = graph.add_edge( map_exit, map_to_tmp_edge.src_conn, glob_ac, @@ -986,9 +986,9 @@ def apply( subset=copy.deepcopy(glob_in_subset), ), ) - state.remove_edge(map_to_tmp_edge) - state.remove_edge(tmp_to_glob_edge) - state.remove_node(tmp_ac) + graph.remove_edge(map_to_tmp_edge) + graph.remove_edge(tmp_to_glob_edge) + graph.remove_node(tmp_ac) # We can not unconditionally remove the data `tmp` refers to, because # it could be that in a parallel branch the `tmp` is also defined. @@ -1003,10 +1003,10 @@ def apply( # offset. # NOTE: Assumes that `tmp_out_subset` and `tmp_in_subset` are the same. correcting_offset = glob_in_subset.offset_new(tmp_out_subset, negative=True) - mtree = state.memlet_tree(new_map_to_glob_edge) + mtree = graph.memlet_tree(new_map_to_glob_edge) for tree in mtree.traverse_children(include_self=False): curr_edge = tree.edge - curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, state) + curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, graph) if curr_edge.data.data == tmp_data: curr_edge.data.data = glob_data if curr_dst_subset is not None: From 9128ffbf70179bde49c714267ec35849b4eb8fbf Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 10:01:34 +0100 Subject: [PATCH 42/80] increase slurm timeout to 20 minutes --- ci/cscs-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 7adb88459e..943b264bda 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -156,7 +156,7 @@ build_py38_image_x86_64: variables: CRAY_CUDA_MPS: 1 SLURM_JOB_NUM_NODES: 1 - SLURM_TIMELIMIT: 15 + SLURM_TIMELIMIT: 20 NUM_PROCESSES: auto VIRTUALENV_SYSTEM_SITE_PACKAGES: 1 .test_helper_x86_64: From f940c4ee0f903c33c7c736c0056810640ae03af9 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 10:28:27 +0100 Subject: [PATCH 43/80] increase slurm timeout to 30 minutes --- ci/cscs-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 943b264bda..595bcac034 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -156,7 +156,7 @@ build_py38_image_x86_64: variables: CRAY_CUDA_MPS: 1 SLURM_JOB_NUM_NODES: 1 - SLURM_TIMELIMIT: 20 + SLURM_TIMELIMIT: 30 NUM_PROCESSES: auto VIRTUALENV_SYSTEM_SITE_PACKAGES: 1 .test_helper_x86_64: From cc0777b41153ade119d442cfb4505af9c8125127 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 10:30:16 +0100 Subject: [PATCH 44/80] minor edit --- .../dace_fieldview/transformations/strides.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 70668c08d1..72a1916875 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +from typing import Iterable import dace from dace import data as dace_data @@ -135,7 +136,8 @@ def gt_map_strides_to_dst_nested_sdfg( if not isinstance(edge.dst, dace.nodes.NestedSDFG): return - _gt_map_strides_to_nested_sdfg(sdfg, edge.dst, edge.dst_conn, edge.data, outer_node) + outer_strides = outer_node.desc(sdfg).strides + _gt_map_strides_to_nested_sdfg(edge.dst, edge.dst_conn, edge.data, outer_strides) for inner_state in edge.dst.sdfg.states(): for inner_node in inner_state.data_nodes(): @@ -174,7 +176,8 @@ def gt_map_strides_to_src_nested_sdfg( if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar): return # no strides to propagate - _gt_map_strides_to_nested_sdfg(sdfg, edge.src, edge.src_conn, edge.data, outer_node) + outer_strides = outer_node.desc(sdfg).strides + _gt_map_strides_to_nested_sdfg(edge.src, edge.src_conn, edge.data, outer_strides) for inner_state in edge.src.sdfg.states(): for inner_node in inner_state.data_nodes(): @@ -184,17 +187,16 @@ def gt_map_strides_to_src_nested_sdfg( def _gt_map_strides_to_nested_sdfg( - sdfg: dace.SDFG, nsdfg_node: dace.nodes.NestedSDFG, inner_data: str, edge_data: dace.Memlet, - outer_node: dace.nodes.AccessNode, + outer_strides: Iterable[int | dace.symbolic.SymExpr], ) -> None: # We need to propagate the strides inside the nested SDFG on the global arrays new_strides = tuple( stride for stride, to_map_size in zip( - outer_node.desc(sdfg).strides, + outer_strides, edge_data.subset.size(), strict=True, ) @@ -215,7 +217,9 @@ def _gt_map_strides_to_nested_sdfg( inner_desc.set_shape(inner_desc.shape, new_strides) new_strides_symbols: list[dace.symbol] = functools.reduce( - lambda acc, itm: acc + list(itm.free_symbols) if dace.symbolic.issymbolic(itm) else acc, + lambda acc, itm: (acc + list(itm.free_symbols)) # type: ignore[union-attr] + if dace.symbolic.issymbolic(itm) + else acc, new_strides, [], ) From 462f3c5d6f83783556f2c11699147a6aa459f98f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 17 Dec 2024 11:08:01 +0100 Subject: [PATCH 45/80] exclude test_ternary_scan from gpu tests --- ci/cscs-ci.yml | 2 +- pyproject.toml | 1 + .../dace_fieldview/transformations/gpu_utils.py | 9 --------- .../runners/dace_fieldview/workflow.py | 5 +++++ tests/next_tests/definitions.py | 13 +++++++++++-- .../feature_tests/ffront_tests/test_execution.py | 1 + 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 595bcac034..7adb88459e 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -156,7 +156,7 @@ build_py38_image_x86_64: variables: CRAY_CUDA_MPS: 1 SLURM_JOB_NUM_NODES: 1 - SLURM_TIMELIMIT: 30 + SLURM_TIMELIMIT: 15 NUM_PROCESSES: auto VIRTUALENV_SYSTEM_SITE_PACKAGES: 1 .test_helper_x86_64: diff --git a/pyproject.toml b/pyproject.toml index d086363ec4..08203bf25a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -254,6 +254,7 @@ markers = [ 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', 'uses_scan_nested: tests that use nested scans', 'uses_scan_requiring_projector: tests need a projector implementation in gtfn', + 'uses_scan_1d_field: tests scan on a 1D vertical field', 'uses_sparse_fields: tests that require backend support for sparse fields', 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 80a1f77cb3..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -76,15 +76,6 @@ def gt_gpu_transformation( len(kwargs) == 0 ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" - # We run simplify to bring the SDFG into a canonical form that the gpu transformations - # can handle. This is a workaround for an issue with scalar expressions that are - # promoted to symbolic expressions and accessed on the host (CPU) as arguments of - # interstate edge conditions, but the scalar data is stored as a GPU global variable. - # For details, see the dace issue https://github.com/spcl/dace/issues/1773 - dace_transformation.passes.SimplifyPass( - validate=validate, validate_all=validate_all, skip={"ConstantPropagation"} - ).apply_pass(sdfg, {}) - # Turn all global arrays (which we identify as input) into GPU memory. # This way the GPU transformation will not create this copying stuff. if use_gpu_storage: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 07fc6713f4..407faf7ec1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -59,6 +59,11 @@ def generate_sdfg( if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) elif on_gpu: + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + gtx_transformations.gt_simplify(sdfg) gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 321ebb85c7..67b566d35b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -100,6 +100,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" +USES_SCAN_1D_FIELD = "uses_scan_1d_field" USES_SPARSE_FIELDS = "uses_sparse_fields" USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" @@ -168,9 +169,17 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [ + # dace issue https://github.com/spcl/dace/issues/1773 + (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), + ], OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST + + [ + # dace issue https://github.com/spcl/dace/issues/1773 + (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9de4449ac2..caef13df3d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -818,6 +818,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan +@pytest.mark.uses_scan_1d_field def test_ternary_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: From d9218b63c37b678cc13f35adb135a7c679978778 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 17 Dec 2024 13:29:17 +0100 Subject: [PATCH 46/80] This are the changes Edoardo implemented to fix some issues in the optimization pipeline when confronted with scans. --- .../transformations/__init__.py | 8 +- .../transformations/gpu_utils.py | 2 +- .../transformations/simplify.py | 5 +- .../dace_fieldview/transformations/strides.py | 132 ++++++++++++++++++ .../test_map_buffer_elimination.py | 93 +++++++++++- 5 files changed, 233 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4f3efb19b0..439084674e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,11 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, +) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,6 +63,8 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 2cd3020180..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -95,7 +95,7 @@ def gt_gpu_transformation( if try_removing_trivial_maps: # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So # we might end up with lots of these trivial Maps, each requiring a separate # kernel launch. To prevent this we will combine these trivial maps, if # possible, with their downstream maps. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..1a132cacb2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - graph: dace.SDFGState | dace.SDFG, + graph: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -971,6 +971,9 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 4e254f2880..72a1916875 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,6 +6,9 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import functools +from typing import Iterable + import dace from dace import data as dace_data @@ -64,6 +67,13 @@ def _gt_change_transient_strides_non_recursive_impl( # we simply have to reverse the order. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) + for state in sdfg.states(): + for data_node in state.data_nodes(): + if data_node.data == top_level_transient: + for in_edge in state.in_edges(data_node): + gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node) + for out_edge in state.out_edges(data_node): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node) def _find_toplevel_transients( @@ -97,3 +107,125 @@ def _find_toplevel_transients( continue top_level_transients.add(data) return top_level_transients + + +def gt_map_strides_to_dst_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, +) -> None: + """Propagates the strides of the given data node to the nested SDFGs on the edge destination. + + This function will recursively visit the nested SDFGs connected to the given + data node and apply mapping from inner to outer strides. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + """ + if isinstance(edge.dst, dace.nodes.MapEntry): + # Find the destinaion of the edge entering the map entry node + map_entry_out_conn = edge.dst_conn.replace("IN_", "OUT_") + for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node) + return + + if not isinstance(edge.dst, dace.nodes.NestedSDFG): + return + + outer_strides = outer_node.desc(sdfg).strides + _gt_map_strides_to_nested_sdfg(edge.dst, edge.dst_conn, edge.data, outer_strides) + + for inner_state in edge.dst.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == edge.dst: + for inner_edge in inner_state.out_edges(inner_node): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node) + + +def gt_map_strides_to_src_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, +) -> None: + """Propagates the strides of the given data node to the nested SDFGs on the edge source. + + This function will recursively visit the nested SDFGs connected to the given + data node and apply mapping from inner to outer strides. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + """ + if isinstance(edge.src, dace.nodes.MapExit): + # Find the source of the edge entering the map exit node + map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") + for edge_to_map_exit in state.in_edges_by_connector(edge.src, map_exit_in_conn): + gt_map_strides_to_src_nested_sdfg(sdfg, state, edge_to_map_exit, outer_node) + return + + if not isinstance(edge.src, dace.nodes.NestedSDFG): + return + + if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar): + return # no strides to propagate + + outer_strides = outer_node.desc(sdfg).strides + _gt_map_strides_to_nested_sdfg(edge.src, edge.src_conn, edge.data, outer_strides) + + for inner_state in edge.src.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == edge.src_conn: + for inner_edge in inner_state.in_edges(inner_node): + gt_map_strides_to_src_nested_sdfg(sdfg, state, inner_edge, inner_node) + + +def _gt_map_strides_to_nested_sdfg( + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: str, + edge_data: dace.Memlet, + outer_strides: Iterable[int | dace.symbolic.SymExpr], +) -> None: + # We need to propagate the strides inside the nested SDFG on the global arrays + new_strides = tuple( + stride + for stride, to_map_size in zip( + outer_strides, + edge_data.subset.size(), + strict=True, + ) + if to_map_size != 1 + ) + inner_desc = nsdfg_node.sdfg.arrays[inner_data] + assert not inner_desc.transient + + if isinstance(inner_desc, dace.data.Scalar): + assert len(new_strides) == 0 + return + + assert isinstance(inner_desc, dace.data.Array) + if all(isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides): + for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): + nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride + else: + inner_desc.set_shape(inner_desc.shape, new_strides) + + new_strides_symbols: list[dace.symbol] = functools.reduce( + lambda acc, itm: (acc + list(itm.free_symbols)) # type: ignore[union-attr] + if dace.symbolic.issymbolic(itm) + else acc, + new_strides, + [], + ) + new_strides_free_symbols = { + sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols + } + for sym in new_strides_free_symbols: + nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) + nsdfg_node.symbol_mapping[sym.name] = sym diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index 1a4ce6d047..a98eac3c2c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -22,10 +22,6 @@ import dace -def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: - return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} - - def _make_test_sdfg( output_name: str = "G", input_name: str = "G", @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply(): validate_all=True, ) assert count == 0 + + +def test_map_buffer_elimination_with_nested_sdfgs(): + """ + After removing a transient connected to a nested SDFG node, ensure that the strides + are propagated to the arrays in nested SDFG. + """ + + stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] + + # top-level sdfg + sdfg = dace.SDFG(util.unique_name("map_buffer")) + inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) + out, out_desc = sdfg.add_array( + "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) + ) + tmp, _ = sdfg.add_temp_transient_like(out_desc) + state = sdfg.add_state() + tmp_node = state.add_access(tmp) + + nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) + out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) + tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) + state1 = nsdfg1.add_state() + tmp1_node = state1.add_access(tmp1) + + nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) + out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) + tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) + state2 = nsdfg2.add_state() + tmp2_node = state2.add_access(tmp2) + + state2.add_mapped_tasklet( + "broadcast2", + map_ranges={"__i": "0:10"}, + code="__oval = __ival + 1.0", + inputs={ + "__ival": dace.Memlet(f"{inp2}[__i]"), + }, + outputs={ + "__oval": dace.Memlet(f"{tmp2}[__i]"), + }, + output_nodes={tmp2_node}, + external_edges=True, + ) + state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc)) + + nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"}) + me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"}) + state1.add_memlet_path( + state1.add_access(inp1), + me1, + nsdfg2_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp1, inp1_desc), + ) + state1.add_memlet_path( + nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]") + ) + state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc)) + + nsdfg1_node = state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"}) + me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"}) + state.add_memlet_path( + state.add_access(inp), + me, + nsdfg1_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp, inp_desc), + ) + state.add_memlet_path( + nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]") + ) + state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc)) + + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 3 + assert out1_desc.strides == out_desc.strides[1:] + assert out2_desc.strides == out_desc.strides[2:] From 9d7e7225333a1ada28f0273a2495c88ed0fea6df Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 11:08:22 +0100 Subject: [PATCH 47/80] First rework. However the actuall modifier function is not modified yet. --- .../dace_fieldview/transformations/strides.py | 431 ++++++++++++++---- 1 file changed, 354 insertions(+), 77 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 72a1916875..196f7b3e74 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -7,10 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Iterable +from typing import Iterable, Optional import dace from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, @@ -57,93 +58,160 @@ def gt_change_transient_strides( def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order.""" - for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + """Essentially this function just changes the stride to FORTRAN order. + + Todo: + Make this function more intelligent to analyse the access pattern and then + figuring out the best order. + """ + + # NOTE: processing the transient here is enough. If we are inside a + # NestedSDFG then they were handled before on the level above us. + top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( + sdfg=sdfg, + only_transients=True, + only_arrays=True, + ) + for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] + + # Setting the strides only make sense if we have more than two dimensions ndim = len(desc.shape) if ndim <= 1: continue + # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) - for state in sdfg.states(): - for data_node in state.data_nodes(): - if data_node.data == top_level_transient: - for in_edge in state.in_edges(data_node): - gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node) - for out_edge in state.out_edges(data_node): - gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node) - -def _find_toplevel_transients( + # Now we have to propagate the changed strides. Because we already have + # collected all the AccessNodes we are using the + # `gt_propagate_strides_from_access_node()` function, but we have to + # create `processed_nsdfg` set already outside here. + # Furthermore, the same comment as above apply, we do not have to + # propagate the non-transients, because they either come from outside, + # or they were already handled in the levels above, where they were + # defined and then propagated down. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state, access_node in accesses: + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=access_node, + processed_nsdfgs=processed_nsdfgs, + ) + + +def gt_propagate_strides_of( sdfg: dace.SDFG, - only_arrays: bool = False, -) -> set[str]: - """Find all top level transients in the SDFG. + data_name: str, +) -> None: + """Propagates the strides of `data_name` within the whole SDFG. - The function will scan the SDFG, ignoring nested one, and return the - name of all transients that have an access node at the top level. - However, it will ignore access nodes that refers to registers. + This function will call `gt_propagate_strides_from_access_node()` for every + AccessNode that refers to `data_name`. It will also make sure that + a NestedSDFG is visited only once. + + Args: + sdfg: The SDFG on which we operate. + data_name: Name of the data descriptor that should be handled. """ - top_level_transients: set[str] = set() + + # Defining it here ensures that we will not enter an NestedSDFG multiple times. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state in sdfg.states(): - scope_dict = state.scope_dict() for dnode in state.data_nodes(): - data: str = dnode.data - if scope_dict[dnode] is not None: - if data in top_level_transients: - top_level_transients.remove(data) - continue - elif data in top_level_transients: - continue - elif gtx_transformations.util.is_view(dnode, sdfg): + if dnode.data != data_name: continue - desc: dace_data.Data = dnode.desc(sdfg) - - if not desc.transient: - continue - elif only_arrays and not isinstance(desc, dace_data.Array): - continue - top_level_transients.add(data) - return top_level_transients + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=dnode, + processed_nsdfgs=processed_nsdfgs, + ) -def gt_map_strides_to_dst_nested_sdfg( +def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, - edge: dace.sdfg.graph.Edge, - outer_node: dace.nodes.AccessNode, + outer_node: dace_nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, ) -> None: - """Propagates the strides of the given data node to the nested SDFGs on the edge destination. + """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. + + The function will propagate the strides of the data descriptor `outer_node` + refers to along all adjacent edges of `outer_node`. If one of these edges + leads to a NestedSDFG then the function will modify the strides of data + descriptor within to match the strides on the outside. The function will then + recursively process NestedSDFG. - This function will recursively visit the nested SDFGs connected to the given - data node and apply mapping from inner to outer strides. + It is important that this function will only handle the NestedSDFGs that are + reachable from `outer_node`. To fully propagate the strides the + `gt_propagate_strides_of()` should be used. Args: sdfg: The SDFG to process. state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. """ - if isinstance(edge.dst, dace.nodes.MapEntry): - # Find the destinaion of the edge entering the map entry node - map_entry_out_conn = edge.dst_conn.replace("IN_", "OUT_") - for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn): - gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node) - return + if processed_nsdfgs is None: + # For preventing the case that nested SDFGs are handled multiple time. + # TODO: It certainly happens if a node is input and output, but are there other cases? + processed_nsdfgs = set() + + for in_edge in state.in_edges(outer_node): + gt_map_strides_to_src_nested_sdfg( + sdfg=sdfg, + state=state, + edge=in_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ) + for out_edge in state.out_edges(outer_node): + gt_map_strides_to_dst_nested_sdfg( + sdfg=sdfg, + state=state, + edge=out_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ) - if not isinstance(edge.dst, dace.nodes.NestedSDFG): - return - outer_strides = outer_node.desc(sdfg).strides - _gt_map_strides_to_nested_sdfg(edge.dst, edge.dst_conn, edge.data, outer_strides) +def gt_map_strides_to_dst_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, +) -> None: + """Propagates the strides of `outer_node` along `edge` along the dataflow. - for inner_state in edge.dst.sdfg.states(): - for inner_node in inner_state.data_nodes(): - if inner_node.data == edge.dst: - for inner_edge in inner_state.out_edges(inner_node): - gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node) + For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). + However it is recommended to use `gt_propagate_strides_of()` directly. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + you know what your are doing. + """ + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=True, + ) def gt_map_strides_to_src_nested_sdfg( @@ -151,39 +219,165 @@ def gt_map_strides_to_src_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, ) -> None: - """Propagates the strides of the given data node to the nested SDFGs on the edge source. + """Propagates the strides of `outer_node` along `edge` against the dataflow. - This function will recursively visit the nested SDFGs connected to the given - data node and apply mapping from inner to outer strides. + For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). + However it is recommended to use `gt_propagate_strides_of()` directly. Args: sdfg: The SDFG to process. state: The state where the data node is used. edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + you know what your are doing. """ - if isinstance(edge.src, dace.nodes.MapExit): - # Find the source of the edge entering the map exit node - map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") - for edge_to_map_exit in state.in_edges_by_connector(edge.src, map_exit_in_conn): - gt_map_strides_to_src_nested_sdfg(sdfg, state, edge_to_map_exit, outer_node) - return + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=False, + ) - if not isinstance(edge.src, dace.nodes.NestedSDFG): - return - if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar): - return # no strides to propagate +def _gt_map_strides_to_nested_sdfg_src_dst( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]], + propagate_along_dataflow: bool, +) -> None: + """Propagates the stride of `outer_node` along `edge`. + + The function will follow `edge`, the direction depends on the value of + `propagate_along_dataflow` and propagate the strides of `outer_node` + into every NestedSDFG that is reachable by following `edge`. + + When the function encounters a NestedSDFG it will determine the the data + descriptor `outer_node` refers on the inside of the NestedSDFG. + It will then replace the stride of the inner descriptor with the ones + of the outside. Afterwards it will recursively propagates the + stride inside the NestedSDFG. + During this propagation the function will follow any edges. + + If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` + then it will be skipped. NestedSDFGs that have been processed will be added + to the `processed_nsdfgs`. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. + + Note: + A user should not use this function directly, instead `gt_propagate_strides_of()`, + `gt_map_strides_to_src_nested_sdfg()` (`propagate_along_dataflow == `False`) + or `gt_map_strides_to_dst_nested_sdfg()` (`propagate_along_dataflow == `True`) + should be used. + + Todo: + Try using `MemletTree` for the propagation. + """ + # If `processed_nsdfg` is `None` then this is the first call. We will now + # allocate the `set` and pass it as argument to all recursive calls, this + # ensures that the `set` is the same everywhere. + if processed_nsdfgs is None: + processed_nsdfgs = set() - outer_strides = outer_node.desc(sdfg).strides - _gt_map_strides_to_nested_sdfg(edge.src, edge.src_conn, edge.data, outer_strides) + if propagate_along_dataflow: + # Propagate along the dataflow or forward, so we are interested at the `dst` of the edge. + ScopeNode = dace_nodes.MapEntry - for inner_state in edge.src.sdfg.states(): - for inner_node in inner_state.data_nodes(): - if inner_node.data == edge.src_conn: - for inner_edge in inner_state.in_edges(inner_node): - gt_map_strides_to_src_nested_sdfg(sdfg, state, inner_edge, inner_node) + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.dst + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.dst_conn + + def next_edges_by_connector( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): + return [] + return list(state.out_edges_by_connector(edge.dst, "OUT_" + edge.dst_conn[3:])) + + else: + # Propagate against the dataflow or backward, so we are interested at the `src` of the edge. + ScopeNode = dace_nodes.MapExit + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.src + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.src_conn + + def next_edges_by_connector( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) + + if isinstance(get_node(edge), ScopeNode): + for next_edge in next_edges_by_connector(state, edge): + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=next_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=propagate_along_dataflow, + ) + + elif isinstance(get_node(edge), dace.nodes.NestedSDFG): + nsdfg_node = get_node(edge) + inner_data = get_inner_data(edge) + + if nsdfg_node in processed_nsdfgs: + # We have processed this nested SDFG already, so we have nothing to do. + return + + # Mark this nested SDFG as processed. + processed_nsdfgs.add(nsdfg_node) + + # Now set the stride of the data descriptor inside the nested SDFG to + # the ones it has outside. + _gt_map_strides_to_nested_sdfg( + nsdfg_node=nsdfg_node, + inner_data=inner_data, + edge_data=edge.data, + outer_strides=outer_node.desc(sdfg).strides, + ) + + # Because the function call above if not recursive we have now to scan the + # propagate the change into the nested SDFG. Using + # `_gt_find_toplevel_data_accesses()` is a bit overkill, but allows for a + # more uniform processing. + # TODO(phimuell): Instead of scanning every level for every data we modify + # we should scan the whole SDFG once and then reuse this information. + accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( + sdfg=nsdfg_node.sdfg, + only_transients=False, # Because on the nested levels they are globals. + only_arrays=True, + ) + for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): + # We have to use `gt_propagate_strides_of()` here because we have to + # handle its entirety. We could wait until the other branch processes + # the nested SDFG, but this might not work, so let's do it fully now. + gt_propagate_strides_from_access_node( + sdfg=nsdfg_node.sdfg, + state=nested_state, + outer_node=nested_access, + processed_nsdfgs=processed_nsdfgs, + ) def _gt_map_strides_to_nested_sdfg( @@ -192,6 +386,7 @@ def _gt_map_strides_to_nested_sdfg( edge_data: dace.Memlet, outer_strides: Iterable[int | dace.symbolic.SymExpr], ) -> None: + # TODO(phimuell/edopao): Refactor this function. # We need to propagate the strides inside the nested SDFG on the global arrays new_strides = tuple( stride @@ -214,6 +409,7 @@ def _gt_map_strides_to_nested_sdfg( for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride else: + assert len(inner_desc.shape) == len(new_strides) inner_desc.set_shape(inner_desc.shape, new_strides) new_strides_symbols: list[dace.symbol] = functools.reduce( @@ -229,3 +425,84 @@ def _gt_map_strides_to_nested_sdfg( for sym in new_strides_free_symbols: nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) nsdfg_node.symbol_mapping[sym.name] = sym + + +def _gt_find_toplevel_data_accesses( + sdfg: dace.SDFG, + only_transients: bool, + only_arrays: bool = False, +) -> dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]]: + """Find all data that is accessed on the top level. + + The function will scan the SDFG, ignoring nested one, and return the + name of all data (global and transient) that only have AccessNodes on + the top level. In data is found that has an AccessNode on both the top + level and in a nested scope and error is generated. + The function will ignore an access in the following cases: + - The AccessNode refers to data that is a register. + - The AccessNode refers to a View. + + Args: + sdfg: The SDFG to process. + only_transients: If `True` all non transients will be filtered out. + only_arrays: If `True`, defaults to `False`, only arrays are returned. + + Returns: + A `dict` that maps the name of a data container, that should be processed + to a list of tuples containing the state where the AccessNode was found + and the node. + """ + # List of data that is accessed on the top level and all its access node. + top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() + + # List of all data that were found not on top level. + not_top_level_data: set[str] = set() + + for state in sdfg.states(): + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + data: str = dnode.data + if scope_dict[dnode] is not None: + # The node was not found on the top level. So we can ignore it. + # We also check if it was ever found on the top level, this should + # not happen, as everything should go through Maps. But some strange + # DaCe transformation might do it. + assert data in top_level_data, f"Found {data} on the top level and inside a scope." + not_top_level_data.add(data) + continue + + elif data in top_level_data: + # The data is already known to be in top level data, so we must add the + # AccessNode to the list of known nodes. But nothing else. + top_level_data[data].append((state, dnode)) + continue + + elif gtx_transformations.util.is_view(dnode, sdfg): + # The AccessNode refers to a View so we ignore it anyway + # TODO(phimuell/edopao): Should the function return them? + continue + + # We have found a new data node that is on the top node and is unknown. + assert ( + data not in not_top_level_data + ), f"Found {data} on the top level and inside a scope." + desc: dace_data.Data = dnode.desc(sdfg) + + # Check if we only accept arrays + if only_arrays and not isinstance(desc, dace_data.Array): + continue + + # For now we ignore registers. + # We do this because register are allocated on the stack, so the compiler + # has all information and should organize the best thing possible. + # TODO(phimuell): verify this. + elif desc.storage is not dace.StorageType.Register: + continue + + # We are only interested in transients + if only_transients and desc.transient: + continue + + # Now create the new entry in the list and record the AccessNode. + top_level_data[data] = [(state, dnode)] + return top_level_data From 1ddd6fee4a21b565e300a5a6ea3ad2161998a53f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 11:37:55 +0100 Subject: [PATCH 48/80] Updated some commenst. --- .../dace_fieldview/transformations/strides.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 196f7b3e74..363ffd6a93 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -39,8 +39,6 @@ def gt_change_transient_strides( Todo: - Implement the estimation correctly. - - Handle the case of nested SDFGs correctly; on the outside a transient, - but on the inside a non transient. """ # TODO(phimeull): Implement this function correctly. @@ -50,22 +48,32 @@ def gt_change_transient_strides( return sdfg for nsdfg in sdfg.all_sdfgs_recursive(): - # TODO(phimuell): Handle the case when transient goes into nested SDFG - # on the inside it is a non transient, so it is ignored. _gt_change_transient_strides_non_recursive_impl(nsdfg) def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order. + """Set optimal strides of all transients in the SDFG. + + The function will look for all top level transients, see `_gt_find_toplevel_data_accesses()` + and set their strides such that the access is optimal, see Note. The function + will also run `gt_propagate_strides_of()` to propagate the strides into nested SDFGs. + + This function should never be called directly but always through + `gt_change_transient_strides()`! + + Note: + Currently the function just reverses the strides of the data descriptor + it processes. Since DaCe generates `C` order by default this lead to + FORTRAN order, which is (for now) sufficient to optimize the memory + layout to GPU. Todo: Make this function more intelligent to analyse the access pattern and then figuring out the best order. """ - - # NOTE: processing the transient here is enough. If we are inside a + # NOTE: Processing the transient here is enough. If we are inside a # NestedSDFG then they were handled before on the level above us. top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( sdfg=sdfg, From 95e0007022dabe56762936451e583ef092625bb3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 11:44:49 +0100 Subject: [PATCH 49/80] I want to ignore register, not only consider them. --- .../runners/dace_fieldview/transformations/strides.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 363ffd6a93..48d5f7620d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -504,7 +504,7 @@ def _gt_find_toplevel_data_accesses( # We do this because register are allocated on the stack, so the compiler # has all information and should organize the best thing possible. # TODO(phimuell): verify this. - elif desc.storage is not dace.StorageType.Register: + elif desc.storage is dace.StorageType.Register: continue # We are only interested in transients From f1b7a3ff851884cfdb154950c7ebc4fd8fc46d47 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 12:56:06 +0100 Subject: [PATCH 50/80] There was a missing `not` in the check. Which is funny then if you look at the last commit, the number of `not`s in this function was correct. --- .../runners/dace_fieldview/transformations/strides.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 48d5f7620d..61471de74b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -475,7 +475,9 @@ def _gt_find_toplevel_data_accesses( # We also check if it was ever found on the top level, this should # not happen, as everything should go through Maps. But some strange # DaCe transformation might do it. - assert data in top_level_data, f"Found {data} on the top level and inside a scope." + assert ( + data not in top_level_data + ), f"Found {data} on the top level and inside a scope." not_top_level_data.add(data) continue From 50ad620b97284fa3c78c53f92c5bd0e474308f98 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 12:56:49 +0100 Subject: [PATCH 51/80] Had to update the propagation, to also handle aliasing. It seems that we alsohave to handle alias. It makes thing a bit handler, instead of only looking at the NestedSDFG, we now look at the `(NameOfDataDescriptorInside, NestedSDFG)` pair. However, it still has some errors. --- .../dace_fieldview/transformations/strides.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 61471de74b..78a25c4407 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Iterable, Optional +from typing import Iterable, Optional, TypeAlias import dace from dace import data as dace_data @@ -18,6 +18,19 @@ ) +PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] +"""Record of a stride that has been propagated into a NestedSDFG. + +The type combines the NestedSDFG into which the strides were already propagated +and the data within that NestedSDFG to which we have propagated the data, +which is the connector name on the NestedSDFG. +We need the NestedSDFG because we have to know what was already processed, +however, we also need the name within because of aliasing, i.e. a data +descriptor on the outside could be mapped to multiple data descriptors +inside the NestedSDFG. +""" + + def gt_change_transient_strides( sdfg: dace.SDFG, gpu: bool, @@ -118,8 +131,8 @@ def gt_propagate_strides_of( """Propagates the strides of `data_name` within the whole SDFG. This function will call `gt_propagate_strides_from_access_node()` for every - AccessNode that refers to `data_name`. It will also make sure that - a NestedSDFG is visited only once. + AccessNode that refers to `data_name`. It will also make sure that a descriptor + inside a NestedSDFG is only processed once. Args: sdfg: The SDFG on which we operate. @@ -127,7 +140,7 @@ def gt_propagate_strides_of( """ # Defining it here ensures that we will not enter an NestedSDFG multiple times. - processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + processed_nsdfgs: set[PropagatedStrideRecord] = set() for state in sdfg.states(): for dnode in state.data_nodes(): @@ -145,7 +158,7 @@ def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, outer_node: dace_nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. @@ -164,7 +177,7 @@ def gt_propagate_strides_from_access_node( state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. - processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. propagate_along_dataflow: Determine the direction of propagation. If `True` the function follows the dataflow. @@ -197,7 +210,7 @@ def gt_map_strides_to_dst_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` along the dataflow. @@ -227,7 +240,7 @@ def gt_map_strides_to_src_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` against the dataflow. @@ -257,7 +270,7 @@ def _gt_map_strides_to_nested_sdfg_src_dst( state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], outer_node: dace.nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]], + processed_nsdfgs: Optional[set[PropagatedStrideRecord]], propagate_along_dataflow: bool, ) -> None: """Propagates the stride of `outer_node` along `edge`. @@ -348,13 +361,14 @@ def next_edges_by_connector( elif isinstance(get_node(edge), dace.nodes.NestedSDFG): nsdfg_node = get_node(edge) inner_data = get_inner_data(edge) + process_record = (inner_data, nsdfg_node) - if nsdfg_node in processed_nsdfgs: - # We have processed this nested SDFG already, so we have nothing to do. + if process_record in processed_nsdfgs: + # We already handled this NestedSDFG and the inner data. return # Mark this nested SDFG as processed. - processed_nsdfgs.add(nsdfg_node) + processed_nsdfgs.add(process_record) # Now set the stride of the data descriptor inside the nested SDFG to # the ones it has outside. From 983022c3f80a6ccb2bd003dbc9c22bafaafc08b0 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 13:21:00 +0100 Subject: [PATCH 52/80] In the function for looking for top level accesses the `only_transients` flag was not implemented properly. --- .../dace_fieldview/transformations/strides.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 78a25c4407..e808422765 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -457,16 +457,19 @@ def _gt_find_toplevel_data_accesses( """Find all data that is accessed on the top level. The function will scan the SDFG, ignoring nested one, and return the - name of all data (global and transient) that only have AccessNodes on - the top level. In data is found that has an AccessNode on both the top - level and in a nested scope and error is generated. - The function will ignore an access in the following cases: + name of all data that only have AccessNodes on the top level. In data + is found that has an AccessNode on both the top level and in a nested + scope and error is generated. + By default the function will return transient and non transient data, + however, if `only_transients` is `True` then only transient data will + be returned. + Furthermore, the function will ignore an access in the following cases: - The AccessNode refers to data that is a register. - The AccessNode refers to a View. Args: sdfg: The SDFG to process. - only_transients: If `True` all non transients will be filtered out. + only_transients: If `True` only include transients. only_arrays: If `True`, defaults to `False`, only arrays are returned. Returns: @@ -524,7 +527,7 @@ def _gt_find_toplevel_data_accesses( continue # We are only interested in transients - if only_transients and desc.transient: + if only_transients and (not desc.transient): continue # Now create the new entry in the list and record the AccessNode. From e7b1afbf127a7f4a38df9492b147a015501a8a47 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 13:27:29 +0100 Subject: [PATCH 53/80] Small reminder of the future. --- .../runners/dace_fieldview/transformations/strides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e808422765..08cd08120a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -114,6 +114,7 @@ def _gt_change_transient_strides_non_recursive_impl( # propagate the non-transients, because they either come from outside, # or they were already handled in the levels above, where they were # defined and then propagated down. + # TODO(phimuell): Updated the functions such that only once scan is needed. processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() for state, access_node in accesses: gt_propagate_strides_from_access_node( From df7bd0ca993b59a0fcbde67a4ed194c24ac9b3e4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 14:47:31 +0100 Subject: [PATCH 54/80] Forgot to export the new SDFG stuff. --- .../runners/dace_fieldview/transformations/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 439084674e..0902bd665a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -39,6 +39,8 @@ gt_change_transient_strides, gt_map_strides_to_dst_nested_sdfg, gt_map_strides_to_src_nested_sdfg, + gt_propagate_strides_from_access_node, + gt_propagate_strides_of, ) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -65,6 +67,8 @@ "gt_make_transients_persistent", "gt_map_strides_to_dst_nested_sdfg", "gt_map_strides_to_src_nested_sdfg", + "gt_propagate_strides_from_access_node", + "gt_propagate_strides_of", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", From 363ab5942e4da737789554b57f5880cf80ef49ee Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 15:02:43 +0100 Subject: [PATCH 55/80] Had to update function for actuall renaming of the strides. Before the function had a special mode in which it performed the renaming through the `symbol_mapping`. However, this made testing a bit harder and so I decided that there should be a flag to disable this. --- .../dace_fieldview/transformations/strides.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 08cd08120a..e8eb25bd59 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -128,6 +128,7 @@ def _gt_change_transient_strides_non_recursive_impl( def gt_propagate_strides_of( sdfg: dace.SDFG, data_name: str, + ignore_symbol_mapping: bool = False, ) -> None: """Propagates the strides of `data_name` within the whole SDFG. @@ -138,6 +139,8 @@ def gt_propagate_strides_of( Args: sdfg: The SDFG on which we operate. data_name: Name of the data descriptor that should be handled. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. """ # Defining it here ensures that we will not enter an NestedSDFG multiple times. @@ -152,6 +155,7 @@ def gt_propagate_strides_of( state=state, outer_node=dnode, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -159,6 +163,7 @@ def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, outer_node: dace_nodes.AccessNode, + ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. @@ -180,6 +185,8 @@ def gt_propagate_strides_from_access_node( outer_node: The data node whose strides should be propagated. processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. propagate_along_dataflow: Determine the direction of propagation. If `True` the function follows the dataflow. """ @@ -195,6 +202,7 @@ def gt_propagate_strides_from_access_node( edge=in_edge, outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) for out_edge in state.out_edges(outer_node): gt_map_strides_to_dst_nested_sdfg( @@ -203,6 +211,7 @@ def gt_propagate_strides_from_access_node( edge=out_edge, outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -211,6 +220,7 @@ def gt_map_strides_to_dst_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` along the dataflow. @@ -223,6 +233,8 @@ def gt_map_strides_to_dst_nested_sdfg( state: The state where the data node is used. edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when you know what your are doing. """ @@ -233,6 +245,7 @@ def gt_map_strides_to_dst_nested_sdfg( outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, propagate_along_dataflow=True, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -241,6 +254,7 @@ def gt_map_strides_to_src_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` against the dataflow. @@ -253,6 +267,8 @@ def gt_map_strides_to_src_nested_sdfg( state: The state where the data node is used. edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when you know what your are doing. """ @@ -263,6 +279,7 @@ def gt_map_strides_to_src_nested_sdfg( outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, propagate_along_dataflow=False, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -273,6 +290,7 @@ def _gt_map_strides_to_nested_sdfg_src_dst( outer_node: dace.nodes.AccessNode, processed_nsdfgs: Optional[set[PropagatedStrideRecord]], propagate_along_dataflow: bool, + ignore_symbol_mapping: bool = False, ) -> None: """Propagates the stride of `outer_node` along `edge`. @@ -300,6 +318,8 @@ def _gt_map_strides_to_nested_sdfg_src_dst( Only specify when you know what your are doing. propagate_along_dataflow: Determine the direction of propagation. If `True` the function follows the dataflow. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. Note: A user should not use this function directly, instead `gt_propagate_strides_of()`, @@ -357,6 +377,7 @@ def next_edges_by_connector( outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, propagate_along_dataflow=propagate_along_dataflow, + ignore_symbol_mapping=ignore_symbol_mapping, ) elif isinstance(get_node(edge), dace.nodes.NestedSDFG): @@ -378,6 +399,7 @@ def next_edges_by_connector( inner_data=inner_data, edge_data=edge.data, outer_strides=outer_node.desc(sdfg).strides, + ignore_symbol_mapping=ignore_symbol_mapping, ) # Because the function call above if not recursive we have now to scan the @@ -400,6 +422,7 @@ def next_edges_by_connector( state=nested_state, outer_node=nested_access, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -408,6 +431,7 @@ def _gt_map_strides_to_nested_sdfg( inner_data: str, edge_data: dace.Memlet, outer_strides: Iterable[int | dace.symbolic.SymExpr], + ignore_symbol_mapping: bool = False, ) -> None: # TODO(phimuell/edopao): Refactor this function. # We need to propagate the strides inside the nested SDFG on the global arrays @@ -428,7 +452,9 @@ def _gt_map_strides_to_nested_sdfg( return assert isinstance(inner_desc, dace.data.Array) - if all(isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides): + if (not ignore_symbol_mapping) and all( + isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides + ): for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride else: From 9c19d32438477a87e447e0ed510a58c6a85b5fb2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 15:06:00 +0100 Subject: [PATCH 56/80] Added a todo to the replacement function. --- .../runners/dace_fieldview/transformations/strides.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e8eb25bd59..ea14cf97fb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -433,7 +433,12 @@ def _gt_map_strides_to_nested_sdfg( outer_strides: Iterable[int | dace.symbolic.SymExpr], ignore_symbol_mapping: bool = False, ) -> None: - # TODO(phimuell/edopao): Refactor this function. + """ + Todo: + - Refactor this function. + - Handle the case the stride is used somewhere else. + - Handle the case where we have an explicit size 1 dimension in slicing. + """ # We need to propagate the strides inside the nested SDFG on the global arrays new_strides = tuple( stride From 9cad1f7179b7bc8124d9248569c1ee2ccaf904e8 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 15:10:07 +0100 Subject: [PATCH 57/80] Added a first test to the propagation function. There are some functioanlity missing, but it is looking good. --- .../transformation_tests/test_strides.py | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py new file mode 100644 index 0000000000..bb0af074c7 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -0,0 +1,221 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace import symbolic as dace_symbolic +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_strides_propagation_level3_sdfg() -> dace.SDFG: + """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" + sdfg = dace.SDFG(util.unique_name("level3")) + state = sdfg.add_state(is_start_block=True) + names = ["a3", "c3"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL3", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a3[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("c3[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + """Generates the level 2 SDFG (nested) SDFG for `test_strides_propagation()`. + + The function returns the level 2 SDFG and the NestedSDFG node that contains + the level 3 SDFG. + """ + sdfg = dace.SDFG(util.unique_name("level2")) + state = sdfg.add_state(is_start_block=True) + names = ["a2", "a2_alias", "b2", "c2"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL2_1", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "compL2_2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("c2[__i0]")}, + code="__out = __in1", + outputs={"__out": dace.Memlet("a2_alias[__i0]")}, + external_edges=True, + ) + + # This is the nested SDFG we have here. + sdfg_level3 = _make_strides_propagation_level3_sdfg() + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level3, + parent=sdfg, + inputs={"a3"}, + outputs={"c3"}, + symbol_mapping={s3: s3 for s3 in sdfg_level3.free_symbols}, + ) + + state.add_edge(state.add_access("a2"), None, nsdfg, "a3", dace.Memlet("a2[0:10]")) + state.add_edge(nsdfg, "c3", state.add_access("c2"), None, dace.Memlet("c2[0:10]")) + sdfg.validate() + + return sdfg, nsdfg + + +def _make_strides_propagation_level1_sdfg() -> ( + tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG] +): + """Generates the level 1 SDFG (top) SDFG for `test_strides_propagation()`. + + Note that the SDFG is valid, but will be indeterminate. The only point of + this SDFG is to have a lot of different situations that have to be handled + for renaming. + + Returns: + A tuple of length three, with the following members: + - The top level SDFG. + - The NestedSDFG node that contains the level 2 SDFG (member of the top level SDFG). + - The NestedSDFG node that contains the lebel 3 SDFG (member of the level 2 SDFG). + """ + + sdfg = dace.SDFG(util.unique_name("level1")) + state = sdfg.add_state(is_start_block=True) + names = ["a1", "b1", "c1"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + sdfg_level2, nsdfg_level3 = _make_strides_propagation_level2_sdfg() + + nsdfg_level2: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg, + inputs={"a2", "c2"}, + outputs={"a2_alias", "b2", "c2"}, + symbol_mapping={s: s for s in sdfg_level2.free_symbols}, + ) + + for inner_name in nsdfg_level2.in_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + state.add_access(outer_name), + None, + nsdfg_level2, + inner_name, + dace.Memlet(f"{outer_name}[0:10]"), + ) + for inner_name in nsdfg_level2.out_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + nsdfg_level2, + inner_name, + state.add_access(outer_name), + None, + dace.Memlet(f"{outer_name}[0:10]"), + ) + + sdfg.validate() + + return sdfg, nsdfg_level2, nsdfg_level3 + + +def test_strides_propagation(): + """ + Todo: + - Add a case where `ignore_symbol_mapping=False` can be tested. + - What happens if the stride symbol is used somewhere else? + """ + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + assert len(adesc.strides) == 1 + assert exp_stride == str( + adesc.strides[0] + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + # Now we propagate `a` and `b`, but not `c`. + # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + + # After the propagation `a` and `b` should use the same stride (the one that + # it has on level 1, but `c` should still be level depending. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + if aname.startswith("c"): + exp_stride = f"{aname}_stride" + else: + exp_stride = f"{aname[0]}1_stride" + assert len(adesc.strides) == 1 + assert exp_stride == str( + adesc.strides[0] + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname[0]}1_stride" + assert len(adesc.strides) == 1 + assert exp_stride == str( + adesc.strides[0] + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." From 2700f534142464daa38d1ce95edfea26ff3dafc1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 10:02:46 +0100 Subject: [PATCH 58/80] Modified the function that performs the actuall modification of the strides. However, it is not yet fully tested, tehy are on their wa. --- .../dace_fieldview/transformations/strides.py | 167 +++++++++++++----- 1 file changed, 127 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index ea14cf97fb..17bdbceeec 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,8 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import functools -from typing import Iterable, Optional, TypeAlias +from typing import Optional, TypeAlias import dace from dace import data as dace_data @@ -346,6 +345,11 @@ def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_node def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: return edge.dst_conn + def get_subset( + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.src_subset + def next_edges_by_connector( state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: @@ -363,6 +367,11 @@ def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_node def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: return edge.src_conn + def get_subset( + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.dst_subset + def next_edges_by_connector( state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: @@ -394,11 +403,11 @@ def next_edges_by_connector( # Now set the stride of the data descriptor inside the nested SDFG to # the ones it has outside. - _gt_map_strides_to_nested_sdfg( + _gt_map_strides_into_nested_sdfg( nsdfg_node=nsdfg_node, inner_data=inner_data, - edge_data=edge.data, - outer_strides=outer_node.desc(sdfg).strides, + outer_subset=get_subset(edge), + outer_desc=outer_node.desc(sdfg), ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -426,59 +435,137 @@ def next_edges_by_connector( ) -def _gt_map_strides_to_nested_sdfg( +def _gt_map_strides_into_nested_sdfg( nsdfg_node: dace.nodes.NestedSDFG, inner_data: str, - edge_data: dace.Memlet, - outer_strides: Iterable[int | dace.symbolic.SymExpr], + outer_subset: dace.subsets.Subset, + outer_desc: dace_data.Data, ignore_symbol_mapping: bool = False, ) -> None: - """ + """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. + + `inner_data` is the name of of a data descriptor inside the NestedSDFG. + The function will then modify the modify the strides of `inner_data` to + match the ones of `outer_desc`. + + Args: + nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. + inner_data: The name of the data descriptor that should be processed + inside the NestedSDFG (by construction also a connector name). + outer_subset: The subset that describes what part of the outer data is + mapped into the NestedSDFG. + outer_desc: The data descriptor of the data on the outside. + ignore_symbol_mapping: If possible the function will perform the renaming + through the `symbol_mapping` of the nested SDFG. If `True` then + the function will always perform the renaming. + Todo: - - Refactor this function. - - Handle the case the stride is used somewhere else. - - Handle the case where we have an explicit size 1 dimension in slicing. + - Handle explicit dimensions of size 1. """ - # We need to propagate the strides inside the nested SDFG on the global arrays - new_strides = tuple( - stride - for stride, to_map_size in zip( - outer_strides, - edge_data.subset.size(), - strict=True, - ) - if to_map_size != 1 - ) - inner_desc = nsdfg_node.sdfg.arrays[inner_data] + # We need to compute the new strides. In the following we assume that the + # relative order of the dimension does not change, but some dimensions + # that are present on the outside are not present on the inside. For + # example this happens for the Memlet `a[__i0, 0:__a_size1]`. + # We detect this case by checking if that dimension has size 1. + # TODO(phimuell): Handle the case were some additional size 1 dimensions are added. + inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data] + inner_shape = inner_desc.shape + inner_strides_init = inner_desc.strides + + # TODO(phimuell): For now this is fine, but it should be possisble to allow it. assert not inner_desc.transient - if isinstance(inner_desc, dace.data.Scalar): - assert len(new_strides) == 0 + outer_strides = outer_desc.strides + outer_inflow = outer_subset.size() + + new_strides: list = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + current_inner_dim = len(new_strides) + + if inner_shape[current_inner_dim] == 1 and dim_oinflow == 1: + # There is an explicit size 1 dimension. Because the only valid + # index for this dimension is `0` we can use any value here. + # To give the compiler more information we explicitly use `0`, + # instead of the outer value. + new_strides.append(0) + + elif dim_oinflow == 1: + # Only something flows in, thus there is no stride in this dimension. + pass + + else: + # There is inflow into the SDFG, so we need the stride. + assert dim_oinflow != 0 + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + + # If we have a scalar on the inside, then there is nothing to adjust. + # We could have performed the test above, but doing it here, gives us + # the chance of validating it. + if isinstance(inner_desc, dace_data.Scalar): + if len(new_strides) != 0: + raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.") return - assert isinstance(inner_desc, dace.data.Array) + if not isinstance(inner_desc, dace_data.Array): + raise TypeError( + f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." + ) + + # Now we actually replace the strides, there are two ways of doing it. + # The first is to create an alias in the `symbol_mapping`, however, + # this is only possible if the current strides are singular symbols, + # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` + # or literal values. + # The second way would be to replace `strides` attributer of the + # inner data descriptor. In case the new stride consists of expressions + # such as `value1 - value2` we have to make them available inside the + # NestedSDFG. However, it could be that the strides is used somewhere else. + # We will do the following, if `ignore_symbol_mapping` is `False` and + # the strides of the inner descriptors are symbols, we will use the + # symbol mapping. Otherwise, we will replace the `strides` attribute + # of the inner descriptor, in addition we will install a remapping, + # for those values that were a symbol. if (not ignore_symbol_mapping) and all( - isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides + isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init ): + # Use the symbol for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride else: - assert len(inner_desc.shape) == len(new_strides) + # We have to replace the `strides` attribute of the inner descriptor. inner_desc.set_shape(inner_desc.shape, new_strides) - new_strides_symbols: list[dace.symbol] = functools.reduce( - lambda acc, itm: (acc + list(itm.free_symbols)) # type: ignore[union-attr] - if dace.symbolic.issymbolic(itm) - else acc, - new_strides, - [], - ) - new_strides_free_symbols = { - sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols + # Now find the free symbols that the new strides need. + new_strides_symbols: list[str] = [] + for new_stride_dim in new_strides: + if dace.symbolic.issymbolic(new_stride_dim): + new_strides_symbols.append(str(new_stride_dim)) + else: + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + + # Now we determine the set of symbols that should be mapped inside the NestedSDFG. + # We will exclude all that are already inside the `symbol_mapping` (we do not + # check if they map to the same value, we just hope it). Furthermore, + # we will exclude all symbols that are listed in the `symbols` property + # of the SDFG that is nested, and hope that it has the same meaning. + missing_symbol_mappings: set[str] = { + sym + for sym in new_strides_symbols + if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) } - for sym in new_strides_free_symbols: - nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) - nsdfg_node.symbol_mapping[sym.name] = sym + for sym in missing_symbol_mappings: + # We can not create symbols in the nested SDFG, because we do not have + # the type of the symbols. + nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) + + # Now create aliases for the old symbols that were used as strides. + for old_sym, new_sym in zip(inner_strides_init, new_strides): + if dace.symbolic.issymbolic(old_sym): + nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym) def _gt_find_toplevel_data_accesses( From a20d3c00a202aea530dd66ee842b00bb550e7045 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 10:07:24 +0100 Subject: [PATCH 59/80] Updated some tes, but more are missing. --- .../transformation_tests/test_strides.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index bb0af074c7..655e50fb23 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -187,11 +187,17 @@ def test_strides_propagation(): for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] assert len(adesc.strides) == 1 - assert exp_stride == str( - adesc.strides[0] + assert ( + str(actual_stride) == exp_stride ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + # Now we propagate `a` and `b`, but not `c`. # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) @@ -201,6 +207,7 @@ def test_strides_propagation(): # it has on level 1, but `c` should still be level depending. for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): + original_stride = f"{aname}_stride" if aname.startswith("c"): exp_stride = f"{aname}_stride" else: @@ -210,12 +217,23 @@ def test_strides_propagation(): adesc.strides[0] ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): exp_stride = f"{aname[0]}1_stride" + original_stride = f"{aname}_stride" assert len(adesc.strides) == 1 assert exp_stride == str( adesc.strides[0] ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride From b5ff46270733b4edd6fc7d43fcfe13c55558dd84 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:01:07 +0100 Subject: [PATCH 60/80] Subset caching strikes again. --- .../dace_fieldview/transformations/strides.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 17bdbceeec..8808248e40 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -346,12 +346,14 @@ def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str return edge.dst_conn def get_subset( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> dace.subsets.Subset: - return edge.data.src_subset + return edge.data.get_src_subset(edge, state) def next_edges_by_connector( - state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): return [] @@ -368,12 +370,14 @@ def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str return edge.src_conn def get_subset( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> dace.subsets.Subset: - return edge.data.dst_subset + return edge.data.get_dst_subset(edge, state) def next_edges_by_connector( - state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) @@ -406,7 +410,7 @@ def next_edges_by_connector( _gt_map_strides_into_nested_sdfg( nsdfg_node=nsdfg_node, inner_data=inner_data, - outer_subset=get_subset(edge), + outer_subset=get_subset(state, edge), outer_desc=outer_node.desc(sdfg), ignore_symbol_mapping=ignore_symbol_mapping, ) From d326d3b7316f6561897f9b1e7117424b5911baf5 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:01:36 +0100 Subject: [PATCH 61/80] It seems that the explicit handling of one dimensions is not working. It also seems that it inferes with something. --- .../runners/dace_fieldview/transformations/strides.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 8808248e40..e0f21e4163 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -484,16 +484,7 @@ def _gt_map_strides_into_nested_sdfg( new_strides: list = [] for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): - current_inner_dim = len(new_strides) - - if inner_shape[current_inner_dim] == 1 and dim_oinflow == 1: - # There is an explicit size 1 dimension. Because the only valid - # index for this dimension is `0` we can use any value here. - # To give the compiler more information we explicitly use `0`, - # instead of the outer value. - new_strides.append(0) - - elif dim_oinflow == 1: + if dim_oinflow == 1: # Only something flows in, thus there is no stride in this dimension. pass From 252f348e104cff7f75fccc5629044bcdb5347b33 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:02:47 +0100 Subject: [PATCH 62/80] The test must be moved bellow. Because a scalar has a shape of `(1,)` but a stride of `()`. Thus we have first to handle this case. However, now we are back at the index stuff, let's fix it. --- .../runners/dace_fieldview/transformations/strides.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e0f21e4163..1ee0260310 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -487,16 +487,12 @@ def _gt_map_strides_into_nested_sdfg( if dim_oinflow == 1: # Only something flows in, thus there is no stride in this dimension. pass - else: # There is inflow into the SDFG, so we need the stride. assert dim_oinflow != 0 new_strides.append(dim_ostride) assert len(new_strides) <= len(inner_shape) - if len(new_strides) != len(inner_shape): - raise ValueError("Failed to compute the inner strides.") - # If we have a scalar on the inside, then there is nothing to adjust. # We could have performed the test above, but doing it here, gives us # the chance of validating it. @@ -510,6 +506,9 @@ def _gt_map_strides_into_nested_sdfg( f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." ) + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + # Now we actually replace the strides, there are two ways of doing it. # The first is to create an alias in the `symbol_mapping`, however, # this is only possible if the current strides are singular symbols, From 49f81721b27440a22b2cf3f8fcc14401bb1fbaf1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:38:44 +0100 Subject: [PATCH 63/80] The symbol is also needed to be present in the nested SDFG. However, it still seems to fail in some cases. --- .../dace_fieldview/transformations/strides.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 1ee0260310..c03079037d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import warnings from typing import Optional, TypeAlias import dace @@ -408,6 +409,7 @@ def next_edges_by_connector( # Now set the stride of the data descriptor inside the nested SDFG to # the ones it has outside. _gt_map_strides_into_nested_sdfg( + sdfg=sdfg, nsdfg_node=nsdfg_node, inner_data=inner_data, outer_subset=get_subset(state, edge), @@ -440,6 +442,7 @@ def next_edges_by_connector( def _gt_map_strides_into_nested_sdfg( + sdfg: dace.SDFG, nsdfg_node: dace.nodes.NestedSDFG, inner_data: str, outer_subset: dace.subsets.Subset, @@ -453,6 +456,7 @@ def _gt_map_strides_into_nested_sdfg( match the ones of `outer_desc`. Args: + sdfg: The SDFG containing the NestedSDFG. nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. inner_data: The name of the data descriptor that should be processed inside the NestedSDFG (by construction also a connector name). @@ -539,7 +543,9 @@ def _gt_map_strides_into_nested_sdfg( if dace.symbolic.issymbolic(new_stride_dim): new_strides_symbols.append(str(new_stride_dim)) else: - new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + # NOTE: In DaCe `free_symbols` is `set[str]` but in `sympy` it + # returns `set[symbol]`. We need `str` so we have to cast them. + new_strides_symbols.extend(str(sym) for sym in new_stride_dim.free_symbols) # Now we determine the set of symbols that should be mapped inside the NestedSDFG. # We will exclude all that are already inside the `symbol_mapping` (we do not @@ -551,9 +557,19 @@ def _gt_map_strides_into_nested_sdfg( for sym in new_strides_symbols if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) } + # Now create the symbol we in the NestedSDFG. for sym in missing_symbol_mappings: - # We can not create symbols in the nested SDFG, because we do not have - # the type of the symbols. + if sym in sdfg.symbols: + # TODO(phimuell): Handle the case the symbol is already defined. + nsdfg_node.sdfg.add_symbol(sym, sdfg.symbols[sym]) + else: + # The symbol is not known in the parent SDFG, but we need a symbol + # for it. So we use the default. + nsdfg_node.sdfg.add_symbol(sym, dace.symbol("__INVALID_SYMBOL__").dtype) + warnings.warn( + f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides.", + stacklevel=1, + ) nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) # Now create aliases for the old symbols that were used as strides. From 2d6dfc0e9e7497f5e31161c9baeec5f92c71921c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:59:11 +0100 Subject: [PATCH 64/80] Fixed a bug in determining the free symbols that we need. --- .../runners/dace_fieldview/transformations/strides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index c03079037d..1b8ebdfd41 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -541,11 +541,11 @@ def _gt_map_strides_into_nested_sdfg( new_strides_symbols: list[str] = [] for new_stride_dim in new_strides: if dace.symbolic.issymbolic(new_stride_dim): - new_strides_symbols.append(str(new_stride_dim)) - else: # NOTE: In DaCe `free_symbols` is `set[str]` but in `sympy` it # returns `set[symbol]`. We need `str` so we have to cast them. new_strides_symbols.extend(str(sym) for sym in new_stride_dim.free_symbols) + else: + new_strides_symbols.append(str(new_stride_dim)) # Now we determine the set of symbols that should be mapped inside the NestedSDFG. # We will exclude all that are already inside the `symbol_mapping` (we do not From 6124c6d7d461799963a8913a1629d5b74a2bee34 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:59:47 +0100 Subject: [PATCH 65/80] Updated the propagation code for the symbols. The type is now a bit better estimated. --- .../dace_fieldview/transformations/strides.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 1b8ebdfd41..30864c4449 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -557,17 +557,20 @@ def _gt_map_strides_into_nested_sdfg( for sym in new_strides_symbols if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) } - # Now create the symbol we in the NestedSDFG. + + # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: if sym in sdfg.symbols: # TODO(phimuell): Handle the case the symbol is already defined. nsdfg_node.sdfg.add_symbol(sym, sdfg.symbols[sym]) else: - # The symbol is not known in the parent SDFG, but we need a symbol - # for it. So we use the default. - nsdfg_node.sdfg.add_symbol(sym, dace.symbol("__INVALID_SYMBOL__").dtype) + # The symbol is not known in the parent SDFG, but we need to define a + # symbol and for that we need a `dtype`. Our solution (which is as + # wrong as any other) is to create a symbol with that name and then + # use the type that was deduced. + nsdfg_node.sdfg.add_symbol(sym, dace.symbol(sym).dtype) warnings.warn( - f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides.", + f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym]}' as dtype.", stacklevel=1, ) nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) From 45bcf9795496eb857f80bc594cd1fd37406a7ff4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 13:55:09 +0100 Subject: [PATCH 66/80] Addressed Edoardo's changes. --- .../transformations/simplify.py | 2 +- .../dace_fieldview/transformations/strides.py | 94 ++++++++++--------- 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 1a132cacb2..4339a761fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -971,7 +971,7 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None - # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension + # Recursively visit the nested SDFGs for mapping of strides from inner to outer array gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) # We now remove the `tmp` node, and create a new connection between diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 30864c4449..e69d392770 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -22,10 +22,10 @@ """Record of a stride that has been propagated into a NestedSDFG. The type combines the NestedSDFG into which the strides were already propagated -and the data within that NestedSDFG to which we have propagated the data, +and the data within that NestedSDFG to which we have propagated the strides, which is the connector name on the NestedSDFG. We need the NestedSDFG because we have to know what was already processed, -however, we also need the name within because of aliasing, i.e. a data +however, we also need the inner array name because of aliasing, i.e. a data descriptor on the outside could be mapped to multiple data descriptors inside the NestedSDFG. """ @@ -96,13 +96,14 @@ def _gt_change_transient_strides_non_recursive_impl( for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] - # Setting the strides only make sense if we have more than two dimensions + # Setting the strides only make sense if we have more than one dimensions ndim = len(desc.shape) if ndim <= 1: continue # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. + # TODO(phimuell): Improve this. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) @@ -110,11 +111,11 @@ def _gt_change_transient_strides_non_recursive_impl( # collected all the AccessNodes we are using the # `gt_propagate_strides_from_access_node()` function, but we have to # create `processed_nsdfg` set already outside here. - # Furthermore, the same comment as above apply, we do not have to + # Furthermore, the same comment as above applies here, we do not have to # propagate the non-transients, because they either come from outside, # or they were already handled in the levels above, where they were # defined and then propagated down. - # TODO(phimuell): Updated the functions such that only once scan is needed. + # TODO(phimuell): Updated the functions such that only one scan is needed. processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() for state, access_node in accesses: gt_propagate_strides_from_access_node( @@ -166,7 +167,7 @@ def gt_propagate_strides_from_access_node( ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. + """Propagates the stride of `outer_node` to any adjacent reachable through its edges. The function will propagate the strides of the data descriptor `outer_node` refers to along all adjacent edges of `outer_node`. If one of these edges @@ -183,16 +184,13 @@ def gt_propagate_strides_from_access_node( state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. - processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. - Only specify when you know what your are doing. ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. - propagate_along_dataflow: Determine the direction of propagation. If `True` the - function follows the dataflow. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. + Only specify when you know what your are doing. """ if processed_nsdfgs is None: # For preventing the case that nested SDFGs are handled multiple time. - # TODO: It certainly happens if a node is input and output, but are there other cases? processed_nsdfgs = set() for in_edge in state.in_edges(outer_node): @@ -225,8 +223,13 @@ def gt_map_strides_to_dst_nested_sdfg( ) -> None: """Propagates the strides of `outer_node` along `edge` along the dataflow. - For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). - However it is recommended to use `gt_propagate_strides_of()` directly. + In this context "along the dataflow" means that `edge` is an outgoing + edge of `outer_node` and the strides are into all NestedSDFGs that + are downstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. Args: sdfg: The SDFG to process. @@ -235,9 +238,10 @@ def gt_map_strides_to_dst_nested_sdfg( outer_node: The data node whose strides should be propagated. ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. - processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when you know what your are doing. """ + assert edge.src is outer_node _gt_map_strides_to_nested_sdfg_src_dst( sdfg=sdfg, state=state, @@ -259,8 +263,13 @@ def gt_map_strides_to_src_nested_sdfg( ) -> None: """Propagates the strides of `outer_node` along `edge` against the dataflow. - For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). - However it is recommended to use `gt_propagate_strides_of()` directly. + In this context "along the dataflow" means that `edge` is an incoming + edge of `outer_node` and the strides are into all NestedSDFGs that + are upstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. Args: sdfg: The SDFG to process. @@ -269,7 +278,7 @@ def gt_map_strides_to_src_nested_sdfg( outer_node: The data node whose strides should be propagated. ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. - processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when you know what your are doing. """ _gt_map_strides_to_nested_sdfg_src_dst( @@ -298,11 +307,11 @@ def _gt_map_strides_to_nested_sdfg_src_dst( `propagate_along_dataflow` and propagate the strides of `outer_node` into every NestedSDFG that is reachable by following `edge`. - When the function encounters a NestedSDFG it will determine the the data - descriptor `outer_node` refers on the inside of the NestedSDFG. + When the function encounters a NestedSDFG it will determine what data + the `outer_node` is mapped to on the inside of the NestedSDFG. It will then replace the stride of the inner descriptor with the ones - of the outside. Afterwards it will recursively propagates the - stride inside the NestedSDFG. + of the outside. Afterwards it will recursively propagate the strides + inside the NestedSDFG. During this propagation the function will follow any edges. If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` @@ -417,10 +426,9 @@ def next_edges_by_connector( ignore_symbol_mapping=ignore_symbol_mapping, ) - # Because the function call above if not recursive we have now to scan the - # propagate the change into the nested SDFG. Using - # `_gt_find_toplevel_data_accesses()` is a bit overkill, but allows for a - # more uniform processing. + # Since the function call above is not recursive we have now to propagate + # the change into the NestedSDFGs. Using `_gt_find_toplevel_data_accesses()` + # is a bit overkill, but allows for a more uniform processing. # TODO(phimuell): Instead of scanning every level for every data we modify # we should scan the whole SDFG once and then reuse this information. accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( @@ -429,8 +437,8 @@ def next_edges_by_connector( only_arrays=True, ) for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): - # We have to use `gt_propagate_strides_of()` here because we have to - # handle its entirety. We could wait until the other branch processes + # We have to use `gt_propagate_strides_from_access_node()` here because we + # have to handle its entirety. We could wait until the other branch processes # the nested SDFG, but this might not work, so let's do it fully now. gt_propagate_strides_from_access_node( sdfg=nsdfg_node.sdfg, @@ -451,9 +459,9 @@ def _gt_map_strides_into_nested_sdfg( ) -> None: """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. - `inner_data` is the name of of a data descriptor inside the NestedSDFG. - The function will then modify the modify the strides of `inner_data` to - match the ones of `outer_desc`. + `inner_data` is the name of a data descriptor inside the NestedSDFG. + The function will then modify the strides of `inner_data`, assuming this + is an array, to match the ones of `outer_desc`. Args: sdfg: The SDFG containing the NestedSDFG. @@ -471,25 +479,24 @@ def _gt_map_strides_into_nested_sdfg( - Handle explicit dimensions of size 1. """ # We need to compute the new strides. In the following we assume that the - # relative order of the dimension does not change, but some dimensions - # that are present on the outside are not present on the inside. For - # example this happens for the Memlet `a[__i0, 0:__a_size1]`. - # We detect this case by checking if that dimension has size 1. + # relative order of the dimensions does not change, but we support the case + # where some dimensions of the outer data descriptor are not present on the + # inside. For example this happens for the Memlet `a[__i0, 0:__a_size1]`. We + # detect this case by checking if the Memlet subset in that dimension has size 1. # TODO(phimuell): Handle the case were some additional size 1 dimensions are added. inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data] inner_shape = inner_desc.shape inner_strides_init = inner_desc.strides - # TODO(phimuell): For now this is fine, but it should be possisble to allow it. - assert not inner_desc.transient - outer_strides = outer_desc.strides outer_inflow = outer_subset.size() new_strides: list = [] for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): if dim_oinflow == 1: - # Only something flows in, thus there is no stride in this dimension. + # This is the case of implicit slicing along one dimension. The inner + # array descriptor has shape != 1 in `current_inner_dim`, which has + # to map to a subsequent dimension of `outer_inflow` pass else: # There is inflow into the SDFG, so we need the stride. @@ -518,7 +525,7 @@ def _gt_map_strides_into_nested_sdfg( # this is only possible if the current strides are singular symbols, # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` # or literal values. - # The second way would be to replace `strides` attributer of the + # The second way would be to replace `strides` attribute of the # inner data descriptor. In case the new stride consists of expressions # such as `value1 - value2` we have to make them available inside the # NestedSDFG. However, it could be that the strides is used somewhere else. @@ -552,6 +559,7 @@ def _gt_map_strides_into_nested_sdfg( # check if they map to the same value, we just hope it). Furthermore, # we will exclude all symbols that are listed in the `symbols` property # of the SDFG that is nested, and hope that it has the same meaning. + # TODO(phimuell): Add better checks here. missing_symbol_mappings: set[str] = { sym for sym in new_strides_symbols @@ -605,9 +613,8 @@ def _gt_find_toplevel_data_accesses( only_arrays: If `True`, defaults to `False`, only arrays are returned. Returns: - A `dict` that maps the name of a data container, that should be processed - to a list of tuples containing the state where the AccessNode was found - and the node. + A `dict` that maps the name of a data container, to a list of tuples + containing the state where the AccessNode was found and the AccessNode. """ # List of data that is accessed on the top level and all its access node. top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() @@ -637,8 +644,7 @@ def _gt_find_toplevel_data_accesses( continue elif gtx_transformations.util.is_view(dnode, sdfg): - # The AccessNode refers to a View so we ignore it anyway - # TODO(phimuell/edopao): Should the function return them? + # The AccessNode refers to a View so we ignore it anyway. continue # We have found a new data node that is on the top node and is unknown. From 23b0baa530077be44981e800a5b622bc2ae872bc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 14:18:16 +0100 Subject: [PATCH 67/80] Updated how we get the type of symbols. The type are now extracted from the stuff we get from `free_symbols`. --- .../dace_fieldview/transformations/strides.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e69d392770..5c501bca24 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -545,43 +545,47 @@ def _gt_map_strides_into_nested_sdfg( inner_desc.set_shape(inner_desc.shape, new_strides) # Now find the free symbols that the new strides need. - new_strides_symbols: list[str] = [] + # Note that usually `free_symbols` returns `set[str]`, but here, because + # we fall back on SymPy, we get back symbols. We will keep them, because + # then we can use them to extract the type form them, which we need later. + new_strides_symbols: list[dace.symbol] = [] for new_stride_dim in new_strides: if dace.symbolic.issymbolic(new_stride_dim): - # NOTE: In DaCe `free_symbols` is `set[str]` but in `sympy` it - # returns `set[symbol]`. We need `str` so we have to cast them. - new_strides_symbols.extend(str(sym) for sym in new_stride_dim.free_symbols) + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) else: - new_strides_symbols.append(str(new_stride_dim)) + # It is not already a symbol, so we turn it into a symbol. + # However, we only add it, if it is also a symbol, for example `1`. + # should not be added. + new_stride_symbol = dace.symbolic.pystr_to_symbolic(new_stride_dim) + if new_stride_symbol.is_symbol: + new_strides_symbols.append(new_stride_symbol) # Now we determine the set of symbols that should be mapped inside the NestedSDFG. # We will exclude all that are already inside the `symbol_mapping` (we do not # check if they map to the same value, we just hope it). Furthermore, # we will exclude all symbols that are listed in the `symbols` property # of the SDFG that is nested, and hope that it has the same meaning. - # TODO(phimuell): Add better checks here. - missing_symbol_mappings: set[str] = { + # TODO(phimuell): Add better checks to avoid overwriting. + missing_symbol_mappings: set[dace.symbol] = { sym for sym in new_strides_symbols - if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) + if not (sym.name in nsdfg_node.sdfg.symbols or sym.name in nsdfg_node.symbol_mapping) } # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: if sym in sdfg.symbols: - # TODO(phimuell): Handle the case the symbol is already defined. - nsdfg_node.sdfg.add_symbol(sym, sdfg.symbols[sym]) + # TODO(phimuell): Handle the case the symbol is already defined in + # the nested SDFG. + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) else: - # The symbol is not known in the parent SDFG, but we need to define a - # symbol and for that we need a `dtype`. Our solution (which is as - # wrong as any other) is to create a symbol with that name and then - # use the type that was deduced. - nsdfg_node.sdfg.add_symbol(sym, dace.symbol(sym).dtype) + # The symbol is not known in the parent SDFG, so we add it + nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) warnings.warn( - f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym]}' as dtype.", + f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym.name]}' as dtype.", stacklevel=1, ) - nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) + nsdfg_node.symbol_mapping[sym.name] = sym # Now create aliases for the old symbols that were used as strides. for old_sym, new_sym in zip(inner_strides_init, new_strides): From ff058802b5128ddde404e01c909e0ff36f85fefb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 14:25:10 +0100 Subject: [PATCH 68/80] New restriction on the update of the symbol mapping. --- .../runners/dace_fieldview/transformations/strides.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 5c501bca24..f683737f23 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -575,8 +575,6 @@ def _gt_map_strides_into_nested_sdfg( # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: if sym in sdfg.symbols: - # TODO(phimuell): Handle the case the symbol is already defined in - # the nested SDFG. nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) else: # The symbol is not known in the parent SDFG, so we add it @@ -589,7 +587,7 @@ def _gt_map_strides_into_nested_sdfg( # Now create aliases for the old symbols that were used as strides. for old_sym, new_sym in zip(inner_strides_init, new_strides): - if dace.symbolic.issymbolic(old_sym): + if dace.symbolic.issymbolic(old_sym) and old_sym.is_symbol: nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym) From 43ec33ccff098c7beacf4a9588120a047abd0e44 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 14:47:35 +0100 Subject: [PATCH 69/80] Updated the tests, now also made one that has tests for the symbol mapping branch. --- .../transformation_tests/test_strides.py | 81 ++++++++++++++++--- 1 file changed, 71 insertions(+), 10 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 655e50fb23..45c3ebc739 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -174,12 +174,70 @@ def _make_strides_propagation_level1_sdfg() -> ( return sdfg, nsdfg_level2, nsdfg_level3 -def test_strides_propagation(): - """ - Todo: - - Add a case where `ignore_symbol_mapping=False` can be tested. - - What happens if the stride symbol is used somewhere else? - """ +def test_strides_propagation_use_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False) + sdfg_level1.validate() + + # Because `ignore_symbol_mapping=False` the strides of the data descriptor should + # not have changed. But the `symbol_mapping` has been updated for `a` and `b`. + # However, the symbols will only point one level above. + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + + if aname.startswith("c"): + target_symbol = f"{aname}_stride" + else: + target_symbol = f"{aname[0]}{level - 1}_stride" + + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False) + sdfg_level1.validate() + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + target_symbol = f"{aname[0]}{level-1}_stride" + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + +def test_strides_propagation_ignore_symbol_mapping(): # Note that the SDFG we are building here is not really meaningful. sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() @@ -201,7 +259,9 @@ def test_strides_propagation(): # Now we propagate `a` and `b`, but not `c`. # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() # After the propagation `a` and `b` should use the same stride (the one that # it has on level 1, but `c` should still be level depending. @@ -213,8 +273,8 @@ def test_strides_propagation(): else: exp_stride = f"{aname[0]}1_stride" assert len(adesc.strides) == 1 - assert exp_stride == str( - adesc.strides[0] + assert ( + str(adesc.strides[0]) == exp_stride ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." nsdfg = sdfg.parent_nsdfg_node @@ -224,13 +284,14 @@ def test_strides_propagation(): # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) + sdfg_level1.validate() for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): exp_stride = f"{aname[0]}1_stride" original_stride = f"{aname}_stride" assert len(adesc.strides) == 1 - assert exp_stride == str( - adesc.strides[0] + assert ( + str(adesc.strides[0]) == exp_stride ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." nsdfg = sdfg.parent_nsdfg_node From d43153a4165878a1cf91e033bf3cbdb5360babea Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:15:14 +0100 Subject: [PATCH 70/80] Fixed two bug in the stride propagation function. --- .../runners/dace_fieldview/transformations/strides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index f683737f23..2cc75e195d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -574,7 +574,7 @@ def _gt_map_strides_into_nested_sdfg( # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: - if sym in sdfg.symbols: + if str(sym) in sdfg.symbols: nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) else: # The symbol is not known in the parent SDFG, so we add it @@ -583,7 +583,7 @@ def _gt_map_strides_into_nested_sdfg( f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym.name]}' as dtype.", stacklevel=1, ) - nsdfg_node.symbol_mapping[sym.name] = sym + nsdfg_node.symbol_mapping[sym.name] = sym # Now create aliases for the old symbols that were used as strides. for old_sym, new_sym in zip(inner_strides_init, new_strides): From 2e82bd5a90e0bb2d1abda35e428b337bf91a7efa Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:19:42 +0100 Subject: [PATCH 71/80] Added a test that ensures that the dependent adding works. --- .../transformation_tests/test_strides.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 45c3ebc739..17874a3450 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -298,3 +298,108 @@ def test_strides_propagation_ignore_symbol_mapping(): if nsdfg is not None: assert original_stride in nsdfg.symbol_mapping assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + + +def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("nested_sdfg")) + state = sdfg.add_state(is_start_block=True) + + array_names = ["a2", "b2"] + for name in array_names: + stride_sym = dace.symbol(f"{name}_stride", dtype=dace.uint64) + sdfg.add_symbol(stride_sym.name, stride_sym.dtype) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("nested_level")) + state = sdfg_level1.add_state(is_start_block=True) + + array_names = ["a1", "b1"] + for name in array_names: + stride_sym1 = dace.symbol(f"{name}_1stride", dtype=dace.uint64) + stride_sym2 = dace.symbol(f"{name}_2stride", dtype=dace.int64) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + sdfg_level1.add_symbol(stride_sym2.name, stride_sym2.dtype) + stride_sym = stride_sym1 * stride_sym2 + sdfg_level1.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_dependent_symbol_nsdfg() + + for sym, sym_dtype in sdfg_level2.symbols.items(): + sdfg_level1.add_symbol(sym, sym_dtype) + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_dependent_symbol(): + sdfg_level1, nsdfg_level2 = _make_strides_propagation_dependent_symbol_sdfg() + sym1_dtype = dace.uint64 + sym2_dtype = dace.int64 + + # Ensure that the special symbols are not already present inside the nested SDFG. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in {fs.name for fs in adesc.strides[0].free_symbols} + assert sym not in nsdfg_level2.symbol_mapping + assert sym not in nsdfg_level2.sdfg.symbols + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + + # Now propagate `a1` and `b1`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() + + # Now we check if the update has worked. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + adesc2 = nsdfg_level2.sdfg.arrays[aname.replace("1", "2")] + assert adesc2.strides == adesc.strides + + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in nsdfg_level2.symbol_mapping + assert nsdfg_level2.symbol_mapping[sym].name == sym + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + assert sym in nsdfg_level2.sdfg.symbols + assert nsdfg_level2.sdfg.symbols[sym] == dtype From 07e6a5cd61c17b4039c0a6d3ce0d3003fbed8f9c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:23:13 +0100 Subject: [PATCH 72/80] Changed the default of `ignore_symbol_mapping` to `True`. --- .../dace_fieldview/transformations/strides.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 2cc75e195d..bf298c0164 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -41,6 +41,11 @@ def gt_change_transient_strides( transients in the optimal way. The function should run after all maps have been created. + After the strides have been adjusted the function will also propagate + the strides into nested SDFG. This propagation will happen with + `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()` + for more. + Args: sdfg: The SDFG to process. gpu: If the SDFG is supposed to run on the GPU. @@ -123,13 +128,14 @@ def _gt_change_transient_strides_non_recursive_impl( state=state, outer_node=access_node, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=True, ) def gt_propagate_strides_of( sdfg: dace.SDFG, data_name: str, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, ) -> None: """Propagates the strides of `data_name` within the whole SDFG. @@ -140,7 +146,7 @@ def gt_propagate_strides_of( Args: sdfg: The SDFG on which we operate. data_name: Name of the data descriptor that should be handled. - ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + ignore_symbol_mapping: If `False` (default is `True`) try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. """ @@ -164,7 +170,7 @@ def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, outer_node: dace_nodes.AccessNode, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the stride of `outer_node` to any adjacent reachable through its edges. @@ -184,7 +190,7 @@ def gt_propagate_strides_from_access_node( state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. - ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. From 4bf145b7d63ca6c98e94cb0d02f6bebed4246690 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:40:19 +0100 Subject: [PATCH 73/80] Added Edoardo's comments. --- .../dace_fieldview/transformations/strides.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index bf298c0164..7854cbea12 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import warnings from typing import Optional, TypeAlias import dace @@ -173,7 +172,7 @@ def gt_propagate_strides_from_access_node( ignore_symbol_mapping: bool = True, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the stride of `outer_node` to any adjacent reachable through its edges. + """Propagates the stride of `outer_node` to any adjacent NestedSDFG. The function will propagate the strides of the data descriptor `outer_node` refers to along all adjacent edges of `outer_node`. If one of these edges @@ -227,10 +226,10 @@ def gt_map_strides_to_dst_nested_sdfg( ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the strides of `outer_node` along `edge` along the dataflow. + """Propagates the strides of `outer_node` along `edge` in the dataflow direction. - In this context "along the dataflow" means that `edge` is an outgoing - edge of `outer_node` and the strides are into all NestedSDFGs that + In this context "along the dataflow direction" means that `edge` is an outgoing + edge of `outer_node` and the strides are propagated into all NestedSDFGs that are downstream of `outer_node`. Except in certain cases this function should not be used directly. It is @@ -267,11 +266,11 @@ def gt_map_strides_to_src_nested_sdfg( ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the strides of `outer_node` along `edge` against the dataflow. + """Propagates the strides of `outer_node` along `edge` in the opposite direction of the dataflow - In this context "along the dataflow" means that `edge` is an incoming - edge of `outer_node` and the strides are into all NestedSDFGs that - are upstream of `outer_node`. + In this context "in the opposite direction of the dataflow" means that `edge` + is an incoming edge of `outer_node` and the strides are propagated into all + NestedSDFGs that are upstream of `outer_node`. Except in certain cases this function should not be used directly. It is instead recommended to use `gt_propagate_strides_of()`, which propagates @@ -500,13 +499,10 @@ def _gt_map_strides_into_nested_sdfg( new_strides: list = [] for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): if dim_oinflow == 1: - # This is the case of implicit slicing along one dimension. The inner - # array descriptor has shape != 1 in `current_inner_dim`, which has - # to map to a subsequent dimension of `outer_inflow` + # This is the case of implicit slicing along one dimension. pass else: # There is inflow into the SDFG, so we need the stride. - assert dim_oinflow != 0 new_strides.append(dim_ostride) assert len(new_strides) <= len(inner_shape) @@ -580,15 +576,8 @@ def _gt_map_strides_into_nested_sdfg( # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: - if str(sym) in sdfg.symbols: - nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) - else: - # The symbol is not known in the parent SDFG, so we add it - nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) - warnings.warn( - f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym.name]}' as dtype.", - stacklevel=1, - ) + assert sym.name in sdfg.symbols, f"Expected that '{sym}' is defined in the parent SDFG." + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) nsdfg_node.symbol_mapping[sym.name] = sym # Now create aliases for the old symbols that were used as strides. From 2b03bb4799ff092a68aecea820074813779cfc17 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 07:55:32 +0100 Subject: [PATCH 74/80] Removed the creation of aliasing if symbol tables are ignored. I realized that allowing this is not very safe. I also added a test to show that. --- .../dace_fieldview/transformations/strides.py | 16 ++++++++-------- .../transformation_tests/test_strides.py | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 7854cbea12..06dfe6626c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -223,7 +223,7 @@ def gt_map_strides_to_dst_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` in the dataflow direction. @@ -460,7 +460,7 @@ def _gt_map_strides_into_nested_sdfg( inner_data: str, outer_subset: dace.subsets.Subset, outer_desc: dace_data.Data, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, ) -> None: """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. @@ -479,9 +479,12 @@ def _gt_map_strides_into_nested_sdfg( ignore_symbol_mapping: If possible the function will perform the renaming through the `symbol_mapping` of the nested SDFG. If `True` then the function will always perform the renaming. + Note that setting this value to `False` might have negative side effects. Todo: - Handle explicit dimensions of size 1. + - What should we do if the stride symbol is used somewhere else, creating an + alias is probably not the right thing? """ # We need to compute the new strides. In the following we assume that the # relative order of the dimensions does not change, but we support the case @@ -526,7 +529,9 @@ def _gt_map_strides_into_nested_sdfg( # The first is to create an alias in the `symbol_mapping`, however, # this is only possible if the current strides are singular symbols, # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` - # or literal values. + # or literal values. Furthermore, this would change the meaning of the + # old stride symbol in any context and not only in the one of the stride + # of a single and isolated data descriptor. # The second way would be to replace `strides` attribute of the # inner data descriptor. In case the new stride consists of expressions # such as `value1 - value2` we have to make them available inside the @@ -580,11 +585,6 @@ def _gt_map_strides_into_nested_sdfg( nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) nsdfg_node.symbol_mapping[sym.name] = sym - # Now create aliases for the old symbols that were used as strides. - for old_sym, new_sym in zip(inner_strides_init, new_strides): - if dace.symbolic.issymbolic(old_sym) and old_sym.is_symbol: - nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym) - def _gt_find_toplevel_data_accesses( sdfg: dace.SDFG, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 17874a3450..22d1b16b39 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -280,7 +280,7 @@ def test_strides_propagation_ignore_symbol_mapping(): nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: assert original_stride in nsdfg.symbol_mapping - assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) @@ -296,8 +296,9 @@ def test_strides_propagation_ignore_symbol_mapping(): nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: + # The symbol mapping must should not be updated. assert original_stride in nsdfg.symbol_mapping - assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: From 40c225d6e601c7fee7c612da620bc0b37d15895b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 08:21:24 +0100 Subject: [PATCH 75/80] Added a test that shows that `ignore_symbol_mapping=False` does produces errors in certain cases. --- .../transformation_tests/test_strides.py | 129 +++++++++++++++++- 1 file changed, 127 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 22d1b16b39..6d6a36028a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -7,6 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest +import numpy as np +import copy dace = pytest.importorskip("dace") from dace import symbolic as dace_symbolic @@ -302,7 +304,7 @@ def test_strides_propagation_ignore_symbol_mapping(): def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("nested_sdfg")) + sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) state = sdfg.add_state(is_start_block=True) array_names = ["a2", "b2"] @@ -330,7 +332,7 @@ def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg_level1 = dace.SDFG(util.unique_name("nested_level")) + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) state = sdfg_level1.add_state(is_start_block=True) array_names = ["a1", "b1"] @@ -404,3 +406,126 @@ def test_strides_propagation_dependent_symbol(): assert sdfg_level1.symbols[sym] == dtype assert sym in nsdfg_level2.sdfg.symbols assert nsdfg_level2.sdfg.symbols[sym] == dtype + + +def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + # NOTE: Both arrays have the same symbols used for strides. + array_names = ["a2", "b2"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=(stride_sym0, stride_sym1), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in1": dace.Memlet("a2[__i0, __i1]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + # NOTE: Both arrays use the same symbols as strides. + # Furthermore, they are the same as in the nested SDFG, i.e. they are shared. + array_names = ["a1", "b1"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg_level1.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=( + stride_sym0, + stride_sym1, + ), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_shared_symbols_nsdfg() + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10, 0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10, 0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_shared_symbols_sdfg(): + """ + Note: + If `ignore_symbol_mapping` is `False` then this test will fail. + This is because the `symbol_mapping` of the NestedSDFG will act on the + whole SDFG. Thus it will not only change the strides of `b` but as an + unintended side effect also the strides of `a`. + """ + + def ref(a1, b1): + for i in range(10): + for j in range(10): + b1[i, j] = a1[i, j] + 10.0 + + sdfg_level1, nsdfg_level2 = _make_strides_propagation_shared_symbols_sdfg() + + res_args = { + "a1": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + "b1": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + # Now we change the strides of `b1`, and then we propagate the new strides + # into the nested SDFG. We want to keep (for whatever reasons) strides of `a1`. + stride_b1_sym0 = dace.symbol(f"__b1_stride_0", dtype=dace.uint64) + stride_b1_sym1 = dace.symbol(f"__b1_stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_b1_sym0.name, stride_b1_sym0.dtype) + sdfg_level1.add_symbol(stride_b1_sym1.name, stride_b1_sym1.dtype) + + desc_b1 = sdfg_level1.arrays["b1"] + desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of(sdfg=sdfg_level1, data_name="b1") + + # Now we have to prepare the call arguments, i.e. adding the strides + itemsize = res_args["b1"].itemsize + res_args.update( + { + "__b1_stride_0": res_args["b1"].strides[0] // itemsize, + "__b1_stride_1": res_args["b1"].strides[1] // itemsize, + "__stride_0": res_args["a1"].strides[0] // itemsize, + "__stride_1": res_args["a1"].strides[1] // itemsize, + } + ) + ref(**ref_args) + sdfg_level1(**res_args) + assert np.allclose(ref_args["b1"], res_args["b1"]) From 419a386722a685316ee5917e9c8d8e44905e153b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 08:44:46 +0100 Subject: [PATCH 76/80] Updated the description. --- .../transformation_tests/test_strides.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 6d6a36028a..5b16e41bc3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -482,7 +482,14 @@ def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nod def test_strides_propagation_shared_symbols_sdfg(): - """ + """Tests what happens if symbols are (unintentionally) shred between descriptor. + + This test looks rather artificial, but it is actually quite likely. Because + transients will most likely have the same shape and if the strides are not + set explicitly, which is the case, the strides will also be related to their + shape. This test explores the situation, where we can, for whatever reason, + only propagate the strides of one such data descriptor. + Note: If `ignore_symbol_mapping` is `False` then this test will fail. This is because the `symbol_mapping` of the NestedSDFG will act on the @@ -514,7 +521,10 @@ def ref(a1, b1): desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) # Now we propagate the data into it. - gtx_transformations.gt_propagate_strides_of(sdfg=sdfg_level1, data_name="b1") + gtx_transformations.gt_propagate_strides_of( + sdfg=sdfg_level1, + data_name="b1", + ) # Now we have to prepare the call arguments, i.e. adding the strides itemsize = res_args["b1"].itemsize From cc9801b7364e91d782772b7b4bbb949857c03ec3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 08:46:38 +0100 Subject: [PATCH 77/80] Applied Edoardo's comment. --- .../runners/dace_fieldview/transformations/strides.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 06dfe6626c..aa9d55b5f6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -460,7 +460,7 @@ def _gt_map_strides_into_nested_sdfg( inner_data: str, outer_subset: dace.subsets.Subset, outer_desc: dace_data.Data, - ignore_symbol_mapping: bool = True, + ignore_symbol_mapping: bool, ) -> None: """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. From 360baae7f3b21521b8d55dec3f4f4e122501b5e1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 09:16:33 +0100 Subject: [PATCH 78/80] Added a todo from Edoardo's suggestions. --- .../runners/dace_fieldview/transformations/strides.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index aa9d55b5f6..980b2a8fdf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -485,6 +485,8 @@ def _gt_map_strides_into_nested_sdfg( - Handle explicit dimensions of size 1. - What should we do if the stride symbol is used somewhere else, creating an alias is probably not the right thing? + - Handle the case if the outer stride symbol is already used in another + context inside the Neste SDFG. """ # We need to compute the new strides. In the following we assume that the # relative order of the dimensions does not change, but we support the case From a0c37cb5ddb177c5103c36d25d943fde5e1091c6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 20 Dec 2024 09:57:10 +0100 Subject: [PATCH 79/80] minor edit --- .../gtir_builtin_translators.py | 61 ++++++++++--------- .../runners/dace_fieldview/gtir_dataflow.py | 4 +- .../runners/dace_fieldview/utility.py | 4 +- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index f59755649b..131321f77e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -161,7 +161,7 @@ def get_local_view( def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + Compute the `ts.TupleType` corresponding to the structure of a tuple of `FieldopResult`. """ return ts.TupleType( types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] @@ -169,6 +169,11 @@ def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + """ + Visit a `FieldopResult`, potentially containing nested tuples, and construct a list + of pairs `(str, FieldopData)` containing the symbol name of each tuple field and + the corresponding `FieldopData`. + """ if isinstance(arg, tuple): tuple_type = get_tuple_type(arg) tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) @@ -337,9 +342,6 @@ def _create_field_operator( # here we setup the edges passing through the map entry node for edge in input_edges: - if isinstance(edge, gtir_dataflow.EmptyInputEdge) and me is None: - # cannot create empty edge from MapEntry node, if this is not present - continue edge.connect(me) def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) -> FieldopData: @@ -402,6 +404,7 @@ def create_field(output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym) - assert isinstance(node_type, ts.FieldType) return create_field(output_edges, im.sym("x", node_type)) else: + # handle tuples of fields assert isinstance(node_type, ts.TupleType) return gtx_utils.tree_map(create_field)( output_edges, dace_gtir_utils.make_symbol_tuple("x", node_type) @@ -888,8 +891,8 @@ def translate_scan( stencil_expr = scan_expr.args[0] assert isinstance(stencil_expr, gtir.Lambda) - # params[0]: the lambda parameter to propagate the scan state on the vertical dimension - scan_state = str(stencil_expr.params[0].id) + # params[0]: the lambda parameter to propagate the scan carry on the vertical dimension + scan_carry = str(stencil_expr.params[0].id) # params[1]: boolean flag for forward/backward scan assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) @@ -908,13 +911,13 @@ def scan_output_name(input_name: str) -> str: # visit the initialization value of the scan expression init_data = sdfg_builder.visit(init_value, sdfg=sdfg, head_state=state) - # extract type definition of the scan state - scan_state_type = ( + # extract type definition of the scan carry + scan_carry_type = ( init_data.gt_type if isinstance(init_data, FieldopData) else get_tuple_type(init_data) ) # create list of params to the lambda function with associated node type - lambda_symbols = {scan_state: scan_state_type} | { + lambda_symbols = {scan_carry: scan_carry_type} | { str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) if isinstance(arg.type, ts.DataType) @@ -925,7 +928,7 @@ def scan_output_name(input_name: str) -> str: # the data descriptor with the correct field domain offsets for field arguments lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] lambda_args_mapping = { - scan_input_name(scan_state): init_data, + scan_input_name(scan_carry): init_data, } | { str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) } @@ -941,11 +944,11 @@ def scan_output_name(input_name: str) -> str: { str(sym.id): sym.type for sym in dace_gtir_utils.flatten_tuple_fields( - scan_output_name(scan_state), scan_state_type + scan_output_name(scan_carry), scan_carry_type ) } - if isinstance(scan_state_type, ts.TupleType) - else {scan_output_name(scan_state): scan_state_type} + if isinstance(scan_carry_type, ts.TupleType) + else {scan_output_name(scan_carry): scan_carry_type} ) # the scan operator is implemented as an nested SDFG implementing the lambda expression @@ -998,14 +1001,14 @@ def scan_output_name(input_name: str) -> str: nsdfg, compute_state, stencil_builder, stencil_expr, args=stencil_args ) - # now initialize the scan state - scan_state_input = ( - dace_gtir_utils.make_symbol_tuple(scan_state, scan_state_type) - if isinstance(scan_state_type, ts.TupleType) - else im.sym(scan_state, scan_state_type) + # now initialize the scan carry + scan_carry_input = ( + dace_gtir_utils.make_symbol_tuple(scan_carry, scan_carry_type) + if isinstance(scan_carry_type, ts.TupleType) + else im.sym(scan_carry, scan_carry_type) ) - def init_scan_state(sym: gtir.Sym) -> None: + def init_scan_carry(sym: gtir.Sym) -> None: scan_state = str(sym.id) scan_state_desc = nsdfg.data(scan_state) input_state = scan_input_name(scan_state) @@ -1018,17 +1021,17 @@ def init_scan_state(sym: gtir.Sym) -> None: nsdfg.make_array_memlet(input_state), ) - if isinstance(scan_state_input, tuple): - gtx_utils.tree_map(init_scan_state)(scan_state_input) + if isinstance(scan_carry_input, tuple): + gtx_utils.tree_map(init_scan_carry)(scan_carry_input) else: - init_scan_state(scan_state_input) + init_scan_carry(scan_carry_input) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region for edge in input_edges: edge.connect(map_entry=None) - # connect the dataflow result nodes to the variables that carry the scan state along the column axis + # connect the dataflow result nodes to the carry variables def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym ) -> FieldopData: @@ -1057,12 +1060,12 @@ def connect_scan_output( return FieldopData(output_node, output_type, scan_output_offset) lambda_output = ( - gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) - if (isinstance(result, tuple) and isinstance(scan_state_input, tuple)) - else connect_scan_output(result, scan_state_input) + gtx_utils.tree_map(connect_scan_output)(result, scan_carry_input) + if (isinstance(result, tuple) and isinstance(scan_carry_input, tuple)) + else connect_scan_output(result, scan_carry_input) if ( isinstance(result, gtir_dataflow.DataflowOutputEdge) - and isinstance(scan_state_input, gtir.Sym) + and isinstance(scan_carry_input, gtir.Sym) ) else None ) @@ -1075,8 +1078,8 @@ def connect_scan_output( if (compute_state.degree(data_node) == 0) and ( (not data_desc.transient) or data_node.data.startswith( - scan_state - ) # exceptional case where the state is not used, not a scan indeed + scan_carry + ) # exceptional case where the carry variable is not used, not a scan indeed ): # isolated node, connect it to a transient to avoid SDFG validation errors temp, temp_desc = nsdfg.add_temp_transient_like(data_desc) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 5d7159c987..ee22c4cd13 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -196,6 +196,7 @@ class EmptyInputEdge(DataflowInputEdge): node: dace.nodes.Tasklet def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + # cannot create empty edge from MapEntry node, if this is not present if map_entry is not None: self.state.add_nedge(map_entry, self.node, dace.Memlet()) @@ -564,7 +565,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: assert len(node.args) == 3 - # TODO(edopao): enable once DaCe supports it in next release + # TODO(edopao): enable once supported in next DaCe release use_conditional_block: Final[bool] = False condition_value = self.visit(node.args[0]) @@ -690,6 +691,7 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: lambda_args.append(inner_arg) lambda_params.append(im.sym(p)) + # visit each branch of the if-statement as it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) return visit_lambda(nsdfg, state, self.subgraph_builder, lambda_node, lambda_args) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 33c333a9f3..ad120e2502 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -30,7 +30,7 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: """ Creates a tuple representation of the symbols corresponding to the tuple fields. - The constructed tuple presrves the nested nature of the type, is any. + The constructed tuple preserves the nested nature of the type, if any. Examples -------- @@ -53,7 +53,7 @@ def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.S def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: """ - Creates a list of names with the corresponding data type for all elements of the given tuple. + Creates a list of symbols, annotated with the data type, for all elements of the given tuple. Examples -------- From 0f9043bc041bb49dba6d6fcbee13c059e81c2042 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 20 Dec 2024 14:19:47 +0100 Subject: [PATCH 80/80] fix for missing symbols in nested sdfg --- .../runners/dace_fieldview/gtir_dataflow.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index ee22c4cd13..4f53a9dcad 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -664,10 +664,6 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: nsdfg.add_datadesc(inner_data, inner_desc) input_memlets[inner_data] = (arg_node, arg_subset) - if arg_subset: - # symbols used in memlet subset are not automatically mapped to the parent SDFG - nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols}) - inner_node = state.add_access(inner_data) if isinstance(arg, IteratorExpr): return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) @@ -750,7 +746,7 @@ def construct_output( self.sdfg, inputs=set(input_memlets.keys()), outputs=outputs, - symbol_mapping=nsdfg_symbol_mapping, + symbol_mapping=nsdfg_symbol_mapping | {str(sym): sym for sym in nsdfg.free_symbols}, ) for inner, (src_node, src_subset) in input_memlets.items():