Allow Apple MPS as GPU device #912
Conversation
Re the nflows problem: I am fine with not supporting this, in particular if we will support other density estimators soon.
Maybe I misunderstood: you would rather not support MPS devices because we would have to make sure the future density estimators all run with float32?
No, I meant that we do not support MPS devices if nflows is used as the backend. We should support MPS devices for other density estimators (which will hopefully use float32).
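For context, the MPS backend has no double-precision support at all, which is why nflows' hard-coded float64 is a blocker. A minimal illustration (not part of this PR; the exact exception type may vary across PyTorch versions):

```python
import torch

if torch.backends.mps.is_available():
    # float32 tensors work on MPS.
    x = torch.zeros(3, dtype=torch.float32, device="mps")

    # float64 tensors do not: PyTorch raises an error because the
    # MPS framework does not support double precision.
    try:
        torch.zeros(3, dtype=torch.float64, device="mps")
    except (TypeError, RuntimeError) as err:
        print(f"float64 on MPS failed: {err}")
```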
Codecov Report

```diff
@@            Coverage Diff             @@
##             main     #912      +/-   ##
==========================================
+ Coverage   75.29%   76.02%   +0.73%
==========================================
  Files          80       80
  Lines        6286     6319      +33
==========================================
+ Hits         4733     4804      +71
+ Misses       1553     1515      -38
```
@janfb thank you for this (and to the SBI team for a very useful piece of software)! I have been running SBI on MPS, and saw this PR when looking into creating my own. Just to note in case it's useful, I have been using `density_estimator_custom = lambda theta, x: density_estimator_custom_float64(theta, x).to(dtype=torch.float32)`.
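Spelled out, that one-liner wraps the usual builder so that the returned network, including any float64 buffers, is cast to float32 before being moved to MPS. A minimal sketch, assuming sbi's `posterior_nn` is used to obtain the underlying float64 builder (that choice is an assumption, not what the commenter necessarily did):

```python
import torch
from sbi.utils import posterior_nn

# Standard builder; the returned nflows-based net contains a
# hard-coded float64 buffer that MPS cannot handle.
build_float64 = posterior_nn(model="nsf")

# Wrap it so all floating-point parameters and buffers are cast to
# float32, the only float type the MPS backend supports.
density_estimator_custom = lambda theta, x: build_float64(theta, x).to(
    dtype=torch.float32
)
```

The wrapped callable can then be passed wherever a density-estimator builder is expected.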
Hi @vivienr, thanks for your comment! Good to know that you have been using MPS with SBI already. Thanks also for the suggestion; that would indeed work.
I'm seeing pretty small speed-ups (~10%) with my current test set-up: O(100k) simulations, and my default embedding net is a sequence of dense residual blocks with linear resizing layers. This test case has O(10) layers, input dimension 64, output dimension 16. But I do need to scale up to my real use-case with a larger embedding network. I'm also limited by macOS 12 not supporting several operators, so PyTorch falls back to the CPU for those. I will upgrade to 13 and see if things improve.
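For concreteness, a rough sketch of the kind of embedding net described above (block count and layer sizes taken from the comment; the exact architecture is assumed):

```python
import torch
from torch import nn


class DenseResidualBlock(nn.Module):
    """Two linear layers with a skip connection."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


# O(10) residual blocks with linear resizing layers: input 64, output 16.
embedding_net = nn.Sequential(
    nn.Linear(64, 64),
    *[DenseResidualBlock(64) for _ in range(10)],
    nn.Linear(64, 16),
)
```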
Thank you for the details, that's good to know. 👍 |
Update: To fix the problem with MPS, I added the option to set the type of that buffer when we are building our flows (see lines 480 to 487 at 189fb75). The corresponding test is in sbi/tests/inference_on_device_test.py, lines 491 to 497 at 189fb75.
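The embedded snippets are not reproduced here, but the idea can be sketched as follows (a hypothetical helper, not the code at the permalinks): cast the hard-coded float64 buffer to a configurable dtype when the flow is built.

```python
import torch
from torch import nn


def cast_float_buffers(net: nn.Module, dtype: torch.dtype = torch.float32) -> nn.Module:
    """Hypothetical helper: cast all floating-point buffers of `net` to `dtype`.

    nflows' StandardNormal registers a float64 buffer; on MPS it must
    be float32, so each float buffer is re-assigned with the target dtype.
    """
    for module in net.modules():
        for name, buf in module.named_buffers(recurse=False):
            if buf.is_floating_point():
                setattr(module, name, buf.to(dtype))
    return net
```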
Well, I do not have a MacBook (nor access to any), so I guess I can't test it then.
I made a comment in the corresponding VI test. This is ready for review now.
Thanks! Small comment regarding the nsf problems; feel free to merge once it is addressed.
Problem
We only support CUDA as a GPU device, but PyTorch 2.1 now supports Apple's MPS backend as well (to some extent, see below).
https://pytorch.org/docs/stable/notes/mps.html
Solution
This PR changes the processing of passed device arguments to also allow MPS: e.g., instead of using `cuda` in the tests, we use `gpu` and parse the string to `mps:0` or `cuda:0` accordingly.
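In outline, the parsing looks like this (a simplified sketch of the described logic, not the exact code in this PR):

```python
import torch


def process_device(device: str) -> str:
    """Map the user-facing "gpu" string to a concrete device string,
    preferring CUDA and falling back to MPS (simplified sketch)."""
    if device == "gpu":
        if torch.cuda.is_available():
            return "cuda:0"
        if torch.backends.mps.is_available():
            return "mps:0"
        raise RuntimeError(
            "Device 'gpu' was requested, but neither CUDA nor MPS is available."
        )
    return device
```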
Additional comments
- There is a problem with `VIPosterior`, `q="nsf"` and `num_dims>1`: sampling `q` then produces NaNs, see "VIPosterior with device="mps" fails for "nsf"" #948.
- MPS does not support `torch.linalg.cholesky`, so it uses CPU as a fallback there (see the sketch after this list).
- `nflows` requires float64 (see here: https://github.com/bayesiains/nflows/blob/3b122e5bbc14ed196301969c12d1c2d94fdfba47/nflows/distributions/normal.py#L19-L20).
- Unless we change `nflows` to remove the hard-coding of float64 and add an option to set the default float type globally, using MPS will not work.
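The CPU fallback mentioned in the second bullet follows a common pattern: run the unsupported op on CPU and move the result back to the original device. A minimal sketch (illustrative, not the exact code in this PR):

```python
import torch


def cholesky_with_cpu_fallback(matrix: torch.Tensor) -> torch.Tensor:
    # torch.linalg.cholesky is not implemented for the MPS backend,
    # so compute the factor on CPU and move it back afterwards.
    if matrix.device.type == "mps":
        return torch.linalg.cholesky(matrix.cpu()).to(matrix.device)
    return torch.linalg.cholesky(matrix)
```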