Skip to content

Commit

Permalink
[Breaking Change] WIP: remove AdditiveStructureKernel and ProductStru…
Browse files Browse the repository at this point in the history
…ctureKernel

TODO: update docs / examples
  • Loading branch information
gpleiss committed Jul 2, 2024
1 parent a049476 commit 9e6d35b
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 180 deletions.
12 changes: 0 additions & 12 deletions docs/source/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,24 +119,12 @@ Composition/Decoration Kernels
.. autoclass:: MultiDeviceKernel
:members:

:hidden:`AdditiveStructureKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdditiveStructureKernel
:members:

:hidden:`ProductKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ProductKernel
:members:

:hidden:`ProductStructureKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ProductStructureKernel
:members:

:hidden:`ScaleKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
4 changes: 0 additions & 4 deletions gpytorch/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
from . import keops
from .additive_structure_kernel import AdditiveStructureKernel
from .arc_kernel import ArcKernel
from .constant_kernel import ConstantKernel
from .cosine_kernel import CosineKernel
Expand All @@ -24,7 +23,6 @@
from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
from .polynomial_kernel import PolynomialKernel
from .polynomial_kernel_grad import PolynomialKernelGrad
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .rbf_kernel_gradgrad import RBFKernelGradGrad
Expand All @@ -39,7 +37,6 @@
"Kernel",
"ArcKernel",
"AdditiveKernel",
"AdditiveStructureKernel",
"ConstantKernel",
"CylindricalKernel",
"MultiDeviceKernel",
Expand All @@ -61,7 +58,6 @@
"PolynomialKernel",
"PolynomialKernelGrad",
"ProductKernel",
"ProductStructureKernel",
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
Expand Down
66 changes: 0 additions & 66 deletions gpytorch/kernels/additive_structure_kernel.py

This file was deleted.

86 changes: 0 additions & 86 deletions gpytorch/kernels/product_structure_kernel.py

This file was deleted.

10 changes: 4 additions & 6 deletions test/examples/test_kissgp_additive_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import AdditiveStructureKernel, GridInterpolationKernel, RBFKernel, ScaleKernel
from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ZeroMean

Expand All @@ -36,14 +36,12 @@ class GPRegressionModel(gpytorch.models.ExactGP):
def __init__(self, train_x, train_y, likelihood):
super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
self.mean_module = ZeroMean()
self.base_covar_module = ScaleKernel(RBFKernel(ard_num_dims=2))
self.covar_module = AdditiveStructureKernel(
GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1), num_dims=2
)
self.base_covar_module = ScaleKernel(RBFKernel(batch_shape=torch.Size([2])))
self.covar_module = GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1)

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
covar_x = self.covar_module(x.mT[..., None]).sum(dim=-3)
return MultivariateNormal(mean_x, covar_x)


Expand Down
10 changes: 4 additions & 6 deletions test/examples/test_kissgp_multiplicative_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import GridInterpolationKernel, ProductStructureKernel, RBFKernel, ScaleKernel
from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.priors import SmoothedBoxPrior
Expand Down Expand Up @@ -42,14 +42,12 @@ class GPRegressionModel(gpytorch.models.ExactGP):
def __init__(self, train_x, train_y, likelihood):
super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
self.mean_module = ConstantMean(constant_prior=SmoothedBoxPrior(-1, 1))
self.base_covar_module = ScaleKernel(RBFKernel())
self.covar_module = ProductStructureKernel(
GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1), num_dims=2
)
self.base_covar_module = ScaleKernel(RBFKernel(batch_shape=torch.Size([2])))
self.covar_module = GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1)

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
covar_x = self.covar_module(x.mT[..., None]).prod(dim=-3)
return MultivariateNormal(mean_x, covar_x)


Expand Down

0 comments on commit 9e6d35b

Please sign in to comment.