diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 163a0dee3f..f49895a435 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -18,7 +18,6 @@ import dace import dace.data from dace.sdfg.utils import inline_sdfgs -from dace.serialize import dumps from gt4py import storage as gt_storage from gt4py.cartesian import config as gt_config @@ -56,10 +55,6 @@ from gt4py.cartesian.stencil_object import StencilObject -def _serialize_sdfg(sdfg: dace.SDFG): - return dumps(sdfg) - - def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): repldict = replace_strides( [array for array in sdfg.arrays.values() if array.transient], layout_map @@ -125,7 +120,7 @@ def _set_expansion_orders(sdfg: dace.SDFG): def _set_tile_sizes(sdfg: dace.SDFG): - import gt4py.cartesian.gtc.daceir as dcir # avoid circular import + import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import for node, _ in filter( lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index ed447bf37d..f84577e7b5 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -133,6 +133,8 @@ """ +from __future__ import annotations + import enum import operator import sys @@ -704,7 +706,7 @@ def is_single_index(self) -> bool: return self.start.level == self.end.level and self.start.offset == self.end.offset - 1 - def disjoint_from(self, other: "AxisInterval") -> bool: + def disjoint_from(self, other: AxisInterval) -> bool: def get_offset(bound: AxisBound) -> int: return ( 0 + bound.offset if bound.level == LevelMarker.START else sys.maxsize + bound.offset diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py similarity index 100% rename from src/gt4py/cartesian/gtc/daceir.py rename to src/gt4py/cartesian/gtc/dace/daceir.py diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 399d4d7af5..d5b1c91466 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -18,7 +18,8 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages from gt4py.cartesian.gtc.dace.utils import ( compute_dcir_access_infos, @@ -68,7 +69,13 @@ def _iterator(): ).unique(key=lambda x: x[2]) -def _get_tasklet_inout_memlets(node: oir.HorizontalExecution, *, get_outputs, global_ctx, **kwargs): +def _get_tasklet_inout_memlets( + node: oir.HorizontalExecution, + *, + get_outputs: bool, + global_ctx: DaCeIRBuilder.GlobalContext, + **kwargs, +): access_infos = compute_dcir_access_infos( node, block_extents=global_ctx.library_node.get_extents, @@ -190,12 +197,7 @@ def _get_dcir_decl( @dataclass class IterationContext: grid_subset: dcir.GridSubset - parent: Optional[DaCeIRBuilder.IterationContext] - - @classmethod - def init(cls, *args, **kwargs): - res = cls(*args, parent=None, **kwargs) - return res + parent: Optional[DaCeIRBuilder.IterationContext] = None def push_axes_extents(self, axes_extents) -> DaCeIRBuilder.IterationContext: res = self.grid_subset @@ -611,7 +613,7 @@ def _process_map_item( scope_nodes, item: Map, *, - global_ctx, + global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs, @@ -787,19 +789,13 @@ def _process_iteration_item(self, scope, item, **kwargs): def visit_VerticalLoop( self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs ): - start, end = (node.sections[0].interval.start, node.sections[0].interval.end) - - overall_interval = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.K, level=start.level, offset=start.offset), - end=dcir.AxisBound(axis=dcir.Axis.K, level=end.level, offset=end.offset), - ) overall_extent = Extent.zeros(2) for he in node.walk_values().if_isinstance(oir.HorizontalExecution): overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) - iteration_ctx = DaCeIRBuilder.IterationContext.init( + iteration_ctx = DaCeIRBuilder.IterationContext( grid_subset=dcir.GridSubset.from_gt4py_extent(overall_extent).set_interval( - axis=dcir.Axis.K, interval=overall_interval + axis=dcir.Axis.K, interval=node.sections[0].interval ) ) @@ -849,7 +845,7 @@ def visit_VerticalLoop( read_fields = set(memlet.field for memlet in read_memlets) write_fields = set(memlet.field for memlet in write_memlets) - res = dcir.NestedSDFG( + return dcir.NestedSDFG( label=global_ctx.library_node.label, states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset), field_decls=field_decls, @@ -857,5 +853,3 @@ def visit_VerticalLoop( write_memlets=[memlet for memlet in field_memlets if memlet.field in write_fields], symbol_decls=list(symbol_collector.symbol_decls.values()), ) - - return res diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 56bb6c1b3f..055bf64015 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -17,7 +17,7 @@ import dace.subsets import sympy -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index f6aa725b01..9d64464377 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -18,7 +18,7 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 30e7c56bb4..c219667a4a 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -16,7 +16,7 @@ import gt4py.cartesian.gtc.common as common from gt4py import eve -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.symbol_utils import get_axis_bound_str from gt4py.cartesian.gtc.dace.utils import make_dace_subset from gt4py.eve.codegen import FormatTemplate as as_fmt diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index d8d2ce3176..919ec02996 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -16,7 +16,8 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.definitions import Extent diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index 091849a1b7..c716f1a103 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -14,7 +14,8 @@ import dace -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.definitions import Extent diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 5c2f11f30d..34401e18b9 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -20,7 +20,8 @@ import numpy as np from dace import library -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection @@ -215,10 +216,8 @@ def has_splittable_regions(self): def tile_strides(self): if self.tile_sizes_interpretation == "strides": return self.tile_sizes - else: - overall_extent: Extent = next(iter(self.extents.values())) - for extent in self.extents.values(): - overall_extent |= extent - return { - key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items() - } + + overall_extent: Extent = next(iter(self.extents.values())) + for extent in self.extents.values(): + overall_extent |= extent + return {key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()} diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 283402e1ac..3555d555f9 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -17,7 +17,7 @@ import gt4py.cartesian.gtc.oir as oir from gt4py import eve -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset diff --git a/src/gt4py/cartesian/gtc/dace/symbol_utils.py b/src/gt4py/cartesian/gtc/dace/symbol_utils.py index 86823304db..b9b6a49ce0 100644 --- a/src/gt4py/cartesian/gtc/dace/symbol_utils.py +++ b/src/gt4py/cartesian/gtc/dace/symbol_utils.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: - import gt4py.cartesian.gtc.daceir as dcir + import gt4py.cartesian.gtc.dace.daceir as dcir def data_type_to_dace_typeclass(data_type): diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index f4dade581d..9be2e9a07d 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -17,8 +17,9 @@ import numpy as np from gt4py import eve -from gt4py.cartesian.gtc import common, daceir as dcir, oir -from gt4py.cartesian.gtc.common import CartesianOffset +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents @@ -56,7 +57,9 @@ def replace_strides(arrays, get_layout_map): return symbol_mapping -def get_tasklet_symbol(name, offset, is_target): +def get_tasklet_symbol( + name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool +): if is_target: return f"__{name}" diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 1e6b328a4e..560cbf96cf 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -117,47 +117,39 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: loc=node.loc, ) - # --- Stmts --- + # --- Statements --- def visit_ParAssignStmt( self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any ) -> Union[oir.AssignStmt, oir.MaskStmt]: - stmt: Union[oir.AssignStmt, oir.MaskStmt] = oir.AssignStmt( - left=self.visit(node.left), right=self.visit(node.right) - ) - if mask is not None: - # Wrap inside MaskStmt - stmt = oir.MaskStmt(body=[stmt], mask=mask, loc=node.loc) - return stmt + statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) + if mask is None: + return statement + + # Wrap inside MaskStmt + return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc) def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any ) -> oir.HorizontalRestriction: - body_stmts = [] - for stmt in node.body: - stmt_or_stmts = self.visit(stmt, **kwargs) - stmts = utils.flatten_list( - [stmt_or_stmts] if isinstance(stmt_or_stmts, oir.Stmt) else stmt_or_stmts - ) - body_stmts.extend(stmts) + body = [] + for statement in node.body: + oir_statement = self.visit(statement, **kwargs) + body.extend(utils.flatten_list(utils.listify(oir_statement))) - return oir.HorizontalRestriction(mask=node.mask, body=body_stmts) + return oir.HorizontalRestriction(mask=node.mask, body=body) def visit_While( self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> Union[oir.While, oir.MaskStmt]: - body_stmts: List[oir.Stmt] = [] - for st in node.body: - st_or_sts = self.visit(st, **kwargs) - sts = utils.flatten_list([st_or_sts] if isinstance(st_or_sts, oir.Stmt) else st_or_sts) - body_stmts.extend(sts) - - cond: oir.Expr = self.visit(node.cond) + ) -> oir.While: + body: List[oir.Stmt] = [] + for statement in node.body: + oir_statement = self.visit(statement, **kwargs) + body.extend(utils.flatten_list(utils.listify(oir_statement))) + + condition: oir.Expr = self.visit(node.cond) if mask: - cond = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=cond) - stmt: Union[oir.While, oir.MaskStmt] = oir.While(cond=cond, body=body_stmts, loc=node.loc) - if mask is not None: - stmt = oir.MaskStmt(body=[stmt], mask=mask, loc=node.loc) - return stmt + condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition) + return oir.While(cond=condition, body=body, loc=node.loc) def visit_FieldIfStmt( self, @@ -166,12 +158,12 @@ def visit_FieldIfStmt( mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, - ) -> List[oir.Stmt]: + ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: mask_field_decl = oir.Temporary( name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True) ) ctx.temp_fields.append(mask_field_decl) - stmts: List[oir.Stmt] = [ + statements: List[Union[oir.AssignStmt, oir.MaskStmt]] = [ oir.AssignStmt( left=oir.FieldAccess( name=mask_field_decl.name, @@ -183,31 +175,35 @@ def visit_FieldIfStmt( ) ] - current_mask = oir.FieldAccess( + condition = oir.FieldAccess( name=mask_field_decl.name, offset=CartesianOffset.zero(), dtype=mask_field_decl.dtype, loc=node.loc, ) - combined_mask: oir.Expr = current_mask + combined_mask: oir.Expr = condition if mask: combined_mask = oir.BinaryOp( op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc ) - stmts.extend(self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] + ) + statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) if node.false_branch: - combined_mask_not: oir.Expr = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask) + combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition) if mask: - combined_mask_not = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask_not, loc=node.loc + combined_mask = oir.BinaryOp( + op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc ) - stmts.extend( - self.visit(node.false_branch.body, mask=combined_mask_not, ctx=ctx, **kwargs) + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) + statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) - return stmts + return statements # For now we represent ScalarIf (and FieldIf) both as masks on the HorizontalExecution. # This is not meant to be set in stone... @@ -218,22 +214,30 @@ def visit_ScalarIfStmt( mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, - ) -> List[oir.Stmt]: - current_mask = self.visit(node.cond) - combined_mask = current_mask + ) -> List[oir.MaskStmt]: + condition = self.visit(node.cond) + combined_mask = condition if mask: combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=current_mask, loc=node.loc + op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc ) - stmts = self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs) + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] + ) + statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)] + if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask, loc=node.loc) + combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) if mask: combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) - stmts.extend(self.visit(node.false_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) - return stmts + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] + ) + statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + + return statements # --- Misc --- def visit_Interval(self, node: gtir.Interval) -> oir.Interval: @@ -241,12 +245,13 @@ def visit_Interval(self, node: gtir.Interval) -> oir.Interval: # --- Control flow --- def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop: - horiz_execs: List[oir.HorizontalExecution] = [] - for stmt in node.body: + horizontal_executions: List[oir.HorizontalExecution] = [] + for statement in node.body: ctx.reset_local_scalars() - ret = self.visit(stmt, ctx=ctx) - stmts = utils.flatten_list([ret] if isinstance(ret, oir.Stmt) else ret) - horiz_execs.append(oir.HorizontalExecution(body=stmts, declarations=ctx.local_scalars)) + body = utils.flatten_list(utils.listify(self.visit(statement, ctx=ctx))) + horizontal_executions.append( + oir.HorizontalExecution(body=body, declarations=ctx.local_scalars) + ) ctx.temp_fields += [ oir.Temporary( @@ -263,7 +268,7 @@ def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.Ve sections=[ oir.VerticalLoopSection( interval=self.visit(node.interval), - horizontal_executions=horiz_execs, + horizontal_executions=horizontal_executions, loc=node.loc, ) ],