diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c0a9be9168..f218d754e2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -10,7 +10,7 @@ import pathlib import tempfile import warnings -from typing import Any, Optional +from typing import Any, Callable, Optional import diskcache import factory @@ -21,7 +21,8 @@ from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import ir as itir +from gt4py.next.embedded import nd_array_field +from gt4py.next.iterator import embedded, ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -29,16 +30,71 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module -# TODO(ricoh): Add support for the whole range of arguments that can be passed to a fencil. -def convert_arg(arg: Any) -> Any: - if isinstance(arg, tuple): - return tuple(convert_arg(a) for a in arg) - if isinstance(arg, common.Field): - arr = arg.ndarray - origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) - return arr, origin - else: - return arg +def handle_tuple(arg: Any, convert_arg: Callable) -> Any: + return tuple(convert_arg(a) for a in arg) + + +def handle_field(arg: nd_array_field.NdArrayField) -> tuple: + arr = arg.ndarray + origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) + return arr, origin + + +type_handlers_convert_args = { + tuple: handle_tuple, + nd_array_field.NumPyArrayField: handle_field, +} + +ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]] + + +def handle_neighbortable( + conn: embedded.NeighborTableOffsetProvider, # type: ignore + zero_tuple: tuple[int, ...], + device: core_defs.DeviceType, +) -> ConnectivityArg: + return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore + + +def handle_connectivity_field( + conn: nd_array_field.NdArrayField, + zero_tuple: tuple[int, ...], + device: core_defs.DeviceType, + copy: bool, +) -> ConnectivityArg: + return (_ensure_is_on_device(conn.ndarray, device), zero_tuple) + + +def handle_dimension(*args: Any, **kwargs: Any) -> None: + return None + + +def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None: + raise AssertionError( + f"Unsupported offset provider type '{type(conn).__name__}'. " + "Expected 'Connectivity' or 'Dimension'." + ) + + +type_handlers_connectivity_args = { + embedded.NeighborTableOffsetProvider: handle_neighbortable, + nd_array_field.NumPyArrayConnectivityField: handle_connectivity_field, + common.Dimension: handle_dimension, +} + +try: + import cupy as cp + + type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field + type_handlers_connectivity_args[nd_array_field.CuPyArrayConnectivityField] = ( + handle_connectivity_field + ) +except ImportError: + cp = None + + +def handle_default(arg: Any) -> Any: + return arg def convert_args( @@ -63,6 +119,14 @@ def decorated_program( return decorated_program +def convert_arg(arg: Any) -> Any: + handler = type_handlers_convert_args.get(type(arg), handle_default) # type: ignore + if handler is handle_tuple: + return handler(arg, convert_arg) + else: + return handler(arg) + + def _ensure_is_on_device( connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType ) -> core_defs.NDArrayObject: @@ -79,26 +143,16 @@ def _ensure_is_on_device( def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: - # note: the order here needs to agree with the order of the generated bindings + offset_provider: dict[str, common.Connectivity | common.Dimension], + device: core_defs.DeviceType, +) -> list[ConnectivityArg]: + zero_tuple = (0, 0) args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] - for name, conn in offset_provider.items(): - if isinstance(conn, common.Connectivity): - if not common.is_neighbor_table(conn): - raise NotImplementedError( - "Only 'NeighborTable' connectivities implemented at this point." - ) - # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.ndarray, device) - args.append((conn_arg, tuple([0] * 2))) - elif isinstance(conn, common.Dimension): - pass - else: - raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"but got '{type(conn).__name__}'." - ) + for conn in offset_provider.values(): + handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type) + result = handler(conn, zero_tuple, device) # type: ignore + if result: + args.append(result) return args