From f16d03bbfa324f933106294f5f303caa293b6216 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Mon, 22 Apr 2024 12:36:18 +0200 Subject: [PATCH 01/18] optimise extract_connectivity_args --- .../next/program_processors/runners/gtfn.py | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 39ec607323..e883aad75e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -22,7 +22,7 @@ import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash -from gt4py.next import backend, common, config +from gt4py.next import NeighborTableOffsetProvider, backend, common, config from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import recipes, stages, workflow @@ -74,27 +74,35 @@ def _ensure_is_on_device( return connectivity_arg +ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]] + + +def handle_connectivity( + conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...] +) -> ConnectivityArg: + return (conn.table, zero_tuple) + + +def handle_other_type(*args: Any, **kwargs: Any) -> None: + return None + + +type_handlers = { + NeighborTableOffsetProvider: handle_connectivity, + common.Dimension: handle_other_type, +} + + def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[npt.NDArray, tuple[int, ...]]]: - # note: the order here needs to agree with the order of the generated bindings - args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] - for name, conn in offset_provider.items(): - if isinstance(conn, common.Connectivity): - if not isinstance(conn, common.NeighborTable): - raise NotImplementedError( - "Only 'NeighborTable' connectivities implemented at this point." - ) - # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.table, device) - args.append((conn_arg, tuple([0] * 2))) - elif isinstance(conn, common.Dimension): - pass - else: - raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"but got '{type(conn).__name__}'." - ) + offset_provider: dict[str, Any], device: core_defs.DeviceType +) -> list[ConnectivityArg]: + zero_tuple = (0, 0) + args = [] + for conn in offset_provider.values(): + handler = type_handlers.get(type(conn), handle_other_type) + result = handler(conn, zero_tuple) # type: ignore + if result: + args.append(result) return args From 4e2a31f667b27b982966806737a2fe24dca3eee9 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Mon, 22 Apr 2024 17:18:29 +0200 Subject: [PATCH 02/18] Remove isinstance checks from convert_args --- .../next/program_processors/runners/gtfn.py | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e883aad75e..d4d03e732f 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -14,7 +14,7 @@ import functools import warnings -from typing import Any +from typing import Any, Callable import factory import numpy.typing as npt @@ -23,6 +23,7 @@ import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash from gt4py.next import NeighborTableOffsetProvider, backend, common, config +from gt4py.next.embedded.nd_array_field import CuPyArrayField, NumPyArrayField from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import recipes, stages, workflow @@ -34,16 +35,33 @@ from gt4py.next.type_system.type_translation import from_value -# TODO(ricoh): Add support for the whole range of arguments that can be passed to a fencil. +def handle_tuple(arg: Any, convert_arg: Callable) -> Any: + return tuple(convert_arg(a) for a in arg) + + +def handle_field(arg: Any) -> tuple: + arr = arg.ndarray + origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) + return arr, origin + + +def handle_default(arg: Any) -> Any: + return arg + + +type_handlers_convert_args = { + tuple: handle_tuple, + NumPyArrayField: handle_field, + CuPyArrayField: handle_field, +} + + def convert_arg(arg: Any) -> Any: - if isinstance(arg, tuple): - return tuple(convert_arg(a) for a in arg) - if isinstance(arg, common.Field): - arr = arg.ndarray - origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) - return arr, origin + handler = type_handlers_convert_args.get(type(arg), handle_default) # type: ignore + if handler is handle_tuple: + return handler(arg, convert_arg) else: - return arg + return handler(arg) def convert_args( @@ -87,7 +105,7 @@ def handle_other_type(*args: Any, **kwargs: Any) -> None: return None -type_handlers = { +type_handlers_connectivity_args = { NeighborTableOffsetProvider: handle_connectivity, common.Dimension: handle_other_type, } @@ -99,7 +117,7 @@ def extract_connectivity_args( zero_tuple = (0, 0) args = [] for conn in offset_provider.values(): - handler = type_handlers.get(type(conn), handle_other_type) + handler = type_handlers_connectivity_args.get(type(conn), handle_other_type) result = handler(conn, zero_tuple) # type: ignore if result: args.append(result) From ddf16f106f13fd30d1a358c1a1cf395021aa7bd1 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Mon, 29 Apr 2024 13:56:24 +0200 Subject: [PATCH 03/18] More small optimisations, and connecitivities caching --- src/gt4py/next/embedded/nd_array_field.py | 12 ------------ src/gt4py/next/iterator/embedded.py | 1 - src/gt4py/next/program_processors/runners/gtfn.py | 8 ++++++-- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e290da33a2..157002a203 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -160,21 +160,9 @@ def from_array( domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, ) -> NdArrayField: - domain = common.domain(domain) xp = cls.array_ns - xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) - - if dtype is not None: - assert array.dtype.type == core_defs.dtype(dtype).scalar_type - - assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) - - assert all(isinstance(d, common.Dimension) for d in domain.dims), domain - assert len(domain) == array.ndim - assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) - return cls(domain, array) def remap( diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f5d4c6e53b..c57f8f75d8 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1000,7 +1000,6 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None ) -> Callable[[np.ndarray], common.Field]: - warnings.warn("`np_as_located_field()` is deprecated, use `gtx.as_field()`", DeprecationWarning) # noqa: B028 [no-explicit-stacklevel] origin = origin or {} diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index d4d03e732f..21313d32c2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -68,10 +68,14 @@ def convert_args( inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( - *args: Any, offset_provider: dict[str, common.Connectivity | common.Dimension] + *args: Any, conn_args = None, offset_provider: dict[str, common.Connectivity | common.Dimension] ) -> None: + + # If we don't pass them as in the case of a CachedProgram extract connectivities here. + if conn_args is None: + conn_args = extract_connectivity_args(offset_provider, device) + converted_args = [convert_arg(arg) for arg in args] - conn_args = extract_connectivity_args(offset_provider, device) return inp(*converted_args, *conn_args) return decorated_program From 0f364a9d132d09fab0fc0b9332ec405e895eee2b Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Thu, 2 May 2024 11:45:01 +0200 Subject: [PATCH 04/18] Only do asserts in debug mode --- src/gt4py/next/embedded/nd_array_field.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 157002a203..07b1d09dd6 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -163,6 +163,18 @@ def from_array( xp = cls.array_ns xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) + + if __debug__: + domain = common.domain(domain) + if dtype is not None: + assert array.dtype.type == core_defs.dtype(dtype).scalar_type + + assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) + + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain + assert len(domain) == array.ndim + assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) + return cls(domain, array) def remap( From 789b21122a73dc0e56f96a89aad3bfad30ffe9bc Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Thu, 2 May 2024 14:11:12 +0200 Subject: [PATCH 05/18] Import CuPyArrayField only when cp is available --- .../next/program_processors/runners/gtfn.py | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 21313d32c2..51bd5f7059 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -23,7 +23,7 @@ import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash from gt4py.next import NeighborTableOffsetProvider, backend, common, config -from gt4py.next.embedded.nd_array_field import CuPyArrayField, NumPyArrayField +from gt4py.next.embedded.nd_array_field import NumPyArrayField from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import recipes, stages, workflow @@ -34,7 +34,6 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.type_system.type_translation import from_value - def handle_tuple(arg: Any, convert_arg: Callable) -> Any: return tuple(convert_arg(a) for a in arg) @@ -45,16 +44,22 @@ def handle_field(arg: Any) -> tuple: return arr, origin -def handle_default(arg: Any) -> Any: - return arg - - type_handlers_convert_args = { tuple: handle_tuple, NumPyArrayField: handle_field, - CuPyArrayField: handle_field, } +try: + import cupy as cp + from gt4py.next.embedded.nd_array_field import CuPyArrayField + type_handlers_convert_args[CuPyArrayField] = handle_field +except ImportError: + cp = None + + +def handle_default(arg: Any) -> Any: + return arg + def convert_arg(arg: Any) -> Any: handler = type_handlers_convert_args.get(type(arg), handle_default) # type: ignore @@ -81,21 +86,6 @@ def decorated_program( return decorated_program -def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: - if device == core_defs.DeviceType.CUDA: - import cupy as cp - - if not isinstance(connectivity_arg, cp.ndarray): - warnings.warn( - "Copying connectivity to device. For performance make sure connectivity is provided on device.", - stacklevel=2, - ) - return cp.asarray(connectivity_arg) - return connectivity_arg - - ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]] From a674725abdc085cd3e61d063bc116057336ba8cd Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Thu, 2 May 2024 15:01:48 +0200 Subject: [PATCH 06/18] Run precommit --- src/gt4py/next/embedded/nd_array_field.py | 15 +++++++-------- src/gt4py/next/iterator/embedded.py | 2 -- .../next/program_processors/runners/gtfn.py | 16 +++++++++------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 07b1d09dd6..cea4325218 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -163,17 +163,16 @@ def from_array( xp = cls.array_ns xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) + domain = common.domain(domain) - if __debug__: - domain = common.domain(domain) - if dtype is not None: - assert array.dtype.type == core_defs.dtype(dtype).scalar_type + if dtype is not None: + assert array.dtype.type == core_defs.dtype(dtype).scalar_type - assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) + assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) - assert all(isinstance(d, common.Dimension) for d in domain.dims), domain - assert len(domain) == array.ndim - assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain + assert len(domain) == array.ndim + assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) return cls(domain, array) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index c57f8f75d8..02315206cf 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -22,7 +22,6 @@ import itertools import math import sys -import warnings import numpy as np import numpy.typing as npt @@ -1000,7 +999,6 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None ) -> Callable[[np.ndarray], common.Field]: - origin = origin or {} def _maker(a) -> common.Field: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 51bd5f7059..d2272f1a3c 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -13,11 +13,9 @@ # SPDX-License-Identifier: GPL-3.0-or-later import functools -import warnings -from typing import Any, Callable +from typing import Any, Callable, Optional import factory -import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -34,6 +32,7 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.type_system.type_translation import from_value + def handle_tuple(arg: Any, convert_arg: Callable) -> Any: return tuple(convert_arg(a) for a in arg) @@ -51,7 +50,9 @@ def handle_field(arg: Any) -> tuple: try: import cupy as cp + from gt4py.next.embedded.nd_array_field import CuPyArrayField + type_handlers_convert_args[CuPyArrayField] = handle_field except ImportError: cp = None @@ -73,12 +74,13 @@ def convert_args( inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( - *args: Any, conn_args = None, offset_provider: dict[str, common.Connectivity | common.Dimension] + *args: Any, + conn_args: Optional[list[ConnectivityArg]] = None, + offset_provider: dict[str, common.Connectivity | common.Dimension], ) -> None: - # If we don't pass them as in the case of a CachedProgram extract connectivities here. if conn_args is None: - conn_args = extract_connectivity_args(offset_provider, device) + conn_args = extract_connectivity_args(offset_provider) converted_args = [convert_arg(arg) for arg in args] return inp(*converted_args, *conn_args) @@ -106,7 +108,7 @@ def handle_other_type(*args: Any, **kwargs: Any) -> None: def extract_connectivity_args( - offset_provider: dict[str, Any], device: core_defs.DeviceType + offset_provider: dict[str, Any], ) -> list[ConnectivityArg]: zero_tuple = (0, 0) args = [] From 0829ab4e91aac8fda55e57aa789de0638098df60 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Mon, 6 May 2024 09:37:57 +0200 Subject: [PATCH 07/18] Add _ensure_is_on_device checks to run tests --- .../next/program_processors/runners/gtfn.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index b661d59fdb..a7f38ef745 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -13,9 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later import functools +import warnings from typing import Any, Callable, Optional import factory +import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -80,7 +82,7 @@ def decorated_program( ) -> None: # If we don't pass them as in the case of a CachedProgram extract connectivities here. if conn_args is None: - conn_args = extract_connectivity_args(offset_provider) + conn_args = extract_connectivity_args(offset_provider, device) converted_args = [convert_arg(arg) for arg in args] return inp(*converted_args, *conn_args) @@ -107,14 +109,30 @@ def handle_other_type(*args: Any, **kwargs: Any) -> None: } +def _ensure_is_on_device( + connectivity_arg: npt.NDArray, device: core_defs.DeviceType +) -> npt.NDArray: + if device == core_defs.DeviceType.CUDA: + import cupy as cp + + if not isinstance(connectivity_arg, cp.ndarray): + warnings.warn( + "Copying connectivity to device. For performance make sure connectivity is provided on device.", + stacklevel=2, + ) + return cp.asarray(connectivity_arg) + return connectivity_arg + + def extract_connectivity_args( - offset_provider: dict[str, Any], + offset_provider: dict[str, Any], device: core_defs.DeviceType ) -> list[ConnectivityArg]: zero_tuple = (0, 0) args = [] for conn in offset_provider.values(): - handler = type_handlers_connectivity_args.get(type(conn), handle_other_type) - result = handler(conn, zero_tuple) # type: ignore + conn_arg = _ensure_is_on_device(conn.table, device) + handler = type_handlers_connectivity_args.get(type(conn_arg), handle_other_type) + result = handler(conn_arg, zero_tuple) # type: ignore if result: args.append(result) return args From e717b3382de48d91436b264b978aca2ef445f225 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Mon, 6 May 2024 09:49:45 +0200 Subject: [PATCH 08/18] Place _ensure_is_on_device in right place --- src/gt4py/next/program_processors/runners/gtfn.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a7f38ef745..4d52085270 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -17,7 +17,6 @@ from typing import Any, Callable, Optional import factory -import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -94,9 +93,9 @@ def decorated_program( def handle_connectivity( - conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...] + conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...], device: core_defs.DeviceType ) -> ConnectivityArg: - return (conn.table, zero_tuple) + return (_ensure_is_on_device(conn.table, device), zero_tuple) def handle_other_type(*args: Any, **kwargs: Any) -> None: @@ -110,8 +109,8 @@ def handle_other_type(*args: Any, **kwargs: Any) -> None: def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType +) -> core_defs.NDArrayObject: if device == core_defs.DeviceType.CUDA: import cupy as cp @@ -130,9 +129,8 @@ def extract_connectivity_args( zero_tuple = (0, 0) args = [] for conn in offset_provider.values(): - conn_arg = _ensure_is_on_device(conn.table, device) - handler = type_handlers_connectivity_args.get(type(conn_arg), handle_other_type) - result = handler(conn_arg, zero_tuple) # type: ignore + handler = type_handlers_connectivity_args.get(type(conn), handle_other_type) + result = handler(conn, zero_tuple, device) # type: ignore if result: args.append(result) return args From d921bca4d3f4c7b1ac8404c34ce8a4a72615a94b Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Tue, 7 May 2024 11:04:23 +0200 Subject: [PATCH 09/18] Add deprecation warning --- src/gt4py/next/iterator/embedded.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 02315206cf..91eed53280 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -22,6 +22,7 @@ import itertools import math import sys +import warnings import numpy as np import numpy.typing as npt @@ -999,6 +1000,12 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None ) -> Callable[[np.ndarray], common.Field]: + if __debug__: + warnings.warn( + "`np_as_located_field()` is deprecated, use `gtx.as_field()`", + DeprecationWarning, + stacklevel=2, + ) origin = origin or {} def _maker(a) -> common.Field: From 9eb29be9bb80afd10434c0835d4fc3be78cab1d5 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 10:43:38 +0100 Subject: [PATCH 10/18] Add ConnectivityFields to handler --- src/gt4py/next/program_processors/runners/gtfn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1e7f530fd6..a2e1b62569 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import NeighborTableOffsetProvider, backend, common, config -from gt4py.next.embedded.nd_array_field import NumPyArrayField +from gt4py.next.embedded import nd_array_field from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments from gt4py.next.otf import recipes, stages, workflow @@ -115,6 +115,8 @@ def handle_invalid_type( type_handlers_connectivity_args = { NeighborTableOffsetProvider: handle_connectivity, + nd_array_field.NumPyArrayConnectivityField: handle_connectivity, + nd_array_field.CuPyArrayConnectivityField: handle_connectivity, common.Dimension: handle_dimension, } From 8a0793b94e0cc31c311a5b84a753376d7a701c3f Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:04:17 +0100 Subject: [PATCH 11/18] Fix typing issues --- .../next/program_processors/runners/gtfn.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a2e1b62569..c82af50a66 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,15 +15,15 @@ import diskcache import factory import filelock + import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash -from gt4py.next import NeighborTableOffsetProvider, backend, common, config +from gt4py.next import backend, common, config from gt4py.next.embedded import nd_array_field -from gt4py.next.iterator import ir as itir -from gt4py.next.otf import arguments -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next.iterator import embedded, ir as itir +from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -42,15 +42,13 @@ def handle_field(arg: Any) -> tuple: type_handlers_convert_args = { tuple: handle_tuple, - NumPyArrayField: handle_field, + nd_array_field.NumPyArrayField: handle_field, } try: import cupy as cp - from gt4py.next.embedded.nd_array_field import CuPyArrayField - - type_handlers_convert_args[CuPyArrayField] = handle_field + type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field except ImportError: cp = None @@ -60,12 +58,12 @@ def handle_default(arg: Any) -> Any: def convert_args( - inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU + inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, ) -> None: if out is not None: args = (*args, out) @@ -93,20 +91,21 @@ def convert_arg(arg: Any) -> Any: def handle_connectivity( - conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...], device: core_defs.DeviceType, copy: bool + conn: embedded.NeighborTableOffsetProvider, # type: ignore + zero_tuple: tuple[int, ...], + device: core_defs.DeviceType, + copy: bool, ) -> ConnectivityArg: if not copy: - return (conn.table, zero_tuple) - return (_ensure_is_on_device(conn.table, device), zero_tuple) + return (conn.table, zero_tuple) # type: ignore + return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore def handle_dimension(*args: Any, **kwargs: Any) -> None: return None -def handle_invalid_type( - conn: Any, *args: Any, **kwargs: Any -) -> None: +def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None: raise AssertionError( f"Unsupported offset provider type '{type(conn).__name__}'. " "Expected 'Connectivity' or 'Dimension'." @@ -114,15 +113,15 @@ def handle_invalid_type( type_handlers_connectivity_args = { - NeighborTableOffsetProvider: handle_connectivity, + embedded.NeighborTableOffsetProvider: handle_connectivity, nd_array_field.NumPyArrayConnectivityField: handle_connectivity, - nd_array_field.CuPyArrayConnectivityField: handle_connectivity, + nd_array_field.CuPyArrayConnectivityField: handle_connectivity, common.Dimension: handle_dimension, } def _ensure_is_on_device( - connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType ) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -137,12 +136,13 @@ def _ensure_is_on_device( def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType, - copy: bool = False + offset_provider: dict[str, common.Connectivity | common.Dimension], + device: core_defs.DeviceType, + copy: bool = False, ) -> list[ConnectivityArg]: zero_tuple = (0, 0) args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] - for name, conn in offset_provider.items(): + for conn in offset_provider.values(): handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type) result = handler(conn, zero_tuple, device, copy) # type: ignore if result: From cef757bf2518da1aae09d494dd410e4b1466556b Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:11:58 +0100 Subject: [PATCH 12/18] Enable deprecation warning again --- src/gt4py/next/iterator/embedded.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index dc9e0eae6a..56a068c8dc 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1160,12 +1160,11 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None ) -> Callable[[np.ndarray], common.Field]: - if __debug__: - warnings.warn( - "`np_as_located_field()` is deprecated, use `gtx.as_field()`", - DeprecationWarning, - stacklevel=2, - ) + warnings.warn( + "`np_as_located_field()` is deprecated, use `gtx.as_field()`", + DeprecationWarning, + stacklevel=2, + ) origin = origin or {} def _maker(a) -> common.Field: From 4c90a89731e6941488cacd5de96160b6821f1e23 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:14:11 +0100 Subject: [PATCH 13/18] Enable deprecation warning again --- src/gt4py/next/iterator/embedded.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 56a068c8dc..13c64e264e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1160,11 +1160,8 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None ) -> Callable[[np.ndarray], common.Field]: - warnings.warn( - "`np_as_located_field()` is deprecated, use `gtx.as_field()`", - DeprecationWarning, - stacklevel=2, - ) + warnings.warn("`np_as_located_field()` is deprecated, use `gtx.as_field()`", DeprecationWarning) # noqa: B028 [no-explicit-stacklevel] + origin = origin or {} def _maker(a) -> common.Field: From da6516dd4d584ff45b2f8742b589f94b3f4dd0df Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:16:41 +0100 Subject: [PATCH 14/18] Add back domain to right place --- src/gt4py/next/embedded/nd_array_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 73c89791ec..104d7a2a14 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -169,10 +169,10 @@ def from_array( domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, ) -> NdArrayField: + domain = common.domain(domain) xp = cls.array_ns xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) - domain = common.domain(domain) if dtype is not None: assert array.dtype.type == core_defs.dtype(dtype).scalar_type From 54712fbd3cf92553f3268abe0166551a91fcf564 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:17:42 +0100 Subject: [PATCH 15/18] Add space --- src/gt4py/next/embedded/nd_array_field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 104d7a2a14..e15fb4266a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -171,6 +171,7 @@ def from_array( ) -> NdArrayField: domain = common.domain(domain) xp = cls.array_ns + xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) From a49e341d5500e009d69eb5ee2d933caf04e72dbe Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:22:58 +0100 Subject: [PATCH 16/18] Remove CupyArrayConnectivityField --- src/gt4py/next/program_processors/runners/gtfn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c82af50a66..e6fca64720 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -115,7 +115,6 @@ def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None: type_handlers_connectivity_args = { embedded.NeighborTableOffsetProvider: handle_connectivity, nd_array_field.NumPyArrayConnectivityField: handle_connectivity, - nd_array_field.CuPyArrayConnectivityField: handle_connectivity, common.Dimension: handle_dimension, } From 0b2ca6c58a72bcee7c3ce131ba300f76163a0576 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 11:51:17 +0100 Subject: [PATCH 17/18] Handle connecitivity fields --- .../next/program_processors/runners/gtfn.py | 79 +++++++++++-------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e6fca64720..5b8d34506f 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -34,7 +34,7 @@ def handle_tuple(arg: Any, convert_arg: Callable) -> Any: return tuple(convert_arg(a) for a in arg) -def handle_field(arg: Any) -> tuple: +def handle_field(arg: nd_array_field.NdArrayField) -> tuple: arr = arg.ndarray origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) return arr, origin @@ -45,10 +45,55 @@ def handle_field(arg: Any) -> tuple: nd_array_field.NumPyArrayField: handle_field, } +ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]] + + +def handle_neighbortable( + conn: embedded.NeighborTableOffsetProvider, # type: ignore + zero_tuple: tuple[int, ...], + device: core_defs.DeviceType, + copy: bool, +) -> ConnectivityArg: + if not copy: + return (conn.table, zero_tuple) # type: ignore + return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore + + +def handle_connectivity_field( + conn: nd_array_field.NdArrayField, + zero_tuple: tuple[int, ...], + device: core_defs.DeviceType, + copy: bool, +) -> ConnectivityArg: + if not copy: + return (conn.ndarray, zero_tuple) + return (_ensure_is_on_device(conn.ndarray, device), zero_tuple) + + +def handle_dimension(*args: Any, **kwargs: Any) -> None: + return None + + +def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None: + raise AssertionError( + f"Unsupported offset provider type '{type(conn).__name__}'. " + "Expected 'Connectivity' or 'Dimension'." + ) + + +type_handlers_connectivity_args = { + embedded.NeighborTableOffsetProvider: handle_neighbortable, + nd_array_field.NumPyArrayConnectivityField: handle_connectivity_field, + common.Dimension: handle_dimension, +} + try: import cupy as cp type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field + type_handlers_connectivity_args[nd_array_field.CuPyArrayConnectivityField] = ( + handle_connectivity_field + ) except ImportError: cp = None @@ -87,38 +132,6 @@ def convert_arg(arg: Any) -> Any: return handler(arg) -ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]] - - -def handle_connectivity( - conn: embedded.NeighborTableOffsetProvider, # type: ignore - zero_tuple: tuple[int, ...], - device: core_defs.DeviceType, - copy: bool, -) -> ConnectivityArg: - if not copy: - return (conn.table, zero_tuple) # type: ignore - return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore - - -def handle_dimension(*args: Any, **kwargs: Any) -> None: - return None - - -def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None: - raise AssertionError( - f"Unsupported offset provider type '{type(conn).__name__}'. " - "Expected 'Connectivity' or 'Dimension'." - ) - - -type_handlers_connectivity_args = { - embedded.NeighborTableOffsetProvider: handle_connectivity, - nd_array_field.NumPyArrayConnectivityField: handle_connectivity, - common.Dimension: handle_dimension, -} - - def _ensure_is_on_device( connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType ) -> core_defs.NDArrayObject: From 2607a81908e6261fd03dbad985f728fc07925edd Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 11 Dec 2024 12:59:37 +0100 Subject: [PATCH 18/18] Always check if it is on device --- src/gt4py/next/program_processors/runners/gtfn.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 5b8d34506f..f218d754e2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -52,10 +52,7 @@ def handle_neighbortable( conn: embedded.NeighborTableOffsetProvider, # type: ignore zero_tuple: tuple[int, ...], device: core_defs.DeviceType, - copy: bool, ) -> ConnectivityArg: - if not copy: - return (conn.table, zero_tuple) # type: ignore return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore @@ -65,8 +62,6 @@ def handle_connectivity_field( device: core_defs.DeviceType, copy: bool, ) -> ConnectivityArg: - if not copy: - return (conn.ndarray, zero_tuple) return (_ensure_is_on_device(conn.ndarray, device), zero_tuple) @@ -150,13 +145,12 @@ def _ensure_is_on_device( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType, - copy: bool = False, ) -> list[ConnectivityArg]: zero_tuple = (0, 0) args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] for conn in offset_provider.values(): handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type) - result = handler(conn, zero_tuple, device, copy) # type: ignore + result = handler(conn, zero_tuple, device) # type: ignore if result: args.append(result) return args