Implement MVN.unsqueeze #2624

Merged · 6 commits · Jan 22, 2025
Changes from all commits
65 changes: 61 additions & 4 deletions gpytorch/distributions/multivariate_normal.py
@@ -136,10 +136,67 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal:
        See :py:meth:`torch.distributions.Distribution.expand
        <torch.distributions.distribution.Distribution.expand>`.
        """
-        new_loc = self.loc.expand(torch.Size(batch_size) + self.loc.shape[-1:])
-        new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
-        res = self.__class__(new_loc, new_covar)
-        return res
+        # NOTE: Pyro may call this method with list[int] instead of torch.Size.
+        batch_size = torch.Size(batch_size)
+        new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])
+        if self.islazy:
+            new_covar = self._covar.expand(batch_size + self._covar.shape[-2:])
+            new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
+            if self.__unbroadcasted_scale_tril is not None:
+                # Reuse the scale tril if available.
+                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.expand(
+                    batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
+                )
+        else:
+            # Non-lazy MVN is represented using scale_tril in PyTorch.
+            # Constructing it from scale_tril will avoid unnecessary computation.
+            # Initialize using __new__, so that we can skip __init__ and use scale_tril.
+            new = self.__new__(type(self))
+            new._islazy = False
+            new_scale_tril = self.__unbroadcasted_scale_tril.expand(
+                batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
+            )
+            super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
+            # Set the covar matrix, since it is always available for GPyTorch MVN.
+            new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:])
+        return new

+    def unsqueeze(self, dim: int) -> MultivariateNormal:
+        r"""
+        Constructs a new MultivariateNormal with the batch shape unsqueezed
+        by the given dimension.
+        For example, if `self.batch_shape = torch.Size([2, 3])` and `dim = 0`, then
+        the returned MultivariateNormal will have `batch_shape = torch.Size([1, 2, 3])`.
+        If `dim = -1`, then the returned MultivariateNormal will have
+        `batch_shape = torch.Size([2, 3, 1])`.
+        """
+        if dim > len(self.batch_shape) or dim < -len(self.batch_shape) - 1:
+            raise IndexError(
+                "Dimension out of range (expected to be in range of "
+                f"[{-len(self.batch_shape) - 1}, {len(self.batch_shape)}], but got {dim})."
+            )
+        if dim < 0:
+            # If dim is negative, get the positive equivalent.
+            dim = len(self.batch_shape) + dim + 1
+
+        new_loc = self.loc.unsqueeze(dim)
+        if self.islazy:
+            new_covar = self._covar.unsqueeze(dim)
+            new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
+            if self.__unbroadcasted_scale_tril is not None:
+                # Reuse the scale tril if available.
+                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
+        else:
+            # Non-lazy MVN is represented using scale_tril in PyTorch.
+            # Constructing it from scale_tril will avoid unnecessary computation.
+            # Initialize using __new__, so that we can skip __init__ and use scale_tril.
+            new = self.__new__(type(self))
+            new._islazy = False
+            new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
+            super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
+            # Set the covar matrix, since it is always available for GPyTorch MVN.
+            new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
Comment on lines +183 to +198

Collaborator:

A lot of this is duplicated from the expand() code above. Can we make a helper to reuse most of this? Or potentially consider allowing a new MVN to be instantiated from the scale_tril directly, as suggested on the other PR. Could be a class method MultivariateNormal.from_scale_tril() or the like if we don't want to change the __init__() signature.
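For concreteness, a minimal sketch of what such a classmethod might look like (hypothetical — from_scale_tril is not part of this PR or the gpytorch API; it simply extracts the __new__-based construction from the diff above and would live inside the MultivariateNormal class):

```python
@classmethod
def from_scale_tril(cls, loc: Tensor, scale_tril: Tensor) -> "MultivariateNormal":
    # Hypothetical helper, not in the merged PR: skip cls.__init__ (which
    # expects a covariance matrix) and initialize the torch parent directly
    # from the Cholesky factor, mirroring the non-lazy branch in the diff.
    new = cls.__new__(cls)
    new._islazy = False
    super(MultivariateNormal, new).__init__(loc=loc, scale_tril=scale_tril)
    return new
```

The non-lazy branches of expand() and unsqueeze() could then reduce to new = MultivariateNormal.from_scale_tril(new_loc, new_scale_tril) followed by setting new.covariance_matrix.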

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They follow the same pattern but I wouldn't call them duplicates. You could make a helper that takes in the operation to apply to the tensors, but there are two separate operations for expand (for loc & covar), so I don't know if the resulting helper would improve the code readability.

Could be a class method MultivariateNormal.from_scale_tril() or the like if we don't want to change the init() signature.

This would avoid the issue of checking for covar & scale_tril compatibility in __init__. If we're not using __init__, would this just extract the self.__new__ based construction used here into a separate method? I suppose this'd work for non-lazy scale_tril as is. We'd have to have a separate case for lazy, at which point we end up with a duplicate of __init__ and I again question the added value of it over just keeping this as is.
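For reference, the helper being debated might look like the sketch below (hypothetical — _remap and its signature are assumptions, and it would live inside MultivariateNormal; needing two callables is exactly the wrinkle noted above, since expand() applies different trailing shapes to loc and to the matrices):

```python
def _remap(self, mean_op, matrix_op):
    # Hypothetical helper, not in the merged PR: apply mean_op to loc and
    # matrix_op to the covariance/scale_tril, following the same lazy vs.
    # non-lazy split as the diff above.
    new_loc = mean_op(self.loc)
    if self.islazy:
        new = self.__class__(mean=new_loc, covariance_matrix=matrix_op(self._covar))
        if self.__unbroadcasted_scale_tril is not None:
            new.__unbroadcasted_scale_tril = matrix_op(self.__unbroadcasted_scale_tril)
    else:
        new = self.__new__(type(self))
        new._islazy = False
        super(MultivariateNormal, new).__init__(
            loc=new_loc, scale_tril=matrix_op(self.__unbroadcasted_scale_tril)
        )
        new.covariance_matrix = matrix_op(self.covariance_matrix)
    return new
```

With this, unsqueeze() would become self._remap(lambda t: t.unsqueeze(dim), lambda t: t.unsqueeze(dim)), while expand() would still need two distinct callables for the mean and the matrices.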

Collaborator:

If we have a method exposed that allows constructing the MVN from scale_tril (either as part of an updated __init__() method or as a separate method), then I would expect people to see and use it for other purposes going forward (whereas they most likely won't go looking for this logic somewhere deep in the code, as in this PR).

Not going to die on that hill, but the fact that we don't expose a natural way of doing this seems like a gap we ideally could address.

+        return new

    def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        r"""
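As a quick illustration of the semantics described in the unsqueeze() docstring above (an editor's sketch, not part of the diff; it assumes a plain non-lazy MVN):

```python
import torch
from gpytorch.distributions import MultivariateNormal

# batch_shape = (2, 3), event_shape = (4,)
mean = torch.zeros(2, 3, 4)
covar = torch.eye(4).expand(2, 3, 4, 4)
mvn = MultivariateNormal(mean, covar)

print(mvn.batch_shape)                                # torch.Size([2, 3])
print(mvn.unsqueeze(0).batch_shape)                   # torch.Size([1, 2, 3])
print(mvn.unsqueeze(-1).batch_shape)                  # torch.Size([2, 3, 1])
print(mvn.expand(torch.Size([5, 2, 3])).batch_shape)  # torch.Size([5, 2, 3])
```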
62 changes: 62 additions & 0 deletions test/distributions/test_multivariate_normal.py
@@ -2,6 +2,7 @@

 import math
 import unittest
+from itertools import product

 import torch
 from linear_operator import to_linear_operator
@@ -323,6 +324,67 @@ def test_base_sample_shape(self):
        samples = dist.rsample(torch.Size((16,)), base_samples=torch.randn(16, 5))
        self.assertEqual(samples.shape, torch.Size((16, 5)))

+    def test_multivariate_normal_expand(self, cuda=False):
+        device = torch.device("cuda") if cuda else torch.device("cpu")
+        for dtype, lazy in product((torch.float, torch.double), (True, False)):
+            mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
+            covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
+            if lazy:
+                mvn = MultivariateNormal(mean=mean, covariance_matrix=DenseLinearOperator(covmat), validate_args=True)
+                # Initialize scale tril so we can test that it was expanded.
+                mvn.scale_tril
+            else:
+                mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
+            self.assertEqual(mvn.batch_shape, torch.Size([]))
+            self.assertEqual(mvn.islazy, lazy)
+            expanded = mvn.expand(torch.Size([2]))
+            self.assertIsInstance(expanded, MultivariateNormal)
+            self.assertEqual(expanded.islazy, lazy)
+            self.assertEqual(expanded.batch_shape, torch.Size([2]))
+            self.assertEqual(expanded.event_shape, mvn.event_shape)
+            self.assertTrue(torch.equal(expanded.mean, mean.expand(2, -1)))
+            self.assertEqual(expanded.mean.shape, torch.Size([2, 3]))
+            self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1)))
+            self.assertEqual(expanded.covariance_matrix.shape, torch.Size([2, 3, 3]))
+            self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1)))
+            self.assertEqual(expanded.scale_tril.shape, torch.Size([2, 3, 3]))

+    def test_multivariate_normal_unsqueeze(self, cuda=False):
+        device = torch.device("cuda") if cuda else torch.device("cpu")
+        for dtype, lazy in product((torch.float, torch.double), (True, False)):
+            batch_shape = torch.Size([2, 3])
+            mean = torch.tensor([0, 1, 2], device=device, dtype=dtype).expand(*batch_shape, -1)
+            covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)).expand(*batch_shape, -1, -1)
+            if lazy:
+                mvn = MultivariateNormal(mean=mean, covariance_matrix=DenseLinearOperator(covmat), validate_args=True)
+                # Initialize scale tril so we can test that it was unsqueezed.
+                mvn.scale_tril
+            else:
+                mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
+            self.assertEqual(mvn.batch_shape, batch_shape)
+            self.assertEqual(mvn.islazy, lazy)
+            for dim, positive_dim, expected_batch in ((1, 1, torch.Size([2, 1, 3])), (-1, 2, torch.Size([2, 3, 1]))):
+                new = mvn.unsqueeze(dim)
+                self.assertIsInstance(new, MultivariateNormal)
+                self.assertEqual(new.islazy, lazy)
+                self.assertEqual(new.batch_shape, expected_batch)
+                self.assertEqual(new.event_shape, mvn.event_shape)
+                self.assertTrue(torch.equal(new.mean, mean.unsqueeze(positive_dim)))
+                self.assertEqual(new.mean.shape, expected_batch + torch.Size([3]))
+                self.assertTrue(torch.allclose(new.covariance_matrix, covmat.unsqueeze(positive_dim)))
+                self.assertEqual(new.covariance_matrix.shape, expected_batch + torch.Size([3, 3]))
+                self.assertTrue(torch.allclose(new.scale_tril, mvn.scale_tril.unsqueeze(positive_dim)))
+                self.assertEqual(new.scale_tril.shape, expected_batch + torch.Size([3, 3]))
+
+            # Check for dim validation.
+            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+                mvn.unsqueeze(3)
+            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
+                mvn.unsqueeze(-4)
+            # Should not raise error up to 2 or -3.
+            mvn.unsqueeze(2)
+            mvn.unsqueeze(-3)


if __name__ == "__main__":
unittest.main()