diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index 80e4db225c..6ea6564233 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -8,7 +8,7 @@ AutoGuide .. autoclass:: pyro.infer.autoguide.AutoGuide :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoGuideList @@ -16,7 +16,7 @@ AutoGuideList .. autoclass:: pyro.infer.autoguide.AutoGuideList :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoCallable @@ -24,7 +24,7 @@ AutoCallable .. autoclass:: pyro.infer.autoguide.AutoCallable :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoNormal @@ -32,7 +32,7 @@ AutoNormal .. autoclass:: pyro.infer.autoguide.AutoNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoDelta @@ -40,7 +40,7 @@ AutoDelta .. autoclass:: pyro.infer.autoguide.AutoDelta :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoContinuous @@ -48,7 +48,7 @@ AutoContinuous .. autoclass:: pyro.infer.autoguide.AutoContinuous :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoMultivariateNormal @@ -56,7 +56,7 @@ AutoMultivariateNormal .. autoclass:: pyro.infer.autoguide.AutoMultivariateNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoDiagonalNormal @@ -64,7 +64,7 @@ AutoDiagonalNormal .. autoclass:: pyro.infer.autoguide.AutoDiagonalNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoLowRankMultivariateNormal @@ -72,7 +72,7 @@ AutoLowRankMultivariateNormal .. autoclass:: pyro.infer.autoguide.AutoLowRankMultivariateNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: @@ -81,7 +81,7 @@ AutoNormalizingFlow .. autoclass:: pyro.infer.autoguide.AutoNormalizingFlow :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: @@ -90,7 +90,7 @@ AutoIAFNormal .. autoclass:: pyro.infer.autoguide.AutoIAFNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoLaplaceApproximation @@ -98,7 +98,7 @@ AutoLaplaceApproximation .. autoclass:: pyro.infer.autoguide.AutoLaplaceApproximation :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoDiscreteParallel @@ -106,7 +106,7 @@ AutoDiscreteParallel .. autoclass:: pyro.infer.autoguide.AutoDiscreteParallel :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoStructured @@ -114,7 +114,15 @@ AutoStructured .. autoclass:: pyro.infer.autoguide.AutoStructured :members: :undoc-members: - :special-members: __call__ + :member-order: bysource + :show-inheritance: + +AutoGaussian +------------ +.. autoclass:: pyro.infer.autoguide.AutoGaussian + :members: + :undoc-members: + :member-order: bysource :show-inheritance: .. 
_autoguide-initialization:
@@ -125,5 +133,5 @@ Initialization
     :members:
     :undoc-members:
     :special-members: __call__
-    :show-inheritance:
     :member-order: bysource
+    :show-inheritance:
diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py
index 5e3bd4774d..18db07b79c 100644
--- a/pyro/infer/autoguide/__init__.py
+++ b/pyro/infer/autoguide/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+from pyro.infer.autoguide.gaussian import AutoGaussian
 from pyro.infer.autoguide.guides import (
     AutoCallable,
     AutoContinuous,
@@ -34,6 +35,7 @@
     "AutoDelta",
     "AutoDiagonalNormal",
     "AutoDiscreteParallel",
+    "AutoGaussian",
     "AutoGuide",
     "AutoGuideList",
     "AutoIAFNormal",
diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
new file mode 100644
index 0000000000..b7fb190f3a
--- /dev/null
+++ b/pyro/infer/autoguide/gaussian.py
@@ -0,0 +1,494 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import itertools
+from collections import OrderedDict, defaultdict
+from contextlib import ExitStack
+from typing import Callable, Dict, Optional, Set, Tuple, Union
+
+import torch
+from torch.distributions import biject_to
+
+import pyro
+import pyro.distributions as dist
+import pyro.poutine as poutine
+from pyro.distributions import constraints
+from pyro.infer.inspect import get_dependencies
+from pyro.nn.module import PyroModule, PyroParam
+from pyro.poutine.runtime import am_i_wrapped, get_plates
+from pyro.poutine.util import site_is_subsample
+
+from .guides import AutoGuide
+from .initialization import InitMessenger, init_to_feasible
+from .utils import deep_getattr, deep_setattr, helpful_support_errors
+
+
+# Helper to dispatch to concrete subclasses of AutoGaussian, e.g.
+#   AutoGaussian(model, backend="dense")
+# is converted to
+#   AutoGaussianDense(model)
+# The intent is to avoid proliferation of subclasses and docstrings,
+# and provide a single interface AutoGaussian(...).
+class AutoGaussianMeta(type(AutoGuide)):
+    backends = {}
+    default_backend = "dense"
+
+    def __init__(cls, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert cls.__name__.startswith("AutoGaussian")
+        key = cls.__name__.replace("AutoGaussian", "").lower()
+        cls.backends[key] = cls
+
+    def __call__(cls, *args, **kwargs):
+        if cls is AutoGaussian:
+            backend = kwargs.pop("backend", cls.default_backend)
+            cls = cls.backends[backend]
+        return super(AutoGaussianMeta, cls).__call__(*args, **kwargs)
+
+
+class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
+    """
+    Gaussian guide with optimal conditional independence structure.
+
+    This is equivalent to a full rank :class:`AutoMultivariateNormal` guide,
+    but with a sparse precision matrix determined by dependencies and plates in
+    the model [1]. Depending on model structure, this can have asymptotically
+    better statistical efficiency than :class:`AutoMultivariateNormal` .
+
+    This guide implements multiple backends for computation. All backends use
+    the same statistically optimal parametrization. The default "dense" backend
+    has computational complexity similar to :class:`AutoMultivariateNormal` .
+    The experimental "funsor" backend can be asymptotically cheaper in terms of
+    time and space (using Gaussian tensor variable elimination [2,3]), but
+    incurs large constant overhead. The "funsor" backend requires `funsor
+    <https://funsor.pyro.ai>`_ which can be installed via ``pip install
+    pyro-ppl[funsor]``.
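+
+    Dependency structure is discovered by tracing the model once with
+    :func:`~pyro.infer.inspect.get_dependencies`.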
+
+    The guide currently does not depend on the model's ``*args, **kwargs``.
+
+    Example::
+
+        guide = AutoGaussian(model)
+        svi = SVI(model, guide, ...)
+
+    Example using experimental funsor backend::
+
+        !pip install pyro-ppl[funsor]
+        guide = AutoGaussian(model, backend="funsor")
+        svi = SVI(model, guide, ...)
+
+    **References**
+
+    [1] S.Webb, A.GoliƄski, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)
+        "Faithful inversion of generative models for effective amortized inference"
+        https://dl.acm.org/doi/10.5555/3327144.3327229
+    [2] F.Obermeyer, E.Bingham, M.Jankowiak, J.Chiu, N.Pradhan, A.M.Rush, N.Goodman
+        (2019)
+        "Tensor Variable Elimination for Plated Factor Graphs"
+        http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf
+    [3] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen
+        (2019)
+        "Functional Tensors for Probabilistic Programming"
+        https://arxiv.org/abs/1910.10775
+
+    :param callable model: A Pyro model.
+    :param callable init_loc_fn: A per-site initialization function.
+        See :ref:`autoguide-initialization` section for available functions.
+    :param float init_scale: Initial scale for the standard deviation of each
+        (unconstrained transformed) latent variable.
+    :param str backend: Back end for performing Gaussian tensor variable
+        elimination. Defaults to "dense"; other options include "funsor".
+    """
+
+    scale_constraint = constraints.softplus_positive
+
+    def __init__(
+        self,
+        model: Callable,
+        *,
+        init_loc_fn: Callable = init_to_feasible,
+        init_scale: float = 0.1,
+        backend: Optional[str] = None,  # used only by metaclass
+    ):
+        if not isinstance(init_scale, float) or not (init_scale > 0):
+            raise ValueError(f"Expected init_scale > 0, but got {init_scale}")
+        self._init_scale = init_scale
+        self._original_model = (model,)
+        model = InitMessenger(init_loc_fn)(model)
+        super().__init__(model)
+
+    def _setup_prototype(self, *args, **kwargs) -> None:
+        super()._setup_prototype(*args, **kwargs)
+
+        self.locs = PyroModule()
+        self.scales = PyroModule()
+        self.factors = PyroModule()
+        self._factors = OrderedDict()
+        self._plates = OrderedDict()
+        self._event_numel = OrderedDict()
+        self._unconstrained_event_shapes = OrderedDict()
+
+        # Trace model dependencies.
+        model = self._original_model[0]
+        self._original_model = None
+        self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[
+            "prior_dependencies"
+        ]
+
+        # Collect factors and plates.
+        for d, site in self.prototype_trace.nodes.items():
+            if site["type"] == "sample" and not site_is_subsample(site):
+                assert all(f.vectorized for f in site["cond_indep_stack"])
+                self._factors[d] = site
+                plates = frozenset(site["cond_indep_stack"])
+                if site["fn"].batch_shape != _plates_to_shape(plates):
+                    raise ValueError(
+                        f"Shape mismatch at site '{d}'. "
+                        "Are you missing a pyro.plate() or .to_event()?"
+                    )
+                if site["is_observed"]:
+                    # Eagerly eliminate irrelevant observation plates.
+                    plates &= frozenset.union(
+                        *(self._plates[u] for u in self.dependencies[d] if u != d)
+                    )
+                self._plates[d] = plates
+
+        # Create location-scale parameters, one per latent variable.
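+        # For example a latent site sampled as
+        #     with pyro.plate("i", 2):
+        #         x = pyro.sample("x", dist.Normal(torch.zeros(3), 1).to_event(1))
+        # has batch_shape (2,) and unconstrained event_shape (3,), so below it
+        # gets loc and scale parameters each of shape (2, 3).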
+        for d, site in self._factors.items():
+            if site["is_observed"]:
+                # Observed sites have no loc/scale parameters, but their event
+                # size is still needed to size dependency factors below.
+                self._event_numel[d] = site["fn"].event_shape.numel()
+                continue
+            with helpful_support_errors(site):
+                init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
+            batch_shape = site["fn"].batch_shape
+            event_shape = init_loc.shape[len(batch_shape) :]
+            self._unconstrained_event_shapes[d] = event_shape
+            self._event_numel[d] = event_shape.numel()
+            event_dim = len(event_shape)
+            deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim))
+            deep_setattr(
+                self.scales,
+                d,
+                PyroParam(
+                    torch.full_like(init_loc, self._init_scale),
+                    constraint=self.scale_constraint,
+                    event_dim=event_dim,
+                ),
+            )
+
+        # Create parameters for dependencies, one per factor.
+        for d, site in self._factors.items():
+            u_size = 0
+            for u in self.dependencies[d]:
+                broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
+                u_size += broken_shape.numel() * self._event_numel[u]
+            d_size = self._event_numel[d]
+            if site["is_observed"]:
+                d_size = min(d_size, u_size)  # just an optimization
+            batch_shape = _plates_to_shape(self._plates[d])
+
+            # Create a square root parameter (full, not lower triangular).
+            sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size))
+            if d in self.dependencies[d]:
+                # Initialize the [d,d] block to the identity matrix.
+                sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
+            deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2))
+
+    def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
+        if self.prototype_trace is None:
+            self._setup_prototype(*args, **kwargs)
+
+        aux_values = self._sample_aux_values()
+        values, log_densities = self._transform_values(aux_values)
+
+        # Replay via Pyro primitives, skipping observed sites.
+        plates = self._create_plates(*args, **kwargs)
+        for name, site in self._factors.items():
+            if site["is_observed"]:
+                continue
+            with ExitStack() as stack:
+                for frame in site["cond_indep_stack"]:
+                    stack.enter_context(plates[frame.name])
+                values[name] = pyro.sample(
+                    name,
+                    dist.Delta(values[name], log_densities[name], site["fn"].event_dim),
+                )
+        return values
+
+    def median(self) -> Dict[str, torch.Tensor]:
+        """
+        Returns the posterior median value of each latent variable.
+
+        :return: A dict mapping sample site name to median tensor.
+        :rtype: dict
+        """
+        with torch.no_grad(), poutine.mask(mask=False):
+            aux_values = {name: 0.0 for name in self._factors}
+            values, _ = self._transform_values(aux_values)
+        return values
+
+    def _transform_values(
+        self,
+        aux_values: Dict[str, torch.Tensor],
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Union[float, torch.Tensor]]]:
+        # Learnably transform auxiliary values to user-facing values.
+        values = {}
+        log_densities = defaultdict(float)
+        compute_density = am_i_wrapped() and poutine.get_mask() is not False
+        for name, site in self._factors.items():
+            if site["is_observed"]:
+                continue
+            loc = deep_getattr(self.locs, name)
+            scale = deep_getattr(self.scales, name)
+            unconstrained = aux_values[name] * scale + loc
+
+            # Transform to constrained space.
+            transform = biject_to(site["fn"].support)
+            values[name] = transform(unconstrained)
+            if compute_density:
+                assert transform.codomain.event_dim == site["fn"].event_dim
+                log_densities[name] = transform.inv.log_abs_det_jacobian(
+                    values[name], unconstrained
+                ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1)
+
+        return values, log_densities
+
+    def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError
+
+
+class AutoGaussianDense(AutoGaussian):
+    """
+    Dense implementation of :class:`AutoGaussian` .
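+
+    All latent variables are flattened into a single joint auxiliary variable,
+    sampled from one :class:`~pyro.distributions.MultivariateNormal` whose
+    precision matrix is assembled from the per-factor square-root parameters.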
+ + The following are equivalent:: + + guide = AutoGaussian(model, backend="dense") + guide = AutoGaussianDense(model) + """ + + def _setup_prototype(self, *args, **kwargs): + super()._setup_prototype(*args, **kwargs) + + # Collect global shapes and per-axis indices. + self._dense_shapes = {} + global_indices = {} + pos = 0 + for d, event_shape in self._unconstrained_event_shapes.items(): + batch_shape = self._factors[d]["fn"].batch_shape + self._dense_shapes[d] = batch_shape, event_shape + end = pos + (batch_shape + event_shape).numel() + global_indices[d] = torch.arange(pos, end).reshape(batch_shape + (-1,)) + pos = end + self._dense_size = pos + + # Create sparse -> dense precision scatter indices. + self._dense_scatter = {} + for d, site in self._factors.items(): + sqrt_shape = deep_getattr(self.factors, d).shape + precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1] + index = torch.zeros(precision_shape, dtype=torch.long) + + # Collect local offsets. + local_offsets = {} + pos = 0 + for u in self.dependencies[d]: + local_offsets[u] = pos + broken_plates = self._plates[u] - self._plates[d] + pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel() + + # Create indices blockwise. + for u, v in itertools.product(self.dependencies[d], self.dependencies[d]): + u_index = global_indices[u] + v_index = global_indices[v] + + # Permute broken plates to the right of preserved plates. + u_index = _break_plates(u_index, self._plates[u], self._plates[d]) + v_index = _break_plates(v_index, self._plates[v], self._plates[d]) + + # Scatter global indices into the [u,v] block. + u_start = local_offsets[u] + u_stop = u_start + u_index.size(-1) + v_start = local_offsets[v] + v_stop = v_start + v_index.size(-1) + index[ + ..., u_start:u_stop, v_start:v_stop + ] = self._dense_size * u_index.unsqueeze(-1) + v_index.unsqueeze(-2) + + self._dense_scatter[d] = index.reshape(-1) + + def _sample_aux_values(self) -> Dict[str, torch.Tensor]: + # Sample from a dense joint Gaussian over flattened variables. + precision = self._get_precision() + loc = precision.new_zeros(self._dense_size) + flat_samples = pyro.sample( + f"_{self._pyro_name}", + dist.MultivariateNormal(loc, precision_matrix=precision), + infer={"is_auxiliary": True}, + ) + sample_shape = flat_samples.shape[:-1] + + # Convert flat to shaped tensors. + samples = {} + pos = 0 + for d, (batch_shape, event_shape) in self._dense_shapes.items(): + end = pos + (batch_shape + event_shape).numel() + flat_sample = flat_samples[..., pos:end] + pos = end + # Assumes sample shapes are left of batch shapes. + samples[d] = flat_sample.reshape( + torch.broadcast_shapes(sample_shape, batch_shape) + event_shape + ) + return samples + + def _get_precision(self): + flat_precision = torch.zeros(self._dense_size ** 2) + for d, index in self._dense_scatter.items(): + sqrt = deep_getattr(self.factors, d) + precision = sqrt @ sqrt.transpose(-1, -2) + flat_precision.scatter_add_(0, index, precision.reshape(-1)) + precision = flat_precision.reshape(self._dense_size, self._dense_size) + return precision + + +class AutoGaussianFunsor(AutoGaussian): + """ + Funsor implementation of :class:`AutoGaussian` . 
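+
+    Instead of materializing a single dense precision matrix, this backend
+    performs Gaussian tensor variable elimination via
+    ``funsor.recipes.forward_filter_backward_rsample``.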
+
+    The following are equivalent::
+
+        guide = AutoGaussian(model, backend="funsor")
+        guide = AutoGaussianFunsor(model)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        _import_funsor()
+
+    def _setup_prototype(self, *args, **kwargs):
+        super()._setup_prototype(*args, **kwargs)
+        funsor = _import_funsor()
+
+        # Check TVE condition 1: plate nesting is monotone.
+        for d in self._factors:
+            pd = {p.name for p in self._plates[d]}
+            for u in self.dependencies[d]:
+                pu = {p.name for p in self._plates[u]}
+                if pu <= pd:
+                    continue  # ok
+                raise NotImplementedError(
+                    "Expected monotone plate nesting, but found dependency "
+                    f"{repr(u)} -> {repr(d)} leaves plates {pu - pd}. "
+                    "Consider splitting into multiple guides via AutoGuideList, "
+                    "or replacing the plate in the model by .to_event()."
+                )
+
+        # Determine TVE problem shape.
+        factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {}
+        eliminate: Set[str] = set()
+        plate_to_dim: Dict[str, int] = {}
+        for d, site in self._factors.items():
+            inputs = OrderedDict()
+            for f in sorted(site["cond_indep_stack"], key=lambda f: f.dim):
+                plate_to_dim[f.name] = f.dim
+                inputs[f.name] = funsor.Bint[f.size]
+                eliminate.add(f.name)
+            for u in self.dependencies[d]:
+                inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]]
+                eliminate.add(u)
+            factor_inputs[d] = inputs
+
+        self._funsor_factor_inputs = factor_inputs
+        self._funsor_eliminate = frozenset(eliminate)
+        self._funsor_plate_to_dim = plate_to_dim
+        self._funsor_plates = frozenset(plate_to_dim)
+
+    def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
+        funsor = _import_funsor()
+
+        # Convert torch to funsor.
+        particle_plates = frozenset(get_plates())
+        plate_to_dim = self._funsor_plate_to_dim.copy()
+        plate_to_dim.update({f.name: f.dim for f in particle_plates})
+        factors = {}
+        for d, inputs in self._funsor_factor_inputs.items():
+            sqrt = deep_getattr(self.factors, d)
+            batch_shape = torch.Size(
+                p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
+            )
+            sqrt = sqrt.reshape(batch_shape + sqrt.shape[-2:])
+            precision = sqrt @ sqrt.transpose(-1, -2)
+            info_vec = precision.new_zeros(()).expand(precision.shape[:-1])
+            factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs)
+
+        # Perform Gaussian tensor variable elimination.
+        try:  # Convert ValueError into NotImplementedError.
+            samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
+                factors=factors,
+                eliminate=self._funsor_eliminate,
+                plates=frozenset(plate_to_dim),
+                sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
+            )
+        except ValueError as e:
+            if str(e) != "intractable!":
+                raise e from None
+            raise NotImplementedError(
+                "Funsor backend found intractable plate nesting. "
+                'Consider using AutoGaussian(..., backend="dense"), '
+                "splitting into multiple guides via AutoGuideList, or "
+                "replacing some plates in the model by .to_event()."
+            ) from e
+
+        # Convert funsor to torch.
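+        # The joint log density enters the ELBO through a single pyro.factor
+        # statement rather than through per-site densities, so it is
+        # materialized only when some effect handler will actually consume it.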
+        if am_i_wrapped() and poutine.get_mask() is not False:
+            log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
+            pyro.factor(f"_{self._pyro_name}_latent", log_prob)
+        samples = {
+            k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items()
+        }
+        return samples
+
+
+def _plates_to_shape(plates):
+    shape = [1] * max([0] + [-f.dim for f in plates])
+    for f in plates:
+        shape[f.dim] = f.size
+    return torch.Size(shape)
+
+
+def _break_plates(x, all_plates, kept_plates):
+    """
+    Reshapes and permutes a tensor ``x`` with event_dim=1 and batch shape given
+    by ``all_plates``, breaking all plates not in ``kept_plates``. Each
+    broken plate is moved into the event shape, and finally the event shape is
+    flattened back to a single dimension.
+    """
+    assert x.shape[:-1] == _plates_to_shape(all_plates)  # event_dim == 1
+    kept_plates = kept_plates & all_plates
+    broken_plates = all_plates - kept_plates
+
+    if not broken_plates:
+        return x
+
+    if not kept_plates:
+        # Empty batch shape.
+        return x.reshape(-1)
+
+    batch_shape = _plates_to_shape(kept_plates)
+    if max(p.dim for p in kept_plates) < min(p.dim for p in broken_plates):
+        # No permutation is necessary.
+        return x.reshape(batch_shape + (-1,))
+
+    # We need to permute broken plates left past kept plates.
+    event_dims = {-1} | {p.dim - 1 for p in broken_plates}
+    perm = sorted(range(-x.dim(), 0), key=lambda d: (d in event_dims, d))
+    return x.permute(perm).reshape(batch_shape + (-1,))
+
+
+def _import_funsor():
+    try:
+        import funsor
+    except ImportError as e:
+        raise ImportError(
+            'AutoGaussian(..., backend="funsor") requires funsor. '
+            "Try installing via: pip install pyro-ppl[funsor]"
+        ) from e
+    funsor.set_backend("torch")
+    return funsor
+
+
+__all__ = [
+    "AutoGaussian",
+]
diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index 63b64b9f52..27a7fed53a 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -76,8 +76,25 @@ def __init__(self, model, *, create_plates=None):
     def model(self):
         return self._model[0]
 
+    def __getstate__(self):
+        # Do not pickle weakrefs.
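+        # ``self.master`` holds a ``weakref.ref`` to the root guide, which
+        # cannot be pickled; ``__setstate__`` below rebuilds master references
+        # on all subguides after unpickling. The model callable is dropped too.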
+ self._model = None + self.master = None + return getattr(super(), "__getstate__", self.__dict__.copy)() + + def __setstate__(self, state): + getattr(super(), "__setstate__", self.__dict__.update)(state) + assert self.master is None + master_ref = weakref.ref(self) + for _, mod in self.named_modules(): + if mod is not self and isinstance(mod, AutoGuide): + mod._update_master(master_ref) + def _update_master(self, master_ref): self.master = master_ref + for _, mod in self.named_modules(): + if mod is not self and isinstance(mod, AutoGuide): + mod._update_master(master_ref) def call(self, *args, **kwargs): """ @@ -103,8 +120,8 @@ def sample_latent(*args, **kwargs): def __setattr__(self, name, value): if isinstance(value, AutoGuide): - master_ref = self if self.master is None else self.master - value._update_master(weakref.ref(master_ref)) + master_ref = weakref.ref(self) if self.master is None else self.master + value._update_master(master_ref) super().__setattr__(name, value) def _create_plates(self, *args, **kwargs): @@ -184,11 +201,6 @@ def _check_prototype(self, part_trace): assert part_site["fn"].event_shape == self_site["fn"].event_shape assert part_site["value"].shape == self_site["value"].shape - def _update_master(self, master_ref): - self.master = master_ref - for submodule in self: - submodule._update_master(master_ref) - def append(self, part): """ Add an automatic guide for part of the model. The guide should diff --git a/pyro/infer/autoguide/structured.py b/pyro/infer/autoguide/structured.py index 2e29f3fb29..4825030a2e 100644 --- a/pyro/infer/autoguide/structured.py +++ b/pyro/infer/autoguide/structured.py @@ -159,6 +159,7 @@ def _auto_config(self, sample_sites, args, kwargs): elif prior_order[d] > prior_order[u]: dependencies[d][u] = self.dependencies self.dependencies = dict(dependencies) + self._original_model = None def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) diff --git a/tests/infer/autoguide/__init__.py b/tests/infer/autoguide/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/infer/autoguide/conftest.py b/tests/infer/autoguide/conftest.py new file mode 100644 index 0000000000..17f112a1ad --- /dev/null +++ b/tests/infer/autoguide/conftest.py @@ -0,0 +1,13 @@ +# Copyright (c) 2017-2019 Uber Technologies, Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + + +def pytest_collection_modifyitems(items): + for item in items: + if item.nodeid.startswith("tests/infer/autoguide"): + if "stage" not in item.keywords: + item.add_marker(pytest.mark.stage("unit")) + if "init" not in item.keywords: + item.add_marker(pytest.mark.init(rng_seed=123)) diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py new file mode 100644 index 0000000000..56e7af64bc --- /dev/null +++ b/tests/infer/autoguide/test_gaussian.py @@ -0,0 +1,611 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict, namedtuple + +import pytest +import torch + +import pyro +import pyro.distributions as dist +import pyro.poutine as poutine +from pyro.infer import SVI, Predictive, Trace_ELBO +from pyro.infer.autoguide import AutoGaussian +from pyro.infer.autoguide.gaussian import ( + AutoGaussianDense, + AutoGaussianFunsor, + _break_plates, +) +from pyro.infer.reparam import LocScaleReparam +from pyro.optim import Adam +from tests.common import assert_equal, xfail_if_not_implemented + +BACKENDS = [ + "dense", + pytest.param("funsor", marks=[pytest.mark.stage("funsor")]), +] + + +def test_break_plates(): + shape = torch.Size([5, 4, 3, 2]) + x = torch.arange(shape.numel()).reshape(shape) + + MockPlate = namedtuple("MockPlate", "dim, size") + h = MockPlate(-4, 6) + i = MockPlate(-3, 5) + j = MockPlate(-2, 4) + k = MockPlate(-1, 3) + + actual = _break_plates(x, {i, j, k}, set()) + expected = x.reshape(-1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i}) + expected = x.reshape(5, 1, 1, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {j}) + expected = x.permute((1, 0, 2, 3)).reshape(4, 1, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {k}) + expected = x.permute((2, 0, 1, 3)).reshape(3, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i, j}) + expected = x.reshape(5, 4, 1, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i, k}) + expected = x.permute((0, 2, 1, 3)).reshape(5, 1, 3, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {j, k}) + expected = x.permute((1, 2, 0, 3)).reshape(4, 3, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i, j, k}) + expected = x + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {h, i, j, k}) + expected = x + assert_equal(actual, expected) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_backend_dispatch(backend): + def model(): + pyro.sample("x", dist.Normal(0, 1)) + + guide = AutoGaussian(model, backend=backend) + if backend == "dense": + assert isinstance(guide, AutoGaussianDense) + elif backend == "funsor": + assert isinstance(guide, AutoGaussianFunsor) + else: + raise ValueError(f"Unknown backend: {backend}") + + +def check_structure(model, expected_str): + guide = AutoGaussian(model, backend="dense") + guide() # initialize + + # Inject random noise into all unconstrained parameters. + for parameter in guide.parameters(): + parameter.data.normal_() + + with torch.no_grad(): + precision = guide._get_precision() + actual = precision.abs().gt(1e-5).long() + + str_to_number = {"?": 1, ".": 0} + expected = torch.tensor( + [[str_to_number[c] for c in row if c != " "] for row in expected_str] + ) + assert_equal(actual, expected) + + +def test_structure_1(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a, 1)) + c = pyro.sample("c", dist.Normal(b, 1)) + pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0)) + + expected = [ + "? ? .", + "? ? ?", + ". ? ?", + ] + check_structure(model, expected) + + +def test_structure_2(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + with pyro.plate("i", 2): + c = pyro.sample("c", dist.Normal(a, b.exp())) + pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0)) + + # size = 1 + 1 + 2 = 4 + expected = [ + "? . ? ?", + ". ? ? 
?", + "? ? ? .", + "? ? . ?", + ] + check_structure(model, expected) + + +def test_structure_3(): + I, J = 2, 3 + + def model(): + i_plate = pyro.plate("i", I, dim=-1) + j_plate = pyro.plate("j", J, dim=-2) + with i_plate: + w = pyro.sample("w", dist.Normal(0, 1)) + with j_plate: + x = pyro.sample("x", dist.Normal(0, 1)) + with i_plate, j_plate: + y = pyro.sample("y", dist.Normal(w, x.exp())) + pyro.sample("z", dist.Normal(0, 1), obs=y) + + # size = 2 + 3 + 2 * 3 = 2 + 3 + 6 = 11 + expected = [ + "? . . . . ? . ? . ? .", + ". ? . . . . ? . ? . ?", + ". . ? . . ? ? . . . .", + ". . . ? . . . ? ? . .", + ". . . . ? . . . . ? ?", + "? . ? . . ? . . . . .", + ". ? ? . . . ? . . . .", + "? . . ? . . . ? . . .", + ". ? . ? . . . . ? . .", + "? . . . ? . . . . ? .", + ". ? . . ? . . . . . ?", + ] + check_structure(model, expected) + + +def test_structure_4(): + I, J = 2, 3 + + def model(): + i_plate = pyro.plate("i", I, dim=-1) + j_plate = pyro.plate("j", J, dim=-2) + a = pyro.sample("a", dist.Normal(0, 1)) + with i_plate: + b = pyro.sample("b", dist.Normal(a, 1)) + with j_plate: + c = pyro.sample("c", dist.Normal(b.mean(), 1)) + d = pyro.sample("d", dist.Normal(c.mean(), 1)) + pyro.sample("e", dist.Normal(0, 1), obs=d) + + # size = 1 + 2 + 3 + 1 = 7 + expected = [ + "? ? ? . . . .", + "? ? . ? ? ? .", + "? . ? ? ? ? .", + ". ? ? ? . . ?", + ". ? ? . ? . ?", + ". ? ? . . ? ?", + ". . . ? ? ? ?", + ] + check_structure(model, expected) + + +def test_structure_5(): + def model(): + i_plate = pyro.plate("i", 2, dim=-1) + with i_plate: + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a.mean(-1), 1)) + with i_plate: + pyro.sample("c", dist.Normal(b, 1), obs=torch.zeros(2)) + + # size = 2 + 1 = 3 + expected = [ + "? . ?", + ". ? ?", + "? ? ?", + ] + check_structure(model, expected) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_broken_plates_smoke(backend): + def model(): + with pyro.plate("i", 2): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0)) + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step() + guide() + predictive = Predictive(model, guide=guide, num_samples=2) + predictive() + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_intractable_smoke(backend): + def model(): + i_plate = pyro.plate("i", 2, dim=-1) + j_plate = pyro.plate("j", 3, dim=-2) + with i_plate: + a = pyro.sample("a", dist.Normal(0, 1)) + with j_plate: + b = pyro.sample("b", dist.Normal(0, 1)) + with i_plate, j_plate: + c = pyro.sample("c", dist.Normal(a + b, 1)) + pyro.sample("d", dist.Normal(c, 1), obs=torch.zeros(3, 2)) + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step() + guide() + predictive = Predictive(model, guide=guide, num_samples=2) + predictive() + + +# Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py +def pyrocov_model(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] + assert weekly_strains.shape == (T, P, S) + + # Sample global random variables. 
+ coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] + + # Assume relative growth rate depends strongly on mutations and weakly on place. + coef_loc = torch.zeros(F) + coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] + rate_loc = pyro.deterministic( + "rate_loc", 0.01 * coef @ features.T, event_dim=1 + ) # [S] + + # Assume initial infections depend strongly on strain and place. + init_loc = pyro.sample( + "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) + ) # [S] + with pyro.plate("place", P, dim=-1): + rate = pyro.sample( + "rate", dist.Normal(rate_loc, rate_scale).to_event(1) + ) # [P, S] + init = pyro.sample( + "init", dist.Normal(init_loc, init_scale).to_event(1) + ) # [P, S] + + # Finally observe counts. + with pyro.plate("time", T, dim=-2): + logits = init + rate * local_time # [T, P, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains, + ) + + +# This is modified by relaxing rate from deterministic to latent. +def pyrocov_model_relaxed(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] + assert weekly_strains.shape == (T, P, S) + + # Sample global random variables. + coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None] + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] + + # Assume relative growth rate depends strongly on mutations and weakly on place. + coef_loc = torch.zeros(F) + coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] + rate_loc = pyro.sample( + "rate_loc", + dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1), + ) # [S] + + # Assume initial infections depend strongly on strain and place. + init_loc = pyro.sample( + "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) + ) # [S] + with pyro.plate("place", P, dim=-1): + rate = pyro.sample( + "rate", dist.Normal(rate_loc, rate_scale).to_event(1) + ) # [P, S] + init = pyro.sample( + "init", dist.Normal(init_loc, init_scale).to_event(1) + ) # [P, S] + + # Finally observe counts. + with pyro.plate("time", T, dim=-2): + logits = init + rate * local_time # [T, P, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains, + ) + + +# This is modified by more precisely tracking plates for features and strains. +def pyrocov_model_plated(dataset): + # Tensor shapes are commented at the end of some lines. 
+ features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] # [T, P, S] + assert weekly_strains.shape == (T, P, S) + feature_plate = pyro.plate("feature", F, dim=-1) + strain_plate = pyro.plate("strain", S, dim=-1) + place_plate = pyro.plate("place", P, dim=-2) + time_plate = pyro.plate("time", T, dim=-3) + + # Sample global random variables. + coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2)) + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) + + with feature_plate: + coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] + rate_loc_loc = 0.01 * coef @ features.T + with strain_plate: + rate_loc = pyro.sample( + "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale) + ) # [S] + init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale)) # [S] + with place_plate, strain_plate: + rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale)) # [P, S] + init = pyro.sample("init", dist.Normal(init_loc, init_scale)) # [P, S] + + # Finally observe counts. + with time_plate, place_plate: + logits = (init + rate * local_time)[..., None, :] # [T, P, 1, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains[..., None, :], + ) + + +# This is modified by replacing the multinomial likelihood with poisson. +def pyrocov_model_poisson(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] # [T, P, S] + assert weekly_strains.shape == (T, P, S) + strain_plate = pyro.plate("strain", S, dim=-1) + place_plate = pyro.plate("place", P, dim=-2) + time_plate = pyro.plate("time", T, dim=-3) + + # Sample global random variables. + coef_scale = pyro.sample("coef_scale", dist.LogNormal(-4, 2)) + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) + pois_loc = pyro.sample("pois_loc", dist.Normal(0, 2)) + pois_scale = pyro.sample("pois_scale", dist.LogNormal(0, 2)) + + coef = pyro.sample( + "coef", dist.Logistic(torch.zeros(F), coef_scale).to_event(1) + ) # [F] + rate_loc_loc = 0.01 * coef @ features.T + with strain_plate: + rate_loc = pyro.sample( + "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale) + ) # [S] + init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale)) # [S] + with place_plate, strain_plate: + rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale)) # [P, S] + init = pyro.sample("init", dist.Normal(init_loc, init_scale)) # [P, S] + + # Finally observe counts. + with time_plate, place_plate: + pois = pyro.sample("pois", dist.LogNormal(pois_loc, pois_scale)) + with time_plate, place_plate, strain_plate: + # Note .softmax() breaks conditional independence over strain, but only + # weakly. We could directly call .exp(), but .softmax is more + # numerically stable. 
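+        # Under this parametrization pois acts as a per-(time, place) total
+        # count and the softmax term as per-strain proportions.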
+ logits = pois * (init + rate * local_time).softmax(-1) # [T, P, S] + pyro.sample("obs", dist.Poisson(logits), obs=weekly_strains) + + +PYRO_COV_MODELS = [ + pyrocov_model, + pyrocov_model_relaxed, + pyrocov_model_plated, + pyrocov_model_poisson, +] + + +@pytest.mark.parametrize("model", PYRO_COV_MODELS) +@pytest.mark.parametrize("backend", BACKENDS) +def test_pyrocov_smoke(model, backend): + T, P, S, F = 3, 4, 5, 6 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step(dataset) + guide(dataset) + predictive = Predictive(model, guide=guide, num_samples=2) + predictive(dataset) + + +@pytest.mark.parametrize("model", PYRO_COV_MODELS) +@pytest.mark.parametrize("backend", BACKENDS) +def test_pyrocov_reparam(model, backend): + T, P, S, F = 2, 3, 4, 5 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + # Reparametrize the model. + config = { + "coef": LocScaleReparam(), + "rate_loc": None if model is pyrocov_model else LocScaleReparam(), + "rate": LocScaleReparam(), + "init_loc": LocScaleReparam(), + "init": LocScaleReparam(), + } + model = poutine.reparam(model, config) + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step(dataset) + guide(dataset) + predictive = Predictive(model, guide=guide, num_samples=2) + predictive(dataset) + + +@pytest.mark.stage("funsor") +def test_pyrocov_structure(): + from funsor import Bint, Real, Reals + + T, P, S, F = 2, 3, 4, 5 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(pyrocov_model_poisson, backend="funsor") + guide(dataset) # initialize + + expected_plates = frozenset(["time", "place", "strain"]) + assert guide._funsor_plates == expected_plates + + expected_eliminate = frozenset( + [ + "time", + "place", + "strain", + "coef_scale", + "rate_loc_scale", + "rate_scale", + "init_loc_scale", + "init_scale", + "coef", + "rate_loc", + "init_loc", + "rate", + "init", + "pois_loc", + "pois_scale", + "pois", + ] + ) + assert guide._funsor_eliminate == expected_eliminate + + expected_factor_inputs = { + "coef_scale": OrderedDict([("coef_scale", Real)]), + "rate_loc_scale": OrderedDict([("rate_loc_scale", Real)]), + "rate_scale": OrderedDict([("rate_scale", Real)]), + "init_loc_scale": OrderedDict([("init_loc_scale", Real)]), + "init_scale": OrderedDict([("init_scale", Real)]), + "pois_loc": OrderedDict([("pois_loc", Real)]), + "pois_scale": OrderedDict([("pois_scale", Real)]), + "coef": OrderedDict([("coef", Reals[5]), ("coef_scale", Real)]), + "rate_loc": OrderedDict( + [ + ("strain", Bint[4]), + ("rate_loc", Real), + ("rate_loc_scale", Real), + ("coef", Reals[5]), + ] + ), + "init_loc": OrderedDict( + [("strain", Bint[4]), ("init_loc", Real), ("init_loc_scale", Real)] + ), + "rate": OrderedDict( + [ + ("place", Bint[3]), + ("strain", Bint[4]), + ("rate", Real), + ("rate_scale", Real), + ("rate_loc", Real), + ] + ), + "init": OrderedDict( + [ + ("place", Bint[3]), + ("strain", Bint[4]), + ("init", Real), + ("init_scale", Real), + ("init_loc", Real), + ] + ), + "pois": 
OrderedDict( + [ + ("time", Bint[2]), + ("place", Bint[3]), + ("pois", Real), + ("pois_loc", Real), + ("pois_scale", Real), + ] + ), + } + assert guide._funsor_factor_inputs == expected_factor_inputs + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_profile(backend, n=1, num_steps=1): + """ + Helper function for profiling. + """ + model = pyrocov_model_poisson + T, P, S, F = 2 * n, 3 * n, 4 * n, 5 * n + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(model) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + guide(dataset) # initialize + print("Parameter shapes:") + for name, param in guide.named_parameters(): + print(f" {name}: {tuple(param.shape)}") + + for step in range(num_steps): + svi.step(dataset) + + +if __name__ == "__main__": + test_profile(backend="funsor", n=10, num_steps=100) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 8e1f23b6f6..67eb3319bb 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -4,7 +4,6 @@ import functools import io import warnings -from operator import attrgetter import numpy as np import pytest @@ -21,6 +20,7 @@ AutoDelta, AutoDiagonalNormal, AutoDiscreteParallel, + AutoGaussian, AutoGuide, AutoGuideList, AutoIAFNormal, @@ -34,13 +34,26 @@ init_to_median, init_to_sample, ) +from pyro.infer.autoguide.gaussian import AutoGaussianFunsor from pyro.infer.reparam import ProjectedNormalReparam from pyro.nn.module import PyroModule, PyroParam, PyroSample -from pyro.optim import Adam +from pyro.ops.gaussian import Gaussian +from pyro.optim import Adam, ClippedAdam from pyro.poutine.util import prune_subsample_sites from pyro.util import check_model_guide_match from tests.common import assert_close, assert_equal +AutoGaussianFunsor_xfail = pytest.param( + AutoGaussianFunsor, + marks=[ + pytest.mark.stage("funsor"), + pytest.mark.xfail(reason="jit is not supported"), + ], +) +AutoGaussianFunsor = pytest.param( + AutoGaussianFunsor, marks=[pytest.mark.stage("funsor")] +) + @pytest.mark.parametrize( "auto_class", @@ -87,6 +100,8 @@ def model(): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_factor(auto_class, Elbo): @@ -183,6 +198,8 @@ def dependency_z6_z5(z5): AutoLaplaceApproximation, AutoStructured, AutoStructured_shapes, + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.filterwarnings("ignore::FutureWarning") @@ -220,6 +237,8 @@ def model(): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO]) @@ -270,19 +289,20 @@ def median_x(): return guide -def auto_guide_module_callable(model): - class GuideX(AutoGuide): - def __init__(self, model): - super().__init__(model) - self.x_loc = nn.Parameter(torch.tensor(1.0)) - self.x_scale = PyroParam(torch.tensor(0.1), constraint=constraints.positive) +class GuideX(AutoGuide): + def __init__(self, model): + super().__init__(model) + self.x_loc = nn.Parameter(torch.tensor(1.0)) + self.x_scale = PyroParam(torch.tensor(0.1), constraint=constraints.positive) + + def forward(self, *args, **kwargs): + return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))} - def forward(self, *args, **kwargs): - return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))} + def median(self, *args, **kwargs): + return {"x": 
self.x_loc.detach()} - def median(self, *args, **kwargs): - return {"x": self.x_loc.detach()} +def auto_guide_module_callable(model): guide = AutoGuideList(model) guide.custom = GuideX(model) guide.diagnorm = AutoDiagonalNormal(poutine.block(model, hide=["x"])) @@ -332,6 +352,8 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -362,6 +384,7 @@ def model(): assert_equal(median["z"], torch.tensor(0.5), prec=0.1) +@pytest.mark.parametrize("jit", [False, True], ids=["nojit", "jit"]) @pytest.mark.parametrize( "auto_class", [ @@ -380,10 +403,12 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, + AutoGaussian, + AutoGaussianFunsor_xfail, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) -def test_autoguide_serialization(auto_class, Elbo): +def test_serialization(auto_class, Elbo, jit): def model(): pyro.sample("x", dist.Normal(0.0, 1.0)) with pyro.plate("plate", 2): @@ -396,33 +421,45 @@ def model(): guide = guide.laplace_approximation() pyro.set_rng_seed(0) expected = guide.call() - names = sorted(guide()) - - # Ignore tracer warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) - # XXX: check_trace=True fails for AutoLaplaceApproximation - traced_guide = torch.jit.trace_module(guide, {"call": ()}, check_trace=False) - f = io.BytesIO() - torch.jit.save(traced_guide, f) - f.seek(0) - guide_deser = torch.jit.load(f) + latent_names = sorted(guide()) + expected_params = {k: v.data for k, v in guide.named_parameters()} + + if jit: + # Ignore tracer warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + # XXX: check_trace=True fails for AutoLaplaceApproximation + traced_guide = torch.jit.trace_module( + guide, {"call": ()}, check_trace=False + ) + f = io.BytesIO() + torch.jit.save(traced_guide, f) + del guide, traced_guide + pyro.clear_param_store() + f.seek(0) + guide_deser = torch.jit.load(f) + else: + # Work around https://github.com/pytorch/pytorch/issues/27972 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + f = io.BytesIO() + torch.save(guide, f) + f.seek(0) + guide_deser = torch.load(f) # Check .call() result. pyro.set_rng_seed(0) actual = guide_deser.call() assert len(actual) == len(expected) - for name, a, e in zip(names, actual, expected): + for name, a, e in zip(latent_names, actual, expected): assert_equal(a, e, msg="{}: {} vs {}".format(name, a, e)) # Check named_parameters. - expected_names = {name for name, _ in guide.named_parameters()} - actual_names = {name for name, _ in guide_deser.named_parameters()} - assert actual_names == expected_names - for name in actual_names: - # Get nested attributes. 
- attr_get = attrgetter(name) - assert_equal(attr_get(guide_deser), attr_get(guide).data) + actual_params = {k: v.data for k, v in guide_deser.named_parameters()} + assert set(actual_params) == set(expected_params) + for name, expected in expected_params.items(): + actual = actual_params[name] + assert_equal(actual, expected) def AutoGuideList_x(model): @@ -499,6 +536,8 @@ def model(): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_discrete_parallel(continuous_class): @@ -534,11 +573,13 @@ def model(data): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_guide_list(auto_class): def model(): - pyro.sample("x", dist.Normal(0.0, 1.0).expand([2])) + pyro.sample("x", dist.Normal(0.0, 1.0).expand([2]).to_event(1)) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) guide = AutoGuideList(model) @@ -556,6 +597,8 @@ def model(): AutoMultivariateNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_callable(auto_class): @@ -578,11 +621,14 @@ def guide_x(): "auto_class", [ AutoDelta, + AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_callable_return_dict(auto_class): @@ -625,9 +671,12 @@ def model(): "auto_class", [ AutoDelta, + AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, AutoLowRankMultivariateNormal, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_init_loc_fn(auto_class): @@ -690,6 +739,8 @@ def model(): auto_guide_module_callable, functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + functools.partial(AutoNormal, init_loc_fn=init_to_median), + functools.partial(AutoGaussian, init_loc_fn=init_to_median), ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -701,8 +752,9 @@ def __init__(self): self.x_scale = PyroParam(torch.tensor(0.1), constraints.positive) def forward(self): - pyro.sample("x", dist.Normal(self.x_loc, self.x_scale)) + x = pyro.sample("x", dist.Normal(self.x_loc, self.x_scale)) pyro.sample("y", dist.Normal(2.0, 0.1)) + pyro.sample("z", dist.Normal(1.0, 0.1), obs=x) model = Model() guide = auto_class(model) @@ -775,6 +827,8 @@ def forward(self): AutoLaplaceApproximation, functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -839,6 +893,8 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, AutoStructured_predictive, + AutoGaussian, + AutoGaussianFunsor_xfail, ], ) def test_predictive(auto_class): @@ -891,6 +947,45 @@ def forward(self, x, y=None): assert len(samples) == len(samples_deser) +@pytest.mark.parametrize("sample_shape", [(), (6,), (5, 4)], ids=str) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoStructured, + AutoGaussian, + AutoGaussianFunsor, + ], +) +def 
test_replay_plates(auto_class, sample_shape): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a[..., None], torch.ones(3)).to_event(1)) + c = pyro.sample( + "c", dist.MultivariateNormal(torch.zeros(3) + a[..., None], torch.eye(3)) + ) + with pyro.plate("i", 2): + d = pyro.sample("d", dist.Dirichlet((b + c).exp())) + pyro.sample("e", dist.Categorical(logits=d), obs=torch.tensor([0, 0])) + return a, b, c, d + + guide = auto_class(model) + with pyro.plate_stack("particles", sample_shape, rightmost_dim=-2): + guide_trace = poutine.trace(guide).get_trace() + a, b, c, d = poutine.replay(model, guide_trace)() + assert a.shape == (sample_shape + (1,) if sample_shape else ()) + assert b.shape == (sample_shape + (1, 3) if sample_shape else (3,)) + assert c.shape == (sample_shape + (1, 3) if sample_shape else (3,)) + assert d.shape == sample_shape + (2, 3) + + @pytest.mark.parametrize("auto_class", [AutoDelta, AutoNormal]) def test_subsample_model(auto_class): def model(x, y=None, batch_size=None): @@ -1015,6 +1110,8 @@ def create_plates(data): AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize( @@ -1049,6 +1146,8 @@ def model(): AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize( @@ -1080,6 +1179,8 @@ def model(): AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize( @@ -1095,7 +1196,9 @@ def test_sphere_reparam_ok(auto_class, init_loc_fn): def model(): x = pyro.sample("x", dist.Normal(0.0, 1.0).expand([3]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) + pyro.sample( + "obs", dist.Normal(y, 1).to_event(1), obs=torch.tensor([1.0, 0.0, 0.0]) + ) model = poutine.reparam(model, {"y": ProjectedNormalReparam()}) guide = auto_class(model) @@ -1116,7 +1219,9 @@ def test_sphere_raw_ok(auto_class, init_loc_fn): def model(): x = pyro.sample("x", dist.Normal(0.0, 1.0).expand([3]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) + pyro.sample( + "obs", dist.Normal(y, 1).to_event(1), obs=torch.tensor([1.0, 0.0, 0.0]) + ) guide = auto_class(model, init_loc_fn=init_loc_fn) poutine.trace(guide).get_trace().compute_log_prob() @@ -1146,8 +1251,11 @@ def __init__(self, model): AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, + AutoLowRankMultivariateNormal, AutoStructured_exact_normal, AutoStructured_exact_mvn, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_exact(Guide): @@ -1160,6 +1268,20 @@ def model(data): data = torch.randn(3) expected_mean = (0 + data.sum().item()) / (1 + len(data)) expected_std = (1 + len(data)) ** (-0.5) + g = Gaussian( + log_normalizer=torch.zeros(()), + info_vec=torch.zeros(4), + precision=torch.tensor( + [ + [4, -1, -1, -1], + [-1, 1, 0, 0], + [-1, 0, 1, 0], + [-1, 0, 0, 1], + ], + dtype=data.dtype, + ), + ) + expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -1170,6 +1292,7 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): + # Check moments. 
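+        # Compare Monte Carlo moments of guide samples against the analytic
+        # conjugate-normal posterior computed above.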
vectorize = pyro.plate("particles", 10000, dim=-2) guide_trace = poutine.trace(vectorize(guide)).get_trace(data) samples = poutine.replay(vectorize(model), guide_trace)(data) @@ -1178,6 +1301,10 @@ def model(data): assert_close(actual_mean, expected_mean, atol=0.05) assert_close(actual_std, expected_std, rtol=0.05) + # Check ELBO loss. + actual_loss = elbo.loss(model, guide, data) + assert_close(actual_loss, expected_loss, atol=0.01) + @pytest.mark.parametrize( "Guide", @@ -1185,8 +1312,11 @@ def model(data): AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, + AutoLowRankMultivariateNormal, AutoStructured_exact_normal, AutoStructured_exact_mvn, + AutoGaussian, + AutoGaussianFunsor, ], ) def test_exact_batch(Guide): @@ -1199,6 +1329,17 @@ def model(data): data = torch.randn(3) expected_mean = (0 + data) / (1 + 1) expected_std = (1 + torch.ones_like(data)) ** (-0.5) + g = Gaussian( + log_normalizer=torch.zeros(3), + info_vec=torch.zeros(3, 2), + precision=torch.tensor( + [[[2, -1], [-1, 1]]] * 3, + dtype=data.dtype, + ), + ) + expected_loss = float( + g.event_logsumexp().sum() - g.condition(data[:, None]).event_logsumexp().sum() + ) guide = Guide(model) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -1209,6 +1350,7 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): + # Check moments. vectorize = pyro.plate("particles", 10000, dim=-2) guide_trace = poutine.trace(vectorize(guide)).get_trace(data) samples = poutine.replay(vectorize(model), guide_trace)(data) @@ -1216,3 +1358,79 @@ def model(data): actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) assert_close(actual_std, expected_std, rtol=0.05) + + # Check ELBO loss. + actual_loss = elbo.loss(model, guide, data) + assert_close(actual_loss, expected_loss, atol=0.01) + + +@pytest.mark.parametrize( + "Guide", + [ + AutoNormal, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoStructured, + AutoGaussian, + AutoGaussianFunsor, + ], +) +def test_exact_tree(Guide): + is_exact = Guide not in (AutoNormal, AutoDiagonalNormal) + + def model(data): + x = pyro.sample("x", dist.Normal(0, 1)) + with pyro.plate("data", len(data)): + y = pyro.sample("y", dist.Normal(x, 1)) + pyro.sample("obs", dist.Normal(y, 1), obs=data) + return {"x": x, "y": y} + + data = torch.randn(2) + g = Gaussian( + log_normalizer=torch.zeros(()), + info_vec=torch.zeros(5), + precision=torch.tensor( + [ + [3, -1, -1, 0, 0], # x + [-1, 2, 0, -1, 0], # y[0] + [-1, 0, 2, 0, -1], # y[1] + [0, -1, 0, 1, 0], # obs[0] + [0, 0, -1, 0, 1], # obs[1] + ], + dtype=data.dtype, + ), + ) + g_cond = g.condition(data) + mean = torch.linalg.solve(g_cond.precision, g_cond.info_vec) + std = torch.inverse(g_cond.precision).diag().sqrt() + expected_mean = {"x": mean[0], "y": mean[1:]} + expected_std = {"x": std[0], "y": std[1:]} + expected_loss = float(g.event_logsumexp() - g_cond.event_logsumexp()) + + guide = Guide(model) + elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) + num_steps = 500 + optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) + svi = SVI(model, guide, optim, elbo) + for step in range(num_steps): + svi.step(data) + + guide.train(False) + guide.requires_grad_(False) + with torch.no_grad(): + # Check moments. 
+ vectorize = pyro.plate("particles", 10000, dim=-2) + guide_trace = poutine.trace(vectorize(guide)).get_trace(data) + samples = poutine.replay(vectorize(model), guide_trace)(data) + for name in ["x", "y"]: + actual_mean = samples[name].mean(0).squeeze() + actual_std = samples[name].std(0).squeeze() + assert_close(actual_mean, expected_mean[name], atol=0.05) + if is_exact: + assert_close(actual_std, expected_std[name], rtol=0.05) + + if is_exact: + # Check ELBO loss. + actual_loss = elbo.loss(model, guide, data) + assert_close(actual_loss, expected_loss, atol=0.01) diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index b602a0e07b..904aafc951 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections import OrderedDict import pytest import torch +from torch.distributions import constraints, transform_to from torch.nn.functional import pad import pyro.distributions as dist @@ -418,3 +420,70 @@ def test_gaussian_tensordot( # TODO(fehiepsi): find some condition to make this test stable, so we can compare large value # log densities. assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1) + + +@pytest.mark.stage("funsor") +@pytest.mark.parametrize("batch_shape", [(), (5,), (4, 2)], ids=str) +def test_gaussian_funsor(batch_shape): + # This tests sample distribution, rsample gradients, log_prob, and log_prob + # gradients for both Pyro's and Funsor's Gaussian. + import funsor + + funsor.set_backend("torch") + num_samples = 100000 + + # Declare unconstrained parameters. + loc = torch.randn(batch_shape + (3,)).requires_grad_() + t = transform_to(constraints.positive_definite) + m = torch.randn(batch_shape + (3, 3)) + precision_unconstrained = t.inv(m @ m.transpose(-1, -2)).requires_grad_() + + # Transform to constrained space. + log_normalizer = torch.zeros(batch_shape) + precision = t(precision_unconstrained) + info_vec = (precision @ loc[..., None])[..., 0] + + def check_equal(actual, expected, atol=0.01, rtol=0): + assert_close(actual.data, expected.data, atol=atol, rtol=rtol) + grads = torch.autograd.grad( + (actual - expected).abs().sum(), + [loc, precision_unconstrained], + retain_graph=True, + ) + for grad in grads: + assert grad.abs().max() < atol + + entropy = dist.MultivariateNormal(loc, precision_matrix=precision).entropy() + + # Monte carlo estimate entropy via pyro. + p_gaussian = Gaussian(log_normalizer, info_vec, precision) + p_log_Z = p_gaussian.event_logsumexp() + p_rsamples = p_gaussian.rsample((num_samples,)) + pp_entropy = (p_log_Z - p_gaussian.log_density(p_rsamples)).mean(0) + check_equal(pp_entropy, entropy) + + # Monte carlo estimate entropy via funsor. + inputs = OrderedDict([(k, funsor.Bint[v]) for k, v in zip("ij", batch_shape)]) + inputs["x"] = funsor.Reals[3] + f_gaussian = funsor.gaussian.Gaussian(info_vec, precision, inputs) + f_log_Z = f_gaussian.reduce(funsor.ops.logaddexp, "x") + sample_inputs = OrderedDict(particle=funsor.Bint[num_samples]) + deltas = f_gaussian.sample("x", sample_inputs) + f_rsamples = funsor.montecarlo.extract_samples(deltas)["x"] + ff_entropy = (f_log_Z - f_gaussian(x=f_rsamples)).reduce( + funsor.ops.mean, "particle" + ) + check_equal(ff_entropy.data, entropy) + + # Check Funsor's .rsample against Pyro's .log_prob. + pf_entropy = (p_log_Z - p_gaussian.log_density(f_rsamples.data)).mean(0) + check_equal(pf_entropy, entropy) + + # Check Pyro's .rsample against Funsor's .log_prob. 
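+    # Pyro's samples are positional tensors, so their batch dims must first be
+    # bound to funsor names ("particle", then "i"/"j") before evaluating the
+    # funsor Gaussian's log density.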
+ fp_rsamples = funsor.Tensor(p_rsamples)["particle"] + for i in "ij"[: len(batch_shape)]: + fp_rsamples = fp_rsamples[i] + fp_entropy = (f_log_Z - f_gaussian(x=fp_rsamples)).reduce( + funsor.ops.mean, "particle" + ) + check_equal(fp_entropy.data, entropy)