Skip to content

Commit

Permalink
[RFC] Add support for device extension autoloading (pytorch#127074)
Browse files Browse the repository at this point in the history
Fixes pytorch#122468

- Load device extensions at the end of `torch/__init__.py`
- Enabled by default, or you can disable it with `TORCH_DEVICE_BACKEND_AUTOLOAD=0`

run test:

```bash
python test/run_test.py -i test_autoload_enable
python test/run_test.py -i test_autoload_disable
```

doc:

https://docs-preview.pytorch.org/pytorch/pytorch/127074/miscellaneous_environment_variables.html

co-author:  @jgong5 @bsochack @bkowalskiINTEL @jczaja @FFFrog @hipudding

Co-authored-by: albanD <[email protected]>
Co-authored-by: Jiong Gong <[email protected]>
Pull Request resolved: pytorch#127074
Approved by: https://github.com/albanD, https://github.com/jgong5
  • Loading branch information
shink authored and pytorchmergebot committed Jul 9, 2024
1 parent 6c4efd4 commit 312652c
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 1 deletion.
4 changes: 3 additions & 1 deletion docs/source/miscellaneous_environment_variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ Miscellaneous Environment Variables
* - ``TORCH_FORCE_WEIGHTS_ONLY_LOAD``
- If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weight_only=True``. For more documentation on this, see :func:`torch.load`.
* - ``TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT``
- Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds.
- Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds.
* - ``TORCH_DEVICE_BACKEND_AUTOLOAD``
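- If set to ``1`` (the default when the variable is unset), out-of-tree backend extensions will be automatically imported when running ``import torch``. Set it to any other value, for example ``0``, to disable autoloading.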
5 changes: 5 additions & 0 deletions test/cpp_extensions/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,9 @@
ext_modules=ext_modules,
include_dirs="self_compiler_include_dirs_test",
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
entry_points={
"torch.backends": [
"device_backend = torch_test_cpp_extension:_autoload",
],
},
)
11 changes: 11 additions & 0 deletions test/cpp_extensions/torch_test_cpp_extension/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
This is a device backend extension used for testing.
See this RFC: https://github.com/pytorch/pytorch/issues/122468
"""

import os


def _autoload():
    """Entry point of the test device backend; records that it has been called.

    The only effect is to set the ``IS_CUSTOM_DEVICE_BACKEND_IMPORTED``
    environment variable to ``"1"`` so that a test can observe whether the
    entry point ran.
    """
    # Leave a marker in the process environment for the autoload test to read.
    os.environ.update({"IS_CUSTOM_DEVICE_BACKEND_IMPORTED": "1"})
47 changes: 47 additions & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,51 @@ def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options):
return _test_cpp_extensions_aot(test_directory, options, use_ninja=False)


def test_autoload_enable(test_module, test_directory, options):
    """Run the device backend autoload test with autoloading switched on.

    ``test_module`` is accepted for the custom-handler call signature but is
    not used here. Returns whatever ``_test_autoload`` returns.
    """
    exit_code = _test_autoload(test_directory, options, enable=True)
    return exit_code


def test_autoload_disable(test_module, test_directory, options):
    """Run the device backend autoload test with autoloading switched off.

    ``test_module`` is accepted for the custom-handler call signature but is
    not used here. Returns whatever ``_test_autoload`` returns.
    """
    exit_code = _test_autoload(test_directory, options, enable=False)
    return exit_code


def _test_autoload(test_directory, options, enable=True):
    """Build the test cpp extension and run ``test_autoload.py`` against it.

    The extension under ``<test_directory>/cpp_extensions`` is installed into
    ``./install``, its ``*-packages`` directory is prepended to ``PYTHONPATH``
    and ``TORCH_DEVICE_BACKEND_AUTOLOAD`` is set to ``"1"`` or ``"0"`` according
    to ``enable`` while ``test_autoload.py`` runs. Both environment variables
    are put back exactly as they were (including "unset") afterwards.

    Args:
        test_directory: Directory containing ``cpp_extensions`` and
            ``test_autoload.py``.
        options: Unused; kept for signature parity with the other handlers.
        enable: Whether the child process should have autoloading enabled.

    Returns:
        The value returned by ``shell`` for the build step if it is non-zero,
        otherwise the value returned by ``shell`` for the test run.

    Raises:
        AssertionError: If no ``*-packages`` directory is found under the
            install root.
    """
    # Wipe the build folder, if it exists already
    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
    if os.path.exists(cpp_extensions_test_build_dir):
        shutil.rmtree(cpp_extensions_test_build_dir)

    # Build the test cpp extensions modules
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ)
    if return_code != 0:
        return return_code

    # The install directory is the one whose name contains "-packages"
    # (e.g. site-packages); when several match, the last one walked wins.
    install_directory = ""
    install_root = os.path.join(cpp_extensions_test_dir, "install")
    for root, directories, _ in os.walk(install_root):
        for directory in directories:
            if "-packages" in directory:
                install_directory = os.path.join(root, directory)

    # Raise explicitly so the check is not stripped when running with -O.
    if not install_directory:
        raise AssertionError("install_directory must not be empty")

    # Remember the caller's values (None means "was not set") so the finally
    # block can restore the environment exactly, whatever happens below.
    managed_vars = ("PYTHONPATH", "TORCH_DEVICE_BACKEND_AUTOLOAD")
    saved_env = {name: os.environ.get(name) for name in managed_vars}

    try:
        # "install" the test modules and run tests
        python_path = saved_env["PYTHONPATH"] or ""
        # Skip an empty PYTHONPATH so no trailing separator is produced.
        os.environ["PYTHONPATH"] = os.pathsep.join(
            entry for entry in (install_directory, python_path) if entry
        )
        os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))

        cmd = [sys.executable, "test_autoload.py"]
        return shell(cmd, cwd=test_directory, env=os.environ)
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                # pop with a default: never raise from cleanup and mask the
                # real error.
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous


def test_distributed(test_module, test_directory, options):
# MPI tests are broken with Python-3.9
mpi_available = subprocess.call(
Expand Down Expand Up @@ -1052,6 +1097,8 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options):
"distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess,
"doctests": run_doctests,
"test_ci_sanity_check_fail": run_ci_sanity_check,
"test_autoload_enable": test_autoload_enable,
"test_autoload_disable": test_autoload_disable,
}


Expand Down
21 changes: 21 additions & 0 deletions test/test_autoload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Owner(s): ["module: PrivateUse1"]

import os

from torch.testing._internal.common_utils import run_tests, TestCase


class TestDeviceBackendAutoload(TestCase):
def test_autoload(self):
switch = os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "0")

# After importing the extension, the value of this environment variable should be true
# See: test/cpp_extensions/torch_test_cpp_extension/__init__.py
is_imported = os.getenv("IS_CUSTOM_DEVICE_BACKEND_IMPORTED", "0")

# Both values should be equal
self.assertEqual(is_imported, switch)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions tools/testing/discover_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def skip_test_p(name: str) -> bool:
"distributed/elastic/utils/distributed_test",
"distributed/elastic/multiprocessing/api_test",
"doctests",
"test_autoload_enable",
"test_autoload_disable",
],
)

Expand Down
47 changes: 47 additions & 0 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2564,3 +2564,50 @@ def _constrain_as_size(
from torch import _logging

_logging._init_logs()


def _import_device_backends():
    """
    Discover and run every entry point registered in the ``torch.backends`` group.

    Each entry point is loaded and the resulting object is called with no
    arguments. See this RFC: https://github.com/pytorch/pytorch/issues/122468

    Raises:
        RuntimeError: If loading or calling an entry point raises an
            ``Exception``; the original error is chained as the cause.
    """
    from importlib.metadata import entry_points

    group = "torch.backends"
    # The ``group=`` keyword is only used on Python 3.10+; older versions go
    # through the mapping-style lookup instead.
    if sys.version_info >= (3, 10):
        discovered = entry_points(group=group)
    else:
        discovered = entry_points().get(group, ())

    for extension in discovered:
        try:
            # Resolve the entry point to its callable, then invoke it.
            initializer = extension.load()
            initializer()
        except Exception as err:
            message = (
                f"Failed to load the backend extension: {extension.name}. "
                "You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
            )
            raise RuntimeError(message) from err


def _is_device_backend_autoload_enabled() -> builtins.bool:
    """
    Report whether out-of-the-tree device extensions should be autoloaded.

    The answer comes from the environment variable
    `TORCH_DEVICE_BACKEND_AUTOLOAD`: autoloading is on when the variable is
    unset or equal to ``"1"``, and off for any other value.

    Returns:
        bool: Whether to enable autoloading the extensions. Enabled by default.

    Examples:
        >>> torch._is_device_backend_autoload_enabled()
        True
    """
    # An unset variable falls back to "1", i.e. enabled by default.
    switch = os.environ.get("TORCH_DEVICE_BACKEND_AUTOLOAD", "1")
    return switch == "1"


# Import-time hook: when TORCH_DEVICE_BACKEND_AUTOLOAD is unset or "1", load
# and call every entry point registered in the "torch.backends" group.
if _is_device_backend_autoload_enabled():
    _import_device_backends()

0 comments on commit 312652c

Please sign in to comment.