Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Nov 5, 2024
1 parent 3f76e88 commit 87d4e94
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
3 changes: 0 additions & 3 deletions src/gt4py/next/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ def env_flag_to_bool(name: str, default: bool) -> bool:
)


GTFN_SOURCE_CACHE_DIR: str = os.environ.get(f"{_PREFIX}_GTFN_SOURCE_CACHE_DIR", "gtfn_cache")


#: Whether generated code projects should be kept around between runs.
#: - SESSION: generated code projects get destroyed when the interpreter shuts down
#: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int:
)


def generate_stencil_source_hash_function(inp: stages.CompilableProgram) -> str:
def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str:
"""
Generates a unique hash string for a stencil source program representing
the program, sorted offset_provider, and column_axis.
Expand Down Expand Up @@ -167,8 +167,8 @@ class Params:
translation=factory.LazyAttribute(
lambda o: workflow.CachedStep(
o.translation_,
hash_function=generate_stencil_source_hash_function,
cache=FileCache(str(config.BUILD_CACHE_DIR / config.GTFN_SOURCE_CACHE_DIR)),
hash_function=fingerprint_compilable_program,
cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")),
)
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from gt4py.next.program_processors.runners import gtfn
from gt4py.next.type_system import type_translation
from next_tests.integration_tests import cases
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim

from next_tests.integration_tests.cases import cartesian_case

Expand Down Expand Up @@ -97,9 +98,8 @@ def test_hash_and_diskcache(fencil_example):
*parameters, **{"offset_provider": {}}
),
)

hash = gtfn.generate_stencil_source_hash_function(compilable_program)
path = str(gt4py.next.config.BUILD_CACHE_DIR / gt4py.next.config.GTFN_SOURCE_CACHE_DIR)
hash = gtfn.fingerprint_compilable_program(compilable_program)
path = tempfile.gettempdir()
with diskcache.Cache(path) as cache:
cache[hash] = compilable_program

Expand All @@ -111,15 +111,27 @@ def test_hash_and_diskcache(fencil_example):
del reopened_cache[hash] # delete data

# hash creation is deterministic
assert hash == gtfn.generate_stencil_source_hash_function(compilable_program)
assert hash == gtfn.generate_stencil_source_hash_function(compilable_program_from_cache)
assert hash == gtfn.fingerprint_compilable_program(compilable_program)
assert hash == gtfn.fingerprint_compilable_program(compilable_program_from_cache)

# hash is different if program changes
altered_program = copy.deepcopy(compilable_program)
altered_program.data.id = "example2"
assert gtfn.generate_stencil_source_hash_function(
altered_program_id = copy.deepcopy(compilable_program)
altered_program_id.data.id = "example2"
assert gtfn.fingerprint_compilable_program(
compilable_program
) != gtfn.fingerprint_compilable_program(altered_program_id)

altered_program_offset_provider = copy.deepcopy(compilable_program)
object.__setattr__(altered_program_offset_provider.args, "offset_provider", {"Koff": KDim})
assert gtfn.fingerprint_compilable_program(
compilable_program
) != gtfn.generate_stencil_source_hash_function(altered_program)
) != gtfn.fingerprint_compilable_program(altered_program_offset_provider)

altered_program_column_axis = copy.deepcopy(compilable_program)
object.__setattr__(altered_program_column_axis.args, "column_axis", KDim)
assert gtfn.fingerprint_compilable_program(
compilable_program
) != gtfn.fingerprint_compilable_program(altered_program_column_axis)


def test_gtfn_file_cache(fencil_example):
Expand All @@ -138,21 +150,24 @@ def test_gtfn_file_cache(fencil_example):
gpu=False, cached=True, otf_workflow__cached_translation=False
).executor.step.translation

cached_gtfn_translation_step(
compilable_program
) # run cached translation step once to populate cache
cache_key = gtfn.fingerprint_compilable_program(compilable_program)

# ensure the actual cached step in the backend generates the cache item for the test
if cache_key in (translation_cache := cached_gtfn_translation_step.cache):
del translation_cache[hash]
cached_gtfn_translation_step(compilable_program)
assert bare_gtfn_translation_step(compilable_program) == cached_gtfn_translation_step(
compilable_program
)

cache_key = gtfn.generate_stencil_source_hash_function(compilable_program)
assert cache_key in cached_gtfn_translation_step.cache
assert (
bare_gtfn_translation_step(compilable_program)
== cached_gtfn_translation_step.cache[cache_key]
)


# TODO(egparedes): we should switch to use the cached backend by default and then remove this test
def test_gtfn_file_cache_whole_workflow(cartesian_case):
if cartesian_case.backend != gtfn.run_gtfn:
pytest.skip("Skipping backend.")
Expand Down

0 comments on commit 87d4e94

Please sign in to comment.