Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Optimisations for icon4py #1536

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f16d03b
optimise extract_connectivity_args
samkellerhals Apr 22, 2024
4e2a31f
Remove isinstance checks from convert_args
samkellerhals Apr 22, 2024
ddf16f1
More small optimisations, and connecitivities caching
samkellerhals Apr 29, 2024
0f364a9
Only do asserts in debug mode
samkellerhals May 2, 2024
789b211
Import CuPyArrayField only when cp is available
samkellerhals May 2, 2024
a674725
Run precommit
samkellerhals May 2, 2024
f5fe70d
Merge branch 'main' into optimisations-for-icon4py
samkellerhals May 2, 2024
c5cbe5d
Merge remote-tracking branch 'origin/main' into optimisations-for-ico…
samkellerhals May 2, 2024
b080621
Merge remote-tracking branch 'samkellerhals/optimisations-for-icon4py…
samkellerhals May 2, 2024
4400ede
Merge remote-tracking branch 'origin' into optimisations-for-icon4py
samkellerhals May 3, 2024
0829ab4
Add _ensure_is_on_device checks to run tests
samkellerhals May 6, 2024
e717b33
Place _ensure_is_on_device in right place
samkellerhals May 6, 2024
d921bca
Add deprecation warning
samkellerhals May 7, 2024
33bd45b
Merge branch 'main' of https://github.com/GridTools/gt4py into optimi…
samkellerhals May 24, 2024
c68b970
Merge branch 'main' into optimisations-for-icon4py
philip-paul-mueller Sep 5, 2024
c9c5c8e
Update from main
samkellerhals Dec 11, 2024
9eb29be
Add ConnectivityFields to handler
samkellerhals Dec 11, 2024
8a0793b
Fix typing issues
samkellerhals Dec 11, 2024
cef757b
Enable deprecation warning again
samkellerhals Dec 11, 2024
4c90a89
Enable deprecation warning again
samkellerhals Dec 11, 2024
da6516d
Add back domain to right place
samkellerhals Dec 11, 2024
54712fb
Add space
samkellerhals Dec 11, 2024
a49e341
Remove CupyArrayConnectivityField
samkellerhals Dec 11, 2024
0b2ca6c
Handle connecitivity fields
samkellerhals Dec 11, 2024
2607a81
Always check if it is on device
samkellerhals Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 85 additions & 31 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pathlib
import tempfile
import warnings
from typing import Any, Optional
from typing import Any, Callable, Optional

import diskcache
import factory
Expand All @@ -21,24 +21,80 @@
from gt4py.eve import utils
from gt4py.eve.utils import content_hash
from gt4py.next import backend, common, config
from gt4py.next.iterator import ir as itir
from gt4py.next.embedded import nd_array_field
from gt4py.next.iterator import embedded, ir as itir
from gt4py.next.otf import arguments, recipes, stages, workflow
from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import compiler
from gt4py.next.otf.compilation.build_systems import compiledb
from gt4py.next.program_processors.codegens.gtfn import gtfn_module


# TODO(ricoh): Add support for the whole range of arguments that can be passed to a fencil.
def convert_arg(arg: Any) -> Any:
if isinstance(arg, tuple):
return tuple(convert_arg(a) for a in arg)
if isinstance(arg, common.Field):
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin
else:
return arg
def handle_tuple(arg: Any, convert_arg: Callable) -> Any:
return tuple(convert_arg(a) for a in arg)


def handle_field(arg: nd_array_field.NdArrayField) -> tuple:
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin


type_handlers_convert_args = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@samkellerhals samkellerhals May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could work, however whilst nicer I have the feeling that it would be slower than a simple dictionary lookup.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be preferable to sacrifice readability only on the basis of hard evidence.

tuple: handle_tuple,
nd_array_field.NumPyArrayField: handle_field,
}

ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]]


def handle_neighbortable(
conn: embedded.NeighborTableOffsetProvider, # type: ignore
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
) -> ConnectivityArg:
return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore


def handle_connectivity_field(
conn: nd_array_field.NdArrayField,
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
return (_ensure_is_on_device(conn.ndarray, device), zero_tuple)


def handle_dimension(*args: Any, **kwargs: Any) -> None:
return None


def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None:
raise AssertionError(
f"Unsupported offset provider type '{type(conn).__name__}'. "
"Expected 'Connectivity' or 'Dimension'."
)


type_handlers_connectivity_args = {
embedded.NeighborTableOffsetProvider: handle_neighbortable,
nd_array_field.NumPyArrayConnectivityField: handle_connectivity_field,
common.Dimension: handle_dimension,
}

try:
import cupy as cp

type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field
type_handlers_connectivity_args[nd_array_field.CuPyArrayConnectivityField] = (
handle_connectivity_field
)
except ImportError:
cp = None


def handle_default(arg: Any) -> Any:
return arg


def convert_args(
Expand All @@ -63,6 +119,14 @@ def decorated_program(
return decorated_program


def convert_arg(arg: Any) -> Any:
handler = type_handlers_convert_args.get(type(arg), handle_default) # type: ignore
if handler is handle_tuple:
return handler(arg, convert_arg)
else:
return handler(arg)


def _ensure_is_on_device(
connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType
) -> core_defs.NDArrayObject:
Expand All @@ -79,26 +143,16 @@ def _ensure_is_on_device(


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType
) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]:
# note: the order here needs to agree with the order of the generated bindings
offset_provider: dict[str, common.Connectivity | common.Dimension],
device: core_defs.DeviceType,
) -> list[ConnectivityArg]:
zero_tuple = (0, 0)
args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = []
for name, conn in offset_provider.items():
if isinstance(conn, common.Connectivity):
if not common.is_neighbor_table(conn):
raise NotImplementedError(
"Only 'NeighborTable' connectivities implemented at this point."
)
# copying to device here is a fallback for easy testing and might be removed later
conn_arg = _ensure_is_on_device(conn.ndarray, device)
args.append((conn_arg, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
pass
else:
raise AssertionError(
f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', "
f"but got '{type(conn).__name__}'."
)
for conn in offset_provider.values():
handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type)
result = handler(conn, zero_tuple, device) # type: ignore
if result:
args.append(result)
return args


Expand Down
Loading