
refact[next][dace]: split handling of let-statement lambdas from stencil body #1781

Merged: 5 commits, Dec 16, 2024
@@ -13,7 +13,7 @@
from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias

import dace
import dace.subsets as sbs
from dace import subsets as dace_subsets

from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.ffront import fbuiltins as gtx_fbuiltins
@@ -30,7 +30,7 @@
gtir_python_codegen,
utility as dace_gtir_utils,
)
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system import type_info as ti, type_specifications as ts


if TYPE_CHECKING:
@@ -39,7 +39,7 @@

def _get_domain_indices(
dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None
) -> sbs.Indices:
) -> dace_subsets.Indices:
"""
Helper function to construct the list of indices for a field domain, applying
an optional offset in each dimension as start index.
@@ -55,9 +55,9 @@ def _get_domain_indices(
"""
index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims]
if offsets is None:
return sbs.Indices(index_variables)
return dace_subsets.Indices(index_variables)
else:
return sbs.Indices(
return dace_subsets.Indices(
[
index - offset if offset != 0 else index
for index, offset in zip(index_variables, offsets, strict=True)
@@ -96,7 +96,7 @@ def get_local_view(
"""Helper method to access a field in local view, given the compute domain of a field operator."""
if isinstance(self.gt_type, ts.ScalarType):
return gtir_dataflow.MemletExpr(
dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0])
dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0])
)

if isinstance(self.gt_type, ts.FieldType):
@@ -263,7 +263,7 @@ def _create_field_operator(

dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)

field_subset = sbs.Range.from_indices(field_indices)
field_subset = dace_subsets.Range.from_indices(field_indices)
if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == node_type.dtype
assert isinstance(dataflow_output_desc, dace.data.Scalar)
@@ -280,7 +280,7 @@ def _create_field_operator(
field_dims.append(output_edge.result.gt_dtype.offset_type)
field_shape.extend(dataflow_output_desc.shape)
field_offset.extend(dataflow_output_desc.offset)
field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc)
field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
@@ -366,36 +366,36 @@ def translate_as_fieldop(
"""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, ts.FieldType)

fun_node = node.fun
assert len(fun_node.args) == 2
stencil_expr, domain_expr = fun_node.args
fieldop_expr, domain_expr = fun_node.args

if isinstance(stencil_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
pass
elif cpm.is_ref_to(stencil_expr, "deref"):
assert isinstance(node.type, ts.FieldType)
if cpm.is_ref_to(fieldop_expr, "deref"):
# Special usage of 'deref' as argument to fieldop expression, to pass a scalar
# value to 'as_fieldop' function. It results in broadcasting the scalar value
# over the field domain.
stencil_expr = im.lambda_("a")(im.deref("a"))
stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined]
stencil_expr.expr.type = node.type.dtype
elif isinstance(fieldop_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
stencil_expr = fieldop_expr
else:
raise NotImplementedError(
f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node."
f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node."
)

# parse the domain of the field operator
domain = extract_domain(domain_expr)

# visit the list of arguments to be passed to the lambda expression
stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]
fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder)
input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args)
input_edges, output_edge = taskgen.apply(stencil_expr, args=fieldop_args)

return _create_field_operator(
sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge
@@ -654,7 +654,7 @@ def translate_tuple_get(

if not isinstance(node.args[0], gtir.Literal):
raise ValueError("Tuple can only be subscripted with compile-time constants.")
assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE)
assert ti.is_integral(node.args[0].type)
index = int(node.args[0].value)

data_nodes = sdfg_builder.visit(
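Side note (not part of the diff): the `deref` special case in `translate_as_fieldop` above synthesizes an identity stencil so that a scalar argument is broadcast over the field domain. A minimal sketch of that rewrite, assuming the usual `ir_makers` import path aliased as `im`:

    # Hedged sketch, not part of this PR: illustrates the identity stencil that
    # `translate_as_fieldop` builds for the `deref` special case.
    from gt4py.next.iterator.ir_utils import ir_makers as im

    # `as_fieldop(deref, domain)(scalar)` provides no stencil body of its own,
    # so the translator builds the identity stencil `lambda a: deref(a)`:
    stencil_expr = im.lambda_("a")(im.deref("a"))

    # In the translator, the deref result is then annotated with the scalar dtype
    # of the output field (`stencil_expr.expr.type = node.type.dtype` in the diff).
    print(stencil_expr)

This keeps the later dataflow lowering uniform: whether the argument was a lambda or a bare `deref`, `LambdaToDataflow` always receives the stencil expressed as a `Lambda`.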
----------------------------------------------------------------
@@ -10,10 +10,22 @@

import abc
import dataclasses
from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union
from typing import (
Any,
Dict,
Final,
List,
Optional,
Protocol,
Sequence,
Set,
Tuple,
TypeAlias,
Union,
)

import dace
import dace.subsets as sbs
from dace import subsets as dace_subsets

from gt4py import eve
from gt4py.next import common as gtx_common
@@ -68,7 +80,7 @@ class MemletExpr:

dc_node: dace.nodes.AccessNode
gt_dtype: itir_ts.ListType | ts.ScalarType
subset: sbs.Indices | sbs.Range
subset: dace_subsets.Range


@dataclasses.dataclass(frozen=True)
@@ -104,7 +116,7 @@ class IteratorExpr:
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]]
indices: dict[gtx_common.Dimension, DataExpr]

def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range:
def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range:
if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain):
raise ValueError(f"Cannot deref iterator {self}.")

@@ -117,7 +129,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range:
assert len(field_desc.shape) == len(self.field_domain)
field_domain = self.field_domain

return sbs.Range.from_string(
return dace_subsets.Range.from_string(
",".join(
str(self.indices[dim].value - offset) # type: ignore[union-attr]
if dim in self.indices
@@ -152,7 +164,7 @@ class MemletInputEdge(DataflowInputEdge):

state: dace.SDFGState
source: dace.nodes.AccessNode
subset: sbs.Range
subset: dace_subsets.Range
dest: dace.nodes.AccessNode | dace.nodes.Tasklet
dest_conn: Optional[str]

@@ -202,7 +214,7 @@ def connect(
self,
mx: dace.nodes.MapExit,
dest: dace.nodes.AccessNode,
subset: sbs.Range,
subset: dace_subsets.Range,
) -> None:
# retrieve the node which writes the result
last_node = self.state.in_edges(self.result.dc_node)[0].src
@@ -258,8 +270,9 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]:

class LambdaToDataflow(eve.NodeVisitor):
"""
Translates an `ir.Lambda` expression to a dataflow graph.
Visitor class to translate a `Lambda` expression to a dataflow graph.

This visitor should be applied by calling the `apply()` method on a `Lambda` IR node.
The dataflow graph generated here typically represents the stencil function
of a field operator. It only computes single elements or pure local fields,
in case of neighbor values. In case of local fields, the dataflow contains
@@ -293,15 +306,15 @@ def __init__(
def _add_input_data_edge(
self,
src: dace.nodes.AccessNode,
src_subset: sbs.Range,
src_subset: dace_subsets.Range,
dst_node: dace.nodes.Node,
dst_conn: Optional[str] = None,
src_offset: Optional[list[dace.symbolic.SymExpr]] = None,
) -> None:
input_subset = (
src_subset
if src_offset is None
else sbs.Range(
else dace_subsets.Range(
(start - off, stop - off, step)
for (start, stop, step), off in zip(src_subset, src_offset, strict=True)
)
@@ -512,7 +525,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
# add new termination point for the field parameter
self._add_input_data_edge(
arg_expr.field,
sbs.Range.from_array(field_desc),
dace_subsets.Range.from_array(field_desc),
deref_node,
"field",
src_offset=[offset for (_, offset) in arg_expr.field_domain],
@@ -580,7 +593,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
MemletExpr(
dc_node=it.field,
gt_dtype=node.type,
subset=sbs.Range.from_string(
subset=dace_subsets.Range.from_string(
",".join(
str(it.indices[dim].value - offset) # type: ignore[union-attr]
if dim != offset_provider.codomain
@@ -596,7 +609,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
MemletExpr(
dc_node=self.state.add_access(connectivity),
gt_dtype=node.type,
subset=sbs.Range.from_string(
subset=dace_subsets.Range.from_string(
f"{origin_index.value}, 0:{offset_provider.max_neighbors}"
),
)
@@ -758,7 +771,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr:
gt_dtype=itir_ts.ListType(
element_type=node.type.element_type, offset_type=offset_type
),
subset=sbs.Range.from_string(
subset=dace_subsets.Range.from_string(
f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"
),
)
@@ -908,7 +921,9 @@ def _make_reduce_with_skip_values(
)
self._add_input_data_edge(
connectivity_node,
sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"),
dace_subsets.Range.from_string(
f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"
),
nsdfg_node,
"neighbor_indices",
)
@@ -1081,7 +1096,7 @@ def _make_dynamic_neighbor_offset(
)
self._add_input_data_edge(
offset_table_node,
sbs.Range.from_array(offset_table_node.desc(self.sdfg)),
dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)),
tasklet_node,
"table",
)
@@ -1127,7 +1142,7 @@ def _make_unstructured_shift(
shifted_indices[neighbor_dim] = MemletExpr(
dc_node=offset_table_node,
gt_dtype=it.gt_dtype,
subset=sbs.Indices([origin_index.value, offset_expr.value]),
subset=dace_subsets.Indices([origin_index.value, offset_expr.value]),
)
else:
# dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node
@@ -1264,39 +1279,39 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr:
elif cpm.is_applied_shift(node):
return self._visit_shift(node)

elif isinstance(node.fun, gtir.Lambda):
# Lambda node should be visited with 'apply()' method.
raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.")

elif isinstance(node.fun, gtir.SymRef):
return self._visit_generic_builtin(node)

else:
raise NotImplementedError(f"Invalid 'FunCall' node: {node}.")

def visit_Lambda(
self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr]
) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]:
for p, arg in zip(node.params, args, strict=True):
self.symbol_map[str(p.id)] = arg
output_expr: DataExpr = self.visit(node.expr)
if isinstance(output_expr, ValueExpr):
return self.input_edges, DataflowOutputEdge(self.state, output_expr)
def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge:
result: DataExpr = self.visit(node.expr)

if isinstance(output_expr, MemletExpr):
if isinstance(result, ValueExpr):
return DataflowOutputEdge(self.state, result)

if isinstance(result, MemletExpr):
# special case where the field operator is simply copying data from source to destination node
output_dtype = output_expr.dc_node.desc(self.sdfg).dtype
output_dtype = result.dc_node.desc(self.sdfg).dtype
tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp")
self._add_input_data_edge(
output_expr.dc_node,
output_expr.subset,
result.dc_node,
result.subset,
tasklet_node,
"__inp",
)
else:
assert isinstance(output_expr, SymbolExpr)
# even simpler case, where a constant value is written to destination node
output_dtype = output_expr.dc_dtype
tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}")
output_dtype = result.dc_dtype
tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}")

output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out")
return self.input_edges, DataflowOutputEdge(self.state, output_expr)
return DataflowOutputEdge(self.state, output_expr)

def visit_Literal(self, node: gtir.Literal) -> SymbolExpr:
dc_dtype = dace_utils.as_dace_type(node.type)
@@ -1309,3 +1324,45 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE
# if not in the lambda symbol map, this must be a symref to a builtin function
assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING
return SymbolExpr(param, dace.string)

def apply(
Contributor:
I don't remember that there is a call to apply().
As far as I can tell, apply is only called in this function, so it is a recursive function.
However, it does some highly non-trivial pre- and post-processing.
So I think that:

  • This function needs a better name; apply is just too generic.
  • This function needs a docstring.
  • This function needs comments that explain why the pre- and post-processing is needed and how it is done.

Contributor (Author):
I would like to keep the name apply for the entry point of this visitor class, since it is consistent with other visitor classes in GT4Py. However, I agree on the rest, so I will write some documentation.

Contributor:
I still think that apply() is the wrong name. If it is for compatibility, then why was it not there before?
_visit_let() would be a much better name.

Contributor (Author):
I have now made a code change in this direction.

self,
node: gtir.Lambda,
args: Sequence[IteratorExpr | MemletExpr | SymbolExpr],
) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]:
"""
Entry point for this visitor class.

This visitor will translate a `Lambda` node into a dataflow graph to be
instantiated inside a map scope implementing the field operator.
However, this `apply()` method is responsible for recognizing the usage of
the `Lambda` node, which can be either a let-statement or the stencil expression
in local view. The usage of a `Lambda` as let-statement corresponds to computing
some results and making them available inside the lambda scope, represented
as a nested SDFG. All let-statements, if any, are supposed to be encountered
before the stencil expression. In other words, the `Lambda` node representing
the stencil expression is always the innermost node.
Therefore, the lowering of let-statements results in recursive calls to
`apply()` until the stencil expression is found. At that point, it falls
back to the `visit()` function.
"""

# lambda arguments are mapped to symbols defined in lambda scope.
prev_symbol_map = self.symbol_map
self.symbol_map = self.symbol_map.copy()
Contributor (suggested change):
- prev_symbol_map = self.symbol_map
- self.symbol_map = self.symbol_map.copy()
+ prev_symbol_map = self.symbol_map.copy()
Contributor (Author):
I was thinking of saving the self.symbol_map object itself as prev_symbol_map, just to be on the safe side.

Contributor:
In my view this indicates that you have some big problems.

Contributor (Author):
I see what you mean: side effects. Luckily I do not have such problems, so I can make the change you propose.
However, in my opinion, I still prefer the original version, where items were always added to and removed from the same dictionary object without copying it.

self.symbol_map |= {str(p.id): arg for p, arg in zip(node.params, args, strict=True)}

if cpm.is_let(node.expr):
let_node = node.expr
let_args = [self.visit(arg) for arg in let_node.args]
assert isinstance(let_node.fun, gtir.Lambda)
input_edges, output_edge = self.apply(let_node.fun, args=let_args)
else:
# this lambda node is not a let-statement, but a stencil expression
output_edge = self.visit(node)
input_edges = self.input_edges

# remove locally defined lambda symbols and restore previous symbols
self.symbol_map = prev_symbol_map

return input_edges, output_edge
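To make the recursion described in the `apply()` docstring concrete, here is a small, self-contained sketch. The classes and names below are illustrative stand-ins, not GT4Py types or the class above; they only mirror the save/copy/restore handling of `symbol_map` and the peeling of nested let-statement lambdas shown in the diff:

    # Hedged, self-contained sketch of the control flow in `apply()`; every name
    # here is a toy stand-in, not part of gt4py.next or of this PR.
    from __future__ import annotations

    import dataclasses
    from typing import Any


    @dataclasses.dataclass
    class Lam:
        """Stand-in for `gtir.Lambda`: parameter names plus a body expression."""
        params: list[str]
        expr: Any  # either a `Let` (nested let-statement) or a leaf stencil expression


    @dataclasses.dataclass
    class Let:
        """Stand-in for a `FunCall` whose `fun` is a `Lambda` (a let-statement)."""
        fun: Lam
        args: list[Any]


    class ToyLowering:
        def __init__(self) -> None:
            self.symbol_map: dict[str, Any] = {}

        def visit(self, expr: Any) -> Any:
            # Stand-in for the real visitor dispatch on the stencil body.
            return ("computed", expr, dict(self.symbol_map))

        def apply(self, node: Lam, args: list[Any]) -> Any:
            prev_symbol_map = self.symbol_map          # save the enclosing scope
            self.symbol_map = self.symbol_map.copy()   # open a new scope for this lambda
            self.symbol_map |= dict(zip(node.params, args, strict=True))

            if isinstance(node.expr, Let):
                # let-statement: lower its arguments, then recurse into the inner lambda
                let_args = [self.visit(a) for a in node.expr.args]
                result = self.apply(node.expr.fun, let_args)
            else:
                # innermost lambda: this is the stencil expression itself
                result = self.visit(node.expr)

            self.symbol_map = prev_symbol_map          # restore the outer scope
            return result


    # Usage: an outer lambda whose body is a let-statement binding `b` before the stencil body.
    inner = Lam(params=["b"], expr="b + a")
    outer = Lam(params=["a"], expr=Let(fun=inner, args=["a * 2"]))
    print(ToyLowering().apply(outer, args=["x"]))

Each recursive call opens a new scope by copying the symbol map, so let-bound symbols are visible to inner lambdas but disappear again once the outer scope is restored, which is the scoping behaviour discussed in the review thread above.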