diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 4e9cd2d1b5..922109a1f7 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -19,7 +19,21 @@ import pathlib import time import warnings -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Protocol, + Tuple, + Type, + Union, +) + +from typing_extensions import deprecated from gt4py import storage as gt_storage from gt4py.cartesian import definitions as gt_definitions, utils as gt_utils @@ -39,7 +53,7 @@ def from_name(name: str) -> Optional[Type["Backend"]]: return REGISTRY.get(name, None) -def register(backend_cls: Type["Backend"]) -> None: +def register(backend_cls: Type["Backend"]) -> Type["Backend"]: assert issubclass(backend_cls, Backend) and backend_cls.name is not None if isinstance(backend_cls.name, str): @@ -413,3 +427,28 @@ def build_extension_module( ) return module_name, file_path + + +def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]], Type[Backend]]: + # We push for hard deprecation here by raising by default and warning if enabling has been forced. + enabled = bool(int(os.environ.get(enabled_env_var, "0"))) + if enabled: + return deprecated(message) + else: + + def _decorator(cls: Type[Backend]) -> Type[Backend]: + def _no_generate(obj) -> Type["StencilObject"]: + raise NotImplementedError( + f"Disabled '{cls.name}' backend: 'f{message}'\n", + f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'", + ) + + # Replace generate method with raise + if not hasattr(cls, "generate"): + raise ValueError(f"Coding error. Expected a generate method on {cls}") + # Flag that it got disabled for register lookup + cls.disabled = True # type: ignore + cls.generate = _no_generate # type: ignore + return cls + + return _decorator diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index 84d1949818..0e3d77e1ae 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type from gt4py import storage as gt_storage -from gt4py.cartesian.backend.base import CLIBackendMixin, register +from gt4py.cartesian.backend.base import CLIBackendMixin, disabled, register from gt4py.cartesian.backend.gtc_common import ( BackendCodegen, bindings_main_template, @@ -125,12 +125,19 @@ def apply_codegen(cls, root, *, module_name="stencil", backend, **kwargs) -> str return generated_code +@disabled( + message="CUDA backend is deprecated. New features developed after February 2024 are not available.", + enabled_env_var="GT4PY_GTC_ENABLE_CUDA", +) @register class CudaBackend(BaseGTBackend, CLIBackendMixin): """CUDA backend using gtc.""" name = "cuda" - options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}} + options = { + **BaseGTBackend.GT_BACKEND_OPTS, + "device_sync": {"versioning": True, "type": bool}, + } languages = {"computation": "cuda", "bindings": ["python"]} storage_info = gt_storage.layout.CUDALayout PYEXT_GENERATOR_CLASS = CudaExtGenerator # type: ignore diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index beaf70c567..c21f04ee49 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -278,7 +278,10 @@ def make_extension( gt_pyext_sources: Dict[str, Any] if not self.builder.options._impl_opts.get("disable-code-generation", False): gt_pyext_files = self.make_extension_sources(stencil_ir=stencil_ir) - gt_pyext_sources = {**gt_pyext_files["computation"], **gt_pyext_files["bindings"]} + gt_pyext_sources = { + **gt_pyext_files["computation"], + **gt_pyext_files["bindings"], + } else: # Pass NOTHING to the self.builder means try to reuse the source code files gt_pyext_files = {} diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 5ce4095e26..11eebfeb5d 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -43,7 +43,7 @@ def _get_backends_with_storage_info(storage_info_kind: str): res = [] for name in _ALL_BACKEND_NAMES: backend = gt4pyc.backend.from_name(name) - if backend is not None: + if not getattr(backend, "disabled", False): if backend.storage_info["device"] == storage_info_kind: res.append(_backend_name_as_param(name)) return res diff --git a/tests/cartesian_tests/unit_tests/backend_tests/test_backend.py b/tests/cartesian_tests/unit_tests/backend_tests/test_backend.py index 4d814a93e8..e8e1f1db64 100644 --- a/tests/cartesian_tests/unit_tests/backend_tests/test_backend.py +++ b/tests/cartesian_tests/unit_tests/backend_tests/test_backend.py @@ -48,7 +48,11 @@ def stencil_def( out = pa * fa + pb * fb - pc * fc # type: ignore # noqa -field_info_val = {0: ("out", "fa"), 1: ("out", "fa", "fb"), 2: ("out", "fa", "fb", "fc")} +field_info_val = { + 0: ("out", "fa"), + 1: ("out", "fa", "fb"), + 2: ("out", "fa", "fb", "fc"), +} parameter_info_val = {0: ("pa",), 1: ("pa", "pb"), 2: ("pa", "pb", "pc")} unreferenced_val = {0: ("pb", "fb", "pc", "fc"), 1: ("pc", "fc"), 2: ()} @@ -168,5 +172,23 @@ def test_toolchain_profiling(backend_name: str, mode: int, rebuild: bool): assert build_info["load_time"] > 0.0 +@pytest.mark.parametrize("backend_name", ["cuda"]) +def test_deprecation_gtc_cuda(backend_name: str): + # Default deprecation, raise an error + build_info: Dict[str, Any] = {} + builder = ( + StencilBuilder(cast(StencilFunc, stencil_def)) + .with_backend(backend_name) + .with_externals({"MODE": 2}) + .with_options( + name=stencil_def.__name__, + module=stencil_def.__module__, + build_info=build_info, + ) + ) + with pytest.raises(NotImplementedError): + builder.build() + + if __name__ == "__main__": pytest.main([__file__])