diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst index 5fa89b916..0964085be 100644 --- a/docs/source/kernels.rst +++ b/docs/source/kernels.rst @@ -119,24 +119,12 @@ Composition/Decoration Kernels .. autoclass:: MultiDeviceKernel :members: -:hidden:`AdditiveStructureKernel` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: AdditiveStructureKernel - :members: - :hidden:`ProductKernel` ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ProductKernel :members: -:hidden:`ProductStructureKernel` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: ProductStructureKernel - :members: - :hidden:`ScaleKernel` ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py index 55119b784..2d1189e9c 100644 --- a/gpytorch/kernels/__init__.py +++ b/gpytorch/kernels/__init__.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 from . import keops -from .additive_structure_kernel import AdditiveStructureKernel from .arc_kernel import ArcKernel from .constant_kernel import ConstantKernel from .cosine_kernel import CosineKernel @@ -24,7 +23,6 @@ from .piecewise_polynomial_kernel import PiecewisePolynomialKernel from .polynomial_kernel import PolynomialKernel from .polynomial_kernel_grad import PolynomialKernelGrad -from .product_structure_kernel import ProductStructureKernel from .rbf_kernel import RBFKernel from .rbf_kernel_grad import RBFKernelGrad from .rbf_kernel_gradgrad import RBFKernelGradGrad @@ -39,7 +37,6 @@ "Kernel", "ArcKernel", "AdditiveKernel", - "AdditiveStructureKernel", "ConstantKernel", "CylindricalKernel", "MultiDeviceKernel", @@ -61,7 +58,6 @@ "PolynomialKernel", "PolynomialKernelGrad", "ProductKernel", - "ProductStructureKernel", "RBFKernel", "RFFKernel", "RBFKernelGrad", diff --git a/gpytorch/kernels/additive_structure_kernel.py b/gpytorch/kernels/additive_structure_kernel.py deleted file mode 100644 index 1c2ba9c6e..000000000 --- a/gpytorch/kernels/additive_structure_kernel.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 - -from typing import Optional, Tuple - -from .kernel import Kernel - - -class AdditiveStructureKernel(Kernel): - r""" - A Kernel decorator for kernels with additive structure. If a kernel decomposes - additively, then this module will be much more computationally efficient. - - A kernel function `k` decomposes additively if it can be written as - - .. math:: - - \begin{equation*} - k(\mathbf{x_1}, \mathbf{x_2}) = k'(x_1^{(1)}, x_2^{(1)}) + \ldots + k'(x_1^{(d)}, x_2^{(d)}) - \end{equation*} - - for some kernel :math:`k'` that operates on a subset of dimensions. - - Given a `b x n x d` input, `AdditiveStructureKernel` computes `d` one-dimensional kernels - (using the supplied base_kernel), and then adds the component kernels together. - Unlike :class:`~gpytorch.kernels.AdditiveKernel`, `AdditiveStructureKernel` computes each - of the additive terms in batch, making it very fast. - - Args: - base_kernel (Kernel): - The kernel to approximate with KISS-GP - num_dims (int): - The dimension of the input data. - active_dims (tuple of ints, optional): - Passed down to the `base_kernel`. - """ - - @property - def is_stationary(self) -> bool: - """ - Kernel is stationary if the base kernel is stationary. - """ - return self.base_kernel.is_stationary - - def __init__( - self, - base_kernel: Kernel, - num_dims: int, - active_dims: Optional[Tuple[int, ...]] = None, - ): - super(AdditiveStructureKernel, self).__init__(active_dims=active_dims) - self.base_kernel = base_kernel - self.num_dims = num_dims - - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): - if last_dim_is_batch: - raise RuntimeError("AdditiveStructureKernel does not accept the last_dim_is_batch argument.") - - res = self.base_kernel(x1, x2, diag=diag, last_dim_is_batch=True, **params) - res = res.sum(-2 if diag else -3) - return res - - def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood): - return self.base_kernel.prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood) - - def num_outputs_per_input(self, x1, x2): - return self.base_kernel.num_outputs_per_input(x1, x2) diff --git a/gpytorch/kernels/product_structure_kernel.py b/gpytorch/kernels/product_structure_kernel.py deleted file mode 100644 index f25f8d7a7..000000000 --- a/gpytorch/kernels/product_structure_kernel.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 - -from typing import Optional, Tuple - -from linear_operator.operators import to_linear_operator - -from .kernel import Kernel - - -class ProductStructureKernel(Kernel): - r""" - A Kernel decorator for kernels with product structure. If a kernel decomposes - multiplicatively, then this module will be much more computationally efficient. - - A kernel function `k` has product structure if it can be written as - - .. math:: - - \begin{equation*} - k(\mathbf{x_1}, \mathbf{x_2}) = k'(x_1^{(1)}, x_2^{(1)}) * \ldots * k'(x_1^{(d)}, x_2^{(d)}) - \end{equation*} - - for some kernel :math:`k'` that operates on each dimension. - - Given a `b x n x d` input, `ProductStructureKernel` computes `d` one-dimensional kernels - (using the supplied base_kernel), and then multiplies the component kernels together. - Unlike :class:`~gpytorch.kernels.ProductKernel`, `ProductStructureKernel` computes each - of the product terms in batch, making it very fast. - - See `Product Kernel Interpolation for Scalable Gaussian Processes`_ for more detail. - - Args: - base_kernel (Kernel): - The kernel to approximate with KISS-GP - num_dims (int): - The dimension of the input data. - active_dims (tuple of ints, optional): - Passed down to the `base_kernel`. - - .. _Product Kernel Interpolation for Scalable Gaussian Processes: - https://arxiv.org/pdf/1802.08903 - """ - - @property - def is_stationary(self) -> bool: - """ - Kernel is stationary if the base kernel is stationary. - """ - return self.base_kernel.is_stationary - - def __init__( - self, - base_kernel: Kernel, - num_dims: int, - active_dims: Optional[Tuple[int, ...]] = None, - ): - super(ProductStructureKernel, self).__init__(active_dims=active_dims) - self.base_kernel = base_kernel - self.num_dims = num_dims - - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): - if last_dim_is_batch: - raise RuntimeError("ProductStructureKernel does not accept the last_dim_is_batch argument.") - - res = self.base_kernel(x1, x2, diag=diag, last_dim_is_batch=True, **params) - res = res.prod(-2 if diag else -3) - return res - - def num_outputs_per_input(self, x1, x2): - return self.base_kernel.num_outputs_per_input(x1, x2) - - def __call__(self, x1_, x2_=None, diag=False, last_dim_is_batch=False, **params): - """ - We cannot lazily evaluate actual kernel calls when using SKIP, because we - cannot root decompose rectangular matrices. - - Because we slice in to the kernel during prediction to get the test x train - covar before calling evaluate_kernel, the order of operations would mean we - would get a MulLinearOperator representing a rectangular matrix, which we - cannot matmul with because we cannot root decompose it. Thus, SKIP actually - *requires* that we work with the full (train + test) x (train + test) - kernel matrix. - """ - res = super().__call__(x1_, x2_, diag=diag, last_dim_is_batch=last_dim_is_batch, **params) - res = to_linear_operator(res).evaluate_kernel() - return res diff --git a/test/examples/test_kissgp_additive_regression.py b/test/examples/test_kissgp_additive_regression.py index 4018e11ad..087c77d16 100644 --- a/test/examples/test_kissgp_additive_regression.py +++ b/test/examples/test_kissgp_additive_regression.py @@ -9,7 +9,7 @@ import gpytorch from gpytorch.distributions import MultivariateNormal -from gpytorch.kernels import AdditiveStructureKernel, GridInterpolationKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ZeroMean @@ -36,14 +36,12 @@ class GPRegressionModel(gpytorch.models.ExactGP): def __init__(self, train_x, train_y, likelihood): super(GPRegressionModel, self).__init__(train_x, train_y, likelihood) self.mean_module = ZeroMean() - self.base_covar_module = ScaleKernel(RBFKernel(ard_num_dims=2)) - self.covar_module = AdditiveStructureKernel( - GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1), num_dims=2 - ) + self.base_covar_module = ScaleKernel(RBFKernel(batch_shape=torch.Size([2]))) + self.covar_module = GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1) def forward(self, x): mean_x = self.mean_module(x) - covar_x = self.covar_module(x) + covar_x = self.covar_module(x.mT[..., None]).sum(dim=-3) return MultivariateNormal(mean_x, covar_x) diff --git a/test/examples/test_kissgp_multiplicative_regression.py b/test/examples/test_kissgp_multiplicative_regression.py index ca8b40360..d16869f95 100644 --- a/test/examples/test_kissgp_multiplicative_regression.py +++ b/test/examples/test_kissgp_multiplicative_regression.py @@ -10,7 +10,7 @@ import gpytorch from gpytorch.distributions import MultivariateNormal -from gpytorch.kernels import GridInterpolationKernel, ProductStructureKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ConstantMean from gpytorch.priors import SmoothedBoxPrior @@ -42,14 +42,12 @@ class GPRegressionModel(gpytorch.models.ExactGP): def __init__(self, train_x, train_y, likelihood): super(GPRegressionModel, self).__init__(train_x, train_y, likelihood) self.mean_module = ConstantMean(constant_prior=SmoothedBoxPrior(-1, 1)) - self.base_covar_module = ScaleKernel(RBFKernel()) - self.covar_module = ProductStructureKernel( - GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1), num_dims=2 - ) + self.base_covar_module = ScaleKernel(RBFKernel(batch_shape=torch.Size([2]))) + self.covar_module = GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1) def forward(self, x): mean_x = self.mean_module(x) - covar_x = self.covar_module(x) + covar_x = self.covar_module(x.mT[..., None]).prod(dim=-3) return MultivariateNormal(mean_x, covar_x)