Skip to content

Commit

Permalink
feat[next] Enable GPU backend tests (#1357)
Browse files Browse the repository at this point in the history
- connectivities are implicitly copied to the GPU if they are not already on the GPU; this fallback might be removed later
- changes to cases: ensure we don't pass arrays to ConstInitializer

---------

Co-authored-by: Rico Häuselmann <[email protected]>
  • Loading branch information
havogt and Rico Häuselmann authored Nov 20, 2023
1 parent ecd0b68 commit 42912cc
Show file tree
Hide file tree
Showing 16 changed files with 176 additions and 126 deletions.
5 changes: 4 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def ndarray(self) -> core_defs.NDArrayObject:
return self._ndarray

def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray:
return np.asarray(self._ndarray, dtype)
if self.array_ns == cp:
return np.asarray(cp.asnumpy(self._ndarray), dtype)
else:
return np.asarray(self._ndarray, dtype)

@property
def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
Expand Down
59 changes: 31 additions & 28 deletions src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs):
"""
)

def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs):
    """Visit a function definition node.

    The node's expression is visited first and prefixed with ``"return "``;
    the resulting string is handed to ``generic_visit`` as ``expr_``.
    """
    rendered_expr = self.visit(node.expr)
    return self.generic_visit(node, expr_="return " + rendered_expr)

FunctionDefinition = as_mako(
"""
struct ${id} {
Expand Down Expand Up @@ -206,24 +210,6 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs):
"""
)

def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs):
expr_ = "return " + self.visit(node.expr)
return self.generic_visit(node, expr_=expr_)

def visit_FencilDefinition(
self, node: gtfn_ir.FencilDefinition, **kwargs: Any
) -> Union[str, Collection[str]]:
self.is_cartesian = node.grid_type == common.GridType.CARTESIAN
self.user_defined_function_ids = list(
str(fundef.id) for fundef in node.function_definitions
)
return self.generic_visit(
node,
grid_type_str=self._grid_type_str[node.grid_type],
block_sizes=self._block_sizes(node.offset_definitions),
**kwargs,
)

def visit_TemporaryAllocation(self, node, **kwargs):
# TODO(tehrengruber): Revisit. We are currently converting an itir.NamedRange with
# start and stop values into an gtfn_ir.(Cartesian|Unstructured)Domain with
Expand All @@ -244,6 +230,20 @@ def visit_TemporaryAllocation(self, node, **kwargs):
"auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes});"
)

def visit_FencilDefinition(
    self, node: gtfn_ir.FencilDefinition, **kwargs: Any
) -> Union[str, Collection[str]]:
    """Visit a fencil definition node.

    Records on the visitor whether the grid is Cartesian and the ids of the
    fencil's function definitions, then delegates to ``generic_visit`` with
    the grid-type string and the block-size declaration as extra arguments.
    """
    # State set here is read by other methods of this visitor (e.g. `_block_sizes`).
    self.is_cartesian = node.grid_type == common.GridType.CARTESIAN
    self.user_defined_function_ids = [str(fundef.id) for fundef in node.function_definitions]

    grid_type_str = self._grid_type_str[node.grid_type]
    block_sizes = self._block_sizes(node.offset_definitions)
    return self.generic_visit(
        node,
        grid_type_str=grid_type_str,
        block_sizes=block_sizes,
        **kwargs,
    )

FencilDefinition = as_mako(
"""
#include <cmath>
Expand Down Expand Up @@ -277,16 +277,19 @@ def visit_TemporaryAllocation(self, node, **kwargs):
)

def _block_sizes(self, offset_definitions: list[gtfn_ir.TagDefinition]) -> str:
block_dims = []
block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2)
for i, tag in enumerate(offset_definitions):
if tag.alias is None:
block_dims.append(
f"gridtools::meta::list<{tag.name.id}_t, "
f"gridtools::integral_constant<int, {block_sizes[i]}>>"
)
sizes_str = ",\n".join(block_dims)
return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;"
if self.is_cartesian:
block_dims = []
block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2)
for i, tag in enumerate(offset_definitions):
if tag.alias is None:
block_dims.append(
f"gridtools::meta::list<{tag.name.id}_t, "
f"gridtools::integral_constant<int, {block_sizes[i]}>>"
)
sizes_str = ",\n".join(block_dims)
return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;"
else:
return "using block_sizes_t = gridtools::meta::list<gridtools::meta::list<gtfn::unstructured::dim::horizontal, gridtools::integral_constant<int, 32>>, gridtools::meta::list<gtfn::unstructured::dim::vertical, gridtools::integral_constant<int, 8>>>;"

@classmethod
def apply(cls, root: Any, **kwargs: Any) -> str:
Expand Down
30 changes: 25 additions & 5 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import functools
import warnings
from typing import Any

