From 8c2a33e5d4c33a341a87fd6b82e6feca4e4ccca9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 23 Aug 2024 09:00:30 +0200 Subject: [PATCH 01/12] No else after return in if statement --- src/gt4py/cartesian/gtc/dace/nodes.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 5c2f11f30d..5ee5b69a35 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -215,10 +215,8 @@ def has_splittable_regions(self): def tile_strides(self): if self.tile_sizes_interpretation == "strides": return self.tile_sizes - else: - overall_extent: Extent = next(iter(self.extents.values())) - for extent in self.extents.values(): - overall_extent |= extent - return { - key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items() - } + + overall_extent: Extent = next(iter(self.extents.values())) + for extent in self.extents.values(): + overall_extent |= extent + return {key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()} From 79d146519e382e494192e76df08b57c81c2e1742 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 23 Aug 2024 09:02:22 +0200 Subject: [PATCH 02/12] Remove unused debug function --- src/gt4py/cartesian/backend/dace_backend.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 163a0dee3f..6833e16a92 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -18,7 +18,6 @@ import dace import dace.data from dace.sdfg.utils import inline_sdfgs -from dace.serialize import dumps from gt4py import storage as gt_storage from gt4py.cartesian import config as gt_config @@ -56,10 +55,6 @@ from gt4py.cartesian.stencil_object import StencilObject -def _serialize_sdfg(sdfg: dace.SDFG): - return dumps(sdfg) - - def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): repldict = replace_strides( [array for array in sdfg.arrays.values() if array.transient], layout_map From 00d9128a1fe0cf98e6b07a9d8424791b7f5e3ca5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 23 Aug 2024 09:08:57 +0200 Subject: [PATCH 03/12] Add more type hints --- .../cartesian/gtc/dace/expansion/daceir_builder.py | 10 ++++++++-- src/gt4py/cartesian/gtc/dace/utils.py | 6 ++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 399d4d7af5..bf9130a73b 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -68,7 +68,13 @@ def _iterator(): ).unique(key=lambda x: x[2]) -def _get_tasklet_inout_memlets(node: oir.HorizontalExecution, *, get_outputs, global_ctx, **kwargs): +def _get_tasklet_inout_memlets( + node: oir.HorizontalExecution, + *, + get_outputs: bool, + global_ctx: DaCeIRBuilder.GlobalContext, + **kwargs, +): access_infos = compute_dcir_access_infos( node, block_extents=global_ctx.library_node.get_extents, @@ -611,7 +617,7 @@ def _process_map_item( scope_nodes, item: Map, *, - global_ctx, + global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs, diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index f4dade581d..2c00ab76a8 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -18,7 +18,7 @@ from gt4py import eve from gt4py.cartesian.gtc import common, daceir as dcir, oir -from gt4py.cartesian.gtc.common import CartesianOffset +from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents @@ -56,7 +56,9 @@ def replace_strides(arrays, get_layout_map): return symbol_mapping -def get_tasklet_symbol(name, offset, is_target): +def get_tasklet_symbol( + name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool +): if is_target: return f"__{name}" From 4e186b4d75406236fb6874766bd21b962fa7aae5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 23 Aug 2024 09:15:26 +0200 Subject: [PATCH 04/12] Return directly Avoid the otherwise unused "result" variable and return directly. --- src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index bf9130a73b..96167c0870 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -855,7 +855,7 @@ def visit_VerticalLoop( read_fields = set(memlet.field for memlet in read_memlets) write_fields = set(memlet.field for memlet in write_memlets) - res = dcir.NestedSDFG( + return dcir.NestedSDFG( label=global_ctx.library_node.label, states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset), field_decls=field_decls, @@ -863,5 +863,3 @@ def visit_VerticalLoop( write_memlets=[memlet for memlet in field_memlets if memlet.field in write_fields], symbol_decls=list(symbol_collector.symbol_decls.values()), ) - - return res From 748dfe49042172631edd56475717254c9d7d9014 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 23 Aug 2024 09:18:20 +0200 Subject: [PATCH 05/12] Cleanup DaCeIRBuilder.IterationContext No need for this `init()` function if we have a default value for the `parent` field. --- src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 96167c0870..21dbfd78b5 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -196,12 +196,7 @@ def _get_dcir_decl( @dataclass class IterationContext: grid_subset: dcir.GridSubset - parent: Optional[DaCeIRBuilder.IterationContext] - - @classmethod - def init(cls, *args, **kwargs): - res = cls(*args, parent=None, **kwargs) - return res + parent: Optional[DaCeIRBuilder.IterationContext] = None def push_axes_extents(self, axes_extents) -> DaCeIRBuilder.IterationContext: res = self.grid_subset @@ -803,7 +798,7 @@ def visit_VerticalLoop( for he in node.walk_values().if_isinstance(oir.HorizontalExecution): overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) - iteration_ctx = DaCeIRBuilder.IterationContext.init( + iteration_ctx = DaCeIRBuilder.IterationContext( grid_subset=dcir.GridSubset.from_gt4py_extent(overall_extent).set_interval( axis=dcir.Axis.K, interval=overall_interval ) From a05f440e4bee2bf8b9ad4ea7861cc781ab87b8b6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Fri, 23 Aug 2024 09:19:48 +0200 Subject: [PATCH 06/12] Cleanup GridSubset.set_interval usage No need to do this transformation manually. The same transformation is done inside `GridSubset.setInterval()` in case an `oir.Interval` is passed. --- src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 21dbfd78b5..dfce07fe3d 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -788,19 +788,13 @@ def _process_iteration_item(self, scope, item, **kwargs): def visit_VerticalLoop( self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs ): - start, end = (node.sections[0].interval.start, node.sections[0].interval.end) - - overall_interval = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.K, level=start.level, offset=start.offset), - end=dcir.AxisBound(axis=dcir.Axis.K, level=end.level, offset=end.offset), - ) overall_extent = Extent.zeros(2) for he in node.walk_values().if_isinstance(oir.HorizontalExecution): overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) iteration_ctx = DaCeIRBuilder.IterationContext( grid_subset=dcir.GridSubset.from_gt4py_extent(overall_extent).set_interval( - axis=dcir.Axis.K, interval=overall_interval + axis=dcir.Axis.K, interval=node.sections[0].interval ) ) From bc78ab2460f09bdbc55dac08faf978ca8b5dd84b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Thu, 29 Aug 2024 10:23:24 +0200 Subject: [PATCH 07/12] Cleanup visit_ParAssignStmt --- src/gt4py/cartesian/gtc/gtir_to_oir.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 1e6b328a4e..f3a8b332a4 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -121,13 +121,12 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: def visit_ParAssignStmt( self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any ) -> Union[oir.AssignStmt, oir.MaskStmt]: - stmt: Union[oir.AssignStmt, oir.MaskStmt] = oir.AssignStmt( - left=self.visit(node.left), right=self.visit(node.right) - ) - if mask is not None: - # Wrap inside MaskStmt - stmt = oir.MaskStmt(body=[stmt], mask=mask, loc=node.loc) - return stmt + statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) + if mask is None: + return statement + + # Wrap inside MaskStmt + return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc) def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any From b0627efa9e2059c19581f01acf98b826de0da621 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Thu, 29 Aug 2024 10:36:31 +0200 Subject: [PATCH 08/12] Cleanup visit_While 1. The extra mask statement around `oir.While` is unnecessary because the mask is already part of the loop's condition. 2. Cleanup naming to increase readability --- src/gt4py/cartesian/gtc/gtir_to_oir.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index f3a8b332a4..8b3bb600e1 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -143,20 +143,16 @@ def visit_HorizontalRestriction( def visit_While( self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> Union[oir.While, oir.MaskStmt]: - body_stmts: List[oir.Stmt] = [] - for st in node.body: - st_or_sts = self.visit(st, **kwargs) - sts = utils.flatten_list([st_or_sts] if isinstance(st_or_sts, oir.Stmt) else st_or_sts) - body_stmts.extend(sts) - - cond: oir.Expr = self.visit(node.cond) + ) -> oir.While: + body: List[oir.Stmt] = [] + for statement in node.body: + oir_statement = self.visit(statement, **kwargs) + body.extend(utils.flatten_list(utils.listify(oir_statement))) + + condition: oir.Expr = self.visit(node.cond) if mask: - cond = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=cond) - stmt: Union[oir.While, oir.MaskStmt] = oir.While(cond=cond, body=body_stmts, loc=node.loc) - if mask is not None: - stmt = oir.MaskStmt(body=[stmt], mask=mask, loc=node.loc) - return stmt + condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition) + return oir.While(cond=condition, body=body, loc=node.loc) def visit_FieldIfStmt( self, From 4e11c0a1c2fe0055bfb5fbb9d40816d9fdfcbc27 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Thu, 29 Aug 2024 10:40:58 +0200 Subject: [PATCH 09/12] Cleanup visit_HorizontalRestriction --- src/gt4py/cartesian/gtc/gtir_to_oir.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 8b3bb600e1..fd7a0afd5a 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -131,15 +131,12 @@ def visit_ParAssignStmt( def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any ) -> oir.HorizontalRestriction: - body_stmts = [] - for stmt in node.body: - stmt_or_stmts = self.visit(stmt, **kwargs) - stmts = utils.flatten_list( - [stmt_or_stmts] if isinstance(stmt_or_stmts, oir.Stmt) else stmt_or_stmts - ) - body_stmts.extend(stmts) + body = [] + for statement in node.body: + oir_statement = self.visit(statement, **kwargs) + body.extend(utils.flatten_list(utils.listify(oir_statement))) - return oir.HorizontalRestriction(mask=node.mask, body=body_stmts) + return oir.HorizontalRestriction(mask=node.mask, body=body) def visit_While( self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any From 1297c4d1a2232a5b6274fd796ad13b5469ab55fd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Thu, 5 Sep 2024 10:44:59 +0200 Subject: [PATCH 10/12] Clenaup visit_{if statements} and _VeritcalLoop --- src/gt4py/cartesian/gtc/gtir_to_oir.py | 65 +++++++++++++++----------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index fd7a0afd5a..560cbf96cf 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -117,7 +117,7 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: loc=node.loc, ) - # --- Stmts --- + # --- Statements --- def visit_ParAssignStmt( self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any ) -> Union[oir.AssignStmt, oir.MaskStmt]: @@ -158,12 +158,12 @@ def visit_FieldIfStmt( mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, - ) -> List[oir.Stmt]: + ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: mask_field_decl = oir.Temporary( name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True) ) ctx.temp_fields.append(mask_field_decl) - stmts: List[oir.Stmt] = [ + statements: List[Union[oir.AssignStmt, oir.MaskStmt]] = [ oir.AssignStmt( left=oir.FieldAccess( name=mask_field_decl.name, @@ -175,31 +175,35 @@ def visit_FieldIfStmt( ) ] - current_mask = oir.FieldAccess( + condition = oir.FieldAccess( name=mask_field_decl.name, offset=CartesianOffset.zero(), dtype=mask_field_decl.dtype, loc=node.loc, ) - combined_mask: oir.Expr = current_mask + combined_mask: oir.Expr = condition if mask: combined_mask = oir.BinaryOp( op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc ) - stmts.extend(self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] + ) + statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) if node.false_branch: - combined_mask_not: oir.Expr = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask) + combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition) if mask: - combined_mask_not = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask_not, loc=node.loc + combined_mask = oir.BinaryOp( + op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc ) - stmts.extend( - self.visit(node.false_branch.body, mask=combined_mask_not, ctx=ctx, **kwargs) + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) + statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) - return stmts + return statements # For now we represent ScalarIf (and FieldIf) both as masks on the HorizontalExecution. # This is not meant to be set in stone... @@ -210,22 +214,30 @@ def visit_ScalarIfStmt( mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, - ) -> List[oir.Stmt]: - current_mask = self.visit(node.cond) - combined_mask = current_mask + ) -> List[oir.MaskStmt]: + condition = self.visit(node.cond) + combined_mask = condition if mask: combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=current_mask, loc=node.loc + op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc ) - stmts = self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs) + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] + ) + statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)] + if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask, loc=node.loc) + combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) if mask: combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) - stmts.extend(self.visit(node.false_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) - return stmts + body = utils.flatten_list( + [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] + ) + statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + + return statements # --- Misc --- def visit_Interval(self, node: gtir.Interval) -> oir.Interval: @@ -233,12 +245,13 @@ def visit_Interval(self, node: gtir.Interval) -> oir.Interval: # --- Control flow --- def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop: - horiz_execs: List[oir.HorizontalExecution] = [] - for stmt in node.body: + horizontal_executions: List[oir.HorizontalExecution] = [] + for statement in node.body: ctx.reset_local_scalars() - ret = self.visit(stmt, ctx=ctx) - stmts = utils.flatten_list([ret] if isinstance(ret, oir.Stmt) else ret) - horiz_execs.append(oir.HorizontalExecution(body=stmts, declarations=ctx.local_scalars)) + body = utils.flatten_list(utils.listify(self.visit(statement, ctx=ctx))) + horizontal_executions.append( + oir.HorizontalExecution(body=body, declarations=ctx.local_scalars) + ) ctx.temp_fields += [ oir.Temporary( @@ -255,7 +268,7 @@ def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.Ve sections=[ oir.VerticalLoopSection( interval=self.visit(node.interval), - horizontal_executions=horiz_execs, + horizontal_executions=horizontal_executions, loc=node.loc, ) ], From e539be23071be8917c4e5b863bb461015e040ae2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Thu, 5 Sep 2024 14:29:07 +0200 Subject: [PATCH 11/12] Move daceir into the dace folder All other backends (except the dace backend) have their IR in the backend-specific subfolder (e.g. `npir.py` inside the `numpy` folder). Let's do the same for the dace IR. --- src/gt4py/cartesian/backend/dace_backend.py | 2 +- src/gt4py/cartesian/gtc/{ => dace}/daceir.py | 0 src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py | 3 ++- src/gt4py/cartesian/gtc/dace/expansion/expansion.py | 2 +- src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py | 2 +- src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py | 2 +- src/gt4py/cartesian/gtc/dace/expansion/utils.py | 3 ++- src/gt4py/cartesian/gtc/dace/expansion_specification.py | 3 ++- src/gt4py/cartesian/gtc/dace/nodes.py | 3 ++- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 2 +- src/gt4py/cartesian/gtc/dace/symbol_utils.py | 2 +- src/gt4py/cartesian/gtc/dace/utils.py | 3 ++- 12 files changed, 16 insertions(+), 11 deletions(-) rename src/gt4py/cartesian/gtc/{ => dace}/daceir.py (100%) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 6833e16a92..f49895a435 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -120,7 +120,7 @@ def _set_expansion_orders(sdfg: dace.SDFG): def _set_tile_sizes(sdfg: dace.SDFG): - import gt4py.cartesian.gtc.daceir as dcir # avoid circular import + import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import for node, _ in filter( lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py similarity index 100% rename from src/gt4py/cartesian/gtc/daceir.py rename to src/gt4py/cartesian/gtc/dace/daceir.py diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index dfce07fe3d..d5b1c91466 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -18,7 +18,8 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages from gt4py.cartesian.gtc.dace.utils import ( compute_dcir_access_infos, diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 56bb6c1b3f..055bf64015 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -17,7 +17,7 @@ import dace.subsets import sympy -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index f6aa725b01..9d64464377 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -18,7 +18,7 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 30e7c56bb4..c219667a4a 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -16,7 +16,7 @@ import gt4py.cartesian.gtc.common as common from gt4py import eve -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.symbol_utils import get_axis_bound_str from gt4py.cartesian.gtc.dace.utils import make_dace_subset from gt4py.eve.codegen import FormatTemplate as as_fmt diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index d8d2ce3176..919ec02996 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -16,7 +16,8 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.definitions import Extent diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index 091849a1b7..c716f1a103 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -14,7 +14,8 @@ import dace -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.definitions import Extent diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 5ee5b69a35..34401e18b9 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -20,7 +20,8 @@ import numpy as np from dace import library -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 283402e1ac..3555d555f9 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -17,7 +17,7 @@ import gt4py.cartesian.gtc.oir as oir from gt4py import eve -from gt4py.cartesian.gtc import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset diff --git a/src/gt4py/cartesian/gtc/dace/symbol_utils.py b/src/gt4py/cartesian/gtc/dace/symbol_utils.py index 86823304db..b9b6a49ce0 100644 --- a/src/gt4py/cartesian/gtc/dace/symbol_utils.py +++ b/src/gt4py/cartesian/gtc/dace/symbol_utils.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: - import gt4py.cartesian.gtc.daceir as dcir + import gt4py.cartesian.gtc.dace.daceir as dcir def data_type_to_dace_typeclass(data_type): diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 2c00ab76a8..9be2e9a07d 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -17,8 +17,9 @@ import numpy as np from gt4py import eve -from gt4py.cartesian.gtc import common, daceir as dcir, oir +from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents From f49a04ff97d076b9e7699974911b4e20b9b23055 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Tue, 10 Sep 2024 09:17:13 +0200 Subject: [PATCH 12/12] Use new-style type hints for AxisInterval --- src/gt4py/cartesian/frontend/nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index ed447bf37d..f84577e7b5 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -133,6 +133,8 @@ """ +from __future__ import annotations + import enum import operator import sys @@ -704,7 +706,7 @@ def is_single_index(self) -> bool: return self.start.level == self.end.level and self.start.offset == self.end.offset - 1 - def disjoint_from(self, other: "AxisInterval") -> bool: + def disjoint_from(self, other: AxisInterval) -> bool: def get_offset(bound: AxisBound) -> int: return ( 0 + bound.offset if bound.level == LevelMarker.START else sys.maxsize + bound.offset