feat[next][dace]: Add support for lift expressions in neighbor reductions (no unrolling) (#1431)

The baseline dace backend forced unrolling of neighbor reductions in the ITIR pass in order to eliminate all lift expressions. This PR adds support for lowering lift expressions inside neighbor reductions, thus avoiding the need to unroll reduce expressions. The result is a more compact SDFG, which leaves the option of unrolling neighbor reductions to the optimization backend.
edopao authored Feb 2, 2024
1 parent 0d158ad commit e462a2e
Showing 3 changed files with 205 additions and 68 deletions.
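
As background for the diff below, here is a minimal, self-contained Python sketch (not gt4py or dace code; all names and values are illustrative) of what a neighbor reduction over a lifted stencil computes, and why it can remain a loop over the neighbor table instead of being unrolled:

# Hypothetical, simplified analogue of the lowering described above: a neighbor
# reduction that applies a "lifted" stencil per neighbor instead of unrolling the
# reduction into one expression per neighbor.
from typing import Callable, Sequence

SKIP_VALUE = -1  # assumed skip marker in the neighbor table

def reduce_over_neighbors(
    lifted_stencil: Callable[[int], float],  # evaluated at a single neighbor index
    neighbor_table: Sequence[int],           # neighbor indices for one origin location
    init: float = 0.0,
) -> float:
    # The dace backend now keeps this loop as a map over the neighbor dimension and
    # nests the lifted stencil inside it, rather than unrolling it in the ITIR pass.
    acc = init
    for idx in neighbor_table:
        if idx == SKIP_VALUE:
            continue
        acc += lifted_stencil(idx)
    return acc

# Example: summing a field over the edge neighbors of one vertex.
edge_field = [1.0, 2.0, 3.0, 4.0]
v2e_row = [0, 2, SKIP_VALUE]
print(reduce_over_neighbors(lambda e: edge_field[e], v2e_row))  # 4.0
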
@@ -69,28 +69,16 @@ def preprocess_program(
     program: itir.FencilDefinition,
     offset_provider: Mapping[str, Any],
     lift_mode: itir_transforms.LiftMode,
+    unroll_reduce: bool = False,
 ):
-    node = itir_transforms.apply_common_transforms(
+    return itir_transforms.apply_common_transforms(
         program,
         common_subexpression_elimination=False,
         force_inline_lambda_args=True,
         lift_mode=lift_mode,
         offset_provider=offset_provider,
-        unroll_reduce=False,
+        unroll_reduce=unroll_reduce,
     )
-    # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG.
-    # In this case, just retry with unrolled reductions.
-    if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
-        fencil_definition = node
-    else:
-        fencil_definition = itir_transforms.apply_common_transforms(
-            program,
-            common_subexpression_elimination=False,
-            force_inline_lambda_args=True,
-            lift_mode=lift_mode,
-            offset_provider=offset_provider,
-            unroll_reduce=True,
-        )
-    return fencil_definition
 
 
 def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
@@ -124,6 +124,24 @@ def _make_array_shape_and_strides(
     return shape, strides
 
 
+def _check_no_lifts(node: itir.StencilClosure):
+    """
+    Parse the stencil closure ITIR to check that lift expressions only appear as child nodes of neighbor reductions.
+
+    Returns
+    -------
+    True if no lifts appear in the ITIR, except for lift expressions inside neighbor reductions; False otherwise.
+    """
+    neighbors_call_count = 0
+    for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"):
+        if getattr(fun, "id", "") == "neighbors":
+            neighbors_call_count = 3
+        elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1:
+            return False
+        neighbors_call_count = max(0, neighbors_call_count - 1)
+    return True
+
+
 class ItirToSDFG(eve.NodeVisitor):
     param_types: list[ts.TypeSpec]
     storage_types: dict[str, ts.TypeSpec]
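
For clarity, a standalone sketch (plain Python, not gt4py code) of the countdown logic used in `_check_no_lifts` above. The pre-order walk over FunCall nodes visits, for `neighbors(offset, (↑stencil)(args))`, first the neighbors call, then the applied-lift call (whose `fun` has no `id`), then the `lift` symbol itself, so a lift is accepted only when the countdown started at a neighbors call has reached exactly 1:

# Sketch of the same countdown over a list of function ids, as they would be yielded
# by the pre-order walk (empty string stands for a FunCall fun without an id).
def check_no_lifts(walked_fun_ids: list[str]) -> bool:
    neighbors_call_count = 0
    for fun_id in walked_fun_ids:
        if fun_id == "neighbors":
            neighbors_call_count = 3
        elif fun_id == "lift" and neighbors_call_count != 1:
            return False
        neighbors_call_count = max(0, neighbors_call_count - 1)
    return True

# lift applied directly inside a neighbors call: accepted
print(check_no_lifts(["neighbors", "", "lift", "deref"]))  # True
# lift outside a neighbors call: rejected
print(check_no_lifts(["deref", "lift"]))  # False
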
@@ -262,7 +280,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
     def visit_StencilClosure(
         self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
     ) -> tuple[dace.SDFG, list[str], list[str]]:
-        assert ItirToSDFG._check_no_lifts(node)
+        assert _check_no_lifts(node)
 
         # Create the closure's nested SDFG and single state.
         closure_sdfg = dace.SDFG(name="closure")
@@ -681,15 +699,6 @@ def _visit_domain(
 
         return tuple(sorted(bounds, key=lambda item: item[0]))
 
-    @staticmethod
-    def _check_no_lifts(node: itir.StencilClosure):
-        if any(
-            getattr(fun, "id", "") == "lift"
-            for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun")
-        ):
-            return False
-        return True
-
     @staticmethod
     def _check_shift_offsets_are_literals(node: itir.StencilClosure):
         fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall)
@@ -181,6 +181,126 @@ def __init__(
         self.reduce_identity = reduce_identity
 
 
+def _visit_lift_in_neighbors_reduction(
+    transformer: "PythonTaskletCodegen",
+    node: itir.FunCall,
+    node_args: Sequence[IteratorExpr | list[ValueExpr]],
+    offset_provider: NeighborTableOffsetProvider,
+    map_entry: dace.nodes.MapEntry,
+    map_exit: dace.nodes.MapExit,
+    neighbor_index_node: dace.nodes.AccessNode,
+    neighbor_value_node: dace.nodes.AccessNode,
+) -> list[ValueExpr]:
+    neighbor_dim = offset_provider.neighbor_axis.value
+    origin_dim = offset_provider.origin_axis.value
+
+    lifted_args: list[IteratorExpr | ValueExpr] = []
+    for arg in node_args:
+        if isinstance(arg, IteratorExpr):
+            if origin_dim in arg.indices:
+                lifted_indices = arg.indices.copy()
+                lifted_indices.pop(origin_dim)
+                lifted_indices[neighbor_dim] = neighbor_index_node
+                lifted_args.append(
+                    IteratorExpr(
+                        arg.field,
+                        lifted_indices,
+                        arg.dtype,
+                        arg.dimensions,
+                    )
+                )
+            else:
+                lifted_args.append(arg)
+        else:
+            lifted_args.append(arg[0])
+
+    lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args)
+    assert len(inner_outputs) == 1
+    inner_out_connector = inner_outputs[0].value.data
+
+    input_nodes = {}
+    iterator_index_nodes = {}
+    lifted_index_connectors = set()
+
+    for x, y in inner_inputs:
+        if isinstance(y, IteratorExpr):
+            field_connector, inner_index_table = x
+            input_nodes[field_connector] = y.field
+            for dim, connector in inner_index_table.items():
+                if dim == neighbor_dim:
+                    lifted_index_connectors.add(connector)
+                iterator_index_nodes[connector] = y.indices[dim]
+        else:
+            assert isinstance(y, ValueExpr)
+            input_nodes[x] = y.value
+
+    neighbor_tables = filter_neighbor_tables(transformer.offset_provider)
+    connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
+
+    parent_sdfg = transformer.context.body
+    parent_state = transformer.context.state
+
+    input_mapping = {
+        connector: create_memlet_full(node.data, node.desc(parent_sdfg))
+        for connector, node in input_nodes.items()
+    }
+    connectivity_mapping = {
+        name: create_memlet_full(name, parent_sdfg.arrays[name]) for name in connectivity_names
+    }
+    array_mapping = {**input_mapping, **connectivity_mapping}
+    symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping)
+
+    nested_sdfg_node = parent_state.add_nested_sdfg(
+        lift_context.body,
+        parent_sdfg,
+        inputs={*array_mapping.keys(), *iterator_index_nodes.keys()},
+        outputs={inner_out_connector},
+        symbol_mapping=symbol_mapping,
+        debuginfo=lift_context.body.debuginfo,
+    )
+
+    for connectivity_connector, memlet in connectivity_mapping.items():
+        parent_state.add_memlet_path(
+            parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo),
+            map_entry,
+            nested_sdfg_node,
+            dst_conn=connectivity_connector,
+            memlet=memlet,
+        )
+
+    for inner_connector, access_node in input_nodes.items():
+        parent_state.add_memlet_path(
+            access_node,
+            map_entry,
+            nested_sdfg_node,
+            dst_conn=inner_connector,
+            memlet=input_mapping[inner_connector],
+        )
+
+    for inner_connector, access_node in iterator_index_nodes.items():
+        memlet = dace.Memlet(data=access_node.data, subset="0")
+        if inner_connector in lifted_index_connectors:
+            parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet)
+        else:
+            parent_state.add_memlet_path(
+                access_node,
+                map_entry,
+                nested_sdfg_node,
+                dst_conn=inner_connector,
+                memlet=memlet,
+            )
+
+    parent_state.add_memlet_path(
+        nested_sdfg_node,
+        map_exit,
+        neighbor_value_node,
+        src_conn=inner_out_connector,
+        memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)),
+    )
+
+    return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)]
+
+
 def builtin_neighbors(
     transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
 ) -> list[ValueExpr]:
@@ -198,7 +318,16 @@ def builtin_neighbors(
             "Neighbor reduction only implemented for connectivity based on neighbor tables."
         )
 
-    iterator = transformer.visit(data)
+    lift_node = None
+    if isinstance(data, FunCall):
+        assert isinstance(data.fun, itir.FunCall)
+        fun_node = data.fun
+        if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift":
+            lift_node = fun_node
+            lift_args = transformer.visit(data.args)
+            iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None)
+    if lift_node is None:
+        iterator = transformer.visit(data)
     assert isinstance(iterator, IteratorExpr)
     field_desc = iterator.field.desc(transformer.context.body)
     origin_index_node = iterator.indices[offset_provider.origin_axis.value]
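
An illustrative sketch, using plain dataclasses rather than the gt4py node classes, of the shape matched above: the data argument of `neighbors` is an applied lift, i.e. a FunCall whose `fun` is itself a FunCall whose `fun` is the symbol `lift`:

from dataclasses import dataclass

@dataclass
class SymRef:
    id: str

@dataclass
class FunCall:
    fun: object
    args: list

def is_applied_lift(data) -> bool:
    # Mirrors the detection in builtin_neighbors (hypothetical helper, not gt4py code).
    return (
        isinstance(data, FunCall)
        and isinstance(data.fun, FunCall)
        and isinstance(data.fun.fun, SymRef)
        and data.fun.fun.id == "lift"
    )

# (↑deref)(it) passed as the data argument of neighbors(...)
lifted = FunCall(fun=FunCall(fun=SymRef("lift"), args=[SymRef("deref")]), args=[SymRef("it")])
print(is_applied_lift(lifted))        # True
print(is_applied_lift(SymRef("it")))  # False
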
@@ -259,44 +388,56 @@ def builtin_neighbors(
         dace.Memlet(data=neighbor_index_var, subset="0"),
     )
 
-    data_access_tasklet = state.add_tasklet(
-        "data_access",
-        code="__data = __field[__idx]"
-        + (
-            f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
-            if offset_provider.has_skip_values
-            else ""
-        ),
-        inputs={"__field", "__idx"},
-        outputs={"__data"},
-        debuginfo=di,
-    )
-    # select full shape only in the neighbor-axis dimension
-    field_subset = tuple(
-        f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
-        for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
-    )
-    state.add_memlet_path(
-        iterator.field,
-        me,
-        data_access_tasklet,
-        memlet=create_memlet_at(iterator.field.data, field_subset),
-        dst_conn="__field",
-    )
-    state.add_edge(
-        neighbor_index_node,
-        None,
-        data_access_tasklet,
-        "__idx",
-        dace.Memlet(data=neighbor_index_var, subset="0"),
-    )
-    state.add_memlet_path(
-        data_access_tasklet,
-        mx,
-        neighbor_value_node,
-        memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
-        src_conn="__data",
-    )
+    if lift_node is not None:
+        _visit_lift_in_neighbors_reduction(
+            transformer,
+            lift_node,
+            lift_args,
+            offset_provider,
+            me,
+            mx,
+            neighbor_index_node,
+            neighbor_value_node,
+        )
+    else:
+        data_access_tasklet = state.add_tasklet(
+            "data_access",
+            code="__data = __field[__idx]"
+            + (
+                f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
+                if offset_provider.has_skip_values
+                else ""
+            ),
+            inputs={"__field", "__idx"},
+            outputs={"__data"},
+            debuginfo=di,
+        )
+        # select full shape only in the neighbor-axis dimension
+        field_subset = tuple(
+            f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
+            for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
+        )
+        state.add_memlet_path(
+            iterator.field,
+            me,
+            data_access_tasklet,
+            memlet=create_memlet_at(iterator.field.data, field_subset),
+            dst_conn="__field",
+        )
+        state.add_edge(
+            neighbor_index_node,
+            None,
+            data_access_tasklet,
+            "__idx",
+            dace.Memlet(data=neighbor_index_var, subset="0"),
+        )
+        state.add_memlet_path(
+            data_access_tasklet,
+            mx,
+            neighbor_value_node,
+            memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
+            src_conn="__data",
+        )
 
     if not offset_provider.has_skip_values:
         return [ValueExpr(neighbor_value_node, iterator.dtype)]
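
For the non-lift branch above, a sketch of the tasklet code string that gets generated when the connectivity has skip values, assuming a skip value of -1 and a reduce identity of 0.0 (both values are placeholders):

# Rebuilds the same string as the add_tasklet call in the else-branch, with concrete
# placeholder values substituted for the f-string interpolations.
neighbor_skip_value = -1
reduce_identity_value = 0.0
has_skip_values = True

code = "__data = __field[__idx]" + (
    f" if __idx != {neighbor_skip_value} else {reduce_identity_value}"
    if has_skip_values
    else ""
)
print(code)  # __data = __field[__idx] if __idx != -1 else 0.0
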
@@ -377,9 +518,8 @@ def builtin_can_deref(
     # create tasklet to check that field indices are non-negative (-1 is invalid)
     args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()]
     internals = [f"{arg.value.data}_v" for arg in args]
-    expr_code = " and ".join([f"{v} >= 0" for v in internals])
+    expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals)
 
-    # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution
     return transformer.add_expr_tasklet(
         list(zip(args, internals)),
         expr_code,
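
A sketch of the expression generated by the new can_deref check for two index connectors, again assuming a skip value of -1 (connector names are made up):

neighbor_skip_value = -1
internals = ["i_Vertex_v", "i_V2E_v"]
expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals)
print(expr_code)  # i_Vertex_v != -1 and i_V2E_v != -1
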
@@ -946,7 +1086,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
         iterator = self.visit(node.args[0])
         if not isinstance(iterator, IteratorExpr):
             # shift cannot be applied because the argument is not iterable
-            # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it
+            # TODO: remove this special case when ITIR pass is able to catch it
             assert isinstance(iterator, list) and len(iterator) == 1
             assert isinstance(iterator[0], ValueExpr)
             return iterator