import numpy.typing as npt
Expand Down Expand Up @@ -42,12 +44,14 @@ def convert_arg(arg: Any) -> Any:
return arg


def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram:
def convert_args(
inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU
) -> stages.CompiledProgram:
def decorated_program(
*args, offset_provider: dict[str, common.Connectivity | common.Dimension]
):
converted_args = [convert_arg(arg) for arg in args]
conn_args = extract_connectivity_args(offset_provider)
conn_args = extract_connectivity_args(offset_provider, device)
return inp(
*converted_args,
*conn_args,
Expand All @@ -56,8 +60,22 @@ def decorated_program(
return decorated_program


def _ensure_is_on_device(
    connectivity_arg: npt.NDArray, device: core_defs.DeviceType
) -> npt.NDArray:
    """Return ``connectivity_arg`` as a CuPy array when ``device`` is CUDA.

    For any other device, or when the argument already is a ``cupy.ndarray``,
    the argument is returned unchanged. Otherwise a warning is emitted and the
    result of ``cupy.asarray`` is returned.
    """
    if device != core_defs.DeviceType.CUDA:
        return connectivity_arg

    # Imported lazily so that cupy is only needed on the CUDA path.
    import cupy as cp

    if isinstance(connectivity_arg, cp.ndarray):
        return connectivity_arg

    warnings.warn(
        "Copying connectivity to device. For performance make sure connectivity is provided on device."
    )
    return cp.asarray(connectivity_arg)


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension]
offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType
) -> list[tuple[npt.NDArray, tuple[int, ...]]]:
# note: the order here needs to agree with the order of the generated bindings
args: list[tuple[npt.NDArray, tuple[int, ...]]] = []
Expand All @@ -67,7 +85,9 @@ def extract_connectivity_args(
raise NotImplementedError(
"Only `NeighborTable` connectivities implemented at this point."
)
args.append((conn.table, tuple([0] * 2)))
# copying to device here is a fallback for easy testing and might be removed later
conn_arg = _ensure_is_on_device(conn.table, device)
args.append((conn_arg, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
pass
else:
Expand Down Expand Up @@ -126,7 +146,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
translation=GTFN_GPU_TRANSLATION_STEP,
bindings=nanobind.bind_source,
compilation=GTFN_DEFAULT_COMPILE_STEP,
decoration=convert_args,
decoration=functools.partial(convert_args, device=core_defs.DeviceType.CUDA),
)


Expand Down
5 changes: 5 additions & 0 deletions tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
GTFN_CPU_WITH_TEMPORARIES = (
"gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries"
)
GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu"
ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend"
DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend"

Expand Down Expand Up @@ -148,6 +149,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
+ [
(USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
],
ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST
+ [
(USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
],
ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST
+ [
(USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
Expand Down
18 changes: 15 additions & 3 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest

import gt4py.next as gtx
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import Self
from gt4py.next import common, constructors
Expand Down Expand Up @@ -73,7 +74,7 @@
E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim))
# The offset name was "E2V", apparently copy-pasted from the E2V definition;
# it now matches the variable so that C2E no longer shares E2V's name.
# NOTE(review): confirm the test offset providers expose a "C2E" entry.
C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim))

ScalarValue: TypeAlias = np.int32 | np.int64 | np.float32 | np.float64 | np.generic
ScalarValue: TypeAlias = core_defs.Scalar
FieldValue: TypeAlias = gtx.Field
FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...]
FieldViewInout: TypeAlias = FieldValue | tuple["FieldViewInout", ...]
Expand Down Expand Up @@ -117,12 +118,19 @@ def from_case(
return self


@dataclasses.dataclass
@dataclasses.dataclass(init=False)
class ConstInitializer(DataInitializer):
"""Initialize with a given value across the coordinate space."""

value: ScalarValue

def __init__(self, value: ScalarValue):
    """Store ``value`` after checking that it is a scalar.

    Raises:
        ValueError: if ``core_defs.is_scalar_type(value)`` is false.
    """
    if core_defs.is_scalar_type(value):
        self.value = value
    else:
        raise ValueError(
            "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead."
        )

@property
def scalar_value(self) -> ScalarValue:
return self.value
Expand Down Expand Up @@ -460,7 +468,7 @@ def verify_with_default_data(
``comparison(ref, <out | inout>)`` and should return a boolean.
"""
inps, kwfields = get_default_data(case, fieldop)
ref_args = tuple(i.ndarray if hasattr(i, "ndarray") else i for i in inps)
ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps)
verify(
case,
fieldop,
Expand Down Expand Up @@ -598,3 +606,7 @@ class Case:
offset_provider: dict[str, common.Connectivity | gtx.Dimension]
default_sizes: dict[gtx.Dimension, int]
grid_type: common.GridType

@property
def as_field(self):
    """Return ``constructors.as_field`` with ``allocator`` bound to this case's backend."""
    bound_constructor = constructors.as_field.partial(allocator=self.backend)
    return bound_constructor
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non
definitions.ProgramBackendId.GTFN_CPU,
definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE,
definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES,
pytest.param(definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu),
None,
]
+ OPTIONAL_PROCESSORS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ def cast_nested_tuple(

a = cases.allocate(cartesian_case, cast_tuple, "a")()
b = cases.allocate(cartesian_case, cast_tuple, "b")()
a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32))
b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32))
a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32))
b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32))
out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)()
out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)()

