feat[next][dace]: Fix lowering of nested let-statements (#1697)
This PR fixes one corner case of nested let-statements, discovered in
`test_tuple_unpacking_star_multi` during GTIR integration. Test case
added.

Additionally, fixed the handling of symbols already defined in the SDFG, for #1695.
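
For reference, a minimal sketch of the kind of nested let-statement exercised by the new test below (not part of the commit; it assumes the `im` IR-maker helpers and the Cartesian `domain` defined as in the test file):

    # Hypothetical illustration only: a let-expression whose body is another
    # let-expression, returning a tuple built from both let-bound symbols.
    nested_let = im.let("t", im.make_tuple("x", "y"))(
        im.let("p", im.op_as_fieldop("plus", domain)("x", "y"))(
            im.make_tuple("p", im.tuple_get(1, "t"))
        )
    )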
edopao authored Oct 25, 2024
1 parent 77a8a6d commit eb05a0a
Showing 2 changed files with 125 additions and 50 deletions.
122 changes: 74 additions & 48 deletions src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py
@@ -266,18 +266,20 @@ def _add_storage(
elif isinstance(gt_type, ts.ScalarType):
dc_dtype = dace_utils.as_dace_type(gt_type)
if name in symbolic_arguments:
sdfg.add_symbol(name, dc_dtype)
elif dace_utils.is_field_symbol(name):
# Sometimes, when the field domain is implicitly derived from the
# field itself, the gt4py lowering adds the field size as a scalar
# argument to the program IR. For example, for a field '__sym', gt4py
# will add '__sym_size_0'.
# Therefore, here we check whether the shape symbol was already
# created by `_make_array_shape_and_strides`, when allocating
# storage for field arguments. We assume that the scalar argument
# for field size, if present, always follows the field argument.
if name in sdfg.symbols:
assert sdfg.symbols[name].dc_dtype == dc_dtype
# Sometimes, when the field domain is implicitly derived from the
# field itself, the gt4py lowering adds the field size as a scalar
# argument to the program IR. For example, for a field '__sym', gt4py
# will add '__sym_size_0'.
# Therefore, here we check whether the shape symbol was already
# created by `_make_array_shape_and_strides()`, when allocating
# storage for field arguments. We assume that the scalar argument
# for field size, if present, always follows the field argument.
assert dace_utils.is_field_symbol(name)
if sdfg.symbols[name].dtype != dc_dtype:
raise ValueError(
f"Type mismatch on argument {name}: got {dc_dtype}, expected {sdfg.symbols[name].dtype}."
)
else:
sdfg.add_symbol(name, dc_dtype)
else:
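
As background for the hunk above, a minimal standalone sketch of the situation it handles, using plain DaCe calls rather than the actual gt4py helpers (the example names are hypothetical): the shape symbol of a field argument is declared when the field storage is created, so a later scalar argument with the same name must reuse the existing symbol instead of being re-added.

    import dace

    sdfg = dace.SDFG("field_size_symbol_example")
    dc_dtype = dace.int64

    # The field '__sym' is allocated first; its shape symbol '__sym_size_0'
    # is declared at that point (as `_make_array_shape_and_strides` would do).
    sdfg.add_symbol("__sym_size_0", dc_dtype)
    sdfg.add_array(
        "__sym", shape=(dace.symbol("__sym_size_0", dc_dtype),), dtype=dace.float64
    )

    # Later, the scalar program argument '__sym_size_0' is processed: reuse the
    # existing symbol, only checking for a dtype mismatch, instead of re-adding it.
    name = "__sym_size_0"
    if name in sdfg.symbols:
        if sdfg.symbols[name].dtype != dc_dtype:
            raise ValueError(f"Type mismatch on argument {name}")
    else:
        sdfg.add_symbol(name, dc_dtype)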
@@ -599,6 +601,9 @@ def _flatten_tuples(

# Process lambda inputs
#
# All input arguments are passed as parameters to the nested SDFG, therefore
# they are stored as non-transient array and scalar objects.
#
lambda_arg_nodes = dict(
itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping])
)
@@ -639,68 +644,89 @@ def _flatten_tuples(

# Process lambda outputs
#
# The output arguments do not really exist, so they are not allocated before
# visiting the lambda expression. Therefore, the result appears inside the
# nested SDFG as transient array/scalar storage. The exception is input
# arguments that are just passed through and returned by the lambda,
# e.g. when the lambda is constructing a tuple: in this case, the result
# data is non-transient, because it corresponds to an input node.
# The transient storage of the lambda result in the nested SDFG is corrected
# below by the call to `make_temps()`: this function ensures that the result
# transient nodes are changed to non-transient and the corresponding output
# connectors on the nested SDFG are connected to new data nodes in the parent SDFG.
#
lambda_output_data: Iterable[gtir_builtin_translators.FieldopData] = (
gtx_utils.flatten_nested_tuple(lambda_result)
)
# sanity check on isolated nodes
assert all(
nstate.degree(output_data.dc_node) == 0
for output_data in lambda_output_data
if output_data.dc_node.data in input_memlets
)
# keep only non-isolated output nodes
# The output connectors only need to be set up for the actual result of the
# internal dataflow, which is written to transient nodes.
# We filter out the non-transient nodes because they are already available
# in the current context. These nodes will eventually be removed from the
# nested SDFG because they are isolated (see `make_temps()`).
lambda_outputs = {
output_data.dc_node.data
for output_data in lambda_output_data
if output_data.dc_node.data not in input_memlets
if output_data.dc_node.desc(nsdfg).transient
}

if lambda_outputs:
nsdfg_node = head_state.add_nested_sdfg(
nsdfg,
parent=sdfg,
inputs=set(input_memlets.keys()),
outputs=lambda_outputs,
symbol_mapping=nsdfg_symbols_mapping,
debuginfo=dace_utils.debug_info(node, default=sdfg.debuginfo),
)
nsdfg_node = head_state.add_nested_sdfg(
nsdfg,
parent=sdfg,
inputs=set(input_memlets.keys()),
outputs=lambda_outputs,
symbol_mapping=nsdfg_symbols_mapping,
debuginfo=dace_utils.debug_info(node, default=sdfg.debuginfo),
)

for connector, memlet in input_memlets.items():
if connector in lambda_arg_nodes:
src_node = lambda_arg_nodes[connector].dc_node
else:
src_node = head_state.add_access(memlet.data)
for connector, memlet in input_memlets.items():
if connector in lambda_arg_nodes:
src_node = lambda_arg_nodes[connector].dc_node
else:
src_node = head_state.add_access(memlet.data)

head_state.add_edge(src_node, None, nsdfg_node, connector, memlet)
head_state.add_edge(src_node, None, nsdfg_node, connector, memlet)

def make_temps(
output_data: gtir_builtin_translators.FieldopData,
) -> gtir_builtin_translators.FieldopData:
if output_data.dc_node.data in lambda_outputs:
connector = output_data.dc_node.data
desc = output_data.dc_node.desc(nsdfg)
# make lambda result non-transient and map it to external temporary
"""
This function will be called while traversing the result of the lambda
dataflow to setup the intermediate data nodes in the parent SDFG and
the data edges from the nested-SDFG output connectors.
"""
desc = output_data.dc_node.desc(nsdfg)
if desc.transient:
# Transient nodes actually contain some result produced by the dataflow
# itself, so these nodes are changed to non-transient and an output
# edge writes the result from the nested SDFG to a new intermediate
# data node in the parent context.
desc.transient = False
# isolated access node will make validation fail
if nstate.degree(output_data.dc_node) == 0:
nstate.remove_node(output_data.dc_node)
temp, _ = sdfg.add_temp_transient_like(desc)
connector = output_data.dc_node.data
dst_node = head_state.add_access(temp)
head_state.add_edge(
nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp)
)
return gtir_builtin_translators.FieldopData(
temp_field = gtir_builtin_translators.FieldopData(
dst_node, output_data.gt_dtype, output_data.local_offset
)
elif output_data.dc_node.data in lambda_arg_nodes:
nstate.remove_node(output_data.dc_node)
return lambda_arg_nodes[output_data.dc_node.data]
# This if branch and the next one handle the non-transient result nodes.
# Non-transient nodes are just input nodes that are immediately returned
# by the lambda expression. Therefore, these nodes are already available
# in the parent context and can be directly accessed there.
temp_field = lambda_arg_nodes[output_data.dc_node.data]
else:
nstate.remove_node(output_data.dc_node)
data_node = head_state.add_access(output_data.dc_node.data)
return gtir_builtin_translators.FieldopData(
data_node, output_data.gt_dtype, output_data.local_offset
dc_node = head_state.add_access(output_data.dc_node.data)
temp_field = gtir_builtin_translators.FieldopData(
dc_node, output_data.gt_dtype, output_data.local_offset
)
# An isolated access node would make SDFG validation fail.
# Isolated access nodes can be found in the join-state of an if-expression
# or in lambda expressions that just construct tuples from input arguments.
if nstate.degree(output_data.dc_node) == 0:
nstate.remove_node(output_data.dc_node)
return temp_field

return gtx_utils.tree_map(make_temps)(lambda_result)
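
The final `gtx_utils.tree_map(make_temps)(lambda_result)` call applies `make_temps` to every leaf of the (possibly nested) tuple returned by the lambda, preserving the tuple structure. A minimal standalone sketch of that pattern (an assumed re-implementation for illustration, not the actual `gt4py.next.utils` API):

    from typing import Any, Callable

    def tree_map(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
        # Return a function that applies `fn` to every non-tuple leaf of a
        # nested tuple while preserving the tuple structure.
        def apply(value: Any) -> Any:
            if isinstance(value, tuple):
                return tuple(apply(v) for v in value)
            return fn(value)
        return apply

    assert tree_map(lambda x: x + 1)((1, (2, 3))) == (2, (3, 4))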

@@ -1772,10 +1772,10 @@ def test_gtir_let_lambda_with_cond():
assert np.allclose(b, a if s else a * 2)


def test_gtir_let_lambda_with_tuple():
def test_gtir_let_lambda_with_tuple1():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
testee = gtir.Program(
id="let_lambda_with_tuple",
id="let_lambda_with_tuple1",
function_definitions=[],
params=[
gtir.Sym(id="x", type=IFTYPE),
@@ -1816,6 +1816,55 @@ def test_gtir_let_lambda_with_tuple():
assert np.allclose(z_fields[1], b)


def test_gtir_let_lambda_with_tuple2():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
val = np.random.rand()
testee = gtir.Program(
id="let_lambda_with_tuple2",
function_definitions=[],
params=[
gtir.Sym(id="x", type=IFTYPE),
gtir.Sym(id="y", type=IFTYPE),
gtir.Sym(id="z", type=ts.TupleType(types=[IFTYPE, IFTYPE, IFTYPE])),
gtir.Sym(id="size", type=SIZE_TYPE),
],
declarations=[],
body=[
gtir.SetAt(
expr=im.let("s", im.as_fieldop("deref", domain)(val))(
im.let("t", im.make_tuple("x", "y"))(
im.let("p", im.op_as_fieldop("plus", domain)("x", "y"))(
im.make_tuple("p", "s", im.tuple_get(1, "t"))
)
)
),
domain=domain,
target=gtir.SymRef(id="z"),
)
],
)

a = np.random.rand(N)
b = np.random.rand(N)

sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS)

z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a))
z_symbols = dict(
__z_0_size_0=FSYMBOLS["__x_size_0"],
__z_0_stride_0=FSYMBOLS["__x_stride_0"],
__z_1_size_0=FSYMBOLS["__x_size_0"],
__z_1_stride_0=FSYMBOLS["__x_stride_0"],
__z_2_size_0=FSYMBOLS["__x_size_0"],
__z_2_stride_0=FSYMBOLS["__x_stride_0"],
)

sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols)
assert np.allclose(z_fields[0], a + b)
assert np.allclose(z_fields[1], val)
assert np.allclose(z_fields[2], b)


def test_gtir_if_scalars():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
testee = gtir.Program(
