Skip to content

Commit

Permalink
Minor edit
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 12, 2024
1 parent 077c031 commit 5ccd381
Showing 1 changed file with 26 additions and 32 deletions.
58 changes: 26 additions & 32 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,42 +459,36 @@ def visit_HorizontalExecution(
write_memlets=write_memlets,
)

if next(dcir_node.walk_values().if_isinstance(dcir.IndexAccess).iterator, None) is not None:
"""
Special case of tasklet performing array access. The memlet should pass the full array shape
(no slicing) and the tasklet code should use all explicit indexes for array access.
"""
for memlet in [*read_memlets, *write_memlets]:
field_decl = global_ctx.library_node.field_decls[memlet.field]
# calculate array subset from original memlet
memlet_subset = make_dace_subset(
global_ctx.library_node.access_infos[memlet.field],
memlet.access_info,
field_decl.data_dims,
)
# ensure grid access on single point
memlet_data_index = [
dcir.Literal(value=str(r[0]), dtype=common.DataType.INT32)
for r, size in zip(memlet_subset, memlet_subset.size())
if size == 1
]
# loop through assignment statements in the tasklet body
tasklet_subset_size = 0
for memlet in [*read_memlets, *write_memlets]:
array_ndims = len(global_ctx.arrays[memlet.field].shape)
field_decl = global_ctx.library_node.field_decls[memlet.field]
# calculate array subset from original memlet
memlet_subset = make_dace_subset(
global_ctx.library_node.access_infos[memlet.field],
memlet.access_info,
field_decl.data_dims,
)
# index values for single-point grid access
memlet_data_index = [
dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32)
for dim_range, dim_size in zip(memlet_subset, memlet_subset.size())
if dim_size == 1
]
if len(memlet_data_index) < array_ndims:
"""
Special case of tasklet performing array access. The memlet should pass the full array shape
(no slicing) and the tasklet code should use all explicit indexes for array access.
"""
for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess):
if access_node.data_index and access_node.name == memlet.connector:
if tasklet_subset_size != 0:
assert len(access_node.data_index) == tasklet_subset_size
else:
tasklet_subset_size = len(access_node.data_index)
for idx in reversed(memlet_data_index):
access_node.data_index.insert(0, idx)
# reshape memlet if tasklet accessed the endpoint array with partial index
if tasklet_subset_size != 0:
# ensure that memlet symbols used for array subset are defined in context
for sym in memlet.access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym)
# set full shape on memlet
memlet.access_info = global_ctx.library_node.access_infos[memlet.field]
assert len(access_node.data_index) == array_ndims
# ensure that memlet symbols used for array indexing are defined in context
for sym in memlet.access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym)
# set full shape on memlet
memlet.access_info = global_ctx.library_node.access_infos[memlet.field]

for item in reversed(expansion_items):
iteration_ctx = iteration_ctx.pop()
Expand Down

0 comments on commit 5ccd381

Please sign in to comment.