Skip to content

Commit

Permalink
bug[next]: Mark cupy tests as requires gpu (GridTools#1483)
Browse files Browse the repository at this point in the history
It noticed we execute some tests that use cupy in the cpu CSCS-CI. These tests are always executed when cupy is installed (which usually means we have a gpu, but not necessarily). This PR changes this such that they are marked with `requires_gpu` and as such can be disabled.
  • Loading branch information
tehrengruber authored Mar 11, 2024
1 parent 6d7f38c commit cb2056c
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@
KDim = Dimension("KDim")


@pytest.fixture(params=nd_array_field._nd_array_implementations)
def nd_array_implementation_params():
for xp in nd_array_field._nd_array_implementations:
if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp:
yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu)
else:
yield pytest.param(xp, id=xp.__name__)


@pytest.fixture(params=nd_array_implementation_params())
def nd_array_implementation(request):
yield request.param

Expand Down Expand Up @@ -272,12 +280,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte
assert np.allclose(op_result.ndarray, expected_result)


@pytest.fixture(
params=itertools.product(
nd_array_field._nd_array_implementations, nd_array_field._nd_array_implementations
),
ids=lambda param: f"{param[0].__name__}-{param[1].__name__}",
)
def product_nd_array_implementation_params():
for xp1 in nd_array_field._nd_array_implementations:
for xp2 in nd_array_field._nd_array_implementations:
marks = ()
if any(hasattr(nd_array_field, "cp") and xp == nd_array_field.cp for xp in (xp1, xp2)):
marks = pytest.mark.requires_gpu
yield pytest.param((xp1, xp2), id=f"{xp1.__name__}-{xp2.__name__}", marks=marks)


@pytest.fixture(params=product_nd_array_implementation_params())
def product_nd_array_implementation(request):
yield request.param

Expand Down

0 comments on commit cb2056c

Please sign in to comment.