Skip to content

Commit

Permalink
update transfer, disable gpu for itir tests
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Nov 15, 2023
1 parent 5c179a1 commit f831491
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
26 changes: 18 additions & 8 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import functools
import warnings
from typing import Any

import cupy as cp # TODO
import numpy.typing as npt

import gt4py._core.definitions as core_defs
Expand Down Expand Up @@ -43,12 +44,12 @@ def convert_arg(arg: Any) -> Any:
return arg


def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram:
def convert_args(inp: stages.CompiledProgram, for_cp: bool = False) -> stages.CompiledProgram:
def decorated_program(
*args, offset_provider: dict[str, common.Connectivity | common.Dimension]
):
converted_args = [convert_arg(arg) for arg in args]
conn_args = extract_connectivity_args(offset_provider)
conn_args = extract_connectivity_args(offset_provider, for_cp)
return inp(
*converted_args,
*conn_args,
Expand All @@ -58,8 +59,10 @@ def decorated_program(


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension]
offset_provider: dict[str, common.Connectivity | common.Dimension], for_cp: bool
) -> list[tuple[npt.NDArray, tuple[int, ...]]]:
if for_cp:
import cupy as cp
# note: the order here needs to agree with the order of the generated bindings
args: list[tuple[npt.NDArray, tuple[int, ...]]] = []
for name, conn in offset_provider.items():
Expand All @@ -68,9 +71,16 @@ def extract_connectivity_args(
raise NotImplementedError(
"Only `NeighborTable` connectivities implemented at this point."
)
args.append(
(cp.asarray(conn.table), tuple([0] * 2))
) # TODO where do we do the host<->device of neighbortables
conn_arg = conn.table
if (
for_cp
): # copying to device here is a fallback for easy testing and might be removed later
if not isinstance(conn_arg, cp.ndarray):
conn_arg = cp.asarray(conn.table)
warnings.warn(
"Copying connectivity to device. For performance make sure connectivity is provided on device."
)
args.append((conn_arg, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
pass
else:
Expand Down Expand Up @@ -129,7 +139,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
translation=GTFN_GPU_TRANSLATION_STEP,
bindings=nanobind.bind_source,
compilation=GTFN_DEFAULT_COMPILE_STEP,
decoration=convert_args,
decoration=functools.partial(convert_args, for_cp=True),
)


Expand Down
2 changes: 1 addition & 1 deletion tests/next_tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def lift_mode(request):
(definitions.ProgramBackendId.GTFN_CPU, True),
(definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True),
(definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True),
pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu),
# pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation
(definitions.ProgramFormatterId.LISP_FORMATTER, False),
(definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False),
(definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False),
Expand Down

0 comments on commit f831491

Please sign in to comment.