Skip to content

Commit

Permalink
verify by hand
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Nov 15, 2023
1 parent 4e48283 commit 141c3d5
Showing 1 changed file with 21 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
int64,
minimum,
neighbor_sum,
where,
where, NeighborTableOffsetProvider,
)
from gt4py.next.common import Domain, UnitRange, Dimension, DimensionKind, GridType
from gt4py.next.ffront.experimental import as_offset
from gt4py.next.program_processors import otf_compile_executor
from gt4py.next.program_processors.runners import gtfn
Expand All @@ -57,6 +58,7 @@
)

from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries
from tests.next_tests.toy_connectivity import Edge


def test_copy(cartesian_case): # noqa: F811 # fixtures
Expand Down Expand Up @@ -1028,8 +1030,7 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField:
)


def test_temporaries_with_sizes(unstructured_case):
# todo: select backend
def test_temporaries_with_sizes(reduction_setup):
# run_gtfn_with_temporaries_and_sizes = otf_compile_executor.OTFCompileExecutor(
# name="run_gtfn_with_temporaries_and_sizes",
# otf_workflow=run_gtfn_with_temporaries.otf_workflow.replace(
Expand All @@ -1044,8 +1045,21 @@ def testee(a: cases.VField) -> cases.EField:
amul = a * 2
return amul(E2V[0]) + amul(E2V[1])

cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a: (a * 2)[unstructured_case.offset_provider["E2V"].table[:, 0]] + (a * 2)[unstructured_case.offset_provider["E2V"].table[:, 1]],
@gtx.program(grid_type=GridType.UNSTRUCTURED)
def testee_program(a: cases.VField, out: cases.EField) -> cases.EField:
testee(a=a, out=out, )

a = gtx.as_field(domain=Domain(dims=(Vertex,), ranges=(UnitRange(0, 9),)), data=np.arange(0,9), dtype=int32)
out = gtx.as_field(domain=Domain(dims=(Edge,), ranges=(UnitRange(0, 18),)), data=np.zeros(18), dtype=int32)

e2v_offset_provider = NeighborTableOffsetProvider(table=reduction_setup.e2v_table, origin_axis=Edge, neighbor_axis=Vertex, max_neighbors=2)

testee_program.with_backend(run_gtfn_with_temporaries)( # todo: select modified backend
a=a, out=out,
offset_provider={"E2V": e2v_offset_provider},
)

def reference(a):
return (a * 2)[e2v_offset_provider.table[:, 0]] + (a * 2)[e2v_offset_provider.table[:, 1]]

assert np.allclose(reference(a), out)

0 comments on commit 141c3d5

Please sign in to comment.