From 312652c3258a3a8fec8fbfe6a9e8887e23d39c13 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Tue, 9 Jul 2024 06:14:10 +0000 Subject: [PATCH] [RFC] Add support for device extension autoloading (#127074) Fixes #122468 - Load device extensions at the end of `torch/__init__.py` - Enabled by default, or you can disable it with `TORCH_DEVICE_BACKEND_AUTOLOAD=0` run test: ```python python test/run_test.py -i test_autoload_enable python test/run_test.py -i test_autoload_disable ``` doc: https://docs-preview.pytorch.org/pytorch/pytorch/127074/miscellaneous_environment_variables.html co-author: @jgong5 @bsochack @bkowalskiINTEL @jczaja @FFFrog @hipudding Co-authored-by: albanD Co-authored-by: Jiong Gong Pull Request resolved: https://github.com/pytorch/pytorch/pull/127074 Approved by: https://github.com/albanD, https://github.com/jgong5 --- .../miscellaneous_environment_variables.rst | 4 +- test/cpp_extensions/setup.py | 5 ++ .../torch_test_cpp_extension/__init__.py | 11 +++++ test/run_test.py | 47 +++++++++++++++++++ test/test_autoload.py | 21 +++++++++ tools/testing/discover_tests.py | 2 + torch/__init__.py | 47 +++++++++++++++++++ 7 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 test/test_autoload.py diff --git a/docs/source/miscellaneous_environment_variables.rst b/docs/source/miscellaneous_environment_variables.rst index 37b43d4946bd78..f783f4c9235429 100644 --- a/docs/source/miscellaneous_environment_variables.rst +++ b/docs/source/miscellaneous_environment_variables.rst @@ -10,4 +10,6 @@ Miscellaneous Environment Variables * - ``TORCH_FORCE_WEIGHTS_ONLY_LOAD`` - If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weight_only=True``. For more documentation on this, see :func:`torch.load`. * - ``TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT`` - - Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds. \ No newline at end of file + - Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds. + * - ``TORCH_DEVICE_BACKEND_AUTOLOAD`` + - If set to ``1``, out-of-tree backend extensions will be automatically imported when running ``import torch``. diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py index e0f1f858e884ee..b3eb760cbfefec 100644 --- a/test/cpp_extensions/setup.py +++ b/test/cpp_extensions/setup.py @@ -109,4 +109,9 @@ ext_modules=ext_modules, include_dirs="self_compiler_include_dirs_test", cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)}, + entry_points={ + "torch.backends": [ + "device_backend = torch_test_cpp_extension:_autoload", + ], + }, ) diff --git a/test/cpp_extensions/torch_test_cpp_extension/__init__.py b/test/cpp_extensions/torch_test_cpp_extension/__init__.py index e69de29bb2d1d6..9003fc9b9cd4e9 100644 --- a/test/cpp_extensions/torch_test_cpp_extension/__init__.py +++ b/test/cpp_extensions/torch_test_cpp_extension/__init__.py @@ -0,0 +1,11 @@ +""" +This is a device backend extension used for testing. +See this RFC: https://github.com/pytorch/pytorch/issues/122468 +""" + +import os + + +def _autoload(): + # Set the environment variable to true in this entrypoint + os.environ["IS_CUSTOM_DEVICE_BACKEND_IMPORTED"] = "1" diff --git a/test/run_test.py b/test/run_test.py index 72746ab04ae0f9..8372d600cb56f4 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -734,6 +734,51 @@ def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options): return _test_cpp_extensions_aot(test_directory, options, use_ninja=False) +def test_autoload_enable(test_module, test_directory, options): + return _test_autoload(test_directory, options, enable=True) + + +def test_autoload_disable(test_module, test_directory, options): + return _test_autoload(test_directory, options, enable=False) + + +def _test_autoload(test_directory, options, enable=True): + # Wipe the build folder, if it exists already + cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions") + cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build") + if os.path.exists(cpp_extensions_test_build_dir): + shutil.rmtree(cpp_extensions_test_build_dir) + + # Build the test cpp extensions modules + cmd = [sys.executable, "setup.py", "install", "--root", "./install"] + return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ) + if return_code != 0: + return return_code + + # "install" the test modules and run tests + python_path = os.environ.get("PYTHONPATH", "") + + try: + cpp_extensions = os.path.join(test_directory, "cpp_extensions") + install_directory = "" + # install directory is the one that is named site-packages + for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")): + for directory in directories: + if "-packages" in directory: + install_directory = os.path.join(root, directory) + + assert install_directory, "install_directory must not be empty" + os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path]) + os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable)) + + cmd = [sys.executable, "test_autoload.py"] + return_code = shell(cmd, cwd=test_directory, env=os.environ) + return return_code + finally: + os.environ["PYTHONPATH"] = python_path + os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD") + + def test_distributed(test_module, test_directory, options): # MPI tests are broken with Python-3.9 mpi_available = subprocess.call( @@ -1052,6 +1097,8 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options): "distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess, "doctests": run_doctests, "test_ci_sanity_check_fail": run_ci_sanity_check, + "test_autoload_enable": test_autoload_enable, + "test_autoload_disable": test_autoload_disable, } diff --git a/test/test_autoload.py b/test/test_autoload.py new file mode 100644 index 00000000000000..b9f094d6bfb0c3 --- /dev/null +++ b/test/test_autoload.py @@ -0,0 +1,21 @@ +# Owner(s): ["module: PrivateUse1"] + +import os + +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestDeviceBackendAutoload(TestCase): + def test_autoload(self): + switch = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "0") + + # After importing the extension, the value of this environment variable should be true + # See: test/cpp_extensions/torch_test_cpp_extension/__init__.py + is_imported = os.getenv("IS_CUSTOM_DEVICE_BACKEND_IMPORTED", "0") + + # Both values should be equal + self.assertEqual(is_imported, switch) + + +if __name__ == "__main__": + run_tests() diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 6c505ad65e5494..beb4e42963972b 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -132,6 +132,8 @@ def skip_test_p(name: str) -> bool: "distributed/elastic/utils/distributed_test", "distributed/elastic/multiprocessing/api_test", "doctests", + "test_autoload_enable", + "test_autoload_disable", ], ) diff --git a/torch/__init__.py b/torch/__init__.py index 9ad49b253ff58c..932eea09b0a683 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2564,3 +2564,50 @@ def _constrain_as_size( from torch import _logging _logging._init_logs() + + +def _import_device_backends(): + """ + Leverage the Python plugin mechanism to load out-of-the-tree device extensions. + See this RFC: https://github.com/pytorch/pytorch/issues/122468 + """ + from importlib.metadata import entry_points + + group_name = "torch.backends" + if sys.version_info < (3, 10): + backend_extensions = entry_points().get(group_name, ()) + else: + backend_extensions = entry_points(group=group_name) + + for backend_extension in backend_extensions: + try: + # Load the extension + entrypoint = backend_extension.load() + # Call the entrypoint + entrypoint() + except Exception as err: + raise RuntimeError( + f"Failed to load the backend extension: {backend_extension.name}. " + f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0." + ) from err + + +def _is_device_backend_autoload_enabled() -> builtins.bool: + """ + Whether autoloading out-of-the-tree device extensions is enabled. + The switch depends on the value of the environment variable + `TORCH_DEVICE_BACKEND_AUTOLOAD`. + + Returns: + bool: Whether to enable autoloading the extensions. Enabled by default. + + Examples: + >>> torch._is_device_backend_autoload_enabled() + True + """ + # enabled by default + return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1" + + +if _is_device_backend_autoload_enabled(): + _import_device_backends()