Implement MVN.unsqueeze #2624
Conversation
```python
if self.islazy:
    new_covar = self._covar.unsqueeze(dim)
    new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
    if self.__unbroadcasted_scale_tril is not None:
        # Reuse the scale tril if available.
        new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
else:
    # Non-lazy MVN is represented using scale_tril in PyTorch.
    # Constructing it from scale_tril will avoid unnecessary computation.
    # Initialize using __new__, so that we can skip __init__ and use scale_tril.
    new = self.__new__(type(self))
    new._islazy = False
    new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
    super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
    # Set the covar matrix, since it is always available for GPyTorch MVN.
    new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
```
A lot of this is duplicated from the `expand()` code above. Can we make a helper to reuse most of this? Or potentially consider allowing a new MVN to be instantiated from the `scale_tril` directly, as suggested on the other PR. Could be a class method `MultivariateNormal.from_scale_tril()` or the like if we don't want to change the `__init__()` signature.
They follow the same pattern, but I wouldn't call them duplicates. You could make a helper that takes in the operation to apply to the tensors, but there are two separate operations for `expand` (for loc & covar), so I don't know if the resulting helper would improve the code readability.

> Could be a class method `MultivariateNormal.from_scale_tril()` or the like if we don't want to change the `__init__()` signature.

This would avoid the issue of checking for `covar` & `scale_tril` compatibility in `__init__`. If we're not using `__init__`, would this just extract the `self.__new__`-based construction used here into a separate method? I suppose this would work for non-lazy `scale_tril` as is. We'd have to have a separate case for lazy, at which point we end up with a duplicate of `__init__`, and I again question the added value of it over just keeping this as is.
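To make the trade-off concrete, here is a hypothetical version of such a helper. Nothing named here exists in GPyTorch; `_map_mvn` is illustrative only and is simplified to rebuild through the public constructor. Even for `unsqueeze`, the mean and covariance need different `dim` arguments, since they carry one and two trailing event dimensions respectively:

```python
import torch
from gpytorch.distributions import MultivariateNormal


def _map_mvn(mvn, loc_op, covar_op):
    # Apply separate operations to the mean and the (lazy) covariance,
    # then rebuild the distribution through the public constructor.
    return MultivariateNormal(
        mean=loc_op(mvn.loc),
        covariance_matrix=covar_op(mvn.lazy_covariance_matrix),
    )


mvn = MultivariateNormal(torch.zeros(2, 4), torch.eye(4).expand(2, 4, 4))

# Unsqueeze the innermost batch dimension: batch_shape (2,) -> (2, 1).
# Note the mismatched dims (-2 vs. -3) for the two tensors.
new = _map_mvn(mvn, lambda t: t.unsqueeze(-2), lambda c: c.unsqueeze(-3))
assert new.batch_shape == torch.Size([2, 1])
```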
If we have a method exposed that allows constructing the MVN from `scale_tril` (either as part of an updated `__init__()` method or as a separate method), then I would expect that people will see it and use it for other purposes going forward (but they most likely won't look for the logic somewhere deep in the code, as in this PR).

Not going to die on that hill, but the fact that we don't expose a natural way of doing this seems like a gap we could ideally address.
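For reference, a minimal sketch of what the suggested classmethod could look like. `from_scale_tril` does not exist in GPyTorch; the name, signature, and attribute handling here are assumptions modeled on the `__new__`-based construction in the snippet above:

```python
import torch
from gpytorch.distributions import MultivariateNormal


def from_scale_tril(loc, scale_tril):
    # Hypothetical constructor: build a non-lazy GPyTorch MVN directly from a
    # Cholesky factor, skipping __init__ (which expects a covariance matrix).
    new = MultivariateNormal.__new__(MultivariateNormal)
    new._islazy = False  # mark the instance as non-lazy, as in the PR snippet
    # Delegate to the torch.distributions initializer, which accepts scale_tril.
    super(MultivariateNormal, new).__init__(loc=loc, scale_tril=scale_tril)
    return new


mvn = from_scale_tril(torch.zeros(3), torch.eye(3))
assert torch.allclose(mvn.covariance_matrix, torch.eye(3))
```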
Force-pushed from 6287b0b to 312b0f1.
Implements an `unsqueeze` method for MVN that constructs a new MVN with the underlying tensors / linear operators unsqueezed along the given batch dimension. The choice of unsqueezing along the batch dimensions is consistent with the definition of `expand`. `MVN.expand` allows us to start with `MVN.batch_shape = b2 x b1` and add additional batch dimensions to the left, so that `new_MVN.batch_shape = b3 x b2 x b1`. `unsqueeze`, when combined with `expand`, will allow us to add batch dimensions in the middle, so that `new_MVN.batch_shape = b2 x b3 x b1` or `b2 x b1 x b3`.

The use case for this is to match MVNs produced by two models of different batch shapes. Assume that `m1.batch_shape = mb2 x mb1` and `m2.batch_shape = mb1`. If we evaluate these two models with `X` of shape `xb1 x q x d`, we get `m1(X).batch_shape = xb1 x mb2 x mb1` and `m2(X).batch_shape = xb1 x mb1`. By calling `m2(X).unsqueeze(-2).expand(-1, mb2, -1)`, we can match the batch shapes of the two MVNs, allowing them to be combined into a single MTMVN using the `from_independent_mvns` method, as sketched below.
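A minimal end-to-end sketch of this use case, using plain MVNs with illustrative shapes (`xb1 = 5`, `mb2 = 3`, `mb1 = 2`, `q = 4`) as stand-ins for the two model posteriors. Only `unsqueeze` is the method added by this PR; everything else is existing GPyTorch API:

```python
import torch
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal

xb1, mb2, mb1, q = 5, 3, 2, 4

# Stand-ins for m1(X) and m2(X) with mismatched batch shapes.
mvn1 = MultivariateNormal(
    torch.zeros(xb1, mb2, mb1, q),
    torch.eye(q).expand(xb1, mb2, mb1, q, q),
)
mvn2 = MultivariateNormal(
    torch.zeros(xb1, mb1, q),
    torch.eye(q).expand(xb1, mb1, q, q),
)

# Insert a singleton batch dim in the middle, then expand it to match mvn1.
mvn2_matched = mvn2.unsqueeze(-2).expand(torch.Size([xb1, mb2, mb1]))
assert mvn2_matched.batch_shape == mvn1.batch_shape

# With matching batch shapes, the two MVNs can be combined into an MTMVN.
mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn1, mvn2_matched])
```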