Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: Add support for lift expressions in neighbor reductions (no unrolling) #1431

Merged
merged 19 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,13 @@ def preprocess_program(
offset_provider: Mapping[str, Any],
lift_mode: itir_transforms.LiftMode,
):
node = itir_transforms.apply_common_transforms(
return itir_transforms.apply_common_transforms(
program,
common_subexpression_elimination=False,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=False,
)
# If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG.
# In this case, just retry with unrolled reductions.
if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
fencil_definition = node
else:
fencil_definition = itir_transforms.apply_common_transforms(
program,
common_subexpression_elimination=False,
force_inline_lambda_args=True,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=True,
)
return fencil_definition


def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -299,7 +285,7 @@ def build_sdfg_from_itir(
for nested_sdfg in sdfg.all_sdfgs_recursive():
if not nested_sdfg.debuginfo:
_, frameinfo = warnings.warn(
f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg."
f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg."
), getframeinfo(
currentframe() # type: ignore
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ def _make_array_shape_and_strides(
return shape, strides


def _check_no_lifts(node: itir.StencilClosure):
"""
Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions.

Returns
-------
True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise.
"""
neighbors_call_count = 0
for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"):
if getattr(fun, "id", "") == "neighbors":
neighbors_call_count = 3
elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1:
return False
neighbors_call_count = max(0, neighbors_call_count - 1)
return True


class ItirToSDFG(eve.NodeVisitor):
param_types: list[ts.TypeSpec]
storage_types: dict[str, ts.TypeSpec]
Expand Down Expand Up @@ -262,7 +280,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
def visit_StencilClosure(
self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
) -> tuple[dace.SDFG, list[str], list[str]]:
assert ItirToSDFG._check_no_lifts(node)
assert _check_no_lifts(node)

# Create the closure's nested SDFG and single state.
closure_sdfg = dace.SDFG(name="closure")
Expand Down Expand Up @@ -681,15 +699,6 @@ def _visit_domain(

return tuple(sorted(bounds, key=lambda item: item[0]))

@staticmethod
def _check_no_lifts(node: itir.StencilClosure):
if any(
getattr(fun, "id", "") == "lift"
for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun")
):
return False
return True

@staticmethod
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
def _check_shift_offsets_are_literals(node: itir.StencilClosure):
fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall)
Expand Down
Loading
Loading