Skip to content

Commit

Permalink
fix[next-dace]: scan_dim consistent with canonical field domain (#1346)
Browse files Browse the repository at this point in the history
The DaCe backend reorders the dimensions of a field domain alphabetically; we call this the canonical representation of the field domain. Therefore, array strides, sizes and offsets must be permuted everywhere to be consistent with the alphabetical order of the dimensions.
This PR corrects the indexing of the field domain in get_scan_dim(), which was not consistent with the canonical representation.

Additional minor edits:
* rename map_domain -> map_ranges
* replace dace.Memlet() with dace.Memlet.simple()
  • Loading branch information
edopao authored Oct 16, 2023
1 parent 6c69398 commit d07104d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
create_memlet_at,
create_memlet_full,
filter_neighbor_tables,
get_sorted_dims,
map_nested_sdfg_symbols,
unique_var_name,
)
Expand Down Expand Up @@ -79,9 +80,10 @@ def get_scan_dim(
- scan_dim_dtype: data type along the scan dimension
"""
output_type = cast(ts.FieldType, storage_types[output.id])
sorted_dims = [dim for _, dim in get_sorted_dims(output_type.dims)]
return (
column_axis.value,
output_type.dims.index(column_axis),
sorted_dims.index(column_axis),
output_type.dtype,
)

Expand Down Expand Up @@ -246,7 +248,7 @@ def visit_StencilClosure(
)
access = closure_init_state.add_access(out_name)
value = ValueExpr(access, dtype)
memlet = create_memlet_at(out_name, ("0",))
memlet = dace.Memlet.simple(out_name, "0")
closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet)
program_arg_syms[name] = value
else:
Expand Down Expand Up @@ -274,7 +276,7 @@ def visit_StencilClosure(
transient_to_arg_name_mapping[nsdfg_output_name] = output_name
# scan operator should always be the first function call in a closure
if is_scan(node.stencil):
nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure(
nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure(
node, closure_sdfg.arrays, closure_domain, nsdfg_output_name
)
results = [nsdfg_output_name]
Expand All @@ -294,13 +296,13 @@ def visit_StencilClosure(
output_name,
tuple(
f"i_{dim}"
if f"i_{dim}" in map_domain
if f"i_{dim}" in map_ranges
else f"0:{output_descriptor.shape[scan_dim_index]}"
for dim, _ in closure_domain
),
)
else:
nsdfg, map_domain, results = self._visit_parallel_stencil_closure(
nsdfg, map_ranges, results = self._visit_parallel_stencil_closure(
node, closure_sdfg.arrays, closure_domain
)
assert len(results) == 1
Expand All @@ -313,7 +315,7 @@ def visit_StencilClosure(
transient=True,
)

output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys()))
output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys()))

input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)}
output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])}
Expand All @@ -325,7 +327,7 @@ def visit_StencilClosure(
nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg(
closure_state,
sdfg=nsdfg,
map_ranges=map_domain or {"__dummy": "0"},
map_ranges=map_ranges or {"__dummy": "0"},
inputs=array_mapping,
outputs=output_mapping,
symbol_mapping=symbol_mapping,
Expand All @@ -341,10 +343,10 @@ def visit_StencilClosure(
edge.src_conn,
transient_access,
None,
dace.Memlet(data=memlet.data, subset=output_subset),
dace.Memlet.simple(memlet.data, output_subset),
)
inner_memlet = dace.Memlet(
data=memlet.data, subset=output_subset, other_subset=memlet.subset
inner_memlet = dace.Memlet.simple(
memlet.data, output_subset, other_subset_str=memlet.subset
)
closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet)
closure_state.remove_edge(edge)
Expand All @@ -360,7 +362,7 @@ def visit_StencilClosure(
None,
map_entry,
b.value.data,
create_memlet_at(b.value.data, ("0",)),
dace.Memlet.simple(b.value.data, "0"),
)
return closure_sdfg

Expand Down Expand Up @@ -390,12 +392,12 @@ def _visit_scan_stencil_closure(
connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]

# find the scan dimension, same as output dimension, and exclude it from the map domain
map_domain = {}
map_ranges = {}
for dim, (lb, ub) in closure_domain:
lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value
ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value
if not dim == scan_dim:
map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}"
map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}"
else:
scan_lb_str = lb_str
scan_ub_str = ub_str
Expand Down Expand Up @@ -481,29 +483,28 @@ def _visit_scan_stencil_closure(
"__result",
carry_node1,
None,
dace.Memlet(data=f"{scan_carry_name}", subset="0"),
dace.Memlet.simple(scan_carry_name, "0"),
)

carry_node2 = lambda_state.add_access(scan_carry_name)
lambda_state.add_memlet_path(
carry_node2,
scan_inner_node,
memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"),
memlet=dace.Memlet.simple(scan_carry_name, "0"),
src_conn=None,
dst_conn=lambda_carry_name,
)

# connect access nodes to lambda inputs
for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names):
data_subset = (
", ".join([f"i_{dim}" for dim, _ in closure_domain])
if isinstance(self.storage_types[data_name], ts.FieldType)
else "0"
)
if isinstance(self.storage_types[data_name], ts.FieldType):
memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain))
else:
memlet = dace.Memlet.simple(data_name, "0")
lambda_state.add_memlet_path(
lambda_state.add_access(data_name),
scan_inner_node,
memlet=dace.Memlet(data=f"{data_name}", subset=data_subset),
memlet=memlet,
src_conn=None,
dst_conn=inner_name,
)
Expand Down Expand Up @@ -532,7 +533,7 @@ def _visit_scan_stencil_closure(
lambda_state.add_memlet_path(
scan_inner_node,
lambda_state.add_access(data_name),
memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"),
memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"),
src_conn=lambda_connector.value.label,
dst_conn=None,
)
Expand All @@ -544,10 +545,10 @@ def _visit_scan_stencil_closure(
lambda_update_state.add_memlet_path(
result_node,
carry_node3,
memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"),
memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"),
)

return scan_sdfg, map_domain, scan_dim_index
return scan_sdfg, map_ranges, scan_dim_index

def _visit_parallel_stencil_closure(
self,
Expand All @@ -562,11 +563,11 @@ def _visit_parallel_stencil_closure(
conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]

# find the scan dimension, same as output dimension, and exclude it from the map domain
map_domain = {}
map_ranges = {}
for dim, (lb, ub) in closure_domain:
lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value
ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value
map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}"
map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}"

# Create an SDFG for the tasklet that computes a single item of the output domain.
index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain}
Expand All @@ -583,7 +584,7 @@ def _visit_parallel_stencil_closure(
self.node_types,
)

return context.body, map_domain, [r.value.data for r in results]
return context.body, map_ranges, [r.value.data for r in results]

def _visit_domain(
self, node: itir.FunCall, context: Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
add_mapped_nested_sdfg,
as_dace_type,
connectivity_identifier,
create_memlet_at,
create_memlet_full,
filter_neighbor_tables,
map_nested_sdfg_symbols,
Expand Down Expand Up @@ -595,7 +594,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
)

# if dim is not found in iterator indices, we take the neighbor index over the reduction domain
array_index = [
flat_index = [
f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name
for dim in sorted(iterator.dimensions)
]
Expand All @@ -608,7 +607,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
name="deref",
inputs=set(internals),
outputs={"__result"},
code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]",
code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]",
)

for arg, internal in zip(args, internals):
Expand All @@ -634,8 +633,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
flat_index = [
ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions
]

args = [ValueExpr(iterator.field, int), *flat_index]
args = [ValueExpr(iterator.field, iterator.dtype), *flat_index]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]}[{', '.join(internals[1:])}]"
return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref")
Expand Down Expand Up @@ -849,7 +847,7 @@ def _visit_reduce(self, node: itir.FunCall):
p.apply_pass(lambda_context.body, {})

input_memlets = [
create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args)
dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args)
]
output_memlet = dace.Memlet.simple(result_name, "0")

Expand Down Expand Up @@ -928,7 +926,7 @@ def add_expr_tasklet(
)
self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet)

memlet = create_memlet_at(result_access.data, ("0",))
memlet = dace.Memlet.simple(result_access.data, "0")
self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet)

return [ValueExpr(result_access, result_type)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Any
from typing import Any, Sequence

import dace

from gt4py.next import Dimension
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -49,14 +50,18 @@ def connectivity_identifier(name: str):
def create_memlet_full(source_identifier: str, source_array: dace.data.Array):
    """Build a memlet that covers every element of ``source_array``.

    Args:
        source_identifier: Name of the data container the memlet refers to.
        source_array: Array descriptor whose ``shape`` provides the extent of
            each dimension.

    Returns:
        A memlet on ``source_identifier`` whose subset is the range
        ``0:size`` for each entry of ``source_array.shape``.
    """
    # One "0:size" range per dimension, in the order given by the shape.
    subset = ", ".join(f"0:{size}" for size in source_array.shape)
    return dace.Memlet.simple(source_identifier, subset)


def create_memlet_at(source_identifier: str, index: tuple[str, ...]):
    """Build a memlet addressing ``source_identifier`` at the given index.

    Args:
        source_identifier: Name of the data container the memlet refers to.
        index: One subset expression per dimension; the entries are joined
            with ``", "`` to form the memlet subset string.

    Returns:
        A ``dace.Memlet`` constructed from that data name and subset string.
    """
    return dace.Memlet(data=source_identifier, subset=", ".join(index))


def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]:
    """Pair each dimension with its original position and order the pairs by ``value``.

    Args:
        dims: Dimensions in their original order.

    Returns:
        ``(original_index, dimension)`` tuples in ascending order of
        ``dimension.value``. Dimensions with equal ``value`` keep their
        original relative order, because the sort is stable.
    """
    indexed_dims = list(enumerate(dims))
    indexed_dims.sort(key=lambda entry: entry[1].value)
    return indexed_dims


def map_nested_sdfg_symbols(
parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet]
) -> dict[str, str]:
Expand Down

0 comments on commit d07104d

Please sign in to comment.