From e3d5e255940a6771826b5408fe84829fad88835d Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Thu, 28 Mar 2024 14:44:47 +0200 Subject: [PATCH 1/2] Add effective sample size analytics to WeighedPredictive results. --- pyro/infer/importance.py | 96 +++++++++++++++++++--------------- pyro/infer/predictive.py | 13 +++-- tests/infer/test_predictive.py | 5 +- 3 files changed, 66 insertions(+), 48 deletions(-) diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py index d25cf16680..bcb144b170 100644 --- a/pyro/infer/importance.py +++ b/pyro/infer/importance.py @@ -15,7 +15,59 @@ from .util import plate_log_prob_sum -class Importance(TracePosterior): +class WeightAnalytics: + def get_log_normalizer(self): + """ + Estimator of the normalizing constant of the target distribution. + (mean of the unnormalized weights) + """ + # ensure list is not empty + if len(self.log_weights) > 0: + log_w = ( + self.log_weights + if isinstance(self.log_weights, torch.Tensor) + else torch.tensor(self.log_weights) + ) + log_num_samples = torch.log(torch.tensor(log_w.numel() * 1.0)) + return torch.logsumexp(log_w - log_num_samples, 0) + else: + warnings.warn( + "The log_weights list is empty, can not compute normalizing constant estimate." + ) + + def get_normalized_weights(self, log_scale=False): + """ + Compute the normalized importance weights. + """ + if len(self.log_weights) > 0: + log_w = ( + self.log_weights + if isinstance(self.log_weights, torch.Tensor) + else torch.tensor(self.log_weights) + ) + log_w_norm = log_w - torch.logsumexp(log_w, 0) + return log_w_norm if log_scale else torch.exp(log_w_norm) + else: + warnings.warn( + "The log_weights list is empty. There is nothing to normalize." + ) + + def get_ESS(self): + """ + Compute (Importance Sampling) Effective Sample Size (ESS). + """ + if len(self.log_weights) > 0: + log_w_norm = self.get_normalized_weights(log_scale=True) + ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0)) + else: + warnings.warn( + "The log_weights list is empty, effective sample size is zero." + ) + ess = 0 + return ess + + +class Importance(TracePosterior, WeightAnalytics): """ :param model: probabilistic model defined as a function :param guide: guide used for sampling defined as a function @@ -55,48 +107,6 @@ def _traces(self, *args, **kwargs): log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum() yield (model_trace, log_weight) - def get_log_normalizer(self): - """ - Estimator of the normalizing constant of the target distribution. - (mean of the unnormalized weights) - """ - # ensure list is not empty - if self.log_weights: - log_w = torch.tensor(self.log_weights) - log_num_samples = torch.log(torch.tensor(self.num_samples * 1.0)) - return torch.logsumexp(log_w - log_num_samples, 0) - else: - warnings.warn( - "The log_weights list is empty, can not compute normalizing constant estimate." - ) - - def get_normalized_weights(self, log_scale=False): - """ - Compute the normalized importance weights. - """ - if self.log_weights: - log_w = torch.tensor(self.log_weights) - log_w_norm = log_w - torch.logsumexp(log_w, 0) - return log_w_norm if log_scale else torch.exp(log_w_norm) - else: - warnings.warn( - "The log_weights list is empty. There is nothing to normalize." - ) - - def get_ESS(self): - """ - Compute (Importance Sampling) Effective Sample Size (ESS). - """ - if self.log_weights: - log_w_norm = self.get_normalized_weights(log_scale=True) - ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0)) - else: - warnings.warn( - "The log_weights list is empty, effective sample size is zero." - ) - ess = 0 - return ess - def vectorized_importance_weights(model, guide, *args, **kwargs): """ diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 6be8b5cb5f..3d020f0fdb 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -9,6 +9,7 @@ import pyro import pyro.poutine as poutine +from pyro.infer.importance import WeightAnalytics from pyro.infer.util import plate_log_prob_sum from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -317,16 +318,20 @@ def get_vectorized_trace(self, *args, **kwargs): class WeighedPredictiveResults(NamedTuple): - """ - Return value of call to instance of :class:`WeighedPredictive`. - """ - samples: Union[dict, tuple] log_weights: torch.Tensor guide_log_prob: torch.Tensor model_log_prob: torch.Tensor +class WeighedPredictiveResults(WeighedPredictiveResults, WeightAnalytics): + """ + Return value of call to instance of :class:`WeighedPredictive`. + """ + + pass + + class WeighedPredictive(Predictive): """ Class used to construct a weighed predictive distribution that is based diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 1f28e1f05c..319a1196dd 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -46,6 +46,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): num_trials = ( torch.ones(5) * 400 ) # Reduced to 400 from 1000 in order for guide optimization to converge + num_samples = 10000 num_success = dist.Binomial(num_trials, true_probs).sample() conditioned_model = poutine.condition(model, data={"obs": num_success}) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -57,7 +58,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): posterior_predictive = predictive( model, guide=beta_guide, - num_samples=10000, + num_samples=num_samples, parallel=parallel, return_sites=["_RETURN"], ) @@ -71,6 +72,8 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape # Weights should be uniform as the guide has the same distribution as the model assert weighed_samples.log_weights.std() < 0.6 + # Effective sample size should be close to actual number of samples taken from the guide + assert weighed_samples.get_ESS() > 0.8 * num_samples assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1) From 06ecfda513fb6751cd104e04377aa50918fcfc3b Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Thu, 28 Mar 2024 20:42:28 +0200 Subject: [PATCH 2/2] Clarify Mixin usage and convert namedtuple to dataclass. --- pyro/infer/importance.py | 11 +++++++++-- pyro/infer/predictive.py | 23 +++++++++++------------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py index bcb144b170..ca088645cb 100644 --- a/pyro/infer/importance.py +++ b/pyro/infer/importance.py @@ -3,6 +3,7 @@ import math import warnings +from typing import List, Union import torch @@ -15,7 +16,13 @@ from .util import plate_log_prob_sum -class WeightAnalytics: +class LogWeightsMixin: + """ + Mixin class to compute analytics from a ``.log_weights`` attribute. + """ + + log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor] + def get_log_normalizer(self): """ Estimator of the normalizing constant of the target distribution. @@ -67,7 +74,7 @@ def get_ESS(self): return ess -class Importance(TracePosterior, WeightAnalytics): +class Importance(TracePosterior, LogWeightsMixin): """ :param model: probabilistic model defined as a function :param guide: guide used for sampling defined as a function diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 3d020f0fdb..ea89aff5e5 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -2,14 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import warnings +from dataclasses import dataclass from functools import reduce -from typing import List, NamedTuple, Union +from typing import List, Union import torch import pyro import pyro.poutine as poutine -from pyro.infer.importance import WeightAnalytics +from pyro.infer.importance import LogWeightsMixin from pyro.infer.util import plate_log_prob_sum from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -35,7 +36,8 @@ def _guess_max_plate_nesting(model, args, kwargs): return max_plate_nesting -class _predictiveResults(NamedTuple): +@dataclass(frozen=True, eq=False) +class _predictiveResults: """ Return value of call to ``_predictive`` and ``_predictive_sequential``. """ @@ -317,19 +319,16 @@ def get_vectorized_trace(self, *args, **kwargs): ).trace -class WeighedPredictiveResults(NamedTuple): - samples: Union[dict, tuple] - log_weights: torch.Tensor - guide_log_prob: torch.Tensor - model_log_prob: torch.Tensor - - -class WeighedPredictiveResults(WeighedPredictiveResults, WeightAnalytics): +@dataclass(frozen=True, eq=False) +class WeighedPredictiveResults(LogWeightsMixin): """ Return value of call to instance of :class:`WeighedPredictive`. """ - pass + samples: Union[dict, tuple] + log_weights: torch.Tensor + guide_log_prob: torch.Tensor + model_log_prob: torch.Tensor class WeighedPredictive(Predictive):