diff --git a/gpustat/core.py b/gpustat/core.py index de0aa3f..4f381f5 100644 --- a/gpustat/core.py +++ b/gpustat/core.py @@ -31,6 +31,7 @@ from blessed import Terminal import gpustat.util as util +import gpustat.nvml as nvml from gpustat.nvml import pynvml as N from gpustat.nvml import check_driver_nvml_version @@ -443,7 +444,7 @@ def clean_processes(): def new_query(debug=False, id=None) -> 'GPUStatCollection': """Query the information of all the GPUs on local machine""" - N.nvmlInit() + nvml.ensure_initialized() log = util.DebugHelper() def _decode(b: Union[str, bytes]) -> str: @@ -625,7 +626,6 @@ def _wrapped(*args, **kwargs): if debug: log.report_summary() - N.nvmlShutdown() return GPUStatCollection(gpu_list, driver_version=driver_version) def __len__(self): @@ -752,15 +752,10 @@ def new_query() -> GPUStatCollection: def gpu_count() -> int: '''Return the number of available GPUs in the system.''' try: - N.nvmlInit() + nvml.ensure_initialized() return N.nvmlDeviceGetCount() except N.NVMLError: return 0 # fallback - finally: - try: - N.nvmlShutdown() - except N.NVMLError: - pass def is_available() -> bool: diff --git a/gpustat/nvml.py b/gpustat/nvml.py index 9c3847c..c60f950 100644 --- a/gpustat/nvml.py +++ b/gpustat/nvml.py @@ -2,7 +2,7 @@ # pylint: disable=protected-access -from typing import Tuple +import atexit import functools import os import sys @@ -204,4 +204,30 @@ def nvmlDeviceGetMemoryInfo(handle): setattr(pynvml, 'nvmlDeviceGetMemoryInfo', pynvml_monkeypatch.nvmlDeviceGetMemoryInfo) -__all__ = ['pynvml'] +# Upon importing this module, let pynvml be initialized and remain active +# throughout the lifespan of the python process (until gpustat exists). +_initialized: bool +_init_error = None +try: + pynvml.nvmlInit() + _initialized = True + + def _shutdown(): + pynvml.nvmlShutdown() + atexit.register(_shutdown) + +except pynvml.NVMLError as exc: + _initialized = False + _init_error = exc + + +def ensure_initialized(): + if not _initialized: + raise _init_error # type: ignore + + +__all__ = [ + 'pynvml', + 'check_driver_nvml_version', + 'ensure_initialized', +] diff --git a/gpustat/test_gpustat.py b/gpustat/test_gpustat.py index 478aac4..4d4017d 100644 --- a/gpustat/test_gpustat.py +++ b/gpustat/test_gpustat.py @@ -48,6 +48,7 @@ def _configure_mock(N=pynvml, unstub(N) # reset all the stubs when(N).nvmlInit().thenReturn() + gpustat.nvml._initialized = True # nvmlInit() is called upon module import when(N).nvmlShutdown().thenReturn() when(N).nvmlSystemGetDriverVersion().thenReturn('415.27.mock')