
Commit

Revert "feat[next][dace]: Add support for lift expressions in neighbo…
Browse files Browse the repository at this point in the history
…r reductions (no unrolling) (GridTools#1431)"

This reverts commit e462a2e.
edopao committed Feb 13, 2024
1 parent 1d305e1 commit 696b47c
Showing 3 changed files with 67 additions and 205 deletions.
@@ -70,23 +70,34 @@ def preprocess_program(
program: itir.FencilDefinition,
offset_provider: Mapping[str, Any],
lift_mode: itir_transforms.LiftMode,
unroll_reduce: bool = False,
):
node = itir_transforms.apply_common_transforms(
program,
common_subexpression_elimination=False,
force_inline_lambda_args=True,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=unroll_reduce,
unroll_reduce=False,
)

if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries):
fencil_definition = node.fencil
tmps = node.tmps

elif isinstance(node, itir.FencilDefinition):
fencil_definition = node
# If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG.
# In this case, just retry with unrolled reductions.
if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
fencil_definition = node
else:
fencil_definition = itir_transforms.apply_common_transforms(
program,
common_subexpression_elimination=False,
force_inline_lambda_args=True,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=True,
)

tmps = []

else:
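The restored logic above first applies the common transforms without reduce unrolling and, only if a closure still contains a lift that cannot be lowered to SDFG, re-runs them with unroll_reduce=True. A minimal, self-contained sketch of that fallback (hypothetical helper names, not GT4Py API):

def lower_with_fallback(program, apply_transforms, closure_has_lifts):
    # First attempt: keep reductions rolled.
    node = apply_transforms(program, unroll_reduce=False)
    if any(closure_has_lifts(closure) for closure in node.closures):
        # Lifts survived the transforms: retry with unrolled reductions so
        # that the lift expressions get inlined before SDFG lowering.
        node = apply_transforms(program, unroll_reduce=True)
    return node
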
@@ -130,24 +130,6 @@ def _make_array_shape_and_strides(
return shape, strides


def _check_no_lifts(node: itir.StencilClosure):
"""
Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions.
Returns
-------
True if lifts do not appear in the ITIR, except for lift expressions in neighbor reductions. False otherwise.
"""
neighbors_call_count = 0
for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"):
if getattr(fun, "id", "") == "neighbors":
neighbors_call_count = 3
elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1:
return False
neighbors_call_count = max(0, neighbors_call_count - 1)
return True


class ItirToSDFG(eve.NodeVisitor):
param_types: list[ts.TypeSpec]
storage_types: dict[str, ts.TypeSpec]
@@ -366,7 +348,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
def visit_StencilClosure(
self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
) -> tuple[dace.SDFG, list[str], list[str]]:
assert _check_no_lifts(node)
assert ItirToSDFG._check_no_lifts(node)

# Create the closure's nested SDFG and single state.
closure_sdfg = dace.SDFG(name="closure")
@@ -785,6 +767,15 @@ def _visit_domain(

return tuple(sorted(bounds, key=lambda item: item[0]))

@staticmethod
def _check_no_lifts(node: itir.StencilClosure):
if any(
getattr(fun, "id", "") == "lift"
for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun")
):
return False
return True

@staticmethod
def _check_shift_offsets_are_literals(node: itir.StencilClosure):
fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall)
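The static _check_no_lifts restored above reports whether any FunCall in the closure refers to lift; unlike the removed module-level helper, it no longer tolerates lifts inside neighbor reductions. A pure-Python stand-in (toy node type, no gt4py imports) that mirrors the walk:

from dataclasses import dataclass, field

@dataclass
class FunCall:  # toy analogue of itir.FunCall, with a plain string for `fun`
    fun: str
    args: list = field(default_factory=list)

def check_no_lifts(node) -> bool:
    # Reject the tree as soon as any call node refers to "lift".
    if isinstance(node, FunCall):
        if node.fun == "lift":
            return False
        return all(check_no_lifts(arg) for arg in node.args)
    return True

assert check_no_lifts(FunCall("deref", [FunCall("shift")]))
assert not check_no_lifts(FunCall("reduce", [FunCall("lift", [FunCall("deref")])]))
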
@@ -181,126 +181,6 @@ def __init__(
self.reduce_identity = reduce_identity


def _visit_lift_in_neighbors_reduction(
transformer: "PythonTaskletCodegen",
node: itir.FunCall,
node_args: Sequence[IteratorExpr | list[ValueExpr]],
offset_provider: NeighborTableOffsetProvider,
map_entry: dace.nodes.MapEntry,
map_exit: dace.nodes.MapExit,
neighbor_index_node: dace.nodes.AccessNode,
neighbor_value_node: dace.nodes.AccessNode,
) -> list[ValueExpr]:
neighbor_dim = offset_provider.neighbor_axis.value
origin_dim = offset_provider.origin_axis.value

lifted_args: list[IteratorExpr | ValueExpr] = []
for arg in node_args:
if isinstance(arg, IteratorExpr):
if origin_dim in arg.indices:
lifted_indices = arg.indices.copy()
lifted_indices.pop(origin_dim)
lifted_indices[neighbor_dim] = neighbor_index_node
lifted_args.append(
IteratorExpr(
arg.field,
lifted_indices,
arg.dtype,
arg.dimensions,
)
)
else:
lifted_args.append(arg)
else:
lifted_args.append(arg[0])

lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args)
assert len(inner_outputs) == 1
inner_out_connector = inner_outputs[0].value.data

input_nodes = {}
iterator_index_nodes = {}
lifted_index_connectors = set()

for x, y in inner_inputs:
if isinstance(y, IteratorExpr):
field_connector, inner_index_table = x
input_nodes[field_connector] = y.field
for dim, connector in inner_index_table.items():
if dim == neighbor_dim:
lifted_index_connectors.add(connector)
iterator_index_nodes[connector] = y.indices[dim]
else:
assert isinstance(y, ValueExpr)
input_nodes[x] = y.value

neighbor_tables = filter_neighbor_tables(transformer.offset_provider)
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

parent_sdfg = transformer.context.body
parent_state = transformer.context.state

input_mapping = {
connector: create_memlet_full(node.data, node.desc(parent_sdfg))
for connector, node in input_nodes.items()
}
connectivity_mapping = {
name: create_memlet_full(name, parent_sdfg.arrays[name]) for name in connectivity_names
}
array_mapping = {**input_mapping, **connectivity_mapping}
symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping)

nested_sdfg_node = parent_state.add_nested_sdfg(
lift_context.body,
parent_sdfg,
inputs={*array_mapping.keys(), *iterator_index_nodes.keys()},
outputs={inner_out_connector},
symbol_mapping=symbol_mapping,
debuginfo=lift_context.body.debuginfo,
)

for connectivity_connector, memlet in connectivity_mapping.items():
parent_state.add_memlet_path(
parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo),
map_entry,
nested_sdfg_node,
dst_conn=connectivity_connector,
memlet=memlet,
)

for inner_connector, access_node in input_nodes.items():
parent_state.add_memlet_path(
access_node,
map_entry,
nested_sdfg_node,
dst_conn=inner_connector,
memlet=input_mapping[inner_connector],
)

for inner_connector, access_node in iterator_index_nodes.items():
memlet = dace.Memlet(data=access_node.data, subset="0")
if inner_connector in lifted_index_connectors:
parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet)
else:
parent_state.add_memlet_path(
access_node,
map_entry,
nested_sdfg_node,
dst_conn=inner_connector,
memlet=memlet,
)

parent_state.add_memlet_path(
nested_sdfg_node,
map_exit,
neighbor_value_node,
src_conn=inner_out_connector,
memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)),
)

return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)]
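The removed helper above inlined the lifted stencil as a nested SDFG inside the neighbor map, routing its inputs through the map entry and its result through the map exit. A standalone DaCe sketch of that wiring pattern, not taken from this diff (toy arrays and names, and it assumes a dace version where add_nested_sdfg takes the parent SDFG, as in the code above):

import dace
import numpy as np

# Inner SDFG: computes one output element from one input element.
inner = dace.SDFG("inner")
inner.add_array("x_in", [1], dace.float64)
inner.add_array("x_out", [1], dace.float64)
istate = inner.add_state()
tasklet = istate.add_tasklet("double", {"a"}, {"b"}, "b = 2 * a")
istate.add_edge(istate.add_read("x_in"), None, tasklet, "a", dace.Memlet("x_in[0]"))
istate.add_edge(tasklet, None, istate.add_write("x_out"), "b", dace.Memlet("x_out[0]"))

# Outer SDFG: a map whose body is the nested SDFG, fed through the entry/exit nodes.
outer = dace.SDFG("outer")
outer.add_array("A", [8], dace.float64)
outer.add_array("B", [8], dace.float64)
ostate = outer.add_state()
me, mx = ostate.add_map("neighbor_map", dict(i="0:8"))
nsdfg = ostate.add_nested_sdfg(inner, outer, inputs={"x_in"}, outputs={"x_out"})
ostate.add_memlet_path(ostate.add_read("A"), me, nsdfg,
                       dst_conn="x_in", memlet=dace.Memlet("A[i]"))
ostate.add_memlet_path(nsdfg, mx, ostate.add_write("B"),
                       src_conn="x_out", memlet=dace.Memlet("B[i]"))

A = np.arange(8, dtype=np.float64)
B = np.zeros_like(A)
outer(A=A, B=B)  # B == 2 * A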


def builtin_neighbors(
transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
@@ -318,16 +198,7 @@ def builtin_neighbors(
"Neighbor reduction only implemented for connectivity based on neighbor tables."
)

lift_node = None
if isinstance(data, FunCall):
assert isinstance(data.fun, itir.FunCall)
fun_node = data.fun
if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift":
lift_node = fun_node
lift_args = transformer.visit(data.args)
iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None)
if lift_node is None:
iterator = transformer.visit(data)
iterator = transformer.visit(data)
assert isinstance(iterator, IteratorExpr)
field_desc = iterator.field.desc(transformer.context.body)
origin_index_node = iterator.indices[offset_provider.origin_axis.value]
@@ -388,56 +259,44 @@ def builtin_neighbors(
dace.Memlet(data=neighbor_index_var, subset="0"),
)

if lift_node is not None:
_visit_lift_in_neighbors_reduction(
transformer,
lift_node,
lift_args,
offset_provider,
me,
mx,
neighbor_index_node,
neighbor_value_node,
)
else:
data_access_tasklet = state.add_tasklet(
"data_access",
code="__data = __field[__idx]"
+ (
f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
if offset_provider.has_skip_values
else ""
),
inputs={"__field", "__idx"},
outputs={"__data"},
debuginfo=di,
)
# select full shape only in the neighbor-axis dimension
field_subset = tuple(
f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
)
state.add_memlet_path(
iterator.field,
me,
data_access_tasklet,
memlet=create_memlet_at(iterator.field.data, field_subset),
dst_conn="__field",
)
state.add_edge(
neighbor_index_node,
None,
data_access_tasklet,
"__idx",
dace.Memlet(data=neighbor_index_var, subset="0"),
)
state.add_memlet_path(
data_access_tasklet,
mx,
neighbor_value_node,
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
src_conn="__data",
)
data_access_tasklet = state.add_tasklet(
"data_access",
code="__data = __field[__idx]"
+ (
f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
if offset_provider.has_skip_values
else ""
),
inputs={"__field", "__idx"},
outputs={"__data"},
debuginfo=di,
)
# select full shape only in the neighbor-axis dimension
field_subset = tuple(
f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
)
state.add_memlet_path(
iterator.field,
me,
data_access_tasklet,
memlet=create_memlet_at(iterator.field.data, field_subset),
dst_conn="__field",
)
state.add_edge(
neighbor_index_node,
None,
data_access_tasklet,
"__idx",
dace.Memlet(data=neighbor_index_var, subset="0"),
)
state.add_memlet_path(
data_access_tasklet,
mx,
neighbor_value_node,
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
src_conn="__data",
)

if not offset_provider.has_skip_values:
return [ValueExpr(neighbor_value_node, iterator.dtype)]
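
The data_access tasklet restored above reads __field[__idx] for each neighbor and substitutes the reduction identity where the connectivity entry equals the skip value. A plain NumPy illustration of that gather, with toy data and -1 as the skip value of a sum-reduction:

import numpy as np

field = np.array([10.0, 20.0, 30.0, 40.0])
neighbor_table = np.array([[1, 3, -1],      # -1 marks a missing neighbor
                           [0, 2, 3]])
skip_value, identity = -1, 0.0

# Clamp skipped indices to a valid position, then overwrite them with the
# identity, mirroring "__data = __field[__idx] if __idx != skip else identity".
gathered = np.where(neighbor_table == skip_value,
                    identity,
                    field[np.maximum(neighbor_table, 0)])
print(gathered)   # [[20. 40.  0.]
                  #  [10. 30. 40.]]
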
@@ -518,8 +377,9 @@ def builtin_can_deref(
# create tasklet to check that field indices are non-negative (-1 is invalid)
args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()]
internals = [f"{arg.value.data}_v" for arg in args]
expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals)
expr_code = " and ".join([f"{v} >= 0" for v in internals])

# TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution
return transformer.add_expr_tasklet(
list(zip(args, internals)),
expr_code,
@@ -1086,7 +946,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
iterator = self.visit(node.args[0])
if not isinstance(iterator, IteratorExpr):
# shift cannot be applied because the argument is not iterable
# TODO: remove this special case when ITIR pass is able to catch it
# TODO: remove this special case when ITIR reduce-unroll pass is able to catch it
assert isinstance(iterator, list) and len(iterator) == 1
assert isinstance(iterator[0], ValueExpr)
return iterator
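For the builtin_can_deref hunk above, the restored code joins one non-negativity test per index connector, since -1 marks an invalid (skipped) neighbor entry. A small stand-in with hypothetical connector names:

internals = ["__idx_Vertex_v", "__idx_Edge_v"]   # hypothetical connector names
expr_code = " and ".join([f"{v} >= 0" for v in internals])
assert expr_code == "__idx_Vertex_v >= 0 and __idx_Edge_v >= 0"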
