Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of quantiles for messenger guides [WIP] #2988

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions pyro/infer/autoguide/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,16 @@ def __init__(
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
self._computing_median = False
self._computing_quantiles = False
self._quantile_values = None

def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)
if self._computing_quantiles:
return self._get_posterior_quantiles(name, prior)

with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
Expand Down Expand Up @@ -205,11 +209,30 @@ def median(self, *args, **kwargs):
finally:
self._computing_median = False

@torch.no_grad()
def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
return transform(loc)

def quantiles(self, quantiles, *args, **kwargs):
self._computing_quantiles = True
self._quantile_values = quantiles
try:
return self(*args, **kwargs)
finally:
self._computing_quantiles = False

@torch.no_grad()
def _get_posterior_quantiles(self, name, prior):
transform = biject_to(prior.support)
loc, scale = self._get_params(name, prior)
site_quantiles = torch.tensor(
self._quantile_values, dtype=loc.dtype, device=loc.device
)
site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles)
return transform(site_quantiles_values)


class AutoHierarchicalNormalMessenger(AutoNormalMessenger):
"""
Expand Down Expand Up @@ -263,12 +286,16 @@ def __init__(
self._init_weight = init_weight
self._hierarchical_sites = hierarchical_sites
self._computing_median = False
self._computing_quantiles = False
self._quantile_values = None

def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)
if self._computing_quantiles:
return self._get_posterior_quantiles(name, prior)

with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
Expand Down Expand Up @@ -351,6 +378,7 @@ def median(self, *args, **kwargs):
finally:
self._computing_median = False

@torch.no_grad()
def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
Expand All @@ -360,6 +388,29 @@ def _get_posterior_median(self, name, prior):
loc, scale = self._get_params(name, prior)
return transform(loc)

def quantiles(self, quantiles, *args, **kwargs):
self._computing_quantiles = True
self._quantile_values = quantiles
try:
return self(*args, **kwargs)
finally:
self._computing_quantiles = False

@torch.no_grad()
def _get_posterior_quantiles(self, name, prior):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
else:
loc, scale = self._get_params(name, prior)

site_quantiles = torch.tensor(
self._quantile_values, dtype=loc.dtype, device=loc.device
)
site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles)
return transform(site_quantiles_values)


class AutoRegressiveMessenger(AutoMessenger):
"""
Expand Down
3 changes: 3 additions & 0 deletions tests/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,13 @@ def AutoGuideList_x(model):
AutoLowRankMultivariateNormal,
AutoLaplaceApproximation,
AutoGuideList_x,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_quantiles(auto_class, Elbo):
xfail_messenger(auto_class, Elbo)
def model():
pyro.sample("y", dist.LogNormal(0.0, 1.0))
pyro.sample("z", dist.Beta(2.0, 2.0).expand([2]).to_event(1))
Expand Down
Loading