Skip to content

Commit

Permalink
use custom unstructured case in test_icon_like_scan
Browse files Browse the repository at this point in the history
  • Loading branch information
Rico Häuselmann committed Nov 16, 2023
1 parent 393e196 commit 5e579fa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
2 changes: 1 addition & 1 deletion tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,4 +601,4 @@ class Case:

@property
def as_field(self):
return constructors.as_field_with(allocator=self.backend)
return constructors.as_field.partial(allocator=self.backend)
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
import pytest

import gt4py.next as gtx
from gt4py.next import common
from gt4py.next.program_processors.runners import gtfn, roundtrip

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import Cell, KDim, Koff
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
fieldview_backend,
)


Cell = gtx.Dimension("Cell")
KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL)
Koff = gtx.FieldOffset("Koff", KDim, (KDim,))


@gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0, True))
def _scan(
state: tuple[float, float, bool],
Expand Down Expand Up @@ -187,100 +185,117 @@ def reference(


@pytest.fixture
def test_setup():
def test_setup(fieldview_backend):
test_case = cases.Case(
fieldview_backend,
offset_provider={"Koff": KDim},
default_sizes={Cell: 14, KDim: 10},
grid_type=common.GridType.UNSTRUCTURED,
)

@dataclass(frozen=True)
class setup:
cell_size = 14
k_size = 10
z_alpha = gtx.as_field(
case: cases.Case = test_case
cell_size = case.default_sizes[Cell]
k_size = case.default_sizes[KDim]
z_alpha = case.as_field(
[Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1))
)
z_beta = gtx.as_field(
z_beta = case.as_field(
[Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
)
z_q = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
w = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray)
dummy = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool))
z_q_out = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size)))
dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool))
z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size)))

return setup()


@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):
if fieldview_backend in [
def test_solve_nonhydro_stencil_52_like_z_q(test_setup):
if test_setup.case.backend in [
gtfn.run_gtfn,
gtfn.run_gtfn_gpu,
gtfn.run_gtfn_imperative,
gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail("Needs implementation of scan projector.")

solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)(
cases.verify(
test_setup.case,
solve_nonhydro_stencil_52_like_z_q,
test_setup.z_alpha,
test_setup.z_beta,
test_setup.z_q,
test_setup.w,
test_setup.z_q_out,
offset_provider={"Koff": KDim},
ref=test_setup.z_q_ref,
inout=test_setup.z_q_out,
comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]),
)

assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:])


@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):
if fieldview_backend in [gtfn.run_gtfn_with_temporaries]:
def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup):
if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail(
"Needs implementation of scan projector. Breaks in type inference as executed"
"again after CollapseTuple."
)
if fieldview_backend == roundtrip.backend:
if test_setup.case.backend == roundtrip.backend:
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")

solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)(
cases.verify(
test_setup.case,
solve_nonhydro_stencil_52_like_z_q_tup,
test_setup.z_alpha,
test_setup.z_beta,
test_setup.z_q,
test_setup.w,
test_setup.z_q_out,
offset_provider={"Koff": KDim},
ref=test_setup.z_q_ref,
inout=test_setup.z_q_out,
comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]),
)

assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:])


@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):
if fieldview_backend in [gtfn.run_gtfn_with_temporaries]:
def test_solve_nonhydro_stencil_52_like(test_setup):
if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)(

cases.run(
test_setup.case,
solve_nonhydro_stencil_52_like,
test_setup.z_alpha,
test_setup.z_beta,
test_setup.z_q,
test_setup.w,
test_setup.dummy,
offset_provider={"Koff": KDim},
)

assert np.allclose(test_setup.z_q_ref, test_setup.z_q)
assert np.allclose(test_setup.w_ref, test_setup.w)


@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend):
if fieldview_backend in [gtfn.run_gtfn_with_temporaries]:
def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup):
if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
if fieldview_backend == roundtrip.backend:
if test_setup.case.backend == roundtrip.backend:
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")

solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)(
cases.run(
test_setup.case,
solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge,
test_setup.z_alpha,
test_setup.z_beta,
test_setup.z_q,
test_setup.w,
offset_provider={"Koff": KDim},
)

assert np.allclose(test_setup.z_q_ref, test_setup.z_q)
Expand Down

0 comments on commit 5e579fa

Please sign in to comment.