Skip to content

Commit

Permalink
Merge pull request #12 from txdai/main
Browse files Browse the repository at this point in the history
Include automatic selection of metal backend for Mac
  • Loading branch information
ymahlau authored Jan 23, 2025
2 parents c4cc009 + 88877f4 commit 981ccb5
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/fdtdx/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,37 @@ class SimulationConfig(ExtendedTreeClass):

time: float
resolution: float
backend: Literal["gpu", "tpu", "cpu"] = frozen_field(default="gpu")
backend: Literal["gpu", "tpu", "METAL", "cpu"] = frozen_field(default="gpu")
dtype: jnp.dtype = frozen_field(default=jnp.float32)
courant_factor: float = 0.99
gradient_config: GradientConfig | None = None

def __post_init__(self):
if self.backend in ["gpu", "tpu"]:
# Try to initialize GPU
from jax import extend

current_platform = extend.backend.get_backend().platform

if current_platform == "METAL" and self.backend == "gpu":
self.backend = "METAL"

if self.backend == "METAL":
try:
jax.devices()
logger.info("METAL device found and will be used for computations")
jax.config.update("jax_platform_name", "metal")
except RuntimeError:
logger.warning("METAL initialization failed, falling back to CPU!")
self.backend = "cpu"
elif self.backend in ["gpu", "tpu"]:
try:
jax.devices(self.backend)
logger.info(f"{str.upper(self.backend)} found and will be used for computations")
jax.config.update("jax_platform_name", self.backend)
except RuntimeError:
logger.warning(f"{str.upper(self.backend)} not found, falling back to CPU!")
self.backend = "cpu"
jax.config.update("jax_platform_name", "cpu")
else:

if self.backend == "cpu":
jax.config.update("jax_platform_name", "cpu")

@property
Expand Down

0 comments on commit 981ccb5

Please sign in to comment.