Skip to content

Commit

Permalink
Merge pull request #2486 from chrisyeh96/patch-2
Browse files Browse the repository at this point in the history
DOC: fix typos in likelihood.py
  • Loading branch information
Balandat authored Jan 27, 2025
2 parents 0bace60 + 1560c8d commit c5fc3b2
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions gpytorch/likelihoods/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _draw_likelihood_samples(
with pyro.plate(plate_name, size=num_samples, dim=(-max_plate_nesting - 1)):
if sample_shape is None:
function_samples = pyro.sample(self.name_prefix, function_dist.mask(False))
# Deal with the fact that we're not assuming conditional indendence over data points here
# Deal with the fact that we're not assuming conditional independence over data points here
function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
else:
sample_shape = sample_shape[: -len(function_dist.batch_shape)]
Expand All @@ -182,8 +182,8 @@ def expected_log_prob(
:param observations: Values of :math:`y`.
:param function_dist: Distribution for :math:`f(x)`.
:param args: Additional args (passed to the foward function).
:param kwargs: Additional kwargs (passed to the foward function).
:param args: Additional args (passed to the forward function).
:param kwargs: Additional kwargs (passed to the forward function).
"""
return super().expected_log_prob(observations, function_dist, *args, **kwargs)

Expand Down Expand Up @@ -225,8 +225,8 @@ def log_marginal(
:param observations: Values of :math:`y`.
:param function_dist: Distribution for :math:`f(x)`.
:param args: Additional args (passed to the foward function).
:param kwargs: Additional kwargs (passed to the foward function).
:param args: Additional args (passed to the forward function).
:param kwargs: Additional kwargs (passed to the forward function).
"""
return super().log_marginal(observations, function_dist, *args, **kwargs)

Expand All @@ -243,8 +243,8 @@ def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any)
(co)variance of :math:`p(\mathbf f|...)`.
:param function_dist: Distribution for :math:`f(x)`.
:param args: Additional args (passed to the foward function).
:param kwargs: Additional kwargs (passed to the foward function).
:param args: Additional args (passed to the forward function).
:param kwargs: Additional kwargs (passed to the forward function).
:return: The marginal distribution, or samples from it.
"""
return super().marginal(function_dist, *args, **kwargs)
Expand All @@ -259,8 +259,8 @@ def pyro_guide(self, function_dist: MultivariateNormal, target: Tensor, *args: A
:param function_dist: Distribution of latent function
:math:`q(\mathbf f)`.
:param target: Observed :math:`\mathbf y`.
:param args: Additional args (passed to the foward function).
:param kwargs: Additional kwargs (passed to the foward function).
:param args: Additional args (passed to the forward function).
:param kwargs: Additional kwargs (passed to the forward function).
"""
with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
pyro.sample(self.name_prefix + ".f", function_dist)
Expand All @@ -276,8 +276,8 @@ def pyro_model(self, function_dist: MultivariateNormal, target: Tensor, *args: A
:param function_dist: Distribution of latent function
:math:`p(\mathbf f)`.
:param target: Observed :math:`\mathbf y`.
:param args: Additional args (passed to the foward function).
:param kwargs: Additional kwargs (passed to the foward function).
:param args: Additional args (passed to the forward function).
:param kwargs: Additional kwargs (passed to the forward function).
"""
with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
Expand Down Expand Up @@ -324,17 +324,17 @@ def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwarg
# Analytic marginal computation - Bernoulli and Gaussian likelihoods only
analytic_marginal_likelihood = gpytorch.likelihoods.GaussianLikelihood()
marginal = analytic_marginal_likeihood(f)
marginal = analytic_marginal_likelihood(f)
print(type(marginal), marginal.batch_shape, marginal.event_shape)
# >>> gpytorch.distributions.MultivariateNormal, torch.Size([]), torch.Size([20])
# >>> <class 'gpytorch.distributions.multivariate_normal.MultivariateNormal'> torch.Size([]) torch.Size([20]) # noqa: E501
# MC marginal computation - all other likelihoods
mc_marginal_likelihood = gpytorch.likelihoods.BetaLikelihood()
with gpytorch.settings.num_likelihood_samples(15):
marginal = analytic_marginal_likeihood(f)
marginal = mc_marginal_likelihood(f)
print(type(marginal), marginal.batch_shape, marginal.event_shape)
# >>> torch.distributions.Beta, torch.Size([15, 20]), torch.Size([])
# (The batch_shape of torch.Size([15, 20]) represents 15 MC samples for 20 data points.
# >>> <class 'torch.distributions.beta.Beta'> torch.Size([15, 20]) torch.Size([])
# The batch_shape torch.Size([15, 20]) represents 15 MC samples for 20 data points.
.. note::
Expand All @@ -344,8 +344,8 @@ def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwarg
:param input: Either a (... x N) sample from :math:`\mathbf f`
or a (... x N) MVN distribution of :math:`\mathbf f`.
:param args: Additional args (passed to the foward function).
:param kwargs: Additional kwargs (passed to the foward function).
:param args: Additional args (passed to the forward function).
:param kwargs: Additional kwargs (passed to the forward function).
:return: Either a conditional :math:`p(\mathbf y \mid \mathbf f)`
or marginal :math:`p(\mathbf y)`
based on whether :attr:`input` is a Tensor or a MultivariateNormal (see above).
Expand Down Expand Up @@ -377,21 +377,21 @@ def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwarg
class Likelihood(_Likelihood):
@property
def num_data(self) -> int:
warnings.warn("num_data is only used for likehoods that are integrated with Pyro.", RuntimeWarning)
warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
return 0

@num_data.setter
def num_data(self, val: int) -> None:
warnings.warn("num_data is only used for likehoods that are integrated with Pyro.", RuntimeWarning)
warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)

@property
def name_prefix(self) -> str:
warnings.warn("name_prefix is only used for likehoods that are integrated with Pyro.", RuntimeWarning)
warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
return ""

@name_prefix.setter
def name_prefix(self, val: str) -> None:
warnings.warn("name_prefix is only used for likehoods that are integrated with Pyro.", RuntimeWarning)
warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)


class _OneDimensionalLikelihood(Likelihood, ABC):
Expand Down

0 comments on commit c5fc3b2

Please sign in to comment.