Skip to content

Commit

Permalink
feat[next][dace]: Add more debug info to DaCe (#1384)
Browse files Browse the repository at this point in the history
* Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors
  • Loading branch information
kotsaloscv authored Jan 23, 2024
1 parent 8bd5a41 commit d5cfa7d
Show file tree
Hide file tree
Showing 29 changed files with 288 additions and 120 deletions.
8 changes: 7 additions & 1 deletion src/gt4py/eve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@
field,
frozenmodel,
)
from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait
from .traits import (
PreserveLocationVisitor,
SymbolTableTrait,
ValidatedSymbolTableTrait,
VisitorWithSymbolTableTrait,
)
from .trees import (
bfs_walk_items,
bfs_walk_values,
Expand Down Expand Up @@ -113,6 +118,7 @@
"SymbolTableTrait",
"ValidatedSymbolTableTrait",
"VisitorWithSymbolTableTrait",
"PreserveLocationVisitor",
# trees
"bfs_walk_items",
"bfs_walk_values",
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/eve/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
kwargs["symtable"] = kwargs["symtable"].parents

return result


class PreserveLocationVisitor(visitors.NodeVisitor):
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
result.location = node.location
return result
8 changes: 4 additions & 4 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
from typing import Any, Callable, Optional

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.utils import UIDGenerator
from gt4py.next.ffront import (
dialect_ast_enums,
Expand All @@ -39,7 +39,7 @@ def promote_to_list(


@dataclasses.dataclass
class FieldOperatorLowering(NodeTranslator):
class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator):
"""
Lower FieldOperator AST (FOAST) to Iterator IR (ITIR).
Expand All @@ -61,7 +61,7 @@ class FieldOperatorLowering(NodeTranslator):
<class 'gt4py.next.iterator.ir.FunctionDefinition'>
>>> lowered.id
SymbolName('fieldop')
>>> lowered.params
>>> lowered.params # doctest: +ELLIPSIS
[Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
"""

Expand Down Expand Up @@ -142,7 +142,7 @@ def visit_IfStmt(
self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
# the lowered if call doesn't need to be lifted as the condition can only originate
# from a scalar value (and not a field)
# from a scalar value (and not a field)
assert (
isinstance(node.condition.type, ts.ScalarType)
and node.condition.type.kind == ts.ScalarKind.BOOL
Expand Down
20 changes: 16 additions & 4 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def _flatten_tuple_expr(
raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.")


class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class ProgramLowering(
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
"""
Lower Program AST (PAST) to Iterator IR (ITIR).
Expand Down Expand Up @@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure:
stencil=itir.SymRef(id=node.func.id),
inputs=[*lowered_args, *lowered_kwargs.values()],
output=output,
location=node.location,
)

def _visit_slice_bound(
Expand All @@ -175,17 +178,22 @@ def _visit_slice_bound(
lowered_bound = self.visit(slice_bound, **kwargs)
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
lowered_bound.location = slice_bound.location
return lowered_bound

def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr:
if isinstance(node, past.Name):
return itir.SymRef(id=node.id)
return itir.SymRef(id=node.id, location=node.location)
elif isinstance(node, past.Subscript):
return self._construct_itir_out_arg(node.value)
itir_node = self._construct_itir_out_arg(node.value)
itir_node.location = node.location
return itir_node
elif isinstance(node, past.TupleExpr):
return itir.FunCall(
fun=itir.SymRef(id="make_tuple"),
args=[self._construct_itir_out_arg(el) for el in node.elts],
location=node.location,
)
else:
raise ValueError(
Expand Down Expand Up @@ -247,7 +255,11 @@ def _construct_itir_domain_arg(
else:
raise AssertionError()

return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args)
return itir.FunCall(
fun=itir.SymRef(id=domain_builtin),
args=domain_args,
location=(node_domain or out_field).location,
)

def _construct_itir_initialized_domain_arg(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
from gt4py.eve.concepts import SourceLocation
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
from gt4py.eve.utils import noninstantiable


@noninstantiable
class Node(eve.Node):
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)

def __str__(self) -> str:
from gt4py.next.iterator.pretty_printer import pformat

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from gt4py.next.iterator import ir


class CollapseListGet(eve.NodeTranslator):
class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator):
"""Simplifies expressions containing `list_get`.
Examples
Expand Down
9 changes: 1 addition & 8 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t


@dataclass(frozen=True)
class CollapseTuple(eve.NodeTranslator):
class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
"""
Simplifies `make_tuple`, `tuple_get` calls.
Expand Down Expand Up @@ -88,13 +88,6 @@ def apply(
node_types,
).visit(node)

return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
).visit(node)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
if (
self.collapse_make_tuple_tuple_get
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im


class ConstantFolding(NodeTranslator):
class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Expand Down
14 changes: 10 additions & 4 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
import operator
import typing

from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
PreserveLocationVisitor,
SymbolTableTrait,
VisitorWithSymbolTableTrait,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda


@dataclasses.dataclass
class _NodeReplacer(NodeTranslator):
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
Expand Down Expand Up @@ -72,7 +78,7 @@ def _is_collectable_expr(node: ir.Node) -> bool:


@dataclasses.dataclass
class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor):
@dataclasses.dataclass
class SubexpressionData:
#: A list of node ids with equal hash and a set of collected child subexpression ids
Expand Down Expand Up @@ -341,7 +347,7 @@ def extract_subexpression(


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(NodeTranslator):
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/eta_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir


class EtaReduction(NodeTranslator):
class EtaReduction(PreserveLocationVisitor, NodeTranslator):
"""Eta reduction: simplifies `λ(args...) → f(args...)` to `f`."""

def visit_Lambda(self, node: ir.Lambda) -> ir.Node:
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Fuses nested `map_`s.
Expand Down Expand Up @@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:
return ir.Lambda(
params=params,
expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]),
location=fun.location,
)

def visit_FunCall(self, node: ir.FunCall, **kwargs):
Expand Down Expand Up @@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
ir.FunCall(
fun=inner_op,
args=[ir.SymRef(id=param.id) for param in inner_op.params],
location=node.location,
)
)
)
Expand Down
10 changes: 8 additions & 2 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import gt4py.eve as eve
import gt4py.next as gtx
from gt4py.eve import Coerced, NodeTranslator
from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
Expand Down Expand Up @@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
stencil=stencil,
output=im.ref(tmp_sym.id),
inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined]
location=current_closure.location,
)
)

Expand Down Expand Up @@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
output=current_closure.output,
inputs=current_closure.inputs
+ [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()],
location=current_closure.location,
)
)
else:
Expand All @@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
+ [ir.Sym(id=tmp.id) for tmp in tmps]
+ [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant
closures=list(reversed(closures)),
location=node.location,
),
params=node.params,
tmps=[Temporary(id=tmp.id) for tmp in tmps],
Expand All @@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari
function_definitions=node.fencil.function_definitions,
params=[p for p in node.fencil.params if p.id not in unused_tmps],
closures=closures,
location=node.fencil.location,
),
params=node.params,
tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps],
Expand Down Expand Up @@ -456,6 +460,7 @@ def update_domains(
stencil=closure.stencil,
output=closure.output,
inputs=closure.inputs,
location=closure.location,
)
else:
domain = closure.domain
Expand Down Expand Up @@ -521,6 +526,7 @@ def update_domains(
function_definitions=node.fencil.function_definitions,
params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again
closures=list(reversed(closures)),
location=node.fencil.location,
),
params=node.params,
tmps=node.tmps,
Expand Down Expand Up @@ -580,7 +586,7 @@ def convert_type(dtype):
# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be
# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore
# and hence also not extract as a temporary.
class CreateGlobalTmps(NodeTranslator):
class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator):
"""Main entry point for introducing global temporaries.
Transforms an existing iterator IR fencil into a fencil with global temporaries.
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/inline_fundefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from typing import Any, Dict, Set

from gt4py.eve import NOTHING, NodeTranslator
from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir


class InlineFundefs(NodeTranslator):
class InlineFundefs(PreserveLocationVisitor, NodeTranslator):
def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]):
if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition):
return ir.Lambda(
Expand All @@ -31,7 +31,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition):
return self.generic_visit(node, symtable=node.annex.symtable)


class PruneUnreferencedFundefs(NodeTranslator):
class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator):
def visit_FunctionDefinition(
self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool
):
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/iterator/transforms/inline_into_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall:
return inlined


class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class InlineIntoScan(
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
"""
Inline non-SymRef arguments into the scan.
Expand Down Expand Up @@ -100,6 +102,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_scan = ir.FunCall(
fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]]
)
result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args])
return result
return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args])
return self.generic_visit(node, **kwargs)
Loading

0 comments on commit d5cfa7d

Please sign in to comment.