From ba1a91f5d5c5bc9ea4f2c8fdaab15206bba72673 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Tue, 16 Jan 2024 16:32:27 +0100 Subject: [PATCH] edit to offset_invariants --- .../runners/dace_iterator/__init__.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index d458be49f8..b9c15f1fb9 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -168,19 +168,20 @@ def get_cache_id( offset_provider: Mapping[str, Any], ) -> str: def offset_invariants(offset): - from gt4py.next.ffront import fbuiltins - if isinstance(offset, itir_embedded.NeighborTableOffsetProvider): + if isinstance( + offset, + ( + itir_embedded.NeighborTableOffsetProvider, + itir_embedded.StridedNeighborOffsetProvider, + ), + ): return offset.origin_axis, offset.neighbor_axis, offset.max_neighbors - if isinstance(offset, itir_embedded.StridedNeighborOffsetProvider): - return offset.origin_axis, offset.neighbor_axis, offset.max_neighbors - if isinstance(offset, fbuiltins.FieldOffset): - return offset.source, offset.target if isinstance(offset, common.Dimension): - return offset, + return (offset,) return tuple() + offset_cache_keys = [ - (name, offset_invariants(offset)) - for name, offset in offset_provider.items() + (name, offset_invariants(offset)) for name, offset in offset_provider.items() ] cache_id_args = [ str(arg)