Currently, the zuko.flows.mixture.GMM class only supports full covariance matrices. However, in many use cases, especially high-dimensional ones, a full covariance matrix is either unnecessary or infeasible to estimate. This issue proposes adding an option to choose between different covariance matrix types, similar to sklearn.mixture.GaussianMixture.
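To make the high-dimensional argument concrete, here is a rough sketch that counts the free covariance parameters for each of the four covariance types offered by sklearn.mixture.GaussianMixture ("full", "tied", "diag", "spherical"). The helper name covariance_parameter_count is made up for illustration and is not part of Zuko or sklearn.

```python
def covariance_parameter_count(covariance_type: str, components: int, features: int) -> int:
    """Number of free covariance parameters of a Gaussian mixture (hypothetical helper)."""
    # A symmetric D x D matrix has D * (D + 1) / 2 free entries.
    full = features * (features + 1) // 2
    counts = {
        "full": components * full,       # one full matrix per component
        "tied": full,                    # one full matrix shared by all components
        "diag": components * features,   # one variance per feature and component
        "spherical": components,         # one variance per component
    }
    return counts[covariance_type]


for kind in ("full", "tied", "diag", "spherical"):
    print(kind, covariance_parameter_count(kind, components=3, features=100))
# full 15150
# tied 5050
# diag 300
# spherical 3
```

With 3 components in 100 dimensions, the full parameterisation needs about 50 times more covariance parameters than the diagonal one, which is why it quickly becomes hard to estimate.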
Here is an example of how different covariance types approximate a mixture of 3 Gaussians with varying covariance matrices.
Implementation
The current structure of the zuko.flows.mixture.GMM class makes it very easy to add the above-mentioned enhancements. I have implemented the changes in a fork of the repository and could open a pull request if this change is wanted. I have only tested the code for the unconditional case, but I do not expect it to break when context features are added.
Further improvements
When generating the above figure, I (again) realised how easily mode collapse happens for GMMs. The zuko.flows.mixture.GMM class could therefore also benefit from some sort of initialisation procedure, again similar to sklearn.mixture.GaussianMixture. I fully understand if that goes beyond the scope of what Zuko wants to achieve. The benefit is that Zuko is very convenient to use and ties in so well with PyTorch code that having such a procedure here could be nice. However, it might add another dependency (e.g., sklearn) if you want to use existing implementations of initialisation algorithms. I have a basic implementation of this (using sklearn) lying around and would be happy to polish it up and make another commit if this is wanted.
I think supporting several covariance types would be a valuable improvement for the GMM class. Actually, I wanted to add a diagonal-plus-low-rank covariance option at some point.
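For reference, a diagonal-plus-low-rank covariance has the form W Wᵀ + diag(d), and plain PyTorch already ships a distribution with that parameterisation. The sketch below builds such a mixture with torch.distributions only; it is not how Zuko's GMM is implemented, and the sizes and variable names are arbitrary assumptions.

```python
import torch
from torch.distributions import Categorical, LowRankMultivariateNormal, MixtureSameFamily

torch.manual_seed(0)
components, features, rank = 3, 50, 2  # arbitrary sizes for illustration

logits = torch.zeros(components)                            # uniform mixture weights
loc = torch.randn(components, features)                     # component means
cov_factor = 0.1 * torch.randn(components, features, rank)  # low-rank factor W
cov_diag = torch.ones(components, features)                 # diagonal d, must be positive

# Each component has covariance W @ W.T + diag(d).
mixture = MixtureSameFamily(
    Categorical(logits=logits),
    LowRankMultivariateNormal(loc, cov_factor, cov_diag),
)

x = mixture.sample((8,))     # 8 samples with 50 features each
log_p = mixture.log_prob(x)  # one log-density per sample
```

Here each component stores 50 * 2 + 50 = 150 covariance parameters instead of the 1275 needed for a full 50 x 50 matrix.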
Initialization is also relevant, although handling the conditional case might be tricky (but possible by editing the weights of the last layer of the hyper network). However, Zuko cannot rely on sklearn, so it would have to be implemented from scratch.
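As a rough idea of what "from scratch" could look like, here is a minimal k-means sketch in pure PyTorch that returns candidate component means. This is only one possible choice (sklearn's GaussianMixture also uses k-means as its default init_params); the function name kmeans_centroids and its arguments are hypothetical, and how the centroids would be written into the GMM parameters or the hyper network is left open.

```python
import torch


def kmeans_centroids(x: torch.Tensor, k: int, iterations: int = 20, seed: int = 0) -> torch.Tensor:
    """Plain Lloyd's k-means on x of shape (N, D); returns centroids of shape (k, D)."""
    generator = torch.Generator().manual_seed(seed)
    # Start from k distinct data points chosen at random.
    index = torch.randperm(x.shape[0], generator=generator)[:k]
    centroids = x[index].clone()

    for _ in range(iterations):
        # Assign every point to its nearest centroid.
        labels = torch.cdist(x, centroids).argmin(dim=1)
        for j in range(k):
            members = x[labels == j]
            if len(members) > 0:  # an empty cluster keeps its previous centroid
                centroids[j] = members.mean(dim=0)

    return centroids


# Synthetic data: two blobs centred near -10 and +10.
torch.manual_seed(0)
data = torch.cat([torch.randn(100, 2) - 10.0, torch.randn(100, 2) + 10.0])
means = kmeans_centroids(data, k=2)
```

Random restarts or a k-means++ style seeding would make this more robust, but they are omitted to keep the sketch short.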
We would accept a PR that implements this feature!