Skip to content

Commit

Permalink
feat[next]: Add IR transform to remove unnecessary cast expressions (#…
Browse files Browse the repository at this point in the history
…1688)

Add IR transformation that removes cast expressions where the argument is
already in the target type.
  • Loading branch information
edopao authored Oct 14, 2024
1 parent ed9d82d commit b339b82
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/gt4py/next/iterator/transforms/prune_casts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.type_system import type_specifications as ts


class PruneCasts(PreserveLocationVisitor, NodeTranslator):
"""
Removes cast expressions where the argument is already in the target type.
This transformation requires the IR to be fully type-annotated,
therefore it should be applied after type-inference.
"""

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = self.generic_visit(node)

if not cpm.is_call_to(node, "cast_"):
return node

value, type_constructor = node.args

assert (
value.type
and isinstance(type_constructor, ir.SymRef)
and (type_constructor.id in ir.TYPEBUILTINS)
)
dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper()))

if value.type == dtype:
return value

return node

@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts
from gt4py.next.iterator.type_system import inference as gtir_type_inference
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand Down Expand Up @@ -656,7 +657,9 @@ def build_sdfg_from_gtir(
Returns:
An SDFG in the DaCe canonical form (simplified)
"""

ir = gtir_type_inference.infer(ir, offset_provider=offset_provider)
ir = ir_prune_casts.PruneCasts().visit(ir)
ir = dace_gtir_utils.patch_gtir(ir)
sdfg_genenerator = GTIRToSDFG(offset_provider)
sdfg = sdfg_genenerator.visit(ir)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator.transforms.prune_casts import PruneCasts
from gt4py.next.iterator.type_system import inference as type_inference


def test_prune_casts_simple():
x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32))
y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64))
testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64"))
testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True)

expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref)
actual = PruneCasts.apply(testee)
assert actual == expected

0 comments on commit b339b82

Please sign in to comment.