Expand Down Expand Up @@ -589,7 +589,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I
def test_fieldop_from_scan(cartesian_case, forward):
init = 1.0
expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1)
out = gtx.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],)))
out = cartesian_case.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],)))

if not forward:
expected = np.flip(expected)
Expand All @@ -610,6 +610,7 @@ def simple_scan_operator(carry: float) -> float:
def test_solve_triag(cartesian_case):
if cartesian_case.backend in [
gtfn.run_gtfn,
gtfn.run_gtfn_gpu,
gtfn.run_gtfn_imperative,
gtfn.run_gtfn_with_temporaries,
]:
Expand Down Expand Up @@ -723,8 +724,8 @@ def simple_scan_operator(carry: float, a: float) -> float:
return carry if carry > a else carry + 1.0

k_size = cartesian_case.default_sizes[KDim]
a = gtx.as_field([KDim], 4.0 * np.ones((k_size,)))
out = gtx.as_field([KDim], np.zeros((k_size,)))
a = cartesian_case.as_field([KDim], 4.0 * np.ones((k_size,)))
out = cartesian_case.as_field([KDim], np.zeros((k_size,)))

cases.verify(
cartesian_case,
Expand Down Expand Up @@ -773,16 +774,19 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]):
def test_scan_nested_tuple_input(cartesian_case):
init = 1.0
k_size = cartesian_case.default_sizes[KDim]
inp1 = gtx.as_field([KDim], np.ones((k_size,)))
inp2 = gtx.as_field([KDim], np.arange(0.0, k_size, 1))
out = gtx.as_field([KDim], np.zeros((k_size,)))

inp1_np = np.ones((k_size,))
inp2_np = np.arange(0.0, k_size, 1)
inp1 = cartesian_case.as_field([KDim], inp1_np)
inp2 = cartesian_case.as_field([KDim], inp2_np)
out = cartesian_case.as_field([KDim], np.zeros((k_size,)))

def prev_levels_iterator(i):
return range(i + 1)

expected = np.asarray(
[
reduce(lambda prev, i: prev + inp1[i] + inp2[i], prev_levels_iterator(i), init)
reduce(lambda prev, i: prev + inp1_np[i] + inp2_np[i], prev_levels_iterator(i), init)
for i in range(k_size)
]
)
Expand Down Expand Up @@ -842,7 +846,7 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = out.ndarray.copy() # ensure we are not overwriting out outside of the domain
ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain
ref[1:9] = a[1:9] * 2

cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref)
Expand All @@ -851,6 +855,7 @@ def program_domain(a: cases.IField, out: cases.IField):
def test_domain_input_bounds(cartesian_case):
if cartesian_case.backend in [
gtfn.run_gtfn,
gtfn.run_gtfn_gpu,
gtfn.run_gtfn_imperative,
gtfn.run_gtfn_with_temporaries,
]:
Expand All @@ -876,7 +881,7 @@ def program_domain(
inp = cases.allocate(cartesian_case, program_domain, "inp")()
out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)()

ref = out.ndarray.copy()
ref = np.asarray(out).copy()
ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2

cases.verify(
Expand Down Expand Up @@ -919,7 +924,7 @@ def program_domain(
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = out.ndarray.copy()
ref = np.asarray(out).copy()
ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = (
a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2
)
Expand Down Expand Up @@ -959,9 +964,9 @@ def program_domain_tuple(
out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")()
out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")()

ref0 = out0.ndarray.copy()
ref0 = np.asarray(out0).copy()
ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6]
ref1 = out1.ndarray.copy()
ref1 = np.asarray(out1).copy()
ref1[1:9, 4:6] = inp1[1:9, 4:6]

cases.verify(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def testee(
inp * ones(V2E), axis=V2EDim
) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported

inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table)
inp = unstructured_case.as_field(
[Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table
)
ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))()

cases.verify(
Expand All @@ -59,7 +61,9 @@ def test_external_local_field_only(unstructured_case):
def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32]:
return neighbor_sum(inp, axis=V2EDim)

inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table)
inp = unstructured_case.as_field(
[Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table
)

cases.verify(
unstructured_case,
Expand Down
Loading

0 comments on commit 42912cc

Please sign in to comment.