Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: Add more debug info to DaCe #1384

Merged
merged 38 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
caebdff
Add more debug info to DaCe
kotsaloscv Nov 7, 2023
f14d091
Merge branch 'main' into more_debinfo
kotsaloscv Nov 27, 2023
bebb122
Add more debug info to DaCe
kotsaloscv Nov 27, 2023
7eeaddb
Add more debug info to DaCe : WIP
kotsaloscv Nov 27, 2023
7447807
Add more debug info to DaCe : WIP
kotsaloscv Nov 29, 2023
6d89149
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Nov 29, 2023
16bc489
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Nov 30, 2023
0ed8090
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
3e38756
merge main
kotsaloscv Dec 4, 2023
0b1fe1a
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
55abb29
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
54dd79d
merge main
kotsaloscv Dec 5, 2023
774b2f5
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
1622866
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
7baedb8
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
223de4e
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
6c636ae
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
93fcb14
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
d6da4ab
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
8460c67
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
5632def
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
b59fd83
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
68dde06
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
a1a91c4
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
1ed9764
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
371dc36
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
bb880dd
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
e0a254f
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
ea2f672
merge main
kotsaloscv Jan 4, 2024
50f96a8
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 4, 2024
9c9c8ae
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
6fb28a1
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
3f4e9d1
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
b29bc5f
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
1a2e978
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
bf33827
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
616a833
merge main
kotsaloscv Jan 23, 2024
371b1da
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/gt4py/eve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@
field,
frozenmodel,
)
from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait
from .traits import (
PreserveLocationVisitor,
SymbolTableTrait,
ValidatedSymbolTableTrait,
VisitorWithSymbolTableTrait,
)
from .trees import (
bfs_walk_items,
bfs_walk_values,
Expand Down Expand Up @@ -113,6 +118,7 @@
"SymbolTableTrait",
"ValidatedSymbolTableTrait",
"VisitorWithSymbolTableTrait",
"PreserveLocationVisitor",
# trees
"bfs_walk_items",
"bfs_walk_values",
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/eve/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
kwargs["symtable"] = kwargs["symtable"].parents

return result


class PreserveLocationVisitor(visitors.NodeVisitor):
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
result.location = node.location
return result
8 changes: 4 additions & 4 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
from typing import Any, Callable, Optional

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.utils import UIDGenerator
from gt4py.next.ffront import (
dialect_ast_enums,
Expand All @@ -39,7 +39,7 @@ def promote_to_list(


@dataclasses.dataclass
class FieldOperatorLowering(NodeTranslator):
class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator):
"""
Lower FieldOperator AST (FOAST) to Iterator IR (ITIR).

Expand All @@ -61,7 +61,7 @@ class FieldOperatorLowering(NodeTranslator):
<class 'gt4py.next.iterator.ir.FunctionDefinition'>
>>> lowered.id
SymbolName('fieldop')
>>> lowered.params
>>> lowered.params # doctest: +ELLIPSIS
[Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
"""

Expand Down Expand Up @@ -142,7 +142,7 @@ def visit_IfStmt(
self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
# the lowered if call doesn't need to be lifted as the condition can only originate
# from a scalar value (and not a field)
# from a scalar value (and not a field)
assert (
isinstance(node.condition.type, ts.ScalarType)
and node.condition.type.kind == ts.ScalarKind.BOOL
Expand Down
20 changes: 16 additions & 4 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def _flatten_tuple_expr(
raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.")


class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class ProgramLowering(
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
"""
Lower Program AST (PAST) to Iterator IR (ITIR).

Expand Down Expand Up @@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure:
stencil=itir.SymRef(id=node.func.id),
inputs=[*lowered_args, *lowered_kwargs.values()],
output=output,
location=node.location,
)

def _visit_slice_bound(
Expand All @@ -175,17 +178,22 @@ def _visit_slice_bound(
lowered_bound = self.visit(slice_bound, **kwargs)
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
lowered_bound.location = slice_bound.location
return lowered_bound

def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr:
if isinstance(node, past.Name):
return itir.SymRef(id=node.id)
return itir.SymRef(id=node.id, location=node.location)
elif isinstance(node, past.Subscript):
return self._construct_itir_out_arg(node.value)
itir_node = self._construct_itir_out_arg(node.value)
itir_node.location = node.location
return itir_node
elif isinstance(node, past.TupleExpr):
return itir.FunCall(
fun=itir.SymRef(id="make_tuple"),
args=[self._construct_itir_out_arg(el) for el in node.elts],
location=node.location,
)
else:
raise ValueError(
Expand Down Expand Up @@ -247,7 +255,11 @@ def _construct_itir_domain_arg(
else:
raise AssertionError()

return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args)
return itir.FunCall(
fun=itir.SymRef(id=domain_builtin),
args=domain_args,
location=(node_domain or out_field).location,
)

def _construct_itir_initialized_domain_arg(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
from gt4py.eve.concepts import SourceLocation
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
from gt4py.eve.utils import noninstantiable


@noninstantiable
class Node(eve.Node):
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)

def __str__(self) -> str:
from gt4py.next.iterator.pretty_printer import pformat

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from gt4py.next.iterator import ir


class CollapseListGet(eve.NodeTranslator):
class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator):
"""Simplifies expressions containing `list_get`.

Examples
Expand Down
9 changes: 1 addition & 8 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t


@dataclass(frozen=True)
class CollapseTuple(eve.NodeTranslator):
class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
"""
Simplifies `make_tuple`, `tuple_get` calls.

Expand Down Expand Up @@ -88,13 +88,6 @@ def apply(
node_types,
).visit(node)

return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
).visit(node)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
if (
self.collapse_make_tuple_tuple_get
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im


class ConstantFolding(NodeTranslator):
class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Expand Down
14 changes: 10 additions & 4 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
import operator
import typing

from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
PreserveLocationVisitor,
SymbolTableTrait,
VisitorWithSymbolTableTrait,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda


@dataclasses.dataclass
class _NodeReplacer(NodeTranslator):
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
Expand Down Expand Up @@ -72,7 +78,7 @@ def _is_collectable_expr(node: ir.Node) -> bool:


@dataclasses.dataclass
class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor):
@dataclasses.dataclass
class SubexpressionData:
#: A list of node ids with equal hash and a set of collected child subexpression ids
Expand Down Expand Up @@ -341,7 +347,7 @@ def extract_subexpression(


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(NodeTranslator):
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.

Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/eta_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir


class EtaReduction(NodeTranslator):
class EtaReduction(PreserveLocationVisitor, NodeTranslator):
"""Eta reduction: simplifies `λ(args...) → f(args...)` to `f`."""

def visit_Lambda(self, node: ir.Lambda) -> ir.Node:
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Fuses nested `map_`s.

Expand Down Expand Up @@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:
return ir.Lambda(
params=params,
expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]),
location=fun.location,
)

def visit_FunCall(self, node: ir.FunCall, **kwargs):
Expand Down Expand Up @@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
ir.FunCall(
fun=inner_op,
args=[ir.SymRef(id=param.id) for param in inner_op.params],
location=node.location,
)
)
)
Expand Down
10 changes: 8 additions & 2 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import gt4py.eve as eve
import gt4py.next as gtx
from gt4py.eve import Coerced, NodeTranslator
from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
Expand Down Expand Up @@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
stencil=stencil,
output=im.ref(tmp_sym.id),
inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined]
location=current_closure.location,
)
)

Expand Down Expand Up @@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
output=current_closure.output,
inputs=current_closure.inputs
+ [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()],
location=current_closure.location,
)
)
else:
Expand All @@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
+ [ir.Sym(id=tmp.id) for tmp in tmps]
+ [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant
closures=list(reversed(closures)),
location=node.location,
),
params=node.params,
tmps=[Temporary(id=tmp.id) for tmp in tmps],
Expand All @@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari
function_definitions=node.fencil.function_definitions,
params=[p for p in node.fencil.params if p.id not in unused_tmps],
closures=closures,
location=node.fencil.location,
),
params=node.params,
tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps],
Expand Down Expand Up @@ -456,6 +460,7 @@ def update_domains(
stencil=closure.stencil,
output=closure.output,
inputs=closure.inputs,
location=closure.location,
)
else:
domain = closure.domain
Expand Down Expand Up @@ -521,6 +526,7 @@ def update_domains(
function_definitions=node.fencil.function_definitions,
params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again
closures=list(reversed(closures)),
location=node.fencil.location,
),
params=node.params,
tmps=node.tmps,
Expand Down Expand Up @@ -580,7 +586,7 @@ def convert_type(dtype):
# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be
# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore
# and hence also not extract as a temporary.
class CreateGlobalTmps(NodeTranslator):
class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator):
"""Main entry point for introducing global temporaries.

Transforms an existing iterator IR fencil into a fencil with global temporaries.
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/inline_fundefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from typing import Any, Dict, Set

from gt4py.eve import NOTHING, NodeTranslator
from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir


class InlineFundefs(NodeTranslator):
class InlineFundefs(PreserveLocationVisitor, NodeTranslator):
def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]):
if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition):
return ir.Lambda(
Expand All @@ -31,7 +31,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition):
return self.generic_visit(node, symtable=node.annex.symtable)


class PruneUnreferencedFundefs(NodeTranslator):
class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator):
def visit_FunctionDefinition(
self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool
):
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/iterator/transforms/inline_into_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall:
return inlined


class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class InlineIntoScan(
traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
"""
Inline non-SymRef arguments into the scan.

Expand Down Expand Up @@ -100,6 +102,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_scan = ir.FunCall(
fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]]
)
result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args])
return result
return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args])
return self.generic_visit(node, **kwargs)
Loading
Loading