Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Optimisations for icon4py #1536

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f16d03b
optimise extract_connectivity_args
samkellerhals Apr 22, 2024
4e2a31f
Remove isinstance checks from convert_args
samkellerhals Apr 22, 2024
ddf16f1
More small optimisations, and connecitivities caching
samkellerhals Apr 29, 2024
0f364a9
Only do asserts in debug mode
samkellerhals May 2, 2024
789b211
Import CuPyArrayField only when cp is available
samkellerhals May 2, 2024
a674725
Run precommit
samkellerhals May 2, 2024
f5fe70d
Merge branch 'main' into optimisations-for-icon4py
samkellerhals May 2, 2024
c5cbe5d
Merge remote-tracking branch 'origin/main' into optimisations-for-ico…
samkellerhals May 2, 2024
b080621
Merge remote-tracking branch 'samkellerhals/optimisations-for-icon4py…
samkellerhals May 2, 2024
4400ede
Merge remote-tracking branch 'origin' into optimisations-for-icon4py
samkellerhals May 3, 2024
0829ab4
Add _ensure_is_on_device checks to run tests
samkellerhals May 6, 2024
e717b33
Place _ensure_is_on_device in right place
samkellerhals May 6, 2024
d921bca
Add deprecation warning
samkellerhals May 7, 2024
33bd45b
Merge branch 'main' of https://github.com/GridTools/gt4py into optimi…
samkellerhals May 24, 2024
c68b970
Merge branch 'main' into optimisations-for-icon4py
philip-paul-mueller Sep 5, 2024
c9c5c8e
Update from main
samkellerhals Dec 11, 2024
9eb29be
Add ConnectivityFields to handler
samkellerhals Dec 11, 2024
8a0793b
Fix typing issues
samkellerhals Dec 11, 2024
cef757b
Enable deprecation warning again
samkellerhals Dec 11, 2024
4c90a89
Enable deprecation warning again
samkellerhals Dec 11, 2024
da6516d
Add back domain to right place
samkellerhals Dec 11, 2024
54712fb
Add space
samkellerhals Dec 11, 2024
a49e341
Remove CupyArrayConnectivityField
samkellerhals Dec 11, 2024
0b2ca6c
Handle connecitivity fields
samkellerhals Dec 11, 2024
2607a81
Always check if it is on device
samkellerhals Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,10 @@ def from_array(
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
) -> NdArrayField:
domain = common.domain(domain)
xp = cls.array_ns

xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type)
array = xp.asarray(data, dtype=xp_dtype)
domain = common.domain(domain)

if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type
Expand Down
8 changes: 6 additions & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,8 +1000,12 @@ def _shift_field_indices(
def np_as_located_field(
*axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None
) -> Callable[[np.ndarray], common.Field]:
warnings.warn("`np_as_located_field()` is deprecated, use `gtx.as_field()`", DeprecationWarning) # noqa: B028 [no-explicit-stacklevel]
samkellerhals marked this conversation as resolved.
Show resolved Hide resolved

if __debug__:
warnings.warn(
"`np_as_located_field()` is deprecated, use `gtx.as_field()`",
DeprecationWarning,
stacklevel=2,
)
origin = origin or {}

def _maker(a) -> common.Field:
Expand Down
108 changes: 73 additions & 35 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

import functools
import warnings
from typing import Any
from typing import Any, Callable, Optional

import factory
import numpy.typing as npt

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
from gt4py.eve.utils import content_hash
from gt4py.next import backend, common, config
from gt4py.next import NeighborTableOffsetProvider, backend, common, config
from gt4py.next.embedded.nd_array_field import NumPyArrayField
from gt4py.next.iterator import transforms
from gt4py.next.iterator.transforms import global_tmps
from gt4py.next.otf import recipes, stages, workflow
Expand All @@ -34,34 +34,83 @@
from gt4py.next.type_system.type_translation import from_value


# TODO(ricoh): Add support for the whole range of arguments that can be passed to a fencil.
def handle_tuple(arg: Any, convert_arg: Callable) -> Any:
    """Apply ``convert_arg`` to every element of ``arg`` and collect the results in a tuple."""
    return tuple(map(convert_arg, arg))


def handle_field(arg: Any) -> tuple:
    """Return ``(arg.ndarray, origin)`` for a field-like ``arg``.

    The origin is ``arg.__gt_origin__`` when that attribute exists; otherwise it
    is a tuple of zeros with one entry per element of ``arg.domain``.
    """
    buffer = arg.ndarray
    # Built eagerly: ``arg.domain`` is read even when ``__gt_origin__`` is present.
    zero_origin = (0,) * len(arg.domain)
    return buffer, getattr(arg, "__gt_origin__", zero_origin)


type_handlers_convert_args = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@samkellerhals samkellerhals May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could work. However, while it would be nicer, I have the feeling that it would be slower than a simple dictionary lookup.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be preferable to sacrifice readability only on the basis of hard evidence.

tuple: handle_tuple,
NumPyArrayField: handle_field,
}

try:
import cupy as cp

from gt4py.next.embedded.nd_array_field import CuPyArrayField

type_handlers_convert_args[CuPyArrayField] = handle_field
except ImportError:
cp = None


def handle_default(arg: Any) -> Any:
    """Fallback handler: hand ``arg`` back unchanged."""
    return arg


def convert_arg(arg: Any) -> Any:
if isinstance(arg, tuple):
return tuple(convert_arg(a) for a in arg)
if isinstance(arg, common.Field):
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin
handler = type_handlers_convert_args.get(type(arg), handle_default) # type: ignore
if handler is handle_tuple:
return handler(arg, convert_arg)
else:
return arg
return handler(arg)


