Skip to content

Commit

Permalink
Improve JAX config parse_device implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 16, 2024
1 parent 39e6d28 commit c41ad0f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions skrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,15 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
"""
import jax

if isinstance(device, str):
if isinstance(device, jax.Device):
return device
elif isinstance(device, str):
device_type, device_index = f"{device}:0".split(':')[:2]
try:
return jax.devices(device_type)[int(device_index)]
except (RuntimeError, IndexError) as e:
logger.info(f"Invalid device specification ({device}): {e}")
device = None
if device is None:
return jax.devices()[0]
return device
logger.warning(f"Invalid device specification ({device}): {e}")
return jax.devices()[0]

@property
def device(self) -> "jax.Device":
Expand All @@ -173,7 +172,8 @@ def device(self) -> "jax.Device":
The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
if CUDA is available, ``cpu`` otherwise
"""
return self.parse_device(self._device)
self._device = self.parse_device(self._device)
return self._device

@device.setter
def device(self, device: Union[str, "jax.Device"]) -> None:
Expand Down

0 comments on commit c41ad0f

Please sign in to comment.