From 88877f4737da82e730ffae5b7ff7316c05982bed Mon Sep 17 00:00:00 2001 From: Tianxiang Dai Date: Wed, 22 Jan 2025 06:03:21 -0800 Subject: [PATCH] Support for GPU acceleration on Mac Include automatic selection of metal backend for Mac --- src/fdtdx/core/config.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/fdtdx/core/config.py b/src/fdtdx/core/config.py index b2268f0..e2ff223 100644 --- a/src/fdtdx/core/config.py +++ b/src/fdtdx/core/config.py @@ -59,14 +59,28 @@ class SimulationConfig(ExtendedTreeClass): time: float resolution: float - backend: Literal["gpu", "tpu", "cpu"] = frozen_field(default="gpu") + backend: Literal["gpu", "tpu", "METAL", "cpu"] = frozen_field(default="gpu") dtype: jnp.dtype = frozen_field(default=jnp.float32) courant_factor: float = 0.99 gradient_config: GradientConfig | None = None def __post_init__(self): - if self.backend in ["gpu", "tpu"]: - # Try to initialize GPU + from jax import extend + + current_platform = extend.backend.get_backend().platform + + if current_platform == "METAL" and self.backend == "gpu": + self.backend = "METAL" + + if self.backend == "METAL": + try: + jax.devices() + logger.info("METAL device found and will be used for computations") + jax.config.update("jax_platform_name", "metal") + except RuntimeError: + logger.warning("METAL initialization failed, falling back to CPU!") + self.backend = "cpu" + elif self.backend in ["gpu", "tpu"]: try: jax.devices(self.backend) logger.info(f"{str.upper(self.backend)} found and will be used for computations") @@ -74,8 +88,8 @@ def __post_init__(self): except RuntimeError: logger.warning(f"{str.upper(self.backend)} not found, falling back to CPU!") self.backend = "cpu" - jax.config.update("jax_platform_name", "cpu") - else: + + if self.backend == "cpu": jax.config.update("jax_platform_name", "cpu") @property