Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[cartesian]: readability improvements in gtir -> oir conversion and other cleanups #1630

Merged
merged 12 commits into from
Sep 11, 2024
7 changes: 1 addition & 6 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import dace
import dace.data
from dace.sdfg.utils import inline_sdfgs
from dace.serialize import dumps

from gt4py import storage as gt_storage
from gt4py.cartesian import config as gt_config
Expand Down Expand Up @@ -56,10 +55,6 @@
from gt4py.cartesian.stencil_object import StencilObject


def _serialize_sdfg(sdfg: dace.SDFG):
return dumps(sdfg)


def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
repldict = replace_strides(
[array for array in sdfg.arrays.values() if array.transient], layout_map
Expand Down Expand Up @@ -125,7 +120,7 @@ def _set_expansion_orders(sdfg: dace.SDFG):


def _set_tile_sizes(sdfg: dace.SDFG):
import gt4py.cartesian.gtc.daceir as dcir # avoid circular import
import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import

for node, _ in filter(
lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive()
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@

"""

from __future__ import annotations

import enum
import operator
import sys
Expand Down Expand Up @@ -704,7 +706,7 @@ def is_single_index(self) -> bool:

return self.start.level == self.end.level and self.start.offset == self.end.offset - 1

def disjoint_from(self, other: "AxisInterval") -> bool:
def disjoint_from(self, other: AxisInterval) -> bool:
def get_offset(bound: AxisBound) -> int:
return (
0 + bound.offset if bound.level == LevelMarker.START else sys.maxsize + bound.offset
Expand Down
34 changes: 14 additions & 20 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages
from gt4py.cartesian.gtc.dace.utils import (
compute_dcir_access_infos,
Expand Down Expand Up @@ -68,7 +69,13 @@ def _iterator():
).unique(key=lambda x: x[2])


def _get_tasklet_inout_memlets(node: oir.HorizontalExecution, *, get_outputs, global_ctx, **kwargs):
def _get_tasklet_inout_memlets(
node: oir.HorizontalExecution,
*,
get_outputs: bool,
global_ctx: DaCeIRBuilder.GlobalContext,
**kwargs,
):
access_infos = compute_dcir_access_infos(
node,
block_extents=global_ctx.library_node.get_extents,
Expand Down Expand Up @@ -190,12 +197,7 @@ def _get_dcir_decl(
@dataclass
class IterationContext:
grid_subset: dcir.GridSubset
parent: Optional[DaCeIRBuilder.IterationContext]

@classmethod
def init(cls, *args, **kwargs):
res = cls(*args, parent=None, **kwargs)
return res
parent: Optional[DaCeIRBuilder.IterationContext] = None

def push_axes_extents(self, axes_extents) -> DaCeIRBuilder.IterationContext:
res = self.grid_subset
Expand Down Expand Up @@ -611,7 +613,7 @@ def _process_map_item(
scope_nodes,
item: Map,
*,
global_ctx,
global_ctx: DaCeIRBuilder.GlobalContext,
iteration_ctx: DaCeIRBuilder.IterationContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
**kwargs,
Expand Down Expand Up @@ -787,19 +789,13 @@ def _process_iteration_item(self, scope, item, **kwargs):
def visit_VerticalLoop(
self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs
):
start, end = (node.sections[0].interval.start, node.sections[0].interval.end)

overall_interval = dcir.DomainInterval(
start=dcir.AxisBound(axis=dcir.Axis.K, level=start.level, offset=start.offset),
end=dcir.AxisBound(axis=dcir.Axis.K, level=end.level, offset=end.offset),
)
overall_extent = Extent.zeros(2)
for he in node.walk_values().if_isinstance(oir.HorizontalExecution):
overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he))

iteration_ctx = DaCeIRBuilder.IterationContext.init(
iteration_ctx = DaCeIRBuilder.IterationContext(
grid_subset=dcir.GridSubset.from_gt4py_extent(overall_extent).set_interval(
axis=dcir.Axis.K, interval=overall_interval
axis=dcir.Axis.K, interval=node.sections[0].interval
)
)

Expand Down Expand Up @@ -849,13 +845,11 @@ def visit_VerticalLoop(

read_fields = set(memlet.field for memlet in read_memlets)
write_fields = set(memlet.field for memlet in write_memlets)
res = dcir.NestedSDFG(
return dcir.NestedSDFG(
label=global_ctx.library_node.label,
states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset),
field_decls=field_decls,
read_memlets=[memlet for memlet in field_memlets if memlet.field in read_fields],
write_memlets=[memlet for memlet in field_memlets if memlet.field in write_fields],
symbol_decls=list(symbol_collector.symbol_decls.values()),
)

return res
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dace.subsets
import sympy

from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder
from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen
from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import gt4py.cartesian.gtc.common as common
from gt4py import eve
from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.symbol_utils import get_axis_bound_str
from gt4py.cartesian.gtc.dace.utils import make_dace_subset
from gt4py.eve.codegen import FormatTemplate as as_fmt
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.definitions import Extent


Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import dace

from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.definitions import Extent


Expand Down
15 changes: 7 additions & 8 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import numpy as np
from dace import library

from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection
Expand Down Expand Up @@ -215,10 +216,8 @@ def has_splittable_regions(self):
def tile_strides(self):
if self.tile_sizes_interpretation == "strides":
return self.tile_sizes
else:
overall_extent: Extent = next(iter(self.extents.values()))
for extent in self.extents.values():
overall_extent |= extent
return {
key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()
}

overall_extent: Extent = next(iter(self.extents.values()))
for extent in self.extents.values():
overall_extent |= extent
return {key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()}
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import gt4py.cartesian.gtc.oir as oir
from gt4py import eve
from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/symbol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


if TYPE_CHECKING:
import gt4py.cartesian.gtc.daceir as dcir
import gt4py.cartesian.gtc.dace.daceir as dcir


def data_type_to_dace_typeclass(data_type):
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np

from gt4py import eve
from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc.common import CartesianOffset
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


Expand Down Expand Up @@ -56,7 +57,9 @@ def replace_strides(arrays, get_layout_map):
return symbol_mapping


def get_tasklet_symbol(name, offset, is_target):
def get_tasklet_symbol(
name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool
):
if is_target:
return f"__{name}"

Expand Down
Loading
Loading