Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Optimisations for icon4py #1536

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f16d03b
optimise extract_connectivity_args
samkellerhals Apr 22, 2024
4e2a31f
Remove isinstance checks from convert_args
samkellerhals Apr 22, 2024
ddf16f1
More small optimisations, and connecitivities caching
samkellerhals Apr 29, 2024
0f364a9
Only do asserts in debug mode
samkellerhals May 2, 2024
789b211
Import CuPyArrayField only when cp is available
samkellerhals May 2, 2024
a674725
Run precommit
samkellerhals May 2, 2024
f5fe70d
Merge branch 'main' into optimisations-for-icon4py
samkellerhals May 2, 2024
c5cbe5d
Merge remote-tracking branch 'origin/main' into optimisations-for-ico…
samkellerhals May 2, 2024
b080621
Merge remote-tracking branch 'samkellerhals/optimisations-for-icon4py…
samkellerhals May 2, 2024
4400ede
Merge remote-tracking branch 'origin' into optimisations-for-icon4py
samkellerhals May 3, 2024
0829ab4
Add _ensure_is_on_device checks to run tests
samkellerhals May 6, 2024
e717b33
Place _ensure_is_on_device in right place
samkellerhals May 6, 2024
d921bca
Add deprecation warning
samkellerhals May 7, 2024
33bd45b
Merge branch 'main' of https://github.com/GridTools/gt4py into optimi…
samkellerhals May 24, 2024
c68b970
Merge branch 'main' into optimisations-for-icon4py
philip-paul-mueller Sep 5, 2024
c9c5c8e
Update from main
samkellerhals Dec 11, 2024
9eb29be
Add ConnectivityFields to handler
samkellerhals Dec 11, 2024
8a0793b
Fix typing issues
samkellerhals Dec 11, 2024
cef757b
Enable deprecation warning again
samkellerhals Dec 11, 2024
4c90a89
Enable deprecation warning again
samkellerhals Dec 11, 2024
da6516d
Add back domain to right place
samkellerhals Dec 11, 2024
54712fb
Add space
samkellerhals Dec 11, 2024
a49e341
Remove CupyArrayConnectivityField
samkellerhals Dec 11, 2024
0b2ca6c
Handle connecitivity fields
samkellerhals Dec 11, 2024
2607a81
Always check if it is on device
samkellerhals Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
from gt4py.eve.utils import content_hash
from gt4py.next import backend, common, config
from gt4py.next import NeighborTableOffsetProvider, backend, common, config
from gt4py.next.iterator import transforms
from gt4py.next.iterator.transforms import global_tmps
from gt4py.next.otf import recipes, stages, workflow
Expand Down Expand Up @@ -74,27 +74,35 @@ def _ensure_is_on_device(
return connectivity_arg


ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]]


def handle_connectivity(
conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...]
) -> ConnectivityArg:
return (conn.table, zero_tuple)


def handle_other_type(*args: Any, **kwargs: Any) -> None:
return None


type_handlers = {
NeighborTableOffsetProvider: handle_connectivity,
common.Dimension: handle_other_type,
}


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType
) -> list[tuple[npt.NDArray, tuple[int, ...]]]:
# note: the order here needs to agree with the order of the generated bindings
args: list[tuple[npt.NDArray, tuple[int, ...]]] = []
for name, conn in offset_provider.items():
if isinstance(conn, common.Connectivity):
if not isinstance(conn, common.NeighborTable):
raise NotImplementedError(
"Only 'NeighborTable' connectivities implemented at this point."
)
# copying to device here is a fallback for easy testing and might be removed later
conn_arg = _ensure_is_on_device(conn.table, device)
args.append((conn_arg, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
pass
else:
raise AssertionError(
f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', "
f"but got '{type(conn).__name__}'."
)
offset_provider: dict[str, Any], device: core_defs.DeviceType
) -> list[ConnectivityArg]:
zero_tuple = (0, 0)
args = []
for conn in offset_provider.values():
handler = type_handlers.get(type(conn), handle_other_type)
result = handler(conn, zero_tuple) # type: ignore
if result:
args.append(result)
samkellerhals marked this conversation as resolved.
Show resolved Hide resolved
return args


Expand Down
Loading