Skip to content

Commit

Permalink
fix[next][dace]: Bugfix for nested neighbor reduction (GridTools#1457)
Browse files Browse the repository at this point in the history
In the case of a nested neighbor reduction with a lift expression on the inner node, the DaCe backend should generate a conditional state transition to the field access, based on the value of the neighbor index provided by the outer connectivity table.
Additional change: the previous selection of valid neighbors, implemented as a conditional inter-state edge, is replaced by a select tasklet, which makes the SDFG easier to read.
  • Loading branch information
edopao authored Feb 16, 2024
1 parent d31d2cf commit 0a29261
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _visit_lift_in_neighbors_reduction(
neighbor_index_node: dace.nodes.AccessNode,
neighbor_value_node: dace.nodes.AccessNode,
) -> list[ValueExpr]:
assert transformer.context.reduce_identity is not None
neighbor_dim = offset_provider.neighbor_axis.value
origin_dim = offset_provider.origin_axis.value

Expand Down Expand Up @@ -220,15 +221,15 @@ def _visit_lift_in_neighbors_reduction(

input_nodes = {}
iterator_index_nodes = {}
lifted_index_connectors = set()
lifted_index_connectors = []

for x, y in inner_inputs:
if isinstance(y, IteratorExpr):
field_connector, inner_index_table = x
input_nodes[field_connector] = y.field
for dim, connector in inner_index_table.items():
if dim == neighbor_dim:
lifted_index_connectors.add(connector)
lifted_index_connectors.append(connector)
iterator_index_nodes[connector] = y.indices[dim]
else:
assert isinstance(y, ValueExpr)
Expand Down Expand Up @@ -298,6 +299,30 @@ def _visit_lift_in_neighbors_reduction(
memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)),
)

if offset_provider.has_skip_values:
# check neighbor validity on if/else inter-state edge
start_state = lift_context.body.add_state("start", is_start_block=True)
skip_neighbor_state = lift_context.body.add_state("skip_neighbor")
skip_neighbor_state.add_edge(
skip_neighbor_state.add_tasklet(
"identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}"
),
"val",
skip_neighbor_state.add_access(inner_outputs[0].value.data),
None,
dace.Memlet(data=inner_outputs[0].value.data, subset="0"),
)
lift_context.body.add_edge(
start_state,
skip_neighbor_state,
dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"),
)
lift_context.body.add_edge(
start_state,
lift_context.state,
dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}"),
)

return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)]


Expand Down Expand Up @@ -467,7 +492,7 @@ def builtin_neighbors(
neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di)

neighbor_valid_tasklet = state.add_tasklet(
"check_valid_neighbor",
f"check_valid_neighbor_{offset_dim}",
{"__idx"},
{"__valid"},
f"__valid = True if __idx != {neighbor_skip_value} else False",
Expand Down Expand Up @@ -1223,7 +1248,7 @@ def _visit_reduce(self, node: itir.FunCall):
nreduce_shape = args_shape[0]

input_args = [arg[0] for arg in args]
input_valid = [arg[1] for arg in args if len(arg) == 2]
input_valid_args = [arg[1] for arg in args if len(arg) == 2]

nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape)))
nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)}
Expand Down Expand Up @@ -1255,41 +1280,56 @@ def _visit_reduce(self, node: itir.FunCall):
self.context.body, lambda_context.body, input_mapping
)

if input_valid:
if input_valid_args:
"""
The neighbors builtin returns an array of booleans in case the connectivity table
contains skip values. These boolean values indicate whether the neighbor value is present or not,
and are used below to construct an if/else branch to bypass the lambda call for neighbor skip values.
The neighbors builtin returns an array of booleans in case the connectivity table contains skip values.
These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select
the result of field access or the identity value, respectively.
If the neighbor table has full connectivity (no skip values by type definition), the input_valid node
is not built, and the construction of the if/else branch below is also skipped.
is not built, and the construction of the select tasklet below is also skipped.
"""
input_args.append(input_valid[0])
input_valid_node = input_valid[0].value
input_args.append(input_valid_args[0])
input_valid_node = input_valid_args[0].value
lambda_output_node = inner_outputs[0].value
# add input connector to nested sdfg
input_mapping["is_valid"] = create_memlet_at(input_valid_node.data, nreduce_index)
# check neighbor validity on if/else inter-state edge
start_state = lambda_context.body.add_state("start", is_start_block=True)
skip_neighbor_state = lambda_context.body.add_state("skip_neighbor")
skip_neighbor_state.add_edge(
skip_neighbor_state.add_tasklet(
"identity", {}, {"val"}, f"val = {reduce_identity}"
),
"val",
skip_neighbor_state.add_access(inner_outputs[0].value.data),
lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool)
input_mapping["_valid_neighbor"] = create_memlet_at(
input_valid_node.data, nreduce_index
)
# add select tasklet before writing to output node
# TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API
output_edge = lambda_context.state.in_edges(lambda_output_node)[0]
assert isinstance(
lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar
)
select_tasklet = lambda_context.state.add_tasklet(
"neighbor_select",
{"_inp", "_valid"},
{"_out"},
f"_out = _inp if _valid else {reduce_identity}",
)
lambda_context.state.add_edge(
output_edge.src,
None,
dace.Memlet(data=inner_outputs[0].value.data, subset="0"),
select_tasklet,
"_inp",
dace.Memlet(data=output_edge.src.data, subset="0"),
)
lambda_context.body.add_scalar("is_valid", dace.dtypes.bool)
lambda_context.body.add_edge(
start_state,
skip_neighbor_state,
dace.InterstateEdge(condition="is_valid == False"),
lambda_context.state.add_edge(
lambda_context.state.add_access("_valid_neighbor"),
None,
select_tasklet,
"_valid",
dace.Memlet(data="_valid_neighbor", subset="0"),
)
lambda_context.body.add_edge(
start_state,
lambda_context.state,
dace.InterstateEdge(condition="is_valid == True"),
lambda_context.state.add_edge(
select_tasklet,
"_out",
lambda_output_node,
None,
dace.Memlet(data=lambda_output_node.data, subset="0"),
)
lambda_context.state.remove_edge(output_edge)

reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,22 +515,22 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField:
@pytest.mark.uses_reduction_over_lift_expressions
def test_nested_reduction(unstructured_case):
@gtx.field_operator
def testee(a: cases.EField) -> cases.EField:
tmp = neighbor_sum(a(V2E), axis=V2EDim)
tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim)
def testee(a: cases.VField) -> cases.VField:
tmp = neighbor_sum(a(E2V), axis=E2VDim)
tmp_2 = neighbor_sum(tmp(V2E), axis=V2EDim)
return tmp_2

cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a: np.sum(
np.sum(
a[unstructured_case.offset_provider["V2E"].table],
a[unstructured_case.offset_provider["E2V"].table],
axis=1,
initial=0,
where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE,
)[unstructured_case.offset_provider["E2V"].table],
)[unstructured_case.offset_provider["V2E"].table],
axis=1,
where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE,
),
comparison=lambda a, tmp_2: np.all(a == tmp_2),
)
Expand Down

0 comments on commit 0a29261

Please sign in to comment.