From 42912cc9d14e409801c1c71fc99a98f46e7c4a1b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 20 Nov 2023 11:13:36 +0100 Subject: [PATCH] feat[next] Enable GPU backend tests (#1357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - connectivities are implicitly copied to GPU if they are not already on GPU, this might be removed later - changes to cases: ensure we don't pass arrays to ConstInitializer --------- Co-authored-by: Rico Häuselmann --- src/gt4py/next/embedded/nd_array_field.py | 5 +- .../codegens/gtfn/codegen.py | 59 +++++++------- .../next/program_processors/runners/gtfn.py | 30 +++++-- tests/next_tests/exclusion_matrices.py | 5 ++ tests/next_tests/integration_tests/cases.py | 18 ++++- .../ffront_tests/ffront_test_utils.py | 1 + .../ffront_tests/test_execution.py | 33 ++++---- .../ffront_tests/test_external_local_field.py | 8 +- .../ffront_tests/test_gt4py_builtins.py | 18 ++--- .../test_math_builtin_execution.py | 4 +- .../ffront_tests/test_math_unary_builtins.py | 35 +++----- .../ffront_tests/test_program.py | 2 +- .../ffront_tests/test_icon_like_scan.py | 79 ++++++++++++------- .../ffront_tests/test_laplacian.py | 2 +- tests/next_tests/unit_tests/conftest.py | 1 + tox.ini | 2 +- 16 files changed, 176 insertions(+), 126 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 51e613ef81..9357570b05 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -121,7 +121,10 @@ def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: - return np.asarray(self._ndarray, dtype) + if self.array_ns == cp: + return np.asarray(cp.asnumpy(self._ndarray), dtype) + else: + return np.asarray(self._ndarray, dtype) @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 645d1f742f..23165854de 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -179,6 +179,10 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) + def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): + expr_ = "return " + self.visit(node.expr) + return self.generic_visit(node, expr_=expr_) + FunctionDefinition = as_mako( """ struct ${id} { @@ -206,24 +210,6 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) - def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): - expr_ = "return " + self.visit(node.expr) - return self.generic_visit(node, expr_=expr_) - - def visit_FencilDefinition( - self, node: gtfn_ir.FencilDefinition, **kwargs: Any - ) -> Union[str, Collection[str]]: - self.is_cartesian = node.grid_type == common.GridType.CARTESIAN - self.user_defined_function_ids = list( - str(fundef.id) for fundef in node.function_definitions - ) - return self.generic_visit( - node, - grid_type_str=self._grid_type_str[node.grid_type], - block_sizes=self._block_sizes(node.offset_definitions), - **kwargs, - ) - def visit_TemporaryAllocation(self, node, **kwargs): # TODO(tehrengruber): Revisit. 
We are currently converting an itir.NamedRange with # start and stop values into an gtfn_ir.(Cartesian|Unstructured)Domain with @@ -244,6 +230,20 @@ def visit_TemporaryAllocation(self, node, **kwargs): "auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes});" ) + def visit_FencilDefinition( + self, node: gtfn_ir.FencilDefinition, **kwargs: Any + ) -> Union[str, Collection[str]]: + self.is_cartesian = node.grid_type == common.GridType.CARTESIAN + self.user_defined_function_ids = list( + str(fundef.id) for fundef in node.function_definitions + ) + return self.generic_visit( + node, + grid_type_str=self._grid_type_str[node.grid_type], + block_sizes=self._block_sizes(node.offset_definitions), + **kwargs, + ) + FencilDefinition = as_mako( """ #include @@ -277,16 +277,19 @@ def visit_TemporaryAllocation(self, node, **kwargs): ) def _block_sizes(self, offset_definitions: list[gtfn_ir.TagDefinition]) -> str: - block_dims = [] - block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) - for i, tag in enumerate(offset_definitions): - if tag.alias is None: - block_dims.append( - f"gridtools::meta::list<{tag.name.id}_t, " - f"gridtools::integral_constant>" - ) - sizes_str = ",\n".join(block_dims) - return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + if self.is_cartesian: + block_dims = [] + block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) + for i, tag in enumerate(offset_definitions): + if tag.alias is None: + block_dims.append( + f"gridtools::meta::list<{tag.name.id}_t, " + f"gridtools::integral_constant>" + ) + sizes_str = ",\n".join(block_dims) + return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + else: + return "using block_sizes_t = gridtools::meta::list>, gridtools::meta::list>>;" @classmethod def apply(cls, root: Any, **kwargs: Any) -> str: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7233e7a893..5d4b450d39 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import functools +import warnings from typing import Any import numpy.typing as npt @@ -42,12 +44,14 @@ def convert_arg(arg: Any) -> Any: return arg -def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram: +def convert_args( + inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.CompiledProgram: def decorated_program( *args, offset_provider: dict[str, common.Connectivity | common.Dimension] ): converted_args = [convert_arg(arg) for arg in args] - conn_args = extract_connectivity_args(offset_provider) + conn_args = extract_connectivity_args(offset_provider, device) return inp( *converted_args, *conn_args, @@ -56,8 +60,22 @@ def decorated_program( return decorated_program +def _ensure_is_on_device( + connectivity_arg: npt.NDArray, device: core_defs.DeviceType +) -> npt.NDArray: + if device == core_defs.DeviceType.CUDA: + import cupy as cp + + if not isinstance(connectivity_arg, cp.ndarray): + warnings.warn( + "Copying connectivity to device. For performance make sure connectivity is provided on device." 
+ ) + return cp.asarray(connectivity_arg) + return connectivity_arg + + def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType ) -> list[tuple[npt.NDArray, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] @@ -67,7 +85,9 @@ def extract_connectivity_args( raise NotImplementedError( "Only `NeighborTable` connectivities implemented at this point." ) - args.append((conn.table, tuple([0] * 2))) + # copying to device here is a fallback for easy testing and might be removed later + conn_arg = _ensure_is_on_device(conn.table, device) + args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass else: @@ -126,7 +146,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: translation=GTFN_GPU_TRANSLATION_STEP, bindings=nanobind.bind_source, compilation=GTFN_DEFAULT_COMPILE_STEP, - decoration=convert_args, + decoration=functools.partial(convert_args, device=core_defs.DeviceType.CUDA), ) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 249e17d358..ef30a61687 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -50,6 +50,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU_WITH_TEMPORARIES = ( "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" ) + GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -148,6 +149,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 634d85e64c..730ce18fd5 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -25,6 +25,7 @@ import pytest import gt4py.next as gtx +from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self from gt4py.next import common, constructors @@ -73,7 +74,7 @@ E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim)) -ScalarValue: TypeAlias = np.int32 | np.int64 | np.float32 | np.float64 | np.generic +ScalarValue: TypeAlias = core_defs.Scalar FieldValue: TypeAlias = gtx.Field FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...] FieldViewInout: TypeAlias = FieldValue | tuple["FieldViewInout", ...] 
@@ -117,12 +118,19 @@ def from_case( return self -@dataclasses.dataclass +@dataclasses.dataclass(init=False) class ConstInitializer(DataInitializer): """Initialize with a given value across the coordinate space.""" value: ScalarValue + def __init__(self, value: ScalarValue): + if not core_defs.is_scalar_type(value): + raise ValueError( + "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead." + ) + self.value = value + @property def scalar_value(self) -> ScalarValue: return self.value @@ -460,7 +468,7 @@ def verify_with_default_data( ``comparison(ref, )`` and should return a boolean. """ inps, kwfields = get_default_data(case, fieldop) - ref_args = tuple(i.ndarray if hasattr(i, "ndarray") else i for i in inps) + ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps) verify( case, fieldop, @@ -598,3 +606,7 @@ class Case: offset_provider: dict[str, common.Connectivity | gtx.Dimension] default_sizes: dict[gtx.Dimension, int] grid_type: common.GridType + + @property + def as_field(self): + return constructors.as_field.partial(allocator=self.backend) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index fb753bf169..01c78cf950 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -53,6 +53,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non definitions.ProgramBackendId.GTFN_CPU, definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, + pytest.param(definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu), None, ] + OPTIONAL_PROCESSORS, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8036c22670..fe18bda9e3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -371,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32)) - b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32)) + a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32)) + b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -589,7 +589,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1) - out = gtx.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) + out = cartesian_case.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) if not forward: expected = np.flip(expected) @@ -610,6 +610,7 @@ def simple_scan_operator(carry: float) -> float: def test_solve_triag(cartesian_case): if cartesian_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -723,8 +724,8 @@ def 
simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 k_size = cartesian_case.default_sizes[KDim] - a = gtx.as_field([KDim], 4.0 * np.ones((k_size,))) - out = gtx.as_field([KDim], np.zeros((k_size,))) + a = cartesian_case.as_field([KDim], 4.0 * np.ones((k_size,))) + out = cartesian_case.as_field([KDim], np.zeros((k_size,))) cases.verify( cartesian_case, @@ -773,16 +774,19 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] - inp1 = gtx.as_field([KDim], np.ones((k_size,))) - inp2 = gtx.as_field([KDim], np.arange(0.0, k_size, 1)) - out = gtx.as_field([KDim], np.zeros((k_size,))) + + inp1_np = np.ones((k_size,)) + inp2_np = np.arange(0.0, k_size, 1) + inp1 = cartesian_case.as_field([KDim], inp1_np) + inp2 = cartesian_case.as_field([KDim], inp2_np) + out = cartesian_case.as_field([KDim], np.zeros((k_size,))) def prev_levels_iterator(i): return range(i + 1) expected = np.asarray( [ - reduce(lambda prev, i: prev + inp1[i] + inp2[i], prev_levels_iterator(i), init) + reduce(lambda prev, i: prev + inp1_np[i] + inp2_np[i], prev_levels_iterator(i), init) for i in range(k_size) ] ) @@ -842,7 +846,7 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.ndarray.copy() # ensure we are not overwriting out outside of the domain + ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain ref[1:9] = a[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) @@ -851,6 +855,7 @@ def program_domain(a: cases.IField, out: cases.IField): def test_domain_input_bounds(cartesian_case): if cartesian_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -876,7 +881,7 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() - ref = out.ndarray.copy() + ref = np.asarray(out).copy() ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 cases.verify( @@ -919,7 +924,7 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.ndarray.copy() + ref = np.asarray(out).copy() ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 ) @@ -959,9 +964,9 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() - ref0 = out0.ndarray.copy() + ref0 = np.asarray(out0).copy() ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] - ref1 = out1.ndarray.copy() + ref1 = np.asarray(out1).copy() ref1[1:9, 4:6] = inp1[1:9, 4:6] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 5135b3d47a..05adc63a45 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -38,7 +38,9 @@ def testee( inp * ones(V2E), axis=V2EDim ) # multiplication with shifted `ones` 
because reduction of only non-shifted field with local dimension is not supported - inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) + inp = unstructured_case.as_field( + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() cases.verify( @@ -59,7 +61,9 @@ def test_external_local_field_only(unstructured_case): def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32]: return neighbor_sum(inp, axis=V2EDim) - inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) + inp = unstructured_case.as_field( + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + ) cases.verify( unstructured_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 1eba95e880..8bc325d276 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -48,6 +48,7 @@ def test_maxover_execution_(unstructured_case, strategy): if unstructured_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -142,10 +143,7 @@ def conditional_nested_tuple( return where(mask, ((a, b), (b, a)), ((5.0, 7.0), (7.0, 5.0))) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional_nested_tuple, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=size)) a = cases.allocate(cartesian_case, conditional_nested_tuple, "a")() b = cases.allocate(cartesian_case, conditional_nested_tuple, "b")() @@ -216,10 +214,7 @@ def conditional( return where(mask, a, b) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional, "a")() b = cases.allocate(cartesian_case, conditional, "b")() out = cases.allocate(cartesian_case, conditional, cases.RETURN)() @@ -233,10 +228,7 @@ def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases return where(mask, a, 10.0) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional_promotion, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_promotion, "a")() out = cases.allocate(cartesian_case, conditional_promotion, cases.RETURN)() @@ -274,7 +266,7 @@ def conditional_program( conditional_shifted(mask, a, b, out=out) size = cartesian_case.default_sizes[IDim] + 1 - mask = gtx.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_program, "a").extend({IDim: (0, 1)})() b = cases.allocate(cartesian_case, 
conditional_program, "b").extend({IDim: (0, 1)})() out = cases.allocate(cartesian_case, conditional_shifted, cases.RETURN)() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index a1839b8e17..937b05e087 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -125,9 +125,9 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp else: ref_impl: Callable = getattr(np, builtin_name) - inps = [gtx.as_field([IDim], np.asarray(input)) for input in inputs] + inps = [cartesian_case.as_field([IDim], np.asarray(input)) for input in inputs] expected = ref_impl(*inputs) - out = gtx.as_field([IDim], np.zeros_like(expected)) + out = cartesian_case.as_field([IDim], np.zeros_like(expected)) builtin_field_op = make_builtin_field_operator(builtin_name).with_backend( cartesian_case.backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 59e11a7de8..8660ecfdbd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -72,6 +72,7 @@ def test_floordiv(cartesian_case): gtfn.run_gtfn, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn_gpu, ]: pytest.xfail( "FloorDiv not yet supported." @@ -90,7 +91,7 @@ def test_mod(cartesian_case): def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 - inp1 = gtx.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) + inp1 = cartesian_case.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) @@ -102,13 +103,8 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie return inp1 ^ inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, binary_xor, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, binary_xor, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, binary_xor, cases.RETURN)() cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2) @@ -119,13 +115,8 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: return inp1 & inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, bit_and, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, bit_and, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = 
cases.allocate(cartesian_case, bit_and, cases.RETURN)() cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2) @@ -136,13 +127,8 @@ def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: return inp1 | inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, bit_or, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, bit_or, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, bit_or, cases.RETURN)() cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1 | inp2) @@ -164,10 +150,7 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: return ~inp1 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, tilde_fieldop, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, tilde_fieldop, cases.RETURN)() cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 545abd2825..b82cae25a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -215,7 +215,7 @@ def prog( def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) - inp = gtx.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) + inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() with pytest.raises(TypeError) as exc_info: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index eaae9a2a3e..cd948ffa02 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,8 +18,11 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.program_processors.runners import gtfn, roundtrip +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import Cell, KDim, Koff from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, ) @@ -190,80 +193,97 @@ def reference( @pytest.fixture -def test_setup(): +def test_setup(fieldview_backend): + test_case = cases.Case( + fieldview_backend, + offset_provider={"Koff": KDim}, + default_sizes={Cell: 14, KDim: 10}, + grid_type=common.GridType.UNSTRUCTURED, + ) + @dataclass(frozen=True) class setup: - cell_size = 14 - k_size = 10 - z_alpha = gtx.as_field( + case: cases.Case = test_case + cell_size = case.default_sizes[Cell] + k_size = 
case.default_sizes[KDim] + z_alpha = case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = gtx.as_field( + z_beta = case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) - w = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): - if fieldview_backend in [ +def test_solve_nonhydro_stencil_52_like_z_q(test_setup): + if test_setup.case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") - solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)( + cases.verify( + test_setup.case, + solve_nonhydro_stencil_52_like_z_q, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.z_q_out, - offset_provider={"Koff": KDim}, + ref=test_setup.z_q_ref, + inout=test_setup.z_q_out, + comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." 
) - if fieldview_backend == roundtrip.backend: + if test_setup.case.backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( + cases.verify( + test_setup.case, + solve_nonhydro_stencil_52_like_z_q_tup, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.z_q_out, - offset_provider={"Koff": KDim}, + ref=test_setup.z_q_ref, + inout=test_setup.z_q_out, + comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) - @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( + + cases.run( + test_setup.case, + solve_nonhydro_stencil_52_like, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.dummy, - offset_provider={"Koff": KDim}, ) assert np.allclose(test_setup.z_q_ref, test_setup.z_q) @@ -271,18 +291,19 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == roundtrip.backend: + if test_setup.case.backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( + cases.run( + test_setup.case, + solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, - offset_provider={"Koff": KDim}, ) assert np.allclose(test_setup.z_q_ref, test_setup.z_q) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 9a1e968de0..4f4d4969a9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -86,5 +86,5 @@ def test_ffront_lap(cartesian_case): in_field, out_field, inout=out_field[2:-2, 2:-2], - ref=lap_ref(lap_ref(np.asarray(in_field.ndarray))), + ref=lap_ref(lap_ref(in_field.array_ns.asarray(in_field.ndarray))), ) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index b43eeb3f91..372062d08a 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,6 +60,7 @@ def lift_mode(request): (definitions.ProgramBackendId.GTFN_CPU, True), (definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), (definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), + # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper 
allocation (definitions.ProgramFormatterId.LISP_FORMATTER, False), (definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), diff --git a/tox.ini b/tox.ini index 5b644e7d97..44dc912c8a 100644 --- a/tox.ini +++ b/tox.ini @@ -84,7 +84,7 @@ commands = nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests nomesh-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist + # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next [testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
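The core of this patch is making the gtfn toolchain and the test suite device-aware; the sketches below illustrate the main mechanisms with standalone stand-ins rather than the real gt4py classes.

The `__array__` change in `src/gt4py/next/embedded/nd_array_field.py` makes `np.asarray(field)` usable for GPU-backed fields by routing the data through host memory first. A minimal standalone sketch of the same pattern, assuming `cupy` is installed (the wrapper class here is illustrative, not the gt4py one):

    import numpy as np
    import cupy as cp  # assumed available; only needed for device-backed buffers


    class DeviceBuffer:
        """Illustrative wrapper around either a numpy or a cupy ndarray."""

        def __init__(self, ndarray):
            self._ndarray = ndarray

        @property
        def array_ns(self):
            # cupy arrays report cupy as their array namespace, numpy arrays report numpy
            return cp.get_array_module(self._ndarray)

        def __array__(self, dtype=None) -> np.ndarray:
            if self.array_ns == cp:
                # numpy cannot view device memory directly; copy to host first
                return np.asarray(cp.asnumpy(self._ndarray), dtype)
            return np.asarray(self._ndarray, dtype)


    host_copy = np.asarray(DeviceBuffer(cp.arange(5)))   # triggers a device-to-host copy
    host_view = np.asarray(DeviceBuffer(np.arange(5)))   # numpy behaviour is unchanged

This is also why the tests switch from `out.ndarray.copy()` to `np.asarray(out).copy()` when building reference arrays: the former stays on the device, while the latter always yields host data.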
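The gtfn runner changes follow the same host/device split: connectivity tables that arrive on the host are copied to the GPU on demand, with a warning, and `convert_args` is specialized per device via `functools.partial` in the GPU workflow. A rough standalone sketch of that fallback, with simplified names (the real code keys everything off `core_defs.DeviceType` and `common.Connectivity`):

    import functools
    import warnings
    from enum import Enum, auto

    import numpy as np


    class DeviceType(Enum):
        CPU = auto()
        CUDA = auto()


    def ensure_is_on_device(connectivity: np.ndarray, device: DeviceType):
        if device == DeviceType.CUDA:
            import cupy as cp  # deferred, so CPU-only runs never need cupy

            if not isinstance(connectivity, cp.ndarray):
                warnings.warn(
                    "Copying connectivity to device. For performance make sure "
                    "connectivity is provided on device."
                )
                return cp.asarray(connectivity)
        return connectivity


    def convert_args(program, device: DeviceType = DeviceType.CPU):
        def decorated_program(*args, connectivities=()):
            conn_args = [ensure_is_on_device(c, device) for c in connectivities]
            return program(*args, *conn_args)

        return decorated_program


    # the GPU workflow binds the device once, as the patch does for run_gtfn_gpu:
    gpu_decoration = functools.partial(convert_args, device=DeviceType.CUDA)

As the commit message notes, the implicit copy is a convenience for testing and might be removed later, so callers should still prefer providing connectivities on the device.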
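On the test-infrastructure side, `ConstInitializer` now rejects anything that is not a scalar, and array-valued inputs go through the new `Case.as_field` property, which pre-binds the case's backend as the allocator so the field is created on the right device. A condensed sketch of the validation pattern, again with standalone stand-ins (not the real gt4py classes):

    import dataclasses

    import numpy as np

    SCALAR_TYPES = (bool, int, float, complex, np.generic)  # stand-in for core_defs.is_scalar_type


    @dataclasses.dataclass(init=False)
    class ConstInitializer:
        """Initialize with a single scalar value across the whole field."""

        value: bool | int | float | complex | np.generic

        def __init__(self, value):
            if not isinstance(value, SCALAR_TYPES):
                raise ValueError(
                    "`ConstInitializer` can not be used with non-scalars. "
                    "Use `Case.as_field` instead."
                )
            self.value = value


    ConstInitializer(1)              # fine
    ConstInitializer(np.float64(1))  # fine, numpy scalars pass too
    # ConstInitializer(np.ones(3))   # would raise ValueError; use case.as_field instead

The `as_field` property itself is plain partial application: `constructors.as_field.partial(allocator=self.backend)` returns `as_field` with the allocator already fixed, which is why the tests can call `cartesian_case.as_field([IDim], np_array)` just like `gtx.as_field` and still get backend-appropriate storage.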
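Finally, the GPU backend is folded into the test matrix via the `requires_gpu` pytest marker: the backend fixture wraps `GTFN_GPU` in `pytest.param(..., marks=pytest.mark.requires_gpu)`, and the tox environments select or deselect it with `-m "... requires_gpu"`. A small self-contained sketch of this gating pattern (marker name taken from the patch; the backend ids are placeholders):

    import pytest


    def pytest_configure(config):
        # normally registered in conftest.py or the pytest config; shown inline for completeness
        config.addinivalue_line("markers", "requires_gpu: test needs a CUDA device and cupy")


    BACKENDS = [
        "gtfn_cpu",                                                # always collected
        pytest.param("gtfn_gpu", marks=pytest.mark.requires_gpu),  # filtered with -m
        None,                                                      # embedded execution
    ]


    @pytest.fixture(params=BACKENDS, ids=str)
    def fieldview_backend(request):
        return request.param


    def test_backend_parametrization(fieldview_backend):
        assert fieldview_backend is None or isinstance(fieldview_backend, str)

With this in place, `pytest -m "not requires_gpu"` runs only the CPU-parametrized cases and `pytest -m "requires_gpu"` only the GPU ones, mirroring the nomesh-cpu and nomesh-cuda selections in tox.ini (which additionally filter on `requires_atlas`).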