Support for Metal hardware #11
Hi, this would indeed be a good feature for a lot of users! We do not have an Apple machine to test this, so it is a bit difficult to develop this feature on our side. Also, JAX on Apple Metal is an experimental feature of JAX, so there may be issues with that as well. Do you want to open a PR for this? I think the only thing that may need to be changed is the backend in the config. I am not sure what the string identifier for the Metal backend is; the JAX documentation does not seem to say. You can install the Metal version of JAX with

If you get this feature working, we might also change the installation option on Apple to optionally support Metal.
The name for Metal is in fact "METAL" when you have jax-metal installed. I included this, but it seems that matrix inversion is not implemented in jax-metal (as referred to here). FDTDX depends on this to run, so I guess more modifications would be needed (or one could use MLX entirely, which is quite similar to JAX and supports a decent number of features).
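A minimal sketch of selecting a backend by its string identifier with a CPU fallback, assuming the "METAL" name reported above. `resolve_device` is a hypothetical helper for illustration, not part of FDTDX's config; it only relies on `jax.devices(backend)`, which raises an error when the requested backend is not available.

```python
import jax
import jax.numpy as jnp


def resolve_device(requested: str = "METAL", fallback: str = "cpu"):
    """Return the first device of `requested`, or of `fallback` if unavailable.

    Hypothetical helper for illustration; FDTDX's own config may differ.
    """
    # Backend lookups may be case-sensitive ("METAL" per this thread,
    # "cpu"/"gpu" for built-ins), so try the name as given and in both cases.
    for name in dict.fromkeys((requested, requested.lower(), requested.upper())):
        try:
            devices = jax.devices(name)
        except (RuntimeError, ValueError):
            continue  # backend unknown, not installed, or failed to initialize
        if devices:
            return devices[0]
    return jax.devices(fallback)[0]


device = resolve_device("METAL")
x = jax.device_put(jnp.arange(4.0), device)
print(device.platform, x.sum())
```

On a machine without jax-metal this simply falls back to the CPU device, so the same script runs everywhere.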
Hm, that's interesting. As far as I know, there is no matrix inversion used in the FDTDX code. Do you get a specific error that points to the part of the code or the dependency that uses the inversion? We might be able to remove or work around that dependency.
For example, in "fdtdx/objects/sources/plane_source.py", line 235, in _rotate_vector, "jnp.linalg.inv" is a matrix inversion.
Ah, I forgot about that part of the code :D This is pretty easy to fix: the matrix is orthogonal, so it can also be inverted by a transpose operation. I changed this; could you verify whether it works with Metal now?
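The fix relies on the fact that an orthogonal matrix R satisfies R^T R = I, so R^-1 = R^T and no `jnp.linalg.inv` call is needed. A hedged sketch of that idea follows; `rotation_matrix_z` and `rotate_vector` are hypothetical stand-ins, not the actual `_rotate_vector` in `plane_source.py`.

```python
import jax.numpy as jnp


def rotation_matrix_z(angle: float) -> jnp.ndarray:
    """Rotation about the z-axis; every rotation matrix is orthogonal."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotate_vector(v: jnp.ndarray, rot: jnp.ndarray, inverse: bool = False) -> jnp.ndarray:
    """Apply `rot` (or its inverse) to `v` without calling jnp.linalg.inv.

    Because `rot` is orthogonal, its inverse equals its transpose, which
    avoids the matrix inversion that jax-metal reportedly lacks.
    """
    m = rot.T if inverse else rot
    return m @ v


rot = rotation_matrix_z(0.3)
v = jnp.array([1.0, 2.0, 3.0])
# Rotating and then applying the inverse rotation should recover v.
roundtrip = rotate_vector(rotate_vector(v, rot), rot, inverse=True)
```

The transpose is also cheaper than a general inverse, but it is only valid when the matrix is known to be orthogonal.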
This works now. Perhaps you could also provide the benchmarking example used in the white paper for the speed comparison? I would like to see how the M4 Max performs in comparison.
I just uploaded the script, which is probably a good idea for reproducibility purposes anyway :) Please let me know when you have results on the M4 Max; I would be very interested to hear about them!
I tried running the benchmark with a pixel size of 25 nm. JAX's JIT seems to have some problems on the GPU, so I first tested running without JIT. It took 65 s to run on the Metal GPU. Curiously, on the M4 Max CPU it took 2.4 s to compile and 63 s to run. I was able to confirm with the resource monitor that the code is indeed executed on the corresponding devices. Then I tried JIT on the CPU, which took 91 s to run. It seems like jax-metal is not quite ready for this task. Buffer donation is also not supported on Metal.
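When comparing timings like these, it helps to separate compilation time from execution time and to wait for JAX's asynchronous dispatch to finish. A minimal timing sketch, assuming a toy `step` function as a hypothetical stand-in for one FDTD update (it is not the benchmark script from this thread). It uses JAX's ahead-of-time `jax.jit(...).lower(...).compile()` API and `block_until_ready()`; `donate=True` requests buffer donation via `donate_argnums`, which backends without donation support (such as Metal, per the report above) would typically ignore with a warning.

```python
import time

import jax
import jax.numpy as jnp


def step(field: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical stand-in for one FDTD update: a simple smoothing stencil.
    return 0.5 * field + 0.25 * (jnp.roll(field, 1, axis=0) + jnp.roll(field, -1, axis=0))


def time_eager(fn, x, n_steps: int = 100):
    """Run fn op-by-op without jax.jit and return (result, seconds)."""
    t0 = time.perf_counter()
    for _ in range(n_steps):
        x = fn(x)
    x.block_until_ready()  # JAX dispatches asynchronously; wait for the result
    return x, time.perf_counter() - t0


def time_jit(fn, x, n_steps: int = 100, donate: bool = False):
    """Compile fn ahead of time, then run it; return (result, compile_s, run_s)."""
    jitted = jax.jit(fn, donate_argnums=(0,) if donate else ())
    t0 = time.perf_counter()
    compiled = jitted.lower(x).compile()  # compilation only, no execution
    t_compile = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(n_steps):
        # With donate=True the input buffer may be reused, so the caller's
        # original array must not be used after this call.
        x = compiled(x)
    x.block_until_ready()
    return x, t_compile, time.perf_counter() - t0


if __name__ == "__main__":
    field = jnp.ones((256, 256))
    print(jax.default_backend())
    _, t_eager = time_eager(step, field)
    _, t_comp, t_run = time_jit(step, field)
    print(f"eager: {t_eager:.3f}s  jit compile: {t_comp:.3f}s  jit run: {t_run:.3f}s")
```

Timing the compiled function separately makes it clearer whether a slow result comes from compilation or from the per-step execution on the device.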
Just want to chime in that I've got an M3 Max Mac and would be very interested in Metal support. I've had trouble getting Meep to work on Apple silicon, and there aren't a ton of options out there. I'm trying out FDTDX to see how it will work in my lab. Looking forward to seeing how this package develops!
Thanks for your interest! We will do our best to enable Metal support, but I think the main restriction there is JAX itself and how well (or not) it is supported on Metal. If you have any feedback for improving FDTDX after trying it out, let us know!
In addition to NVIDIA GPUs and TPUs, Mac computers also have GPUs with reasonable computational power, which JAX supports through jax-metal. They normally have more available memory (e.g. 32/64 GB) than typical CUDA GPUs. I think some adaptation to support this backend might be helpful.