Skip to content

Commit

Permalink
automatic cpu fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
ymahlau committed Jan 3, 2025
1 parent acb321f commit 790c2b2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ pip install -e .
```bash
# The following lines often lead to better memory usage in JAX
# when using multiple GPU.
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR="platform"
export XLA_PYTHON_CLIENT_PREALLOCATE="false"
export NCCL_LL128_BUFFSIZE="-2"
export NCCL_LL_BUFFSIZE="-2"
export NCCL_PROTO="SIMPLE,LL,LL128"
Expand Down
16 changes: 16 additions & 0 deletions src/fdtdx/core/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math
from typing import Literal

import jax
import jax.numpy as jnp
from loguru import logger

from fdtdx.core.jax.pytrees import ExtendedTreeClass, extended_autoinit, frozen_field
from fdtdx.core.physics import constants
Expand Down Expand Up @@ -62,6 +64,20 @@ class SimulationConfig(ExtendedTreeClass):
courant_factor: float = 0.99
gradient_config: GradientConfig | None = None

def __post_init__(self):
if self.backend in ["gpu", "tpu"]:
# Try to initialize GPU
try:
jax.devices(self.backend)
logger.info(f"{str.upper(self.backend)} found and will be used for computations")
jax.config.update("jax_platform_name", self.backend)
except RuntimeError:
logger.warning(f"{str.upper(self.backend)} not found, falling back to CPU!")
self.backend = "cpu"
jax.config.update("jax_platform_name", "cpu")
else:
jax.config.update("jax_platform_name", "cpu")

@property
def courant_number(self) -> float:
"""Calculate the Courant number for the simulation.
Expand Down

0 comments on commit 790c2b2

Please sign in to comment.