diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 78a64d9277..3ec630c949 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -34,7 +34,15 @@ KDim = Dimension("KDim") -@pytest.fixture(params=nd_array_field._nd_array_implementations) +def nd_array_implementation_params(): + for xp in nd_array_field._nd_array_implementations: + if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: + yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) + else: + yield pytest.param(xp, id=xp.__name__) + + +@pytest.fixture(params=nd_array_implementation_params()) def nd_array_implementation(request): yield request.param @@ -272,12 +280,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte assert np.allclose(op_result.ndarray, expected_result) -@pytest.fixture( - params=itertools.product( - nd_array_field._nd_array_implementations, nd_array_field._nd_array_implementations - ), - ids=lambda param: f"{param[0].__name__}-{param[1].__name__}", -) +def product_nd_array_implementation_params(): + for xp1 in nd_array_field._nd_array_implementations: + for xp2 in nd_array_field._nd_array_implementations: + marks = () + if any(hasattr(nd_array_field, "cp") and xp == nd_array_field.cp for xp in (xp1, xp2)): + marks = pytest.mark.requires_gpu + yield pytest.param((xp1, xp2), id=f"{xp1.__name__}-{xp2.__name__}", marks=marks) + + +@pytest.fixture(params=product_nd_array_implementation_params()) def product_nd_array_implementation(request): yield request.param