
Allow Apple MPS as GPU device #912

Merged
merged 2 commits into from
Feb 19, 2024
Conversation

janfb
Contributor

@janfb janfb commented Jan 18, 2024

Problem

We only support CUDA as a GPU device, but PyTorch 2.1 now also supports Apple's MPS backend (to some extent, see below).

https://pytorch.org/docs/stable/notes/mps.html

Solution

This PR changes the processing of the passed device argument to also allow MPS. For example, instead of using "cuda" in the tests, we use "gpu" and parse the string to "mps:0" or "cuda:0" accordingly.
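As a rough sketch of the parsing described above (function name, signature, and fallback order are assumptions for illustration; the actual logic lives in sbi/utils/torchutils.py and queries torch directly):

```python
def process_device(device: str, cuda_available: bool, mps_available: bool) -> str:
    """Map a user-facing device string to a concrete torch device string.

    Hypothetical sketch of the parsing this PR describes; the real
    implementation checks torch.cuda / torch.backends.mps availability itself.
    """
    if device == "gpu":
        # "gpu" is resolved to whichever GPU backend is available.
        if cuda_available:
            return "cuda:0"
        if mps_available:
            return "mps:0"
        raise RuntimeError("device='gpu' requested, but neither CUDA nor MPS is available.")
    # Explicit strings like "cpu", "cuda:1", or "mps:0" pass through unchanged.
    return device
```

This keeps "gpu" as a backend-agnostic alias while still accepting explicit device strings.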

Additional comments

@janfb janfb added the enhancement New feature or request label Jan 18, 2024
@janfb janfb self-assigned this Jan 18, 2024
@michaeldeistler
Contributor

Re the nflows problem: I am fine with not supporting this, in particular if we will support other density estimators soon.

@janfb
Contributor Author

janfb commented Jan 24, 2024

Maybe I misunderstood: you would rather not support MPS devices because we would have to make sure the future density estimators all run with float32?

@michaeldeistler
Contributor

No, I meant that we do not support MPS devices if nflows is used as the backend. We should support MPS devices for other density estimators (which will hopefully use float32).


codecov bot commented Jan 25, 2024

Codecov Report

Attention: 9 lines in your changes are missing coverage. Please review.

Comparison is base (f4cebfc) 75.29% compared to head (afc7df2) 76.02%.
Report is 4 commits behind head on main.

Files Patch % Lines
sbi/utils/torchutils.py 57.14% 9 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #912      +/-   ##
==========================================
+ Coverage   75.29%   76.02%   +0.73%     
==========================================
  Files          80       80              
  Lines        6286     6319      +33     
==========================================
+ Hits         4733     4804      +71     
+ Misses       1553     1515      -38     
Flag Coverage Δ
unittests 76.02% <57.14%> (+0.73%) ⬆️


@vivienr

vivienr commented Feb 8, 2024

@janfb thank you for this (and to the SBI team for a very useful package)! I have been running SBI on MPS, and saw this PR when looking at creating my own. Just to note in case it's useful: I have been using nflows, with a somewhat ugly wrapping of the density estimator nflows returns:

density_estimator_custom = lambda theta, x: density_estimator_custom_float64(theta, x).to(dtype=torch.float32)

@janfb
Contributor Author

janfb commented Feb 8, 2024

Hi @vivienr, thanks for your comment! Good to know that you have been using MPS with SBI already!
So I assume you experienced a speed-up compared to CPU? How large / what type are your embedding nets?

Thanks also for the suggestion; that would indeed work.
We hope to find a more sustainable option soon, e.g., by making a PR in nflows or by adding support for other density estimation packages.

@vivienr

vivienr commented Feb 8, 2024

I'm seeing pretty small speed-ups (~10%) with my current test set-up: O(100k) simulations, and my default embedding net is a sequence of dense residual blocks with linear resizing layers. This test case has O(10) layers, input dimension 64, output dimension 16.

But I do need to scale up to my real use case with a larger embedding network. I'm also limited by macOS 12 not supporting several operators, which then fall back to the CPU. I will upgrade to macOS 13 and see if things improve.

@janfb
Contributor Author

janfb commented Feb 8, 2024

Thank you for the details, that's good to know. 👍

@janfb janfb added this to the Pre Hackathon 2024 milestone Feb 9, 2024
@janfb janfb force-pushed the allow-mps-device branch 2 times, most recently from 32c6b3a to 189fb75 Compare February 9, 2024 16:24
@janfb
Contributor Author

janfb commented Feb 13, 2024

Update:

  • the float64 in nflows appears only in the buffer of the StandardNormal:

https://github.com/bayesiains/nflows/blob/3b122e5bbc14ed196301969c12d1c2d94fdfba47/nflows/distributions/normal.py#L18-L21

To fix the problem with MPS, I added the option to set the type of that buffer when we are building our flows using nflows:

sbi/sbi/neural_nets/flow.py

Lines 480 to 487 in 189fb75

def get_base_dist(
    num_dims: int, dtype: torch.dtype = torch.float32, **kwargs
) -> distributions_.Distribution:
    """Returns the base distribution for the flows with given float type."""
    base = distributions_.StandardNormal((num_dims,))
    base._log_z = base._log_z.to(dtype)
    return base
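A torch-free toy analogue of this fix (all names below are stand-ins, not sbi or nflows API), just to illustrate the pattern: construct the base distribution first, then cast its cached buffer to the requested dtype, since MPS only supports float32:

```python
import math

class StandardNormalStub:
    """Toy stand-in for nflows' StandardNormal, which caches its
    log-normalizer in a float64 buffer (the `_log_z` linked above)."""
    def __init__(self, num_dims: int):
        self.num_dims = num_dims
        # log of the Gaussian normalization constant, (d/2) * log(2*pi)
        self.log_z = 0.5 * num_dims * math.log(2 * math.pi)
        self.log_z_dtype = "float64"  # the problematic default dtype

def get_base_dist_stub(num_dims: int, dtype: str = "float32") -> StandardNormalStub:
    """Mirrors the fix above: build the distribution, then cast its buffer
    (stands in for `base._log_z = base._log_z.to(dtype)`)."""
    base = StandardNormalStub(num_dims)
    base.log_z_dtype = dtype
    return base
```

The point is that the float64 value only lives in that one buffer, so a post-construction cast is enough; no change to nflows itself is required.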

@janfb
Contributor Author

janfb commented Feb 13, 2024

@manuelgloeckler

  • the vi_on_gpu tests are failing: for nsf and num_dim=2, the TransformedDistribution q produces NaN samples. @manuelgloeckler can you please reproduce this on this branch and have a look?
    To reproduce, run this command on a MacBook with MPS:
    pytest tests/inference_on_device_test.py::test_vi_on_gpu --pdb

@pytest.mark.slow
@pytest.mark.gpu
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
@pytest.mark.parametrize("sampling_method", ("naive", "sir"))
def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_method: str):

@janfb janfb marked this pull request as ready for review February 16, 2024 10:21
@janfb janfb added architecture Internal changes without API consequences performance Everything related to performance labels Feb 16, 2024
@manuelgloeckler
Contributor

Well, I do not have a MacBook (nor access to one). I guess I can't test it then.

@janfb janfb force-pushed the allow-mps-device branch 2 times, most recently from 9d9b128 to 7a27299 Compare February 16, 2024 13:25
@janfb
Contributor Author

janfb commented Feb 16, 2024

Using VIPosterior with MPS, the nsf variational family, and num_dim > 1 results in NaN samples.
This might be related to pytorch/pytorch#89127; thanks @manuelgloeckler.

I made a comment in the corresponding vi test.

This is ready for review now.

Contributor

@michaeldeistler michaeldeistler left a comment


Thanks! Small comment regarding the nsf problems; feel free to merge once it is addressed.

tests/inference_on_device_test.py
@janfb janfb merged commit 2830fda into main Feb 19, 2024
2 of 3 checks passed
@janfb janfb deleted the allow-mps-device branch February 19, 2024 09:08
Labels
architecture Internal changes without API consequences enhancement New feature or request performance Everything related to performance