From 004db1e26cd393222efc258a6590b5749f6f4002 Mon Sep 17 00:00:00 2001
From: dominik-strutz
Date: Tue, 2 Apr 2024 09:12:28 +0100
Subject: [PATCH 1/3] Add covariance_type option to GMM class

---
 zuko/flows/mixture.py | 62 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 12 deletions(-)

diff --git a/zuko/flows/mixture.py b/zuko/flows/mixture.py
index ab0c5c5..ba50451 100644
--- a/zuko/flows/mixture.py
+++ b/zuko/flows/mixture.py
@@ -7,15 +7,15 @@
 import torch
 import torch.nn as nn
 
-from math import prod
-from torch import Tensor
-from torch.distributions import Distribution, MultivariateNormal
-
 # isort: local
 from .core import LazyDistribution
 from ..distributions import Mixture
 from ..nn import MLP
 from ..utils import unpack
+from math import prod
+from torch import Tensor
+from torch.distributions import Distribution, Independent, MultivariateNormal, Normal
 
 
 class GMM(LazyDistribution):
@@ -30,6 +30,16 @@ class GMM(LazyDistribution):
         features: The number of features.
         context: The number of context features.
         components: The number of components :math:`K` in the mixture.
+        covariance_type: String describing the type of covariance parameters to use. Must be one of:
+
+            - 'full': each component has its own general covariance matrix.
+
+            - 'tied': all components share the same general covariance matrix.
+
+            - 'diag': each component has its own diagonal covariance matrix.
+
+            - 'spherical': each component has its own single variance.
+
         kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`.
     """
 
@@ -38,6 +48,7 @@ def __init__(
         features: int,
         context: int = 0,
         components: int = 2,
+        covariance_type: str = 'full',
         **kwargs,
     ):
         super().__init__()
@@ -45,10 +56,31 @@ def __init__(
         shapes = [
             (components,),  # probabilities
             (components, features),  # mean
-            (components, features),  # diagonal
-            (components, features * (features - 1) // 2),  # off diagonal
         ]
 
+        if covariance_type == 'full':
+            shapes.extend([
+                (components, features),  # diagonal
+                (components, features * (features - 1) // 2),  # off diagonal
+            ])
+        elif covariance_type == 'tied':
+            shapes.extend([
+                (1, features),  # diagonal
+                (1, features * (features - 1) // 2),  # off diagonal
+            ])
+        elif covariance_type == 'diag':
+            shapes.extend([
+                (components, features),  # diagonal
+            ])
+        elif covariance_type == 'spherical':
+            shapes.extend([
+                (components, 1),  # diagonal
+            ])
+        else:
+            raise ValueError(
+                f'Invalid covariance type: {covariance_type} (choose from full, tied, diag, or spherical)'
+            )
+        self.covariance_type = covariance_type
         self.shapes = shapes
         self.total = sum(prod(s) for s in shapes)
@@ -64,10 +96,16 @@ def forward(self, c: Tensor = None) -> Distribution:
         phi = self.hyper(c)
         phi = unpack(phi, self.shapes)
 
-        logits, loc, diag, tril = phi
-
-        scale = torch.diag_embed(diag.exp() + 1e-5)
-        mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1)
-        scale = torch.masked_scatter(scale, mask, tril)
-
-        return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits)
+        if self.covariance_type in ['full', 'tied']:
+            logits, loc, diag, tril = phi
+            scale = torch.diag_embed(diag.exp() + 1e-5)
+            mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1)
+            scale = torch.masked_scatter(scale, mask, tril)
+            # expanded automatically for tied covariance
+            return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits)
+
+        elif self.covariance_type in ['diag', 'spherical']:
+            logits, loc, diag = phi
+            scale = diag.exp() + 1e-5
+            # expanded automatically for spherical covariance
+            return Mixture(Independent(Normal(loc, scale), 1), logits)

From 5f05877c963b7767746b95e9750f5f676df12370 Mon Sep 17 00:00:00 2001
From: dominik-strutz
Date: Wed, 3 Apr 2024 08:10:17 +0100
Subject: [PATCH 2/3] Add low-rank covariance type to GMM

---
 zuko/flows/mixture.py | 34 ++++++++++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 4 deletions(-)

diff --git a/zuko/flows/mixture.py b/zuko/flows/mixture.py
index ba50451..2ed32e2 100644
--- a/zuko/flows/mixture.py
+++ b/zuko/flows/mixture.py
@@ -15,7 +15,13 @@
 from ..utils import unpack
 from math import prod
 from torch import Tensor
-from torch.distributions import Distribution, Independent, MultivariateNormal, Normal
+from torch.distributions import (
+    Distribution,
+    Independent,
+    LowRankMultivariateNormal,
+    MultivariateNormal,
+    Normal,
+)
 
 
 class GMM(LazyDistribution):
@@ -40,6 +46,9 @@ class GMM(LazyDistribution):
 
             - 'spherical': each component has its own single variance.
 
+            - 'lowrank': each component has its own low-rank covariance matrix.
+
+        cov_rank: The rank of the low-rank covariance matrix. Only used when ``covariance_type`` is 'lowrank'.
         kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`.
     """
 
@@ -49,6 +58,7 @@ def __init__(
         context: int = 0,
         components: int = 2,
         covariance_type: str = 'full',
+        cov_rank: int = None,
         **kwargs,
     ):
         super().__init__()
@@ -67,6 +77,13 @@ def __init__(
                 (1, features),  # diagonal
                 (1, features * (features - 1) // 2),  # off diagonal
             ])
+        elif covariance_type == 'lowrank':
+            if cov_rank is None:
+                raise ValueError('cov_rank must be specified when covariance_type is lowrank')
+            shapes.extend([
+                (components, features),  # diagonal
+                (components, features * cov_rank),  # low-rank
+            ])
         elif covariance_type == 'diag':
             shapes.extend([
                 (components, features),  # diagonal
@@ -77,10 +94,11 @@ def __init__(
             shapes.extend([
                 (components, 1),  # diagonal
             ])
         else:
             raise ValueError(
-                f'Invalid covariance type: {covariance_type} (choose from full, tied, diag, or spherical)'
+                f'Invalid covariance type: {covariance_type} (choose from full, tied, lowrank, diag, or spherical)'
             )
         self.covariance_type = covariance_type
+        self.cov_rank = cov_rank
         self.shapes = shapes
         self.total = sum(prod(s) for s in shapes)
 
@@ -104,8 +122,16 @@ def forward(self, c: Tensor = None) -> Distribution:
             # expanded automatically for tied covariance
             return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits)
 
+        elif self.covariance_type == 'lowrank':
+            logits, loc, diag, lowrank = phi
+            diag = diag.exp() + 1e-5
+            lowrank = lowrank.reshape(*lowrank.shape[:-1], -1, self.cov_rank)
+            return Mixture(
+                LowRankMultivariateNormal(loc=loc, cov_factor=lowrank, cov_diag=diag), logits
+            )
+
         elif self.covariance_type in ['diag', 'spherical']:
             logits, loc, diag = phi
-            scale = diag.exp() + 1e-5
+            diag = diag.exp() + 1e-5
             # expanded automatically for spherical covariance
-            return Mixture(Independent(Normal(loc, scale), 1), logits)
+            return Mixture(Independent(Normal(loc, diag), 1), logits)

From 0a176b5d563b2a70e78c3770dc34716264787929 Mon Sep 17 00:00:00 2001
From: dominik-strutz
Date: Wed, 3 Apr 2024 14:33:53 +0100
Subject: [PATCH 3/3] Restructure GMM class options to separate covariance
 type and whether it is tied

---
 zuko/flows/mixture.py | 102 +++++++++++++++++++++++++-----------------
 1 file changed, 62 insertions(+), 40 deletions(-)

diff --git a/zuko/flows/mixture.py b/zuko/flows/mixture.py
index 2ed32e2..722096b 100644
--- a/zuko/flows/mixture.py
+++ b/zuko/flows/mixture.py
@@ -24,6 +24,58 @@
 )
 
 
+def _determine_shapes(components, features, covariance_type, tied, cov_rank):
+    shapes = [
+        (components,),  # probabilities
+        (components, features),  # mean
+    ]
+
+    if covariance_type == 'full' and not tied:
+        shapes.extend([
+            (components, features),  # diagonal
+            (components, features * (features - 1) // 2),  # off diagonal
+        ])
+    elif covariance_type == 'full' and tied:
+        shapes.extend([
+            (1, features),  # diagonal
+            (1, features * (features - 1) // 2),  # off diagonal
+        ])
+    elif covariance_type == 'lowrank' and not tied:
+        if cov_rank is None:
+            raise ValueError('cov_rank must be specified when covariance_type is lowrank')
+        shapes.extend([
+            (components, features),  # diagonal
+            (components, features * cov_rank),  # low-rank
+        ])
+    elif covariance_type == 'lowrank' and tied:
+        if cov_rank is None:
+            raise ValueError('cov_rank must be specified when covariance_type is lowrank')
+        shapes.extend([
+            (1, features),  # diagonal
+            (1, features * cov_rank),  # low-rank
+        ])
+    elif covariance_type == 'diag' and not tied:
+        shapes.extend([
+            (components, features),  # diagonal
+        ])
+    elif covariance_type == 'diag' and tied:
+        shapes.extend([
+            (1, features),  # diagonal
+        ])
+    elif covariance_type == 'spherical' and not tied:
+        shapes.extend([
+            (components, 1),  # diagonal
+        ])
+    elif covariance_type == 'spherical' and tied:
+        shapes.extend([
+            (1, 1),  # diagonal
+        ])
+    else:
+        raise ValueError(
+            f'Invalid covariance type: {covariance_type} (choose from full, lowrank, diag, or spherical)'
+        )
+
+    return shapes
+
+
 class GMM(LazyDistribution):
     r"""Creates a Gaussian mixture model (GMM).
 
