Skip to content

Commit

Permalink
[Breaking Change] WIP: replace NewtonGirardAdditiveKernel with intera…
Browse files Browse the repository at this point in the history
…ction term summing utility
  • Loading branch information
gpleiss committed Jul 2, 2024
1 parent 495772f commit 816baf6
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 283 deletions.
12 changes: 6 additions & 6 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ Interpolation Utilities
.. automodule:: gpytorch.utils.interpolation
:members:

Nearest Neighbors Utilities
---------------------------------

.. automodule:: gpytorch.utils.nearest_neighbors
:members:

Quadrature Utilities
----------------------------

Expand All @@ -31,9 +37,3 @@ Transform Utilities

.. automodule:: gpytorch.utils.transforms
:members:

Nearest Neighbors Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: gpytorch.utils.nearest_neighbors
:members:
2 changes: 0 additions & 2 deletions gpytorch/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from .matern_kernel import MaternKernel
from .multi_device_kernel import MultiDeviceKernel
from .multitask_kernel import MultitaskKernel
from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel
from .periodic_kernel import PeriodicKernel
from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
from .polynomial_kernel import PolynomialKernel
Expand Down Expand Up @@ -52,7 +51,6 @@
"LinearKernel",
"MaternKernel",
"MultitaskKernel",
"NewtonGirardAdditiveKernel",
"PeriodicKernel",
"PiecewisePolynomialKernel",
"PolynomialKernel",
Expand Down
117 changes: 0 additions & 117 deletions gpytorch/kernels/newton_girard_additive_kernel.py

This file was deleted.

2 changes: 2 additions & 0 deletions gpytorch/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from . import deprecation, errors, generic, grid, interpolation, quadrature, transforms, warnings
from .memoize import cached
from .nearest_neighbors import NNUtil
from .sum_interaction_terms import sum_interaction_terms

__all__ = [
"cached",
Expand All @@ -19,6 +20,7 @@
"grid",
"interpolation",
"quadrature",
"sum_interaction_terms",
"transforms",
"warnings",
"NNUtil",
Expand Down
60 changes: 60 additions & 0 deletions gpytorch/utils/sum_interaction_terms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

from typing import Optional, Union

import torch

from jaxtyping import Float
from linear_operator import LinearOperator, to_dense
from torch import Tensor


def sum_interaction_terms(
covars: Float[Union[LinearOperator, Tensor], "... D N N"],
max_degree: Optional[int] = None,
dim: int = -3,
) -> Float[Tensor, "... N N"]:
r"""
Given a batch of D x N x N covariance matrices :math:`\boldsymbol K_1, \ldots, \boldsymbol K_D`,
compute the sum of each covariance matrix as well as the interaction terms up to degree `max_degree`
(denoted as :math:`M` below):
.. math::
\sum_{1 \leq i_1 < i_2 < \ldots < i_M < D} \left[
\prod_{j=1}^M \boldsymbol K_{i_j}
\right].
This function is useful for computing the sum of additive kernels as defined in
`Additive Gaussian Processes (Duvenaud et al., 2011)`_.
Note that the summation is computed in :math:`\mathcal O(D)` time using the Newton-Girard formula.
.. _Additive Gaussian Processes (Duvenaud et al., 2011):
https://arxiv.org/pdf/1112.4394
:param covars: A batch of covariance matrices, representing the base covariances to sum over
:param max_degree: The maximum degree of the interaction terms to compute.
If not provided, this will default to `D`.
:param dim: The dimension to sum over (i.e. the batch dimension containing the base covariance matrices).
Note that dim must be a negative integer (i.e. -3, not 0).
"""
assert dim < 0, "dim must be a negative integer"

covars = to_dense(covars)
ks = torch.arange(max_degree, dtype=covars.dtype, device=covars.device)
neg_one = torch.tensor(-1.0, dtype=covars.dtype, device=covars.device)

# S_times_factor[k] = factor[k] * S[k]
# = (-1)^{k} * \sum_{i=1}^D covar_i^{k+1}
S_times_factor_ks = torch.vmap(lambda k: neg_one.pow(k) * torch.sum(covars.pow(k + 1), dim=dim))(ks)

# E[deg] = 1/(deg+1) \sum_{j=0}^{deg} factor[k] * S[k] * E[deg-k]
# = 1/(deg+1) [ (factor[deg] * S[deg]) + \sum_{j=1}^{deg - 1} factor * S_ks[k] * E_ks[deg-k] ]
E_ks = torch.empty_like(S_times_factor_ks)
E_ks[0] = S_times_factor_ks[0]
for deg in range(1, max_degree):
sum_term = torch.einsum("m...,m...->...", S_times_factor_ks[:deg], E_ks[:deg].flip(0))
E_ks[deg] = (S_times_factor_ks[deg] + sum_term) / (deg + 1)

return E_ks.sum(0)
158 changes: 0 additions & 158 deletions test/kernels/test_newton_girard_additive_kernel.py

This file was deleted.

Loading

0 comments on commit 816baf6

Please sign in to comment.