[Experimental] Derivative Kernel #1794

Open · wants to merge 7 commits into main
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from . import experimental
from . import keops
from .additive_structure_kernel import AdditiveStructureKernel
from .arc_kernel import ArcKernel
@@ -31,6 +32,7 @@
from .spectral_mixture_kernel import SpectralMixtureKernel

__all__ = [
"experimental",
"keops",
"Kernel",
"ArcKernel",
3 changes: 3 additions & 0 deletions gpytorch/kernels/experimental/__init__.py
@@ -0,0 +1,3 @@
from .derivative_kernel import DerivativeKernel

__all__ = ["DerivativeKernel"]
128 changes: 128 additions & 0 deletions gpytorch/kernels/experimental/derivative_kernel.py
@@ -0,0 +1,128 @@
import torch

from ..kernel import Kernel

try:
from functorch import vmap, jacrev

class DerivativeKernel(Kernel):
def __init__(self, base_kernel, shuffle=True, **kwargs):
"""
Wraps a kernel to add support for derivative information automatically using autograd.

Args:
base_kernel (Kernel): Kernel object to add derivative information support to.
shuffle (bool, default True): If True, permute the rows and columns of the output matrix so that each point's function value and its partial derivatives are adjacent, matching GPyTorch's multitask ordering conventions.

.. note::

Currently, this kernel takes advantage of experimental torch functionality found in the `functorch`
package. You must have this package installed in order to use this kernel.

Example:
>>> x = torch.randn(5, 2)
>>> kern = gpytorch.kernels.PolynomialKernel(2)
>>> kern_grad = gpytorch.kernels.PolynomialKernelGrad(2)
>>> kern_autograd = gpytorch.kernels.experimental.DerivativeKernel(kern)
>>> assert torch.norm(kern_grad(x).evaluate() - kern_autograd(x).evaluate()) < 1e-5
"""
super().__init__(**kwargs)

self.base_kernel = base_kernel
self.shuffle = shuffle

def forward(self, x1, x2, diag=False, x1_eq_x2=None, **params):
batch_shape = x1.shape[:-2]
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

if x1_eq_x2 is None:
# functorch can't batch over the equality check, so we have to assume the worst.
x1_eq_x2 = False

if not diag:
kernf = lambda x1_, x2_: self.base_kernel.forward(x1_, x2_, diag=False, x1_eq_x2=x1_eq_x2)

K_x1_x2 = kernf(x1, x2)

# Compute K_{dx1, x2} block.
K_dx1_x2_func = vmap(lambda _x1: jacrev(kernf)(_x1, x2), in_dims=-3)
K_dx1_x2 = K_dx1_x2_func(x1.unsqueeze(-2))
batch_dims = torch.flipud(-(torch.arange(len(batch_shape)) + 4))
K_dx1_x2 = (
K_dx1_x2.squeeze(-2).squeeze(-3).permute(*batch_dims, -1, -3, -2).reshape(*batch_shape, n1 * d, n2)
)

if x1_eq_x2:
K_x1_dx2 = K_dx1_x2.transpose(-2, -1)
else:
# Compute K_{x1, dx2} block the same way (then we'll transpose).
K_dx2_x1_func = vmap(lambda _x2: jacrev(kernf)(_x2, x1), in_dims=-3)
K_dx2_x1 = K_dx2_x1_func(x2.unsqueeze(-2))
batch_dims = torch.flipud(-(torch.arange(len(batch_shape)) + 4))
K_dx2_x1 = (
K_dx2_x1.squeeze(-2)
.squeeze(-3)
.permute(*batch_dims, -1, -3, -2)
.reshape(*batch_shape, n2 * d, n1)
)
K_x1_dx2 = K_dx2_x1.transpose(-2, -1)

# Compute K_{dx1, dx2} block.
K_dx1_dx2_func = vmap(vmap(jacrev(jacrev(kernf, argnums=0), argnums=1), in_dims=-3), in_dims=-4)
x1_expand = x1.unsqueeze(-2).unsqueeze(-2).expand(*batch_shape, n1, n2, 1, d)
x2_expand = x2.unsqueeze(-3).unsqueeze(-2).expand(*batch_shape, n1, n2, 1, d)
K_dx1_dx2 = K_dx1_dx2_func(x1_expand, x2_expand)
K_dx1_dx2 = K_dx1_dx2.squeeze(-2).squeeze(-3).squeeze(-3).squeeze(-3)
batch_dims = torch.flipud(-(torch.arange(len(batch_shape)) + 5))
K_dx1_dx2 = K_dx1_dx2.permute(*batch_dims, -2, -4, -1, -3).reshape(*batch_shape, n1 * d, n2 * d)

R1 = torch.cat((K_x1_x2, K_x1_dx2), dim=-1)
R2 = torch.cat((K_dx1_x2, K_dx1_dx2), dim=-1)
K = torch.cat((R1, R2), dim=-2)

if self.shuffle:
# Apply a perfect shuffle permutation to match the MultiTask ordering
pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
K = K[..., pi1, :][..., :, pi2]

return K
else:
if n1 != n2:
raise RuntimeError("DerivativeKernel does not support diag mode on rectangular kernel matrices.")

# Must use x1_eq_x2=False here, because covar_dist just returns 0 otherwise and we lose gradients.
k_diag_f = lambda x1_, x2_: self.base_kernel.forward(x1_, x2_, diag=True, x1_eq_x2=False)
k_diag = k_diag_f(x1, x2)

# TODO: Currently, this computes the full Hessian of each k(x_i, x_i) diagonal element,
# and then takes the diagonal of each Hessian. As a result, this takes O(d) more memory
# than it should.
#
# This is still better (by a factor of n) than computing the full nd x nd Hessian block
# and taking its diagonal, but it is not as good as it ought to be. I'm not
# 100% sure how to solve this, since thinking about vmapping nested jacrevs hurts my brain.
#
# Maybe we could vmap a vjp against columns of I or something?
k_grad_diag_f = vmap(jacrev(jacrev(k_diag_f, argnums=0), argnums=1))
k_grad_diag = k_grad_diag_f(x1, x2)
k_grad_diag = k_grad_diag.diagonal(dim1=-2, dim2=-1).transpose(-3, -1).reshape(*batch_shape, -1)

K_diag = torch.cat((k_diag, k_grad_diag), dim=-1)

if self.shuffle:
pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
K_diag = K_diag[..., pi1]

return K_diag

def num_outputs_per_input(self, x1, x2):
return x1.size(-1) + 1


except (ImportError, ModuleNotFoundError):

class DerivativeKernel(Kernel):
def __init__(self, base_kernel, shuffle=False, **kwargs):
raise RuntimeError("You must have functorch installed to use DerivativeKernel!")
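
A minimal standalone sketch (not part of the PR; the labels are purely illustrative) of what the perfect-shuffle permutation used in forward does. It maps the block layout built above, with all function values first and the derivative rows grouped by input dimension, to the per-point interleaved ordering that GPyTorch's multitask machinery expects:

import torch

# Label the rows of the block kernel matrix for n points in d dimensions.
# Block ordering: the n function values first, then derivative rows grouped by dimension.
n, d = 3, 2
block_rows = [f"f(x{i})" for i in range(n)] + [
    f"df(x{i})/dx{j}" for j in range(d) for i in range(n)
]

# The same perfect-shuffle permutation as in DerivativeKernel.forward.
pi = torch.arange(n * (d + 1)).view(d + 1, n).t().reshape(n * (d + 1))

print([block_rows[k] for k in pi.tolist()])
# ['f(x0)', 'df(x0)/dx0', 'df(x0)/dx1', 'f(x1)', 'df(x1)/dx0', ...]
# Each point's value and its d partial derivatives end up adjacent, which matches
# the ordering produced by e.g. PolynomialKernelGrad in the docstring example above.
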
2 changes: 2 additions & 0 deletions gpytorch/kernels/grid_kernel.py
@@ -134,6 +134,8 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
# Use padded grid for batch mode
first_grid_point = torch.stack([proj[0].unsqueeze(0) for proj in grid], dim=-1)
full_grid = torch.stack(padded_grid, dim=-1)
# The base kernel is evaluated on grid tensors below, so any x1_eq_x2 flag passed
# down for the original inputs no longer applies; drop it if present.
params.pop("x1_eq_x2", None)
covars = delazify(self.base_kernel(first_grid_point, full_grid, last_dim_is_batch=True, **params))

if last_dim_is_batch:
45 changes: 34 additions & 11 deletions gpytorch/kernels/kernel.py
@@ -139,7 +139,7 @@ def __init__(
eps=1e-6,
**kwargs,
):
super(Kernel, self).__init__()
super().__init__()
self._batch_shape = batch_shape
if active_dims is not None and not torch.is_tensor(active_dims):
active_dims = torch.tensor(active_dims, dtype=torch.long)
@@ -265,6 +265,7 @@ def covar_dist(
diag=False,
last_dim_is_batch=False,
square_dist=False,
x1_eq_x2=None,
dist_postprocess_func=default_postprocess_script,
postprocess=True,
**params,
@@ -297,7 +298,8 @@
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)

x1_eq_x2 = torch.equal(x1, x2)
if x1_eq_x2 is None:
x1_eq_x2 = torch.equal(x1, x2)

# torch scripts expect tensors
postprocess = torch.tensor(postprocess)
@@ -316,9 +318,15 @@
res = dist_postprocess_func(res)
return res
else:
res = torch.norm(x1 - x2, p=2, dim=-1)
if square_dist:
res = res.pow(2)
if not square_dist:
res = torch.norm(x1 - x2, p=2, dim=-1)
else:
if x1.size(-2) == x2.size(-2):
# If n1 = n2, we can compute the squared distance diagonal
# in a way that preserves gradient flow.
res = (x1 * x1 - 2 * x1 * x2 + x2 * x2).sum(-1)
else:
res = torch.norm(x1 - x2, p=2, dim=-1).pow(2)
if postprocess:
res = dist_postprocess_func(res)
return res
@@ -353,7 +361,7 @@ def sub_kernels(self):
for _, kernel in self.named_sub_kernels():
yield kernel

def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params):
def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, x1_eq_x2=None, **params):
x1_, x2_ = x1, x2

# Select the active dimensions
@@ -371,8 +379,15 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params):
if not x1_.size(-1) == x2_.size(-1):
raise RuntimeError("x1_ and x2_ must have the same number of dimensions!")

if x2_ is None:
x2_ = x1_
if x1_eq_x2 is None:
if x2_ is None:
x2_ = x1_
x1_eq_x2 = True
else:
x1_eq_x2 = None
else:
if x2_ is None:
x2_ = x1_

# Check that ard_num_dims matches the supplied number of dimensions
if settings.debug.on():
@@ -383,7 +398,9 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params):
)

if diag:
res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
res = super(Kernel, self).__call__(
x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params
)
# Did this Kernel eat the diag option?
# If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
if not isinstance(res, LazyEvaluatedKernelTensor):
Expand All @@ -393,9 +410,15 @@ def __call__(self, x1, x2=None, diag=False, last_dim_is_batch=False, **params):

else:
if settings.lazily_evaluate_kernels.on():
res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
res = LazyEvaluatedKernelTensor(
x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params
)
else:
res = lazify(super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params))
res = lazify(
super(Kernel, self).__call__(
x1_, x2_, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params
)
)
return res

def __getstate__(self):
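
To motivate the new square_dist branch in covar_dist above: the Euclidean norm is not differentiable at zero, so differentiating a diagonal entry k(x, x) through torch.norm(x1 - x2) tends to produce NaN gradients, which is what the expanded quadratic avoids. A minimal sketch (not part of the PR) using plain torch:

import torch

x = torch.randn(4, 2, requires_grad=True)
y = x.detach().clone().requires_grad_(True)  # numerically equal to x

# Route 1: norm, then square. The 2-norm's backward pass divides by the norm,
# so at coincident points it typically yields NaN, which poisons the chain rule.
d_norm = torch.norm(x - y, p=2, dim=-1).pow(2).sum()
d_norm.backward()
print(x.grad)  # typically NaN everywhere

# Route 2: the expanded quadratic used in covar_dist. Every term is smooth, so
# the gradient is well defined (exactly zero here, as expected for x == y).
x.grad = None
d_quad = (x * x - 2 * x * y + y * y).sum(-1).sum()
d_quad.backward()
print(x.grad)
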
5 changes: 3 additions & 2 deletions gpytorch/kernels/matern_kernel.py
@@ -98,15 +98,16 @@ def forward(self, x1, x2, diag=False, **params):

x1_ = (x1 - mean).div(self.lengthscale)
x2_ = (x2 - mean).div(self.lengthscale)
distance = self.covar_dist(x1_, x2_, diag=diag, **params)
distance_squared = self.covar_dist(x1_, x2_, diag=diag, square_dist=True, **params)
distance = distance_squared.clamp_min(1e-30).sqrt()
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)

if self.nu == 0.5:
constant_component = 1
elif self.nu == 1.5:
constant_component = (math.sqrt(3) * distance).add(1)
elif self.nu == 2.5:
constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance ** 2)
constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance_squared)
return constant_component * exp_component
return MaternCovariance.apply(
x1, x2, self.lengthscale, self.nu, lambda x1, x2: self.covar_dist(x1, x2, **params)
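
The clamp_min(1e-30) inserted before the square root above serves a related purpose: sqrt has an unbounded derivative at zero, so taking the square root of an exactly-zero squared distance yields non-finite gradients. A short sketch of the trick using nothing beyond plain torch:

import torch

sq_dist = torch.tensor([0.0, 4.0], requires_grad=True)

# Without the clamp: d/dx sqrt(x) = 1 / (2 sqrt(x)), which blows up at x = 0.
torch.sqrt(sq_dist).sum().backward()
print(sq_dist.grad)  # inf for the zero entry, 0.25 for the other

# With the clamp: the zero entry sits below the floor, so clamp_min passes a zero
# gradient through it, while nonzero distances are numerically unchanged.
sq_dist.grad = None
torch.sqrt(sq_dist.clamp_min(1e-30)).sum().backward()
print(sq_dist.grad)  # tensor([0.0000, 0.2500])
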
9 changes: 6 additions & 3 deletions gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -19,14 +19,15 @@ def _check_args(self, x1, x2, kernel, last_dim_is_batch=False, **params):
if not torch.is_tensor(x2):
return "x2 must be a tensor. Got {}".format(x2.__class__.__name__)

def __init__(self, x1, x2, kernel, last_dim_is_batch=False, **params):
def __init__(self, x1, x2, kernel, last_dim_is_batch=False, x1_eq_x2=None, **params):
super(LazyEvaluatedKernelTensor, self).__init__(
x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, **params
x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, x1_eq_x2=x1_eq_x2, **params
)
self.kernel = kernel
self.x1 = x1
self.x2 = x2
self.last_dim_is_batch = last_dim_is_batch
self.x1_eq_x2 = x1_eq_x2
self.params = params

@property
@@ -281,7 +282,9 @@ def evaluate_kernel(self):
with settings.lazily_evaluate_kernels(False):
temp_active_dims = self.kernel.active_dims
self.kernel.active_dims = None
res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
res = self.kernel(
x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, x1_eq_x2=self.x1_eq_x2, **self.params
)
self.kernel.active_dims = temp_active_dims

# Check the size of the output