Skip to content

Commit

Permalink
fix[cartesian]: DaCe array access in tasklet (#1410)
Browse files Browse the repository at this point in the history
Found some incompatible tasklet representation while upgrading to dace v0.15.1. Array access inside tasklet with partial index subset worked in v0.14.1, although not valid.
The fix consists of modifying the memlets to pass the full array shape to such tasklet, and use all explicit indices inside the tasklet to access the array. This is the right representation in DaCe SDFG, as discussed with the DaCe developers.
  • Loading branch information
edopao authored Jan 17, 2024
1 parent 6e6271c commit 6283ac9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
35 changes: 35 additions & 0 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
compute_dcir_access_infos,
flatten_list,
get_tasklet_symbol,
make_dace_subset,
union_inout_memlets,
union_node_grid_subsets,
untile_memlets,
Expand Down Expand Up @@ -458,6 +459,40 @@ def visit_HorizontalExecution(
write_memlets=write_memlets,
)

for memlet in [*read_memlets, *write_memlets]:
"""
This loop handles the special case of a tasklet performing array access.
The memlet should pass the full array shape (no tiling) and
the tasklet expression for array access should use all explicit indexes.
"""
array_ndims = len(global_ctx.arrays[memlet.field].shape)
field_decl = global_ctx.library_node.field_decls[memlet.field]
# calculate array subset on original memlet
memlet_subset = make_dace_subset(
global_ctx.library_node.access_infos[memlet.field],
memlet.access_info,
field_decl.data_dims,
)
# select index values for single-point grid access
memlet_data_index = [
dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32)
for dim_range, dim_size in zip(memlet_subset, memlet_subset.size())
if dim_size == 1
]
if len(memlet_data_index) < array_ndims:
reshape_memlet = False
for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess):
if access_node.data_index and access_node.name == memlet.connector:
access_node.data_index = memlet_data_index + access_node.data_index
assert len(access_node.data_index) == array_ndims
reshape_memlet = True
if reshape_memlet:
# ensure that memlet symbols used for array indexing are defined in context
for sym in memlet.access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym)
# set full shape on memlet
memlet.access_info = global_ctx.library_node.access_infos[memlet.field]

for item in reversed(expansion_items):
iteration_ctx = iteration_ctx.pop()
dcir_node = self._process_iteration_item(
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/gtc/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def union(self, other):
else:
assert (
isinstance(interval2, (TileInterval, DomainInterval))
and isinstance(interval1, IndexWithExtent)
and isinstance(interval1, (IndexWithExtent, DomainInterval))
) or (
isinstance(interval1, (TileInterval, DomainInterval))
and isinstance(interval2, IndexWithExtent)
Expand Down Expand Up @@ -573,7 +573,7 @@ def overapproximated_shape(self):
def apply_iteration(self, grid_subset: GridSubset):
res_intervals = dict(self.grid_subset.intervals)
for axis, field_interval in self.grid_subset.intervals.items():
if axis in grid_subset.intervals:
if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval):
grid_interval = grid_subset.intervals[axis]
assert isinstance(field_interval, IndexWithExtent)
extent = field_interval.extent
Expand Down

0 comments on commit 6283ac9

Please sign in to comment.