Force the use of the device local to process in distributed runs in JAX #263

Merged · 3 commits · Jan 28, 2025
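In short: when skrl detects a distributed JAX run, it now ignores any explicitly requested device and forces the device local to the current process. A minimal standalone sketch of the JAX mechanism this PR builds on (the coordinator address and the rank/world-size values below are hypothetical placeholders; skrl itself derives them from the `JAX_*` environment variables):

```python
import jax

# Hypothetical values for illustration; a real launcher sets these per process.
rank = 0          # global index of this process (JAX_RANK)
local_rank = 0    # accelerator assigned to this process on its node (JAX_LOCAL_RANK)
world_size = 1    # total number of processes (JAX_WORLD_SIZE)

# Join the distributed run, binding this process to its assigned local device(s).
jax.distributed.initialize(
    coordinator_address="127.0.0.1:5000",
    num_processes=world_size,
    process_id=rank,
    local_device_ids=[local_rank],
)

# The device local to this process: what skrl now forces in distributed runs.
device = jax.local_devices(process_index=rank)[0]
print(device)
```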
1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/bug_report.yaml
@@ -30,6 +30,7 @@ body:
       description: The skrl version can be obtained with the command `pip show skrl`.
       options:
         - ---
+        - 1.4.1
         - 1.4.0
         - 1.3.0
         - 1.2.0
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [1.4.1] - Unreleased
+### Fixed
+- Force the use of the device local to process in distributed runs in JAX
+
 ## [1.4.0] - 2025-01-16
 ### Added
 - Utilities to operate on Gymnasium spaces (`Box`, `Discrete`, `MultiDiscrete`, `Tuple` and `Dict`)
5 changes: 3 additions & 2 deletions docs/source/api/config/frameworks.rst
@@ -92,11 +92,12 @@ API
 
 .. py:data:: skrl.config.jax.device
    :type: jax.Device
-   :value: "cuda:${LOCAL_RANK}" | "cpu"
+   :value: "cuda:${JAX_LOCAL_RANK}" | "cpu"
 
    Default device.
 
-   The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
+   The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
+   However, in a distributed environment, it is the device local to the process with index ``JAX_RANK``.
 
 .. py:data:: skrl.config.jax.backend
    :type: str
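The documented behavior is observable from user code. A hedged usage sketch (the single-process values in the comments are the typical defaults, not guarantees for every machine):

```python
from skrl import config

# Single-process run: cuda:0 if CUDA is available, cpu otherwise.
# Distributed run: the device local to this process, whatever was configured.
print(config.jax.device)

# True when the launcher has set the JAX distributed environment variables.
print(config.jax.is_distributed)
```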
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -16,7 +16,7 @@
 if skrl.__version__ != "unknown":
     release = version = skrl.__version__
 else:
-    release = version = "1.4.0"
+    release = version = "1.4.1"
 
 master_doc = "index"
 
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.4.0"
+version = "1.4.1"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
23 changes: 21 additions & 2 deletions skrl/__init__.py
@@ -188,6 +188,12 @@ def __init__(self) -> None:
                 process_id=self._rank,
                 local_device_ids=self._local_rank,
             )
+            # get the device local to process
+            try:
+                self._device = jax.local_devices(process_index=self._rank)[0]
+                logger.info(f"Using device local to process with index/rank {self._rank} ({self._device})")
+            except Exception as e:
+                logger.warning(f"Failed to get the device local to process with index/rank {self._rank}: {e}")
 
     @staticmethod
     def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -197,13 +203,26 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
 
         This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
 
+        .. warning::
+
+            This method returns (forces the use of) the device local to the process in a distributed environment.
+
         :param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
                        the default available device will be returned instead.
 
         :return: JAX Device.
         """
         import jax
 
+        # force the use of the device local to process in distributed runs
+        if config.jax.is_distributed:
+            try:
+                return jax.local_devices(process_index=config.jax.rank)[0]
+            except Exception as e:
+                logger.warning(
+                    f"Failed to get the device local to process with index/rank {config.jax.rank}: {e}"
+                )
+
         if isinstance(device, jax.Device):
             return device
         elif isinstance(device, str):
@@ -218,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
     def device(self) -> "jax.Device":
         """Default device.
 
-        The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
-        if CUDA is available, ``cpu`` otherwise.
+        The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
+        However, in a distributed environment, it is the device local to the process with index ``JAX_RANK``.
         """
         self._device = self.parse_device(self._device)
         return self._device
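To round out the diff above, a short sketch of the new `parse_device` behavior (the distributed outcome is described in comments, since it depends on how the run is launched):

```python
from skrl import config

# PyTorch-like "type:ordinal" specification, as supported by parse_device.
device = config.jax.parse_device("cuda:0")

# Single-process run: resolves to cuda:0 (or to the default device if the
# specification cannot be resolved).
# Distributed run (this PR): the explicit "cuda:0" is overridden and
# jax.local_devices(process_index=config.jax.rank)[0] is returned instead.
print(device)
```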