def convert_args(
inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU
) -> stages.CompiledProgram:
def decorated_program(
*args: Any, offset_provider: dict[str, common.Connectivity | common.Dimension]
*args: Any,
conn_args: Optional[list[ConnectivityArg]] = None,
offset_provider: dict[str, common.Connectivity | common.Dimension],
) -> None:
# If we don't pass them as in the case of a CachedProgram extract connectivities here.
if conn_args is None:
conn_args = extract_connectivity_args(offset_provider, device)

samkellerhals marked this conversation as resolved.
Show resolved Hide resolved
converted_args = [convert_arg(arg) for arg in args]
conn_args = extract_connectivity_args(offset_provider, device)
return inp(*converted_args, *conn_args)

return decorated_program


# Type alias for one connectivity argument: an array object paired with a tuple of ints.
ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]]


def handle_connectivity(
    conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...], device: core_defs.DeviceType
) -> ConnectivityArg:
    """Pair the neighbor table of ``conn`` with ``zero_tuple``.

    ``conn.table`` is first passed through ``_ensure_is_on_device`` together
    with ``device``; the result becomes the first element of the returned pair.
    """
    table = _ensure_is_on_device(conn.table, device)
    return table, zero_tuple


def handle_other_type(*args: Any, **kwargs: Any) -> None:
    """Accept any arguments, ignore them, and return ``None``."""
    return


# Dispatch table keyed by type: maps each listed type to the handler function used for it.
type_handlers_connectivity_args = {
NeighborTableOffsetProvider: handle_connectivity,
common.Dimension: handle_other_type,
}


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this pattern proves to be a significant optimization over singledispatch, it still needs to be made more readable.

I propose to encode the pattern in a class. That class then needs a docstring explaining when to use it over standard approaches and why, for example, subclasses of the relevant types won't work with it unless they are added explicitly to the dict.

Sketch:

class FastDispatch:
    """
    Optimized version of functools.singledispatch, does not take into account inheritance or protocol membership.

    This leads to a speed-up of XXX, as documented in ADR YYY.

    Handlers are looked up by the *exact* ``type()`` of the first call argument
    with a single dictionary lookup. Subclasses of a registered type are NOT
    matched unless they are registered explicitly; any unregistered type falls
    back to the default handler installed by ``fastdispatch``.

    Usage::

        @FastDispatch.fastdispatch(Any)
        def extract_connectivity_args(connectivity, *args, **kwargs):
            return None

        @extract_connectivity_args.register(NeighborTableOffsetProvider)
        def extract_connectivity_args_from_nbtable(connectivity, device, *args, **kwargs):
            return (_ensure_is_on_device(connectivity.table, device), (0, 0))
    """

    def __init__(self) -> None:
        # Exact type -> handler. Created per instance so dispatchers never share state.
        self._registry: dict[type, Callable[..., Any]] = {}
        # Handler used when the exact type of the dispatch argument is not registered.
        self._default: Optional[Callable[..., Any]] = None

    def __call__(self, dispatch_arg, *args, **kwargs):
        """Call the handler registered for ``type(dispatch_arg)``, else the default one.

        Raises:
            TypeError: if no handler matches and no default handler was installed.
        """
        handler = self._registry.get(type(dispatch_arg), self._default)
        if handler is None:
            raise TypeError(
                f"No handler registered for type '{type(dispatch_arg).__name__}' "
                "and no default handler available."
            )
        return handler(dispatch_arg, *args, **kwargs)

    def register(self, type):
        """Return a decorator that registers a function as the handler for exactly ``type``.

        The decorated function is returned unchanged, so it stays directly callable.
        """

        def decorator(function):
            self._registry[type] = function
            return function

        return decorator

    @classmethod
    def fastdispatch(cls, default_type):
        """Return a decorator that turns a function into a dispatcher.

        The decorated function is registered under ``default_type`` and is also
        used as the fallback for every type without an explicit registration.
        """

        def decorator(function):
            dispatcher = cls()
            dispatcher.register(default_type)(function)
            # ``default_type`` (e.g. ``Any``) is never the result of ``type(x)``,
            # so the fallback has to be stored separately to be reachable.
            dispatcher._default = function
            return dispatcher

        return decorator

def _ensure_is_on_device(
connectivity_arg: npt.NDArray, device: core_defs.DeviceType
) -> npt.NDArray:
connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType
) -> core_defs.NDArrayObject:
if device == core_defs.DeviceType.CUDA:
import cupy as cp

Expand All @@ -75,26 +124,15 @@ def _ensure_is_on_device(


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType
) -> list[tuple[npt.NDArray, tuple[int, ...]]]:
# note: the order here needs to agree with the order of the generated bindings
args: list[tuple[npt.NDArray, tuple[int, ...]]] = []
for name, conn in offset_provider.items():
if isinstance(conn, common.Connectivity):
if not isinstance(conn, common.NeighborTable):
raise NotImplementedError(
"Only 'NeighborTable' connectivities implemented at this point."
)
# copying to device here is a fallback for easy testing and might be removed later
conn_arg = _ensure_is_on_device(conn.table, device)
args.append((conn_arg, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
pass
else:
raise AssertionError(
f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', "
f"but got '{type(conn).__name__}'."
)
offset_provider: dict[str, Any], device: core_defs.DeviceType
) -> list[ConnectivityArg]:
zero_tuple = (0, 0)
args = []
for conn in offset_provider.values():
handler = type_handlers_connectivity_args.get(type(conn), handle_other_type)
result = handler(conn, zero_tuple, device) # type: ignore
if result:
args.append(result)
return args


Expand Down
Loading