Use symbol in domain for temporary
samkellerhals committed Nov 16, 2023
1 parent e85c273 commit 75ba912
Showing 3 changed files with 37 additions and 49 deletions.
12 changes: 5 additions & 7 deletions src/gt4py/next/iterator/transforms/global_tmps.py
@@ -191,9 +191,7 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non
assert not (set(used_symbol_refs) - {param.id for param in whitelist})


-def split_closures(
-    node: ir.FencilDefinition, offset_provider, symbolic_sizes
-) -> FencilWithTemporaries:
+def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries:
"""Split closures on lifted function calls and introduce new temporary buffers for return values.
Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the
@@ -438,7 +436,7 @@ def _group_offsets(
return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly


-def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]):
+def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any], symbolic_sizes):
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)

closures: list[ir.StencilClosure] = []
@@ -489,7 +487,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
assert new_axis not in consumed_domain.ranges
consumed_domain.ranges[new_axis] = SymbolicRange(
im.literal("0", ir.INTEGER_INDEX_BUILTIN),
-im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN),
+symbolic_sizes[new_axis],
)
else:
raise NotImplementedError
@@ -578,14 +576,14 @@ def visit_FencilDefinition(
symbolic_sizes: dict[str, str],
) -> FencilWithTemporaries:
# Split closures on lifted function calls and introduce temporaries
-res = split_closures(node, offset_provider=offset_provider, symbolic_sizes=symbolic_sizes)
+res = split_closures(node, offset_provider=offset_provider)
# Prune unreferenced closure inputs introduced in the previous step
res = PruneClosureInputs().visit(res)
# Prune unused temporaries possibly introduced in the previous step
res = prune_unused_temporaries(res)
# Perform an eta-reduction which should put all calls at the highest level of a closure
res = EtaReduction().visit(res)
# Perform a naive extent analysis to compute domain sizes of closures and temporaries
-res = update_domains(res, offset_provider)
+res = update_domains(res, offset_provider, symbolic_sizes)
# Use type inference to determine the data type of the temporaries
return collect_tmps_info(res, offset_provider=offset_provider)
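
In short, `update_domains` now receives the `symbolic_sizes` mapping and uses `symbolic_sizes[new_axis]` as the upper bound of a temporary's horizontal range, where previously that bound was an integer literal computed from the offset provider. A minimal standalone sketch of that substitution follows; the helper name and the fallback to the literal path are assumptions for illustration, not the gt4py implementation.

from typing import Mapping, Optional


def temporary_upper_bound(
    dim: str,
    horizontal_sizes: Mapping[str, int],
    symbolic_sizes: Optional[Mapping[str, str]],
) -> str:
    """Sketch: pick the upper bound of a temporary's domain in dimension `dim`."""
    if symbolic_sizes is not None and dim in symbolic_sizes:
        # New behaviour: a symbol such as "num_vertices", resolved when the program runs.
        return symbolic_sizes[dim]
    # Previous behaviour: an integer literal derived from the offset provider tables.
    return str(horizontal_sizes[dim])


assert temporary_upper_bound("Vertex", {"Vertex": 9}, {"Vertex": "num_vertices"}) == "num_vertices"
assert temporary_upper_bound("Vertex", {"Vertex": 9}, None) == "9"
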
17 changes: 17 additions & 0 deletions src/gt4py/next/program_processors/runners/gtfn.py
@@ -174,6 +174,23 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)

+# todo: remove, used temporarily for temporaries in blue icon4py
+run_gtfn_with_temporaries_and_sizes = otf_compile_executor.OTFBackend(
+    executor=otf_compile_executor.OTFCompileExecutor(
+        name="run_gtfn_with_temporaries_and_sizes",
+        otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
+            translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(
+                symbolic_domain_sizes={
+                    "Cell": "num_cells",
+                    "Edge": "num_edges",
+                    "Vertex": "num_vertices",
+                },
+            ),
+        ),
+    ),
+    allocator=run_gtfn_with_temporaries.allocator,
+)

gtfn_gpu_executor = otf_compile_executor.OTFCompileExecutor(
name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW
)
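
A hedged usage sketch of the backend added above (hypothetical program, not part of this commit): the symbols configured in `symbolic_domain_sizes` are expected to be resolvable when a program runs, which the test in this commit achieves by declaring `num_vertices` as a scalar program parameter. The dimension and field definitions and the `backend=` decorator keyword below are assumptions for illustration.

import gt4py.next as gtx
from gt4py.next import float64
from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries_and_sizes

Vertex = gtx.Dimension("Vertex")
VField = gtx.Field[[Vertex], float64]  # assumed field alias for this sketch


@gtx.field_operator
def double(a: VField) -> VField:
    return a * 2.0


@gtx.program(backend=run_gtfn_with_temporaries_and_sizes)  # decorator keyword is an assumption
def run_double(a: VField, out: VField, num_vertices: int):
    # num_vertices is not used in the body; it is the run-time value behind the "Vertex"
    # entry of symbolic_domain_sizes, so temporary buffers can be sized symbolically.
    double(a, out=out)
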
15 changes: 15 additions & 42 deletions
@@ -19,6 +19,7 @@
import pytest

import gt4py.next as gtx
+from gt4py.eve import SymbolRef
from gt4py.next import (
astype,
broadcast,
@@ -31,12 +32,8 @@
neighbor_sum,
where,
NeighborTableOffsetProvider,
common,
)
from gt4py.next.common import Domain, UnitRange, Dimension, DimensionKind, GridType
from gt4py.next.ffront.experimental import as_offset
from gt4py.next.iterator.transforms import LiftMode
from gt4py.next.program_processors import otf_compile_executor
from gt4py.next.program_processors.runners import gtfn

from next_tests.integration_tests import cases
@@ -60,7 +57,7 @@
reduction_setup,
)

-from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries, gtfn_executor
+from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries_and_sizes
from tests.next_tests.integration_tests.cases import Case
from tests.next_tests.toy_connectivity import Edge, Cell

@@ -1035,30 +1032,6 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField:


def test_temporaries_with_sizes(reduction_setup):
-    run_gtfn_with_temporaries_and_sizes = otf_compile_executor.OTFBackend(
-        executor=otf_compile_executor.OTFCompileExecutor(
-            name="run_gtfn_with_temporaries_and_sizes",
-            otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
-                translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(
-                    symbolic_domain_sizes={"Cell": "num_cells", "Edge": "num_edges", "Vertex": "num_vertices"},
-                ),
-            ),
-        ),
-        allocator=run_gtfn_with_temporaries.allocator,
-    )
-
-    unstructured_case = Case(
-        run_gtfn_with_temporaries_and_sizes,
-        offset_provider=reduction_setup.offset_provider,
-        default_sizes={
-            Vertex: reduction_setup.num_vertices,
-            Edge: reduction_setup.num_edges,
-            Cell: reduction_setup.num_cells,
-            KDim: reduction_setup.k_levels,
-        },
-        grid_type=common.GridType.UNSTRUCTURED,
-    )

@gtx.field_operator
def testee_op(a: cases.VField) -> cases.EField:
amul = a * 2
@@ -1068,18 +1041,18 @@ def testee_op(a: cases.VField) -> cases.EField:
def testee(a: cases.VField, out: cases.EField, num_vertices: int):
testee_op(a, out=out)

-    ir = testee.itir
-    e2v_offset_provider = {"E2V": NeighborTableOffsetProvider(table=reduction_setup.e2v_table, origin_axis=Edge,
-                                                              neighbor_axis=Vertex, max_neighbors=2)}
-
-    ir_with_tmp = run_gtfn_with_temporaries_and_sizes.executor.otf_workflow.translation._preprocess_itir(testee.itir, e2v_offset_provider, False)

+    e2v_offset_provider = {
+        "E2V": NeighborTableOffsetProvider(
+            table=reduction_setup.e2v_table, origin_axis=Edge, neighbor_axis=Vertex, max_neighbors=2
+        )
+    }
+
+    # todo: check that symbols are in itir
+    ir_with_tmp = (
+        run_gtfn_with_temporaries_and_sizes.executor.otf_workflow.translation._preprocess_itir(
+            testee.itir, e2v_offset_provider, False
+        )
+    )
+    sym = ir_with_tmp.tmps[0].domain.args[0].args[2].args[0].id

-    # cases.verify_with_default_data(
-    #     unstructured_case,
-    #     testee,
-    #     ref=lambda a: (a * 2)[unstructured_case.offset_provider["E2V"].table[:, 0]]
-    #     + (a * 2)[unstructured_case.offset_provider["E2V"].table[:, 1]],
-    # )
assert sym == "num_vertices"
assert isinstance(sym, SymbolRef)
