Skip to content

Commit

Permalink
bug[next][dace]: Fix lowering of broadcast (#1698)
Browse files Browse the repository at this point in the history
Fix lowering of `as_fieldop` with broadcast expression after changes in
PR #1701.

Additional change:
- Add support for GT4Py zero-dimensional fields, the equivalent of NumPy
zero-dimensional arrays. A test case is added.
  • Loading branch information
edopao authored Oct 25, 2024
1 parent eb05a0a commit db249bd
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any:
if not isinstance(arg, gtx_common.Field):
return arg
if len(arg.domain.dims) == 0:
# Pass zero-dimensional fields as scalars.
# We need to extract the scalar value from the 0d numpy array without changing its type.
# Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar,
# which may change its precision. To avoid this, we use here the empty tuple as index
# for 'ndarray.__getitem__()'.
return arg.ndarray[()]
# field domain offsets are not supported
non_zero_offsets = [
(dim, dim_range)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,32 @@ def _parse_fieldop_arg(
raise NotImplementedError(f"Node type {type(arg.gt_dtype)} not supported.")


def _get_field_shape(
    domain: FieldopDomain,
) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]:
    """
    Parse the field operator domain and generate the shape of the result field.

    It should be enough to allocate an array with shape (upper_bound - lower_bound)
    but this would require to use an array offset to compensate for the start index.
    Suppose that a field operator executes on domain [2,N-2]: the dace array to store
    the result only needs size (N-4), but this would require to compensate all array
    accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose
    to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset
    is known to cause issues to SDFG inlining. Besides, map fusion will in any case
    eliminate most of transient arrays.

    Args:
        domain: The field operator domain, a sequence of (dimension, lower_bound,
            upper_bound) entries.

    Returns:
        A tuple of two lists: the list of field dimensions and the list of dace
        array sizes in each dimension.
    """
    if len(domain) == 0:
        # Zero-dimensional fields have an empty domain: no dimensions, no shape.
        # Without this guard, unpacking 'zip(*domain)' below would raise
        # 'ValueError: not enough values to unpack (expected 3, got 0)'.
        return [], []
    domain_dims, _, domain_ubs = zip(*domain)
    return list(domain_dims), list(domain_ubs)


def _create_temporary_field(
sdfg: dace.SDFG,
state: dace.SDFGState,
Expand All @@ -146,17 +172,7 @@ def _create_temporary_field(
dataflow_output: gtir_dataflow.DataflowOutputEdge,
) -> FieldopData:
"""Helper method to allocate a temporary field where to write the output of a field operator."""
domain_dims, _, domain_ubs = zip(*domain)
field_dims = list(domain_dims)
# It should be enough to allocate an array with shape (upper_bound - lower_bound)
# but this would require to use array offset for compensate for the start index.
# Suppose that a field operator executes on domain [2,N-2], the dace array to store
# the result only needs size (N-4), but this would require to compensate all array
# accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose
# to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset
# is known to cause issues to SDFG inlining. Besides, map fusion will in any case
# eliminate most of transient arrays.
field_shape = list(domain_ubs)
field_dims, field_shape = _get_field_shape(domain)

output_desc = dataflow_output.result.dc_node.desc(sdfg)
if isinstance(output_desc, dace.data.Array):
Expand Down Expand Up @@ -311,17 +327,46 @@ def translate_broadcast_scalar(
assert cpm.is_ref_to(stencil_expr, "deref")

domain = extract_domain(domain_expr)
domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain])
field_dims, field_shape = _get_field_shape(domain)
field_subset = sbs.Range.from_string(
",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims)
)

assert len(node.args) == 1
assert isinstance(node.args[0].type, ts.ScalarType)
scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain)
assert isinstance(scalar_expr, gtir_dataflow.MemletExpr)
assert scalar_expr.subset == sbs.Indices.from_string("0")
result = gtir_dataflow.DataflowOutputEdge(
state, gtir_dataflow.ValueExpr(scalar_expr.dc_node, node.args[0].type)
)
result_field = _create_temporary_field(sdfg, state, domain, node.type, dataflow_output=result)

if isinstance(node.args[0].type, ts.ScalarType):
assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr))
input_subset = (
str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0"
)
input_node = scalar_expr.dc_node
gt_dtype = node.args[0].type
elif isinstance(node.args[0].type, ts.FieldType):
assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr)
if len(node.args[0].type.dims) == 0: # zero-dimensional field
input_subset = "0"
elif all(
isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr)
for dim in scalar_expr.dimensions
if dim not in field_dims
):
input_subset = ",".join(
dace_gtir_utils.get_map_variable(dim)
if dim in field_dims
else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above
for dim in scalar_expr.dimensions
)
else:
raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.")

