Add covariance_type option to GMM class #49

Open
dominik-strutz opened this issue Apr 2, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@dominik-strutz

Description

Currently, the zuko.flows.mixture.GMM class only supports full covariance matrices. However, in a number of use cases (especially high-dimensional ones), a full covariance matrix is either not needed or infeasible to estimate. This issue proposes adding an option to choose between different covariance matrix types, similar to sklearn.mixture.GaussianMixture.

Here is an example of how different covariance types approximate a mixture of 3 Gaussians with varying covariance matrices.

[Figure: a mixture of 3 Gaussians with varying covariance matrices, approximated with different covariance types]
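To make the proposal concrete, here is a minimal sketch of how the four covariance types named by sklearn.mixture.GaussianMixture ("full", "tied", "diag", "spherical") could be parameterised with plain torch.distributions. This is not Zuko's code and not the code in the fork: the helper names make_gmm and tril_from_raw, the raw parameter shapes, and the softplus positivity constraint are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch.distributions import (
    Categorical,
    Independent,
    MixtureSameFamily,
    MultivariateNormal,
    Normal,
)


def tril_from_raw(raw):
    """Map unconstrained (..., D, D) values to a Cholesky factor (positive diagonal)."""
    diag = F.softplus(raw.diagonal(dim1=-2, dim2=-1)) + 1e-6
    return raw.tril(-1) + torch.diag_embed(diag)


def make_gmm(logits, loc, raw, covariance_type="full"):
    """Hypothetical helper: K-component GMM over D features.

    logits: (K,) mixture logits, loc: (K, D) means.
    raw holds unconstrained covariance parameters, with shape
      "full":      (K, D, D)  one Cholesky factor per component
      "tied":      (D, D)     one Cholesky factor shared by all components
      "diag":      (K, D)     one standard deviation per component and feature
      "spherical": (K,)       one standard deviation per component
    """
    if covariance_type in ("full", "tied"):
        # For "tied", the (D, D) factor broadcasts against the (K, D) means.
        components = MultivariateNormal(loc, scale_tril=tril_from_raw(raw))
    elif covariance_type == "diag":
        scale = F.softplus(raw) + 1e-6
        components = Independent(Normal(loc, scale), 1)
    elif covariance_type == "spherical":
        scale = (F.softplus(raw) + 1e-6)[:, None].expand_as(loc)
        components = Independent(Normal(loc, scale), 1)
    else:
        raise ValueError(f"unknown covariance_type: {covariance_type!r}")

    return MixtureSameFamily(Categorical(logits=logits), components)


if __name__ == "__main__":
    K, D = 3, 2
    gmm = make_gmm(torch.zeros(K), torch.randn(K, D), torch.randn(K, D), "diag")
    print(gmm.log_prob(torch.randn(5, D)).shape)
```

The covariance parameter count drops from K·D·D raw entries for "full" to K·D for "diag" and K for "spherical", which is what makes the restricted types attractive in high dimensions.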

Implementation

The current structure of the zuko.flows.mixture.GMM class makes it very easy to add the above-mentioned enhancement. I have implemented the changes in a fork of the repository and could open a pull request if this change is wanted. I have only tested the code for the unconditional case, but I do not see how adding context features could break it.

Further improvements

When generating the above figure, I (again) realised how easily mode collapse happens for GMMs. The zuko.flows.mixture.GMM class could therefore also benefit from some sort of initialisation procedure, again similar to sklearn.mixture.GaussianMixture. I fully understand if that goes beyond the scope of what Zuko wants to achieve. The benefit is that Zuko is very convenient to use and ties in so well with PyTorch code that having such a procedure here could be nice. However, it might add another dependency (e.g., sklearn) if you want to use existing implementations of initialisation algorithms. I have a basic implementation of this (using sklearn) lying around and would be happy to polish it up and make another commit if this is wanted.

@dominik-strutz dominik-strutz added the enhancement New feature or request label Apr 2, 2024
@francois-rozet
Member

Hello @dominik-strutz, thank you for the feature request!

I think supporting several covariance types would be a valuable improvement for the GMM class. Actually, I wanted to add a diagonal plus low-rank covariance option at some point.
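For reference, a diagonal plus low-rank covariance has the form W Wᵀ + diag(d), with W of shape (D, R) and R much smaller than D. A minimal sketch of such a mixture, assuming torch.distributions.LowRankMultivariateNormal as the component distribution, could look as follows. The helper name make_low_rank_gmm and the softplus constraint on the diagonal are assumptions for illustration, not part of Zuko.

```python
import torch
import torch.nn.functional as F
from torch.distributions import (
    Categorical,
    LowRankMultivariateNormal,
    MixtureSameFamily,
)


def make_low_rank_gmm(logits, loc, factor, raw_diag):
    """Hypothetical helper: GMM with covariance W W^T + diag(d) per component.

    logits: (K,), loc: (K, D), factor (W): (K, D, R), raw_diag: (K, D).
    """
    # The diagonal term must be strictly positive.
    cov_diag = F.softplus(raw_diag) + 1e-6
    components = LowRankMultivariateNormal(loc, cov_factor=factor, cov_diag=cov_diag)
    return MixtureSameFamily(Categorical(logits=logits), components)


if __name__ == "__main__":
    K, D, R = 3, 50, 2
    gmm = make_low_rank_gmm(
        torch.zeros(K), torch.randn(K, D), torch.randn(K, D, R), torch.randn(K, D)
    )
    print(gmm.log_prob(torch.randn(4, D)).shape)
```

Each component then needs D·R + D covariance parameters instead of the D·D entries of a full matrix, while still capturing the R strongest directions of correlation.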

For the initialization, this is also relevant, although handling the conditional case might be tricky (but possible by editing the weights of the last layer of the hyper network). However, Zuko cannot rely on sklearn, so it would have to be implemented from scratch.
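A rough sketch of what this could look like without sklearn: a few k-means iterations written in plain PyTorch pick initial component means, and the last linear layer of a hyper network is then edited so that it outputs those means for any context. Everything here is an assumption for illustration: the functions kmeans and init_means are hypothetical, the hyper network is a toy nn.Sequential that only predicts the means, and Zuko's actual parameter layout may differ.

```python
import torch
import torch.nn as nn


def kmeans(x, k, iters=20, seed=0):
    """Plain Lloyd's algorithm; x has shape (N, D), returns (k, D) centers."""
    gen = torch.Generator().manual_seed(seed)
    # Start from k distinct data points chosen at random.
    centers = x[torch.randperm(len(x), generator=gen)[:k]].clone()
    for _ in range(iters):
        # Assign every point to its nearest center.
        assign = torch.cdist(x, centers).argmin(dim=1)
        for j in range(k):
            members = x[assign == j]
            if len(members) > 0:  # keep the old center if a cluster is empty
                centers[j] = members.mean(dim=0)
    return centers


@torch.no_grad()
def init_means(hyper, centers):
    """Make the last Linear layer output `centers` regardless of the context."""
    last = hyper[-1]
    last.weight.zero_()  # removes the dependence on the context at initialisation
    last.bias.copy_(centers.flatten())


if __name__ == "__main__":
    torch.manual_seed(0)
    K, D, C = 3, 2, 4  # components, features, context features
    x = torch.cat([torch.randn(100, D) + 5.0 * i for i in range(K)])
    hyper = nn.Sequential(nn.Linear(C, 16), nn.ReLU(), nn.Linear(16, K * D))
    init_means(hyper, kmeans(x, K))
    print(hyper(torch.randn(1, C)).reshape(K, D))
```

Zeroing the last-layer weights means the earlier layers receive no gradient on the very first step; scaling the weights down by a small factor instead would be a possible alternative if that turns out to matter.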

We would accept a PR that implements this feature!
