From 998ca46f74bc376ef2d59c491020c7cf087dcd9d Mon Sep 17 00:00:00 2001 From: Jake Gardner Date: Thu, 21 Oct 2021 10:48:50 -0400 Subject: [PATCH 1/7] x1_eq_x2 can now be specified manually --- gpytorch/kernels/kernel.py | 23 +++++++++++++------ gpytorch/lazy/lazy_evaluated_kernel_tensor.py | 7 +++--- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py index 7c70bbfbb..865b8523d 100644 --- a/gpytorch/kernels/kernel.py +++ b/gpytorch/kernels/kernel.py @@ -265,6 +265,7 @@ def covar_dist( diag=False, last_dim_is_batch=False, square_dist=False, + x1_eq_x2=None, dist_postprocess_func=default_postprocess_script, postprocess=True, **params, @@ -297,7 +298,8 @@ def covar_dist( x1 = x1.transpose(-1, -2).unsqueeze(-1) x2 = x2.transpose(-1, -2).unsqueeze(-1) - x1_eq_x2 = torch.equal(x1, x2) + if x1_eq_x2 is None: + x1_eq_x2 = torch.equal(x1, x2) # torch scripts expect tensors postprocess = torch.tensor(postprocess) @@ -353,7 +355,7 @@ def sub_kernels(self): for _, kernel in self.named_sub_kernels(): yield kernel - def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params): + def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, x1_eq_x2=None, **params): x1_, x2_ = x1, x2 # Select the active dimensions @@ -371,8 +373,15 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params): if not x1_.size(-1) == x2_.size(-1): raise RuntimeError("x1_ and x2_ must have the same number of dimensions!") - if x2_ is None: - x2_ = x1_ + if x1_eq_x2 is None: + if x2_ is None: + x2_ = x1_ + x1_eq_x2 = True + else: + x1_eq_x2 = None + else: + if x2_ is None: + x2_ = x1_ # Check that ard_num_dims matches the supplied number of dimensions if settings.debug.on(): @@ -383,7 +392,7 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params): ) if diag: - res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params) + res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params) # Did this Kernel eat the diag option? # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output if not isinstance(res, LazyEvaluatedKernelTensor): @@ -393,9 +402,9 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params): else: if settings.lazily_evaluate_kernels.on(): - res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params) + res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params) else: - res = lazify(super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)) + res = lazify(super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params)) return res def __getstate__(self): diff --git a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py index e11dd1286..75f69653e 100644 --- a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py +++ b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py @@ -19,14 +19,15 @@ def _check_args(self, x1, x2, kernel, last_dim_is_batch=False, **params): if not torch.is_tensor(x2): return "x1 must be a tensor. 
Got {}".format(x1.__class__.__name__) - def __init__(self, x1, x2, kernel, last_dim_is_batch=False, **params): + def __init__(self, x1, x2, kernel, last_dim_is_batch=False, x1_eq_x2=None, **params): super(LazyEvaluatedKernelTensor, self).__init__( - x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, **params + x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params ) self.kernel = kernel self.x1 = x1 self.x2 = x2 self.last_dim_is_batch = last_dim_is_batch + self.x1_eq_x2 = x1_eq_x2 self.params = params @property @@ -281,7 +282,7 @@ def evaluate_kernel(self): with settings.lazily_evaluate_kernels(False): temp_active_dims = self.kernel.active_dims self.kernel.active_dims = None - res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params) + res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, x1_eq_x2=self.x1_eq_x2, **self.params) self.kernel.active_dims = temp_active_dims # Check the size of the output From 407e9bb88ee76706512e771e7e454699701f3d4b Mon Sep 17 00:00:00 2001 From: Jake Gardner Date: Sat, 23 Oct 2021 22:48:53 -0400 Subject: [PATCH 2/7] Implement derivative kernel --- gpytorch/kernels/__init__.py | 2 + gpytorch/kernels/experimental/__init__.py | 3 + .../kernels/experimental/derivative_kernel.py | 106 ++++++++++++++++++ gpytorch/kernels/kernel.py | 14 ++- 4 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 gpytorch/kernels/experimental/__init__.py create mode 100644 gpytorch/kernels/experimental/derivative_kernel.py diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py index b4209e477..5dfdc3a0c 100644 --- a/gpytorch/kernels/__init__.py +++ b/gpytorch/kernels/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from . import experimental from . import keops from .additive_structure_kernel import AdditiveStructureKernel from .arc_kernel import ArcKernel @@ -31,6 +32,7 @@ from .spectral_mixture_kernel import SpectralMixtureKernel __all__ = [ + "experimental", "keops", "Kernel", "ArcKernel", diff --git a/gpytorch/kernels/experimental/__init__.py b/gpytorch/kernels/experimental/__init__.py new file mode 100644 index 000000000..b2b2f9162 --- /dev/null +++ b/gpytorch/kernels/experimental/__init__.py @@ -0,0 +1,3 @@ +from .derivative_kernel import DerivativeKernel + +__all__ = ["DerivativeKernel"] \ No newline at end of file diff --git a/gpytorch/kernels/experimental/derivative_kernel.py b/gpytorch/kernels/experimental/derivative_kernel.py new file mode 100644 index 000000000..f23ca962b --- /dev/null +++ b/gpytorch/kernels/experimental/derivative_kernel.py @@ -0,0 +1,106 @@ +import torch + +from ..kernel import Kernel + +try: + from functorch import vmap, jacrev + + class DerivativeKernel(Kernel): + def __init__(self, base_kernel, shuffle=True, **kwargs): + """ + Wraps a kernel to add support for derivative information automatically using autograd. + + Args: + base_kernel (Kernel): Kernel object to add derivative information support to. + shuffle (bool, default True): Do we shuffle the output matrix to match GPyTorch multitask conventions? + + .. note:: + + Currently, this kernel takes advantage of experimental torch functionality found in the `functorch` package. + You must have this package installed in order to use this kernel. 
+ + Example: + >>> x = torch.randn(5, 2) + >>> kern = gpytorch.kernels.PolynomialKernel(2) + >>> kern_grad = gpytorch.kernels.PolynomialKernelGrad(2) + >>> kern_autograd = gpytorch.kernels.experimental.DerivativeKernel(kern) + >>> assert torch.norm(kern_grad(x).evaluate() - kern_autograd(x).evaluate()) < 1e-5 + """ + super().__init__(**kwargs) + + self.base_kernel = base_kernel + self.shuffle = shuffle + + def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params): + n1, d = x1.shape[-2:] + n2 = x2.shape[-2] + + if x1_eq_x2 is None: + # Functorch can't batch over equality checking, we have to assume the worst. + x1_eq_x2 = False + + if not diag: + kernf = lambda x1_, x2_: self.base_kernel.forward(x1_, x2_, diag=False, x1_eq_x2=x1_eq_x2) + + K_x1_x2 = kernf(x1, x2) + + # Compute K_{dx1, x2} block. + K_dx1_x2_func = vmap(lambda _x1: jacrev(kernf)(_x1, x2)) + K_dx1_x2 = K_dx1_x2_func(x1.unsqueeze(-2)) + K_dx1_x2 = K_dx1_x2.squeeze(-2).squeeze(-3).permute(-1, -3, -2).reshape(n1 * d, n2) + + if x1_eq_x2: + K_x1_dx2 = K_dx1_x2.transpose(-2, -1) + else: + # Compute K_{x1, dx2} block the same way (then we'll transpose). + K_dx2_x1_func = vmap(lambda _x2: jacrev(kernf)(_x2, x1)) + K_dx2_x1 = K_dx2_x1_func(x2.unsqueeze(-2)) + K_dx2_x1 = K_dx2_x1.squeeze(-2).squeeze(-3).permute(-1, -3, -2).reshape(n2 * d, n1) + K_x1_dx2 = K_dx2_x1.transpose(-2, -1) + + # Compute K_{dx1, dx2} block. + K_dx1_dx2_func = vmap(vmap(jacrev(jacrev(kernf, argnums=0), argnums=1))) + x1_expand = x1.unsqueeze(-2).unsqueeze(-2).expand(n1, n2, 1, d) + x2_expand = x2.unsqueeze(-3).unsqueeze(-2).expand(n1, n2, 1, d) + K_dx1_dx2 = K_dx1_dx2_func(x1_expand, x2_expand) + K_dx1_dx2 = K_dx1_dx2.squeeze(-2).squeeze(-3).squeeze(-3).squeeze(-3) + K_dx1_dx2 = K_dx1_dx2.permute(-2, -4, -1, -3).reshape(n1 * d, n2 * d) + + + R1 = torch.cat((K_x1_x2, K_x1_dx2), dim=-1) + R2 = torch.cat((K_dx1_x2, K_dx1_dx2), dim=-1) + K = torch.cat((R1, R2), dim=-2) + + if self.shuffle: + # Apply a perfect shuffle permutation to match the MutiTask ordering + pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1))) + pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1))) + K = K[..., pi1, :][..., :, pi2] + + return K + else: + if n1 != n2: + raise RuntimeError("DerivativeKernel does not support diag mode on rectangular kernel matrices.") + + # Must use x1_eq_x2=False here, because covar_dist just returns 0 otherwise and we lose gradients. 
+ k_diag_f = lambda x1_, x2_: self.base_kernel.forward(x1_, x2_, diag=True, x1_eq_x2=False) + k_diag = k_diag_f(x1, x2) + k_grad_diag_f = vmap(jacrev(jacrev(k_diag_f, argnums=0), argnums=1)) + k_grad_diag = k_grad_diag_f(x1, x2).diagonal(dim1=-2, dim2=-1).transpose(-3, -1).reshape(-1) + + K_diag = torch.cat((k_diag, k_grad_diag), dim=-1) + + if self.shuffle: + pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1))) + K_diag = K_diag[..., pi1] + + return K_diag + + def num_outputs_per_input(self, x1, x2): + return x1.size(-1) + 1 + + +except (ImportError, ModuleNotFoundError): + class DerivativeKernel(Kernel): + def __init__(self, base_kernel, shuffle=False, **kwargs): + raise RuntimeError("You must have functorch installed to use DerivativeKernel!") diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py index 865b8523d..cde167f68 100644 --- a/gpytorch/kernels/kernel.py +++ b/gpytorch/kernels/kernel.py @@ -139,7 +139,7 @@ def __init__( eps=1e-6, **kwargs, ): - super(Kernel, self).__init__() + super().__init__() self._batch_shape = batch_shape if active_dims is not None and not torch.is_tensor(active_dims): active_dims = torch.tensor(active_dims, dtype=torch.long) @@ -318,9 +318,15 @@ def covar_dist( res = dist_postprocess_func(res) return res else: - res = torch.norm(x1 - x2, p=2, dim=-1) - if square_dist: - res = res.pow(2) + if not square_dist: + res = torch.norm(x1 - x2, p=2, dim=-1) + else: + if x1.size(-2) == x2.size(-2): + # If n1 = n2, we can compute the squared distance diagonal + # in a way that preserves gradient flow. + res = (x1 * x1 - 2 * x1 * x2 + x2 * x2).sum(-1) + else: + res = torch.norm(x1 - x2, p=2, dim=-1).pow(2) if postprocess: res = dist_postprocess_func(res) return res From c664a10f35eaaf620dd2662e0678c0f4374c5329 Mon Sep 17 00:00:00 2001 From: Jake Gardner Date: Sat, 23 Oct 2021 22:53:28 -0400 Subject: [PATCH 3/7] add todo --- .../kernels/experimental/derivative_kernel.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/gpytorch/kernels/experimental/derivative_kernel.py b/gpytorch/kernels/experimental/derivative_kernel.py index f23ca962b..1468f491d 100644 --- a/gpytorch/kernels/experimental/derivative_kernel.py +++ b/gpytorch/kernels/experimental/derivative_kernel.py @@ -16,8 +16,8 @@ def __init__(self, base_kernel, shuffle=True, **kwargs): .. note:: - Currently, this kernel takes advantage of experimental torch functionality found in the `functorch` package. - You must have this package installed in order to use this kernel. + Currently, this kernel takes advantage of experimental torch functionality found in the `functorch` + package. You must have this package installed in order to use this kernel. Example: >>> x = torch.randn(5, 2) @@ -66,7 +66,6 @@ def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params): K_dx1_dx2 = K_dx1_dx2.squeeze(-2).squeeze(-3).squeeze(-3).squeeze(-3) K_dx1_dx2 = K_dx1_dx2.permute(-2, -4, -1, -3).reshape(n1 * d, n2 * d) - R1 = torch.cat((K_x1_x2, K_x1_dx2), dim=-1) R2 = torch.cat((K_dx1_x2, K_dx1_dx2), dim=-1) K = torch.cat((R1, R2), dim=-2) @@ -85,6 +84,16 @@ def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params): # Must use x1_eq_x2=False here, because covar_dist just returns 0 otherwise and we lose gradients. 
k_diag_f = lambda x1_, x2_: self.base_kernel.forward(x1_, x2_, diag=True, x1_eq_x2=False) k_diag = k_diag_f(x1, x2) + + # TODO: Currently, this computes the full Hessian of each k(x_i, x_i) diagonal element, + # and then takes the diagonal of each Hessian. As a result, this takes O(d) more memory + # than it should. + # + # This is still better than computing the full nd x nd Hessian block + # and taking the diagonal by a factor of n, but not as good as it ought to be. I'm not + # 100% sure how to solve this, since thinking about vmapping nested jacrevs hurts my brain. + # + # Maybe we could vmap a vjp against columns of I or something? k_grad_diag_f = vmap(jacrev(jacrev(k_diag_f, argnums=0), argnums=1)) k_grad_diag = k_grad_diag_f(x1, x2).diagonal(dim1=-2, dim2=-1).transpose(-3, -1).reshape(-1) @@ -101,6 +110,7 @@ def num_outputs_per_input(self, x1, x2): except (ImportError, ModuleNotFoundError): + class DerivativeKernel(Kernel): def __init__(self, base_kernel, shuffle=False, **kwargs): raise RuntimeError("You must have functorch installed to use DerivativeKernel!") From 0dfb80e2eb30b6c1a8c2a8747a0aa3ba63425aca Mon Sep 17 00:00:00 2001 From: Jake Gardner Date: Sat, 23 Oct 2021 23:03:44 -0400 Subject: [PATCH 4/7] more differentiable matern kernel --- gpytorch/kernels/matern_kernel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpytorch/kernels/matern_kernel.py b/gpytorch/kernels/matern_kernel.py index 630989012..c5fc6bf59 100644 --- a/gpytorch/kernels/matern_kernel.py +++ b/gpytorch/kernels/matern_kernel.py @@ -98,7 +98,8 @@ def forward(self, x1, x2, diag=False, **params): x1_ = (x1 - mean).div(self.lengthscale) x2_ = (x2 - mean).div(self.lengthscale) - distance = self.covar_dist(x1_, x2_, diag=diag, **params) + distance_squared = self.covar_dist(x1_, x2_, diag=diag, square_dist=True, **params) + distance = distance_squared.clamp_min(1e-30).sqrt() exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance) if self.nu == 0.5: @@ -106,7 +107,7 @@ def forward(self, x1, x2, diag=False, **params): elif self.nu == 1.5: constant_component = (math.sqrt(3) * distance).add(1) elif self.nu == 2.5: - constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance ** 2) + constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance_squared) return constant_component * exp_component return MaternCovariance.apply( x1, x2, self.lengthscale, self.nu, lambda x1, x2: self.covar_dist(x1, x2, **params) From 967644bc9951c56c1b8578cc4dee68dd5aa36e42 Mon Sep 17 00:00:00 2001 From: Jake Gardner Date: Sat, 23 Oct 2021 23:20:28 -0400 Subject: [PATCH 5/7] batch shape handling --- .../kernels/experimental/derivative_kernel.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/gpytorch/kernels/experimental/derivative_kernel.py b/gpytorch/kernels/experimental/derivative_kernel.py index 1468f491d..017be3be5 100644 --- a/gpytorch/kernels/experimental/derivative_kernel.py +++ b/gpytorch/kernels/experimental/derivative_kernel.py @@ -32,6 +32,7 @@ def __init__(self, base_kernel, shuffle=True, **kwargs): self.shuffle = shuffle def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params): + batch_shape = x1.shape[:-2] n1, d = x1.shape[-2:] n2 = x2.shape[-2] @@ -45,26 +46,36 @@ def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params): K_x1_x2 = kernf(x1, x2) # Compute K_{dx1, x2} block. 
- K_dx1_x2_func = vmap(lambda _x1: jacrev(kernf)(_x1, x2)) + K_dx1_x2_func = vmap(lambda _x1: jacrev(kernf)(_x1, x2), in_dims=-3) K_dx1_x2 = K_dx1_x2_func(x1.unsqueeze(-2)) - K_dx1_x2 = K_dx1_x2.squeeze(-2).squeeze(-3).permute(-1, -3, -2).reshape(n1 * d, n2) + batch_dims = torch.flipud(-(torch.arange(len(batch_shape)) + 4)) + K_dx1_x2 = ( + K_dx1_x2.squeeze(-2).squeeze(-3).permute(*batch_dims, -1, -3, -2).reshape(*batch_shape, n1 * d, n2) + ) if x1_eq_x2: K_x1_dx2 = K_dx1_x2.transpose(-2, -1) else: # Compute K_{x1, dx2} block the same way (then we'll transpose). - K_dx2_x1_func = vmap(lambda _x2: jacrev(kernf)(_x2, x1)) + K_dx2_x1_func = vmap(lambda _x2: jacrev(kernf)(_x2, x1), in_dims=-3) K_dx2_x1 = K_dx2_x1_func(x2.unsqueeze(-2)) - K_dx2_x1 = K_dx2_x1.squeeze(-2).squeeze(-3).permute(-1, -3, -2).reshape(n2 * d, n1) + batch_dims = torch.flipud(-(torch.arange(len(batch_shape)) + 4)) + K_dx2_x1 = ( + K_dx2_x1.squeeze(-2) + .squeeze(-3) + .permute(*batch_dims, -1, -3, -2) + .reshape(*batch_shape, n2 * d, n1) + ) K_x1_dx2 = K_dx2_x1.transpose(-2, -1) # Compute K_{dx1, dx2} block. - K_dx1_dx2_func = vmap(vmap(jacrev(jacrev(kernf, argnums=0), argnums=1))) - x1_expand = x1.unsqueeze(-2).unsqueeze(-2).expand(n1, n2, 1, d) - x2_expand = x2.unsqueeze(-3).unsqueeze(-2).expand(n1, n2, 1, d) + K_dx1_dx2_func = vmap(vmap(jacrev(jacrev(kernf, argnums=0), argnums=1), in_dims=-3), in_dims=-4) + x1_expand = x1.unsqueeze(-2).unsqueeze(-2).expand(*batch_shape, n1, n2, 1, d) + x2_expand = x2.unsqueeze(-3).unsqueeze(-2).expand(*batch_shape, n1, n2, 1, d) K_dx1_dx2 = K_dx1_dx2_func(x1_expand, x2_expand) K_dx1_dx2 = K_dx1_dx2.squeeze(-2).squeeze(-3).squeeze(-3).squeeze(-3) - K_dx1_dx2 = K_dx1_dx2.permute(-2, -4, -1, -3).reshape(n1 * d, n2 * d) + batch_dims = torch.flipud(-(torch.arange(len(batch_shape)) + 5)) + K_dx1_dx2 = K_dx1_dx2.permute(*batch_dims, -2, -4, -1, -3).reshape(*batch_shape, n1 * d, n2 * d) R1 = torch.cat((K_x1_x2, K_x1_dx2), dim=-1) R2 = torch.cat((K_dx1_x2, K_dx1_dx2), dim=-1) @@ -95,7 +106,8 @@ def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params): # # Maybe we could vmap a vjp against columns of I or something? 
k_grad_diag_f = vmap(jacrev(jacrev(k_diag_f, argnums=0), argnums=1)) - k_grad_diag = k_grad_diag_f(x1, x2).diagonal(dim1=-2, dim2=-1).transpose(-3, -1).reshape(-1) + k_grad_diag = k_grad_diag_f(x1, x2) + k_grad_diag = k_grad_diag.diagonal(dim1=-2, dim2=-1).transpose(-3, -1).reshape(*batch_shape, -1) K_diag = torch.cat((k_diag, k_grad_diag), dim=-1) From ab04472eedf22be828bd977304fd8c0922a5a992 Mon Sep 17 00:00:00 2001 From: Jake Gardner Date: Sat, 23 Oct 2021 23:21:43 -0400 Subject: [PATCH 6/7] fix lint --- gpytorch/kernels/experimental/__init__.py | 2 +- gpytorch/kernels/kernel.py | 14 +++++++++++--- gpytorch/lazy/lazy_evaluated_kernel_tensor.py | 4 +++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/gpytorch/kernels/experimental/__init__.py b/gpytorch/kernels/experimental/__init__.py index b2b2f9162..be2b79d76 100644 --- a/gpytorch/kernels/experimental/__init__.py +++ b/gpytorch/kernels/experimental/__init__.py @@ -1,3 +1,3 @@ from .derivative_kernel import DerivativeKernel -__all__ = ["DerivativeKernel"] \ No newline at end of file +__all__ = ["DerivativeKernel"] diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py index cde167f68..274cde8a4 100644 --- a/gpytorch/kernels/kernel.py +++ b/gpytorch/kernels/kernel.py @@ -398,7 +398,9 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, x1_eq_x2=No ) if diag: - res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params) + res = super(Kernel, self).__call__( + x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params + ) # Did this Kernel eat the diag option? # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output if not isinstance(res, LazyEvaluatedKernelTensor): @@ -408,9 +410,15 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, x1_eq_x2=No else: if settings.lazily_evaluate_kernels.on(): - res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params) + res = LazyEvaluatedKernelTensor( + x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params + ) else: - res = lazify(super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params)) + res = lazify( + super(Kernel, self).__call__( + x1_, x2_, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params + ) + ) return res def __getstate__(self): diff --git a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py index 75f69653e..a65fee498 100644 --- a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py +++ b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py @@ -282,7 +282,9 @@ def evaluate_kernel(self): with settings.lazily_evaluate_kernels(False): temp_active_dims = self.kernel.active_dims self.kernel.active_dims = None - res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, x1_eq_x2=self.x1_eq_x2, **self.params) + res = self.kernel( + x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, x1_eq_x2=self.x1_eq_x2, **self.params + ) self.kernel.active_dims = temp_active_dims # Check the size of the output From ac001a03f4655c120f0ba1678394785896423589 Mon Sep 17 00:00:00 2001 From: wjmaddox Date: Thu, 28 Oct 2021 12:39:16 -0400 Subject: [PATCH 7/7] remove x1_eq_x2 when we compute covariance v grid --- gpytorch/kernels/grid_kernel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gpytorch/kernels/grid_kernel.py 
b/gpytorch/kernels/grid_kernel.py
index a275ba514..ebeb5f5cd 100644
--- a/gpytorch/kernels/grid_kernel.py
+++ b/gpytorch/kernels/grid_kernel.py
@@ -134,6 +134,8 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
             # Use padded grid for batch mode
             first_grid_point = torch.stack([proj[0].unsqueeze(0) for proj in grid], dim=-1)
             full_grid = torch.stack(padded_grid, dim=-1)
+            # Drop any x1_eq_x2 hint: it refers to the original inputs, not the grid points compared here.
+            params.pop("x1_eq_x2", None)
             covars = delazify(self.base_kernel(first_grid_point, full_grid, last_dim_is_batch=True, **params))

             if last_dim_is_batch:
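
The pieces above fit together as follows: PATCH 1/7 threads an explicit x1_eq_x2 hint through Kernel.__call__ and covar_dist, PATCH 2/7 consumes it in the functorch-based DerivativeKernel, and PATCH 5/7 extends that kernel to batched inputs. A rough usage sketch of what the series enables, not part of the patches themselves: it assumes functorch is installed and the series is applied, mirrors the PolynomialKernelGrad check from the DerivativeKernel docstring but with the existing RBFKernel/RBFKernelGrad pair, and the tolerance is illustrative only.

    import torch
    import gpytorch

    x = torch.randn(5, 2)  # n = 5 points, d = 2 input dimensions

    # Hand-derived reference kernel that already ships with GPyTorch.
    kern_grad = gpytorch.kernels.RBFKernelGrad()
    # Autograd-derived equivalent introduced by this patch series.
    kern_autograd = gpytorch.kernels.experimental.DerivativeKernel(gpytorch.kernels.RBFKernel())

    # Both produce an n * (d + 1) = 15 x 15 covariance over [value, d/dx_1, d/dx_2] per point;
    # shuffle=True keeps DerivativeKernel in the same MultiTask ordering as the *Grad kernels.
    K_ref = kern_grad(x).evaluate()
    K_auto = kern_autograd(x).evaluate()
    assert torch.norm(K_ref - K_auto) < 1e-4

    # diag=True returns the length-15 diagonal directly; rectangular diag mode raises.
    assert torch.norm(kern_grad(x, diag=True) - kern_autograd(x, diag=True)) < 1e-4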
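
PATCH 4/7 deserves a note on motivation: the Matern forward previously took the Euclidean distance directly, and the derivative of a square root (or a norm) at exactly zero distance is NaN, which is precisely the diagonal case that DerivativeKernel differentiates through. Computing the squared distance first and clamping it before the square root keeps those gradients finite; the new square_dist branch of covar_dist in PATCH 2/7 avoids torch.norm when n1 == n2 for the same reason. A standalone illustration, not part of the patch series (the 1e-30 floor matches the value used in the patch):

    import torch

    x1 = torch.tensor([[0.5, -0.3]], requires_grad=True)
    x2 = x1.detach().clone()  # a second point at exactly the same location

    # Plain Euclidean distance: sqrt has infinite slope at 0, so backprop yields NaNs.
    d_bad = (x1 - x2).pow(2).sum(-1).sqrt()
    d_bad.sum().backward()
    print(x1.grad)  # tensor([[nan, nan]])

    x1.grad = None

    # Clamping the squared distance away from 0 before the sqrt keeps the gradient finite.
    d_good = (x1 - x2).pow(2).sum(-1).clamp_min(1e-30).sqrt()
    d_good.sum().backward()
    print(x1.grad)  # tensor([[0., 0.]])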