Implement MVN.unsqueeze #2624

Merged: 6 commits merged into cornellius-gp:main on Jan 22, 2025

Conversation

saitcakmak (Collaborator):

Implements an unsqueeze method for MVN that constructs a new MVN with the underlying tensors / linear operators unsqueezed along the given batch dimension. Unsqueezing along the batch dimensions is consistent with the definition of expand.

MVN.expand allows us to start with MVN.batch_shape = b2 x b1 and add batch dimensions to the left, so that new_MVN.batch_shape = b3 x b2 x b1. unsqueeze, when combined with expand, allows us to add batch dimensions in the middle, so that new_MVN.batch_shape = b2 x b3 x b1 or b2 x b1 x b3.
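
For intuition, here is the same middle-dimension insertion on plain tensors (a minimal sketch; the sizes b1 = 2, b2 = 3, b3 = 4 are made up):

```python
import torch

b1, b2, b3 = 2, 3, 4
x = torch.randn(b2, b1)    # batch_shape = b2 x b1
y = x.unsqueeze(-2)        # b2 x 1 x b1
y = y.expand(b2, b3, b1)   # b2 x b3 x b1
assert y.shape == torch.Size([b2, b3, b1])
```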

The use case for this is to match MVNs produced by two models with different batch shapes. Assume that m1.batch_shape = mb2 x mb1 and m2.batch_shape = mb1. If we evaluate these two models with X = xb1 x q x d, we get m1(X).batch_shape = xb1 x mb2 x mb1 and m2(X).batch_shape = xb1 x mb1. By calling m2(X).unsqueeze(-2).expand(-1, mb2, -1), we can match the batch shapes of the two MVNs, allowing them to be combined into a single MTMVN using the from_independent_mvns method.
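
A minimal sketch of this use case (the concrete sizes and the identity covariances are made up for illustration; it assumes gpytorch's MultivariateNormal, the unsqueeze method added in this PR, and MultitaskMultivariateNormal.from_independent_mvns):

```python
import torch
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal

xb1, mb2, mb1, q = 5, 3, 2, 4  # illustrative sizes

# Stand-ins for m1(X) and m2(X), with identity covariances for simplicity.
mvn1 = MultivariateNormal(
    torch.randn(xb1, mb2, mb1, q), torch.eye(q).expand(xb1, mb2, mb1, q, q)
)
mvn2 = MultivariateNormal(
    torch.randn(xb1, mb1, q), torch.eye(q).expand(xb1, mb1, q, q)
)

# Insert a singleton batch dimension in the middle, then expand it to mb2.
mvn2 = mvn2.unsqueeze(-2).expand(torch.Size([xb1, mb2, mb1]))
assert mvn1.batch_shape == mvn2.batch_shape  # xb1 x mb2 x mb1

mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn1, mvn2])
```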

Comment on lines +178 to +198
if self.islazy:
    new_covar = self._covar.unsqueeze(dim)
    new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
    if self.__unbroadcasted_scale_tril is not None:
        # Reuse the scale tril if available.
        new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
else:
    # Non-lazy MVN is represented using scale_tril in PyTorch.
    # Constructing it from scale_tril will avoid unnecessary computation.
    # Initialize using __new__, so that we can skip __init__ and use scale_tril.
    new = self.__new__(type(self))
    new._islazy = False
    new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
    super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
    # Set the covar matrix, since it is always available for GPyTorch MVN.
    new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
Collaborator:

A lot of this duplicates the expand() code above. Can we make a helper to reuse most of it? Or potentially consider allowing instantiation of a new MVN from the scale_tril directly, as suggested on the other PR. It could be a class method, MultivariateNormal.from_scale_tril() or the like, if we don't want to change the __init__() signature.
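
For illustration, a hypothetical from_scale_tril along these lines (the name, signature, and body are made up here; this is not part of the merged PR) might look like:

```python
@classmethod
def from_scale_tril(cls, loc, scale_tril):
    # Hypothetical sketch: skip __init__ (and its covar/scale_tril
    # compatibility checks) by constructing via __new__, as in this PR.
    new = cls.__new__(cls)
    new._islazy = False
    super(MultivariateNormal, new).__init__(loc=loc, scale_tril=scale_tril)
    return new
```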

saitcakmak (Collaborator, Author):
They follow the same pattern, but I wouldn't call them duplicates. You could make a helper that takes in the operation to apply to the tensors, but there are two separate operations for expand (for loc & covar), so I don't know whether the resulting helper would improve code readability.

> Could be a class method MultivariateNormal.from_scale_tril() or the like if we don't want to change the __init__() signature.

This would avoid the issue of checking for covar & scale_tril compatibility in __init__. If we're not using __init__, would this just extract the self.__new__-based construction used here into a separate method? I suppose this would work for a non-lazy scale_tril as is. We'd have to have a separate case for the lazy one, at which point we end up with a duplicate of __init__, and I again question its added value over just keeping this as is.
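
For concreteness, the kind of helper weighed above (hypothetical; the name and the split into loc/covar operations are made up) might look like:

```python
def _apply_to_tensors(self, loc_op, covar_op):
    # Hypothetical helper: apply the given callables to the underlying
    # tensors/operators and rebuild the MVN, mirroring the lazy and
    # non-lazy branches in the snippet above.
    new_loc = loc_op(self.loc)
    if self.islazy:
        return self.__class__(mean=new_loc, covariance_matrix=covar_op(self._covar))
    new = self.__new__(type(self))
    new._islazy = False
    super(MultivariateNormal, new).__init__(
        loc=new_loc, scale_tril=covar_op(self.__unbroadcasted_scale_tril)
    )
    new.covariance_matrix = covar_op(self.covariance_matrix)
    return new

# unsqueeze could then call self._apply_to_tensors(
#     lambda t: t.unsqueeze(dim), lambda t: t.unsqueeze(dim)
# ), while expand would still need different loc and covar operations.
```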

Collaborator:
If we have a method exposed that allows constructing the MVN from scale_tril (either as part of an updated __init__() method or as a separate method), then I would expect people to see and use it for other purposes going forward (they most likely won't go looking for this logic somewhere deep in the code, as in this PR).

Not going to die on that hill, but the fact that we don't expose a natural way of doing this seems like a gap we could ideally address.

(An outdated review comment on gpytorch/distributions/multivariate_normal.py was marked as resolved.)
@saitcakmak enabled auto-merge on January 22, 2025, 17:50
@saitcakmak merged commit 2633973 into cornellius-gp:main on Jan 22, 2025
7 checks passed