diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 8826f92a4c..9b67be2f69 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -191,9 +191,7 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non assert not (set(used_symbol_refs) - {param.id for param in whitelist}) -def split_closures( - node: ir.FencilDefinition, offset_provider, symbolic_sizes -) -> FencilWithTemporaries: +def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries: """Split closures on lifted function calls and introduce new temporary buffers for return values. Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the @@ -438,7 +436,7 @@ def _group_offsets( return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly -def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): +def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any], symbolic_sizes): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) closures: list[ir.StencilClosure] = [] @@ -489,7 +487,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An assert new_axis not in consumed_domain.ranges consumed_domain.ranges[new_axis] = SymbolicRange( im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + symbolic_sizes[new_axis], ) else: raise NotImplementedError @@ -578,7 +576,7 @@ def visit_FencilDefinition( symbolic_sizes: dict[str, str], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries - res = split_closures(node, offset_provider=offset_provider, symbolic_sizes=symbolic_sizes) + res = split_closures(node, offset_provider=offset_provider) # Prune unreferences closure inputs introduced in the previous step res = PruneClosureInputs().visit(res) # Prune unused temporaries possibly introduced in the previous step @@ -586,6 +584,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider) + res = update_domains(res, offset_provider, symbolic_sizes) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7233e7a893..446b3ac283 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -174,6 +174,23 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) +# todo: remove, used temporarily for temporaries in blue icon4py +run_gtfn_with_temporaries_and_sizes = otf_compile_executor.OTFBackend( + executor=otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_with_temporaries_and_sizes", + otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( + translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + }, + ), + ), + ), + allocator=run_gtfn_with_temporaries.allocator, +) + gtfn_gpu_executor = otf_compile_executor.OTFCompileExecutor( name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 929d2d5d6e..c2466f5a4c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -19,6 +19,7 @@ import pytest import gt4py.next as gtx +from gt4py.eve import SymbolRef from gt4py.next import ( astype, broadcast, @@ -31,12 +32,8 @@ neighbor_sum, where, NeighborTableOffsetProvider, - common, ) -from gt4py.next.common import Domain, UnitRange, Dimension, DimensionKind, GridType from gt4py.next.ffront.experimental import as_offset -from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases @@ -60,7 +57,7 @@ reduction_setup, ) -from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries, gtfn_executor +from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries_and_sizes from tests.next_tests.integration_tests.cases import Case from tests.next_tests.toy_connectivity import Edge, Cell @@ -1035,30 +1032,6 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: def test_temporaries_with_sizes(reduction_setup): - run_gtfn_with_temporaries_and_sizes = otf_compile_executor.OTFBackend( - executor=otf_compile_executor.OTFCompileExecutor( - name="run_gtfn_with_temporaries_and_sizes", - otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( - translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace( - symbolic_domain_sizes={"Cell": "num_cells", "Edge": "num_edges", "Vertex": "num_vertices"}, - ), - ), - ), - allocator=run_gtfn_with_temporaries.allocator, - ) - - unstructured_case = Case( - run_gtfn_with_temporaries_and_sizes, - offset_provider=reduction_setup.offset_provider, - default_sizes={ - Vertex: reduction_setup.num_vertices, - Edge: reduction_setup.num_edges, - Cell: reduction_setup.num_cells, - KDim: reduction_setup.k_levels, - }, - grid_type=common.GridType.UNSTRUCTURED, - ) - @gtx.field_operator def testee_op(a: cases.VField) -> cases.EField: amul = a * 2 @@ -1068,18 +1041,18 @@ def testee_op(a: cases.VField) -> cases.EField: def testee(a: cases.VField, out: cases.EField, num_vertices: int): testee_op(a, out=out) - ir = testee.itir - e2v_offset_provider = {"E2V": NeighborTableOffsetProvider(table=reduction_setup.e2v_table, origin_axis=Edge, - neighbor_axis=Vertex, max_neighbors=2)} - - ir_with_tmp = run_gtfn_with_temporaries_and_sizes.executor.otf_workflow.translation._preprocess_itir(testee.itir, e2v_offset_provider, False) - + e2v_offset_provider = { + "E2V": NeighborTableOffsetProvider( + table=reduction_setup.e2v_table, origin_axis=Edge, neighbor_axis=Vertex, max_neighbors=2 + ) + } - # todo: check that symbols are in itir + ir_with_tmp = ( + run_gtfn_with_temporaries_and_sizes.executor.otf_workflow.translation._preprocess_itir( + testee.itir, e2v_offset_provider, False + ) + ) + sym = ir_with_tmp.tmps[0].domain.args[0].args[2].args[0].id - # cases.verify_with_default_data( - # unstructured_case, - # testee, - # ref=lambda a: (a * 2)[unstructured_case.offset_provider["E2V"].table[:, 0]] - # + (a * 2)[unstructured_case.offset_provider["E2V"].table[:, 1]], - # ) + assert sym == "num_vertices" + assert isinstance(sym, SymbolRef)