From 9feb51db27bde798245d3f80f4075e622bd42173 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 14 Oct 2024 17:01:26 +0200 Subject: [PATCH] [next]: Fix inline lambda pass opcount preserving option (#1687) In #1531 the `itir.Node` class got a `type` attribute, that until now contributed to the hash computation of all nodes. As such two `itir.SymRef` with the same `id`, but one with a type inferred and one without (i.e. `None`) got a different hash value. Consequently the `inline_lambda` pass did not recognize them as a reference to the same symbol and erroneously inlined the expression even with `opcount_preserving=True`. This PR fixes the hash computation, such that again `node1 == node2` implies `hash(node1) == hash(node2)`. --- src/gt4py/next/iterator/ir.py | 12 ++++++++---- .../transforms_tests/test_inline_lambdas.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b2a549501f..42da4c83a6 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -37,10 +37,14 @@ def __str__(self) -> str: return pformat(self) def __hash__(self) -> int: - return hash(type(self)) ^ hash( - tuple( - hash(tuple(v)) if isinstance(v, list) else hash(v) - for v in self.iter_children_values() + return hash( + ( + type(self), + *( + tuple(v) if isinstance(v, list) else v + for (k, v) in self.iter_children_items() + if k not in ["location", "type"] + ), ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index e45281734b..2e0a83d33b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -8,6 +8,7 @@ import pytest +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -39,6 +40,21 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + # ensure opcount preserving option works whether `itir.SymRef` has a type or not + "typed_ref", + im.let("a", im.call("opaque")())( + im.plus(im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None)) + ), + { + True: im.let("a", im.call("opaque")())( + im.plus( # stays as is + im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None) + ) + ), + False: im.plus(im.call("opaque")(), im.call("opaque")()), + }, + ), ]