Skip to content

Commit

Permalink
edit to offset_invariants
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Jan 16, 2024
1 parent 7261ae7 commit ba1a91f
Showing 1 changed file with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,20 @@ def get_cache_id(
offset_provider: Mapping[str, Any],
) -> str:
def offset_invariants(offset):
from gt4py.next.ffront import fbuiltins
if isinstance(offset, itir_embedded.NeighborTableOffsetProvider):
if isinstance(
offset,
(
itir_embedded.NeighborTableOffsetProvider,
itir_embedded.StridedNeighborOffsetProvider,
),
):
return offset.origin_axis, offset.neighbor_axis, offset.max_neighbors
if isinstance(offset, itir_embedded.StridedNeighborOffsetProvider):
return offset.origin_axis, offset.neighbor_axis, offset.max_neighbors
if isinstance(offset, fbuiltins.FieldOffset):
return offset.source, offset.target
if isinstance(offset, common.Dimension):
return offset,
return (offset,)
return tuple()

offset_cache_keys = [
(name, offset_invariants(offset))
for name, offset in offset_provider.items()
(name, offset_invariants(offset)) for name, offset in offset_provider.items()
]
cache_id_args = [
str(arg)
Expand Down

0 comments on commit ba1a91f

Please sign in to comment.