input_node = scalar_expr.field
gt_dtype = node.args[0].type.dtype
else:
raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.")

output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype)
output_node = state.add_access(output)

sdfg_builder.add_mapped_tasklet(
"broadcast",
Expand All @@ -330,15 +375,15 @@ def translate_broadcast_scalar(
dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
},
inputs={"__inp": dace.Memlet(data=scalar_expr.dc_node.data, subset="0")},
inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)},
code="__val = __inp",
outputs={"__val": dace.Memlet(data=result_field.dc_node.data, subset=domain_indices)},
input_nodes={scalar_expr.dc_node.data: scalar_expr.dc_node},
output_nodes={result_field.dc_node.data: result_field.dc_node},
outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)},
input_nodes={input_node.data: input_node},
output_nodes={output_node.data: output_node},
external_edges=True,
)

return result_field
return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype), local_offset=None)


def translate_if(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,82 +426,86 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
assert len(node.args) == 1
arg_expr = self.visit(node.args[0])

if isinstance(arg_expr, IteratorExpr):
field_desc = arg_expr.field.desc(self.sdfg)
assert len(field_desc.shape) == len(arg_expr.dimensions)
if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
# when all indices are symblic expressions, we can perform direct field access through a memlet
field_subset = sbs.Range(
(arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr]
if dim in arg_expr.indices
else (0, size - 1, 1)
for dim, size in zip(arg_expr.dimensions, field_desc.shape)
)
return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset)

else:
# we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
# either indirection through connectivity table or dynamic cartesian offset.
assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
index_connectors = [
IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
if not isinstance(index, SymbolExpr)
]
# here `internals` refer to the names used as index in the tasklet code string:
# an index can be either a connector name (for dynamic/indirect indices)
# or a symbol value (for literal values and scalar arguments).
index_internals = ",".join(
str(index.value)
if isinstance(index, SymbolExpr)
else IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
)
deref_node = self._add_tasklet(
"runtime_deref",
{"field"} | set(index_connectors),
{"val"},
code=f"val = field[{index_internals}]",
)
# add new termination point for the field parameter
self._add_input_data_edge(
arg_expr.field,
sbs.Range.from_array(field_desc),
deref_node,
"field",
)
if not isinstance(arg_expr, IteratorExpr):
# dereferencing a scalar or a literal node results in the node itself
return arg_expr

for dim, index_expr in field_indices:
# add termination points for the dynamic iterator indices
deref_connector = IndexConnectorFmt.format(dim=dim.value)
if isinstance(index_expr, MemletExpr):
self._add_input_data_edge(
index_expr.dc_node,
index_expr.subset,
deref_node,
deref_connector,
)

elif isinstance(index_expr, ValueExpr):
self._add_edge(
index_expr.dc_node,
None,
deref_node,
deref_connector,
dace.Memlet(data=index_expr.dc_node.data, subset="0"),
)
else:
assert isinstance(index_expr, SymbolExpr)

dc_dtype = arg_expr.field.desc(self.sdfg).dtype
return self._construct_tasklet_result(
dc_dtype, deref_node, "val", arg_expr.local_offset
)
field_desc = arg_expr.field.desc(self.sdfg)
if isinstance(field_desc, dace.data.Scalar):
# deref a zero-dimensional field
assert len(arg_expr.dimensions) == 0
assert isinstance(node.type, ts.ScalarType)
return MemletExpr(arg_expr.field, subset="0")
# default case: deref a field with one or more dimensions
assert len(field_desc.shape) == len(arg_expr.dimensions)
if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
# when all indices are symblic expressions, we can perform direct field access through a memlet
field_subset = sbs.Range(
(arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr]
if dim in arg_expr.indices
else (0, size - 1, 1)
for dim, size in zip(arg_expr.dimensions, field_desc.shape)
)
return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset)