@@ -38,16 +90,15 @@ class GMM(LazyDistribution):
         components: The number of components :math:`K` in the mixture.
         covariance_type: String describing the type of covariance parameters to use. Must be one of:
 
-            - 'full': each component has its own general covariance matrix.
+            - 'full': each component has its own full-rank covariance matrix.
 
-            - 'tied': all components share the same general covariance matrix.
+            - 'lowrank': each component has its own low-rank covariance matrix.
 
             - 'diag': each component has its own diagonal covariance matrix.
 
             - 'spherical': each component has its own single variance.
 
-            - 'lowrank': each component has its own low-rank covariance matrix.
-
+        tied: Whether to use tied covariance matrices. Tied covariances share the same parameters across all components.
         cov_rank: The rank of the low-rank covariance matrix. Only used when ``covariance_type`` is 'lowrank'.
         kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`.
""" @@ -58,46 +109,16 @@ def __init__( context: int = 0, components: int = 2, covariance_type: str = 'full', + tied: bool = False, cov_rank: int = None, **kwargs, ): super().__init__() - shapes = [ - (components,), # probabilities - (components, features), # mean - ] - if covariance_type == 'full': - shapes.extend([ - (components, features), # diagonal - (components, features * (features - 1) // 2), # off diagonal - ]) - elif covariance_type == 'tied': - shapes.extend([ - (1, features), # diagonal - (1, features * (features - 1) // 2), # off diagonal - ]) - elif covariance_type == 'lowrank': - if cov_rank is None: - raise ValueError('cov_rank must be specified when covariance_type is lowrank') - shapes.extend([ - (components, features), # diagonal - (components, features * cov_rank), # low-rank - ]) - elif covariance_type == 'diag': - shapes.extend([ - (components, features), # diagonal - ]) - elif covariance_type == 'spherical': - shapes.extend([ - (components, 1), # diagonal - ]) - else: - raise ValueError( - f'Invalid covariance type: {covariance_type} (choose from full, tied, lowrank, diag, or spherical)' - ) + shapes = _determine_shapes(components, features, covariance_type, tied, cov_rank) self.covariance_type = covariance_type + self.tied = tied self.cov_rank = cov_rank self.shapes = shapes self.total = sum(prod(s) for s in shapes) @@ -114,18 +135,19 @@ def forward(self, c: Tensor = None) -> Distribution: phi = self.hyper(c) phi = unpack(phi, self.shapes) - if self.covariance_type in ['full', 'tied']: + if self.covariance_type == 'full': logits, loc, diag, tril = phi scale = torch.diag_embed(diag.exp() + 1e-5) mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1) scale = torch.masked_scatter(scale, mask, tril) - # expanded automatically for tied covariance + # expanded automatically for tied covariances return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits) if self.covariance_type == 'lowrank': logits, loc, diag, lowrank = phi diag = diag.exp() + 1e-5 lowrank = lowrank.reshape(lowrank.shape[0], lowrank.shape[1], self.cov_rank) + # expanded automatically for tied covariances return Mixture( LowRankMultivariateNormal(loc=loc, cov_factor=lowrank, cov_diag=diag), logits ) @@ -133,5 +155,5 @@ def forward(self, c: Tensor = None) -> Distribution: elif self.covariance_type in ['diag', 'spherical']: logits, loc, diag = phi diag = diag.exp() + 1e-5 - # expanded automatically for spherical covariance + # expanded automatically for spherical and tied covariance return Mixture(Independent(Normal(loc, diag), 1), logits)