else:
# dereferencing a scalar or a literal node results in the node itself
return arg_expr
# we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
# either indirection through connectivity table or dynamic cartesian offset.
assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
index_connectors = [
IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
if not isinstance(index, SymbolExpr)
]
# here `internals` refer to the names used as index in the tasklet code string:
# an index can be either a connector name (for dynamic/indirect indices)
# or a symbol value (for literal values and scalar arguments).
index_internals = ",".join(
str(index.value)
if isinstance(index, SymbolExpr)
else IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
)
deref_node = self._add_tasklet(
"runtime_deref",
{"field"} | set(index_connectors),
{"val"},
code=f"val = field[{index_internals}]",
)
# add new termination point for the field parameter
self._add_input_data_edge(
arg_expr.field,
sbs.Range.from_array(field_desc),
deref_node,
"field",
)

for dim, index_expr in field_indices:
# add termination points for the dynamic iterator indices
deref_connector = IndexConnectorFmt.format(dim=dim.value)
if isinstance(index_expr, MemletExpr):
self._add_input_data_edge(
index_expr.dc_node,
index_expr.subset,
deref_node,
deref_connector,
)

elif isinstance(index_expr, ValueExpr):
self._add_edge(
index_expr.dc_node,
None,
deref_node,
deref_connector,
dace.Memlet(data=index_expr.dc_node.data, subset="0"),
)
else:
assert isinstance(index_expr, SymbolExpr)

return self._construct_tasklet_result(
field_desc.dtype, deref_node, "val", arg_expr.local_offset
)

def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
assert len(node.args) == 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ def _add_storage(
return tuple_fields

elif isinstance(gt_type, ts.FieldType):
if len(gt_type.dims) == 0:
# represent zero-dimensional fields as scalar arguments
return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient)
# handle default case: field with one or more dimensions
dc_dtype = dace_utils.as_dace_type(gt_type.dtype)
# use symbolic shape, which allows to invoke the program with fields of different size;
# and symbolic strides, which enables decoupling the memory layout from generated code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def testee(a: tuple[int32, tuple[int32, int32]]) -> cases.VField:


@pytest.mark.uses_tuple_args
@pytest.mark.uses_zero_dimensional_fields
def test_zero_dim_tuple_arg(unstructured_case):
@gtx.field_operator
def testee(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,35 @@ def test_gtir_tuple_broadcast_scalar():
assert np.allclose(d, a + 2 * b + 3 * c)


def test_gtir_zero_dim_fields():
    """Lower and run a program that reads a zero-dimensional input field into a 1D output."""
    domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
    program = gtir.Program(
        id="gtir_zero_dim_fields",
        function_definitions=[],
        params=[
            # 'x' is a zero-dimensional field: same dtype as IFTYPE, but no dimensions.
            gtir.Sym(id="x", type=ts.FieldType(dims=[], dtype=IFTYPE.dtype)),
            gtir.Sym(id="y", type=IFTYPE),
            gtir.Sym(id="size", type=SIZE_TYPE),
        ],
        declarations=[],
        body=[
            gtir.SetAt(
                expr=im.as_fieldop("deref", domain)("x"),
                domain=domain,
                target=gtir.SymRef(id="y"),
            )
        ],
    )

    sdfg = dace_backend.build_sdfg_from_gtir(program, CARTESIAN_OFFSETS)

    # Model the zero-dimensional field as a 0-d numpy array.
    scalar_input = np.asarray(np.random.rand())
    output = np.empty(N)
    sdfg(scalar_input.item(), output, **FSYMBOLS)
    assert np.allclose(scalar_input, output)


def test_gtir_tuple_return():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
testee = gtir.Program(
Expand Down

0 comments on commit db249bd

Please sign in to comment.