From 233481abb8991425c257982bf17875e922493e50 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 13 Sep 2021 19:58:15 -0400 Subject: [PATCH 01/41] Implement front-end of AutoGaussian guide --- pyro/infer/autoguide/__init__.py | 2 + pyro/infer/autoguide/guides.py | 181 ++++++++++++++++++++++++++++++- tests/infer/test_autoguide.py | 168 +++++++++++++++++++++++++++- 3 files changed, 349 insertions(+), 2 deletions(-) diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index 42da0c7030..10147f9875 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -7,6 +7,7 @@ AutoDelta, AutoDiagonalNormal, AutoDiscreteParallel, + AutoGaussian, AutoGuide, AutoGuideList, AutoIAFNormal, @@ -34,6 +35,7 @@ "AutoDelta", "AutoDiagonalNormal", "AutoDiscreteParallel", + "AutoGaussian", "AutoGuide", "AutoGuideList", "AutoIAFNormal", diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 036632ec16..5ed66e2512 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -22,7 +22,7 @@ def model(): from collections import OrderedDict, defaultdict from contextlib import ExitStack from types import SimpleNamespace -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, Optional, Tuple, Union import torch from torch import nn @@ -44,6 +44,7 @@ def model(): from pyro.nn import PyroModule, PyroParam from pyro.ops.hessian import hessian from pyro.ops.tensor_utils import periodic_repeat +from pyro.poutine.runtime import am_i_wrapped from pyro.poutine.util import site_is_subsample from .utils import _product, helpful_support_errors @@ -1642,3 +1643,181 @@ def median(self, *args, **kwargs): loc = loc.reshape(shape) result[name] = biject_to(site["fn"].support)(loc) return result + + +class AutoGaussian(AutoGuide): + """ + EXPERIMENTAL Gaussian Markov random field guide. 
+ + This is equivalent to a full rank :class:`AutoMultivariateNormal` guide, + but with a sparse precision matrix determined by dependencies and plates in + the model. This can be orders of magnitude cheaper than the naive + :class:`AutoMultivariateNormal` in terms of space, time, number of + parameters, and statistical complexity. This also parametrizes correlations + via precision matrices rather than Cholesky factors, and therefore is less + susceptible to variable ordering in the model. + + The guide currently does not depend on the model's ``*args, **kwargs``. + + Usage:: + + guide = AutoGaussian(model) + svi = SVI(model, guide, ...) + + :param callable model: A Pyro model. + :param callable init_loc_fn: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. + :param float init_scale: Initial scale for the standard deviation of each + (unconstrained transformed) latent variable. + :param callable create_plates: An optional function inputing the same + ``*args,**kwargs`` as ``model()`` and returning a :class:`pyro.plate` + or iterable of plates. Plates not returned will be created + automatically as usual. This is useful for data subsampling. + :param str backend: Back end for performing Gaussian tensor variable + elimination. Currently only the experimental "funsor" backend is supported, + and this requires the optional ``funsor`` dependency, installed via + ``pip install pyro-ppl[funsor]``. + """ + + def __init__( + self, + model: Callable, + *, + init_loc_fn: Callable = init_to_feasible, + init_scale: float = 0.1, + create_plates: Optional[Callable] = None, + backend="funsor", + ): + if not isinstance(init_scale, float) or not (init_scale > 0): + raise ValueError(f"Expected init_scale > 0. 
but got {init_scale}") + self._init_scale = init_scale + self._original_model = (model,) + model = InitMessenger(init_loc_fn)(model) + super().__init__(model, create_plates=create_plates) + self.backend = backend + + def _setup_prototype(self, *args, **kwargs) -> None: + super()._setup_prototype(*args, **kwargs) + + self.locs = PyroModule() + self.scales = PyroModule() + self.precisions = PyroModule() + self._plates = {} + self._unconstrained_event_shapes = {} + + model = self._original_model[0] + meta = poutine.block(get_dependencies)(model, args, kwargs) + self.dependencies = meta["posterior_dependencies"] + order = {d: i for i, d in enumerate(self.dependencies)} + for d, upstreams in self.dependencies.items(): + site = self.prototype_trace.nodes[d] + init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() + batch_shape = site["fn"].batch_shape + self._unconstrained_event_shapes[d] = init_loc.shape[len(batch_shape) :] + event_dim = len(self._unconstrained_event_shapes[d]) + _deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim)) + _deep_setattr( + self.scales, + d, + PyroParam( + torch.full_like(init_loc, self._init_scale), + constraint=constraints.softplus_positive, + event_dim=event_dim, + ), + ) + self._plates[d] = frozenset(site["cond_indep_stack"]) + for f in site["cond_indep_stack"]: + self._plates[f.name] = f + for u, dep_plate_names in upstreams.items(): + if u not in order or order[u] > order[d]: + continue + dep_plates = frozenset(self._plates[p] for p in dep_plate_names) + indep_plates = self._plates[u] | self._plates[d] - dep_plates + shape = torch.Size( + tuple(f.size for f in indep_plates) + + tuple(f.size for f in dep_plates) + + self._unconstrained_event_shapes[u] + + self._unconstrained_event_shapes[d] + ) + _deep_setattr( + self.precisions, + f"{d}.{u}", + torch.nn.Parameter(init_loc.new_zeros(shape)), + ) + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + if self.prototype_trace is None: + 
self._setup_prototype(*args, **kwargs) + + aux_values, log_density_1 = self._sample_aux_values() + values, log_density_2 = self._transform_values(aux_values) + log_density = log_density_1 + log_density_2 + + # Replay via Pyro primitives. + compute_density = am_i_wrapped() and poutine.get_mask() is not False + plates = self._create_plates(*args, **kwargs) + for name in self.dependencies: + site = self.prototype_trace.nodes[name] + if compute_density: + scale = _deep_getattr(self.scales, name) + log_density = log_density + scale.log().sum() + with ExitStack() as stack: + for frame in site["cond_indep_stack"]: + if frame.vectorized: + stack.enter_context(plates[frame.name]) + pyro.sample( + name, dist.Delta(values[name], event_dim=site["fn"].event_dim) + ) + if compute_density: + pyro.factor("AutoGaussian_log_density", log_density) + return values + + @torch.no_grad() + def median(self) -> Dict[str, torch.Tensor]: + with poutine.mask(mask=False): + aux_values = {name: 0.0 for name in self.dependencies} + values, _ = self._transform_values(aux_values) + return values + + def _sample_aux_values( + self, + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.tensor]]: + # Sample auxiliary values via Gaussian tensor variable elimination. + aux_values = {} + log_density = 0.0 + + if self.backend == "funsor": + raise NotImplementedError("TODO") + elif self.backend == "smoke_test": + # The following should be equivalent to AutoNormal. 
+ for name in self.dependencies: + site = self.prototype_trace.nodes[name] + aux_values[name] = pyro.sample( + name + "_aux", + dist.Normal( + torch.zeros_like(_deep_getattr(self.locs, name)), 1 + ).to_event(site["fn"].event_dim), + infer={"is_auxiliary": True}, + ) + else: + raise ValueError(f"Unknown backend: {self.backend}") + + return aux_values, log_density + + def _transform_values( + self, + aux_values: Dict[str, torch.Tensor], + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + # Learnably transform auxiliary values to user-facing values. + compute_density = am_i_wrapped() and poutine.get_mask() is not False + values = {} + log_density = 0.0 + for name in self.dependencies: + site = self.prototype_trace.nodes[name] + loc = _deep_getattr(self.locs, name) + scale = _deep_getattr(self.scales, name) + values[name] = biject_to(site["fn"].support)(aux_values[name] * scale + loc) + if compute_density: + log_density = log_density + scale.log().sum() + + return values, log_density diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 8e1f23b6f6..dfa5357bdb 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -21,6 +21,7 @@ AutoDelta, AutoDiagonalNormal, AutoDiscreteParallel, + AutoGaussian, AutoGuide, AutoGuideList, AutoIAFNormal, @@ -34,7 +35,7 @@ init_to_median, init_to_sample, ) -from pyro.infer.reparam import ProjectedNormalReparam +from pyro.infer.reparam import LocScaleReparam, ProjectedNormalReparam from pyro.nn.module import PyroModule, PyroParam, PyroSample from pyro.optim import Adam from pyro.poutine.util import prune_subsample_sites @@ -183,6 +184,7 @@ def dependency_z6_z5(z5): AutoLaplaceApproximation, AutoStructured, AutoStructured_shapes, + AutoGaussian, ], ) @pytest.mark.filterwarnings("ignore::FutureWarning") @@ -332,6 +334,7 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, + AutoGaussian, ], 
) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -380,6 +383,7 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, + AutoGaussian, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -839,6 +843,7 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, AutoStructured_predictive, + AutoGaussian, ], ) def test_predictive(auto_class): @@ -1216,3 +1221,164 @@ def model(data): actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) assert_close(actual_std, expected_std, rtol=0.05) + + +# Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py +def pyrocov_model(dataset): + # Tensor shapes are commented at at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] + assert weekly_strains.shape == (T, P, S) + + # Configure reparametrization (which does not affect model density). + local_time = local_time + pyro.param( + "local_time", lambda: torch.zeros(P, S) + ) # [T, P, S] + + # Sample global random variables. + coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] + + # Assume relative growth rate depends strongly on mutations and weakly on place. + coef_loc = torch.zeros(F) + coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] + rate_loc = pyro.deterministic( + "rate_loc", 0.01 * coef @ features.T, event_dim=1 + ) # [S] + + # Assume initial infections depend strongly on strain and place. 
+ init_loc = pyro.sample( + "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) + ) # [S] + with pyro.plate("place", P, dim=-1): + rate = pyro.sample( + "rate", dist.Normal(rate_loc, rate_scale).to_event(1) + ) # [P, S] + init = pyro.sample( + "init", dist.Normal(init_loc, init_scale).to_event(1) + ) # [P, S] + + # Finally observe counts. + with pyro.plate("time", T, dim=-2): + logits = init + rate * local_time # [T, P, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains, + ) + + +@pytest.mark.parametrize("backend", ["smoke_test"]) +def test_autogaussian_pyrocov(backend): + T, P, S, F = 3, 4, 5, 6 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(pyrocov_model, backend=backend) + with poutine.mask(mask=False): + guide(dataset) # initialize guide + + # Check automatically determined dependencies. + # Without reparametrization the posterior dependencies are sparse. + _ = set() + expected_dependencies = { + "coef": { + "coef_scale": _, # direct + "rate_scale": _, # moralized + }, + "init_loc": { + "init_loc_scale": _, # direct + "init_scale": _, # moralized + }, + "rate": { + "coef": _, # direct + "rate_scale": _, # direct + }, + "init": { + "init_loc": _, # direct + "init_scale": _, # direct + "rate": _, # moralized + }, + } + assert guide.dependencies == expected_dependencies + + +@pytest.mark.parametrize("backend", ["smoke_test"]) +def test_structured_pyrocov_reparam(backend): + T, P, S, F = 3, 4, 5, 6 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + # Reparametrize the model. 
+ config = { + "coef": LocScaleReparam(), + "rate": LocScaleReparam(), + "init_loc": LocScaleReparam(), + "init": LocScaleReparam(), + } + model = poutine.reparam(pyrocov_model, config) + guide = AutoGaussian(model, backend=backend) + with poutine.mask(mask=False): + guide(dataset) # initialize guide + + # Check automatically determined dependencies. + # With reparametrization the posterior dependencies are dense. + _ = set() + expected_dependencies = { + "rate_scale": { + "coef_scale": _, # reparam moralized + }, + "init_loc_scale": { + "coef_scale": _, # reparam moralized + "rate_scale": _, # reparam moralized + }, + "init_scale": { + "coef_scale": _, # reparam moralized + "rate_scale": _, # reparam moralized + "init_loc_scale": _, # reparam moralized + }, + "coef_decentered": { + "coef_scale": _, # direct + "rate_scale": _, # moralized + "init_loc_scale": _, # reparam moralized + "init_scale": _, # reparam moralized + }, + "init_loc_decentered": { + "init_loc_scale": _, # direct + "init_scale": _, # moralized + "coef_scale": _, # reparam moralized + "rate_scale": _, # reparam moralized + "coef_decentered": _, # reparam moralized + }, + "rate_decentered": { + "coef_decentered": _, # direct + "coef_scale": _, # reparam direct + "rate_scale": _, # direct + "rate_scale": _, # reparam direct + "init_loc_scale": _, # reparam moralized + "init_scale": _, # reparam moralized + "init_loc_decentered": _, # reparam moralized + }, + "init_decentered": { + "init_loc_decentered": _, # direct + "init_scale": _, # direct + "rate_decentered": _, # moralized + "init_loc_scale": _, # reparam direct + "coef_scale": _, # reparam moralized + "rate_scale": _, # reparam moralized + "coef_decentered": _, # reparam moralized + "rate_decentered": _, # reparam moralized + }, + } + assert guide.dependencies == expected_dependencies From 75bf88a82d033f1fdb5dff5ddc8f17db749d9f0a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 14 Sep 2021 12:36:15 -0400 Subject: [PATCH 02/41] 
Implement funsor backend --- pyro/infer/autoguide/guides.py | 173 ++++++++++++++++++++++++--------- 1 file changed, 129 insertions(+), 44 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 5ed66e2512..4315171dbd 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1679,6 +1679,8 @@ class AutoGaussian(AutoGuide): ``pip install pyro-ppl[funsor]``. """ + scale_constraint = constraints.softplus_positive + def __init__( self, model: Callable, @@ -1702,48 +1704,69 @@ def _setup_prototype(self, *args, **kwargs) -> None: self.locs = PyroModule() self.scales = PyroModule() self.precisions = PyroModule() - self._plates = {} self._unconstrained_event_shapes = {} model = self._original_model[0] meta = poutine.block(get_dependencies)(model, args, kwargs) - self.dependencies = meta["posterior_dependencies"] - order = {d: i for i, d in enumerate(self.dependencies)} - for d, upstreams in self.dependencies.items(): - site = self.prototype_trace.nodes[d] - init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() - batch_shape = site["fn"].batch_shape - self._unconstrained_event_shapes[d] = init_loc.shape[len(batch_shape) :] - event_dim = len(self._unconstrained_event_shapes[d]) - _deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim)) + self.dependencies = meta["prior_dependencies"] + self.latents = { + name: site + for name, site in self.prototype_trace.nodes.items() + if site["type"] == "sample" + if not site_is_subsample(site) + } + for d, site in self.latents.items(): + precision_size = 0 + precision_plates = set() + if not site["is_observed"]: + # Initialize latent variable location-scale parameters. 
+ init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() + init_scale = torch.full_like(init_loc, self._init_scale) + batch_shape = site["fn"].batch_shape + event_shape = init_loc.shape[len(batch_shape) :] + self._unconstrained_event_shapes[d] = event_shape + event_dim = len(event_shape) + _deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim)) + _deep_setattr( + self.scales, + d, + PyroParam( + init_scale, + constraint=self.scale_constraint, + event_dim=event_dim, + ), + ) + + # Gather shapes for precision matrices. + for f in site["cond_indep_stack"]: + if f.vectorized: + precision_plates.add(f) + + # Initialize precision matrices. + for u in self.dependencies[d]: + precision_size += self._unconstrained_event_shapes[u].numel() + for f in self.prototype_trace.nodes[u]["cond_indep_stack"]: + if f not in precision_plates: + # TODO Convert upstream plates to event dimensions, and + # support colliders in backend sum_product() algorithms. + raise NotImplementedError("intractable!") + precision_plates = sorted(precision_plates, key=lambda f: f.dim) + batch_shape = torch.Size(f.size for f in precision_plates) + init_precision = torch.zeros(*batch_shape, precision_size, precision_size) + init_precision.view(-1, precision_size ** 2)[ + ..., :: precision_size + 1 + ].fill_( + 1 + ) # init to eye _deep_setattr( - self.scales, + self.precisions, d, PyroParam( - torch.full_like(init_loc, self._init_scale), - constraint=constraints.softplus_positive, - event_dim=event_dim, + init_precision, + constraint=constraints.positive_definite, + event_dim=2, ), ) - self._plates[d] = frozenset(site["cond_indep_stack"]) - for f in site["cond_indep_stack"]: - self._plates[f.name] = f - for u, dep_plate_names in upstreams.items(): - if u not in order or order[u] > order[d]: - continue - dep_plates = frozenset(self._plates[p] for p in dep_plate_names) - indep_plates = self._plates[u] | self._plates[d] - dep_plates - shape = torch.Size( - tuple(f.size for f in 
indep_plates) - + tuple(f.size for f in dep_plates) - + self._unconstrained_event_shapes[u] - + self._unconstrained_event_shapes[d] - ) - _deep_setattr( - self.precisions, - f"{d}.{u}", - torch.nn.Parameter(init_loc.new_zeros(shape)), - ) def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: @@ -1756,8 +1779,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # Replay via Pyro primitives. compute_density = am_i_wrapped() and poutine.get_mask() is not False plates = self._create_plates(*args, **kwargs) - for name in self.dependencies: - site = self.prototype_trace.nodes[name] + for name, site in self.latents.items(): if compute_density: scale = _deep_getattr(self.scales, name) log_density = log_density + scale.log().sum() @@ -1783,15 +1805,13 @@ def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.tensor]]: # Sample auxiliary values via Gaussian tensor variable elimination. - aux_values = {} - log_density = 0.0 - if self.backend == "funsor": - raise NotImplementedError("TODO") + return self._sample_aux_values_funsor() elif self.backend == "smoke_test": # The following should be equivalent to AutoNormal. - for name in self.dependencies: - site = self.prototype_trace.nodes[name] + aux_values = {} + log_density = 0.0 + for name, site in self.latents.items(): aux_values[name] = pyro.sample( name + "_aux", dist.Normal( @@ -1799,10 +1819,76 @@ def _sample_aux_values( ).to_event(site["fn"].event_dim), infer={"is_auxiliary": True}, ) + return aux_values, log_density else: raise ValueError(f"Unknown backend: {self.backend}") - return aux_values, log_density + def _sample_aux_values_funsor( + self, + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.tensor]]: + import funsor + + import pyro.contrib.funsor + + # Construct TVE problem inputs, converting torch to funsor. 
+ factors = {} + plate_to_dim = {} + eliminate = set() + for d, site in self.latents.items(): + inputs = OrderedDict() + for f in site["cond_indep_stack"]: + if f.vectorized: + inputs[f.name] = funsor.Bint[f.size] + plate_to_dim[f.name] = f.dim + eliminate.add(f.name) + if not site["is_observed"]: + inputs[d] = funsor.Reals[self._unconstrained_event_shapes[d]] + eliminate.add(d) + for u in self.dependencies[d]: + inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]] + assert u in eliminate + + precision = _deep_getattr(self.precisions, d) + info_vec = precision.new_zeros(precision.shape[:-1]) + factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) + plates = frozenset(plate_to_dim) + eliminate = frozenset(eliminate) + + # Draw samples via Gaussian tensor variable elimination. + with funsor.interpretations.reflect: + log_Z = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + list(factors.values()), + eliminate=eliminate, + plates=plates, + ) + log_Z = funsor.optimizer.apply_optimizer(log_Z) + with funsor.montecarlo.MonteCarlo(): + samples = funsor.adjoint.adjoint( + funsor.ops.logaddexp, funsor.ops.add, log_Z + ) + # Extract funsor.Tensor values from funsor.Delta samples. + samples = { + name: pyro.contrib.funsor.handlers.enum_messenger._get_support_value( + samples[factors[name]], name + ) + for name in eliminate - plates + } + + # Compute density. + log_prob = 0.0 + for f in factors.values(): + # Substitute samples and eliminate plates. + log_prob += f(**samples).reduce(funsor.ops.add) + + # Convert funsor to torch. 
+ samples = { + k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items() + } + log_density = funsor.to_data(log_prob) + + return samples, log_density def _transform_values( self, @@ -1812,8 +1898,7 @@ def _transform_values( compute_density = am_i_wrapped() and poutine.get_mask() is not False values = {} log_density = 0.0 - for name in self.dependencies: - site = self.prototype_trace.nodes[name] + for name, site in self.latents.items(): loc = _deep_getattr(self.locs, name) scale = _deep_getattr(self.scales, name) values[name] = biject_to(site["fn"].support)(aux_values[name] * scale + loc) From c73979eeffc53b84dd1f13585f50890989f2e3d7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 14 Sep 2021 13:40:16 -0400 Subject: [PATCH 03/41] Break plates causing collision --- pyro/infer/autoguide/guides.py | 47 ++++--- tests/infer/test_autoguide.py | 215 +++++++++++++++++++-------------- 2 files changed, 155 insertions(+), 107 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 4315171dbd..551fbef565 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -22,7 +22,7 @@ def model(): from collections import OrderedDict, defaultdict from contextlib import ExitStack from types import SimpleNamespace -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Set, Tuple, Union import torch from torch import nn @@ -44,6 +44,7 @@ def model(): from pyro.nn import PyroModule, PyroParam from pyro.ops.hessian import hessian from pyro.ops.tensor_utils import periodic_repeat +from pyro.poutine.indep_messenger import CondIndepStackFrame from pyro.poutine.runtime import am_i_wrapped from pyro.poutine.util import site_is_subsample @@ -1704,11 +1705,19 @@ def _setup_prototype(self, *args, **kwargs) -> None: self.locs = PyroModule() self.scales = PyroModule() self.precisions = PyroModule() - self._unconstrained_event_shapes = {} + 
self._unconstrained_event_shapes: Dict[str, torch.Size] = {} + self._broken_event_shapes: Dict[str, torch.Size] = {} + self._broken_plates: Dict[str, Tuple[str, ...]] = defaultdict(tuple) model = self._original_model[0] meta = poutine.block(get_dependencies)(model, args, kwargs) self.dependencies = meta["prior_dependencies"] + broken_plates: Set[str] = { + p + for upstreams in meta["posterior_dependencies"].values() + for plates in upstreams.values() + for p in plates + } self.latents = { name: site for name, site in self.prototype_trace.nodes.items() @@ -1717,7 +1726,7 @@ def _setup_prototype(self, *args, **kwargs) -> None: } for d, site in self.latents.items(): precision_size = 0 - precision_plates = set() + precision_plates: Set[CondIndepStackFrame] = set() if not site["is_observed"]: # Initialize latent variable location-scale parameters. init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() @@ -1738,20 +1747,24 @@ def _setup_prototype(self, *args, **kwargs) -> None: ) # Gather shapes for precision matrices. + broken_shape = torch.Size() for f in site["cond_indep_stack"]: if f.vectorized: - precision_plates.add(f) + if f.name in broken_plates: + self._broken_plates[d] += (f.name,) + broken_shape += (f.size,) + else: + precision_plates.add(f) + self._broken_event_shapes[d] = broken_shape + event_shape # Initialize precision matrices. for u in self.dependencies[d]: - precision_size += self._unconstrained_event_shapes[u].numel() + precision_size += self._broken_event_shapes[u].numel() for f in self.prototype_trace.nodes[u]["cond_indep_stack"]: - if f not in precision_plates: - # TODO Convert upstream plates to event dimensions, and - # support colliders in backend sum_product() algorithms. 
- raise NotImplementedError("intractable!") - precision_plates = sorted(precision_plates, key=lambda f: f.dim) - batch_shape = torch.Size(f.size for f in precision_plates) + assert f in precision_plates or f.name in broken_plates + batch_shape = torch.Size( + f.size for f in sorted(precision_plates, key=lambda f: f.dim) + ) init_precision = torch.zeros(*batch_shape, precision_size, precision_size) init_precision.view(-1, precision_size ** 2)[ ..., :: precision_size + 1 @@ -1838,14 +1851,15 @@ def _sample_aux_values_funsor( inputs = OrderedDict() for f in site["cond_indep_stack"]: if f.vectorized: - inputs[f.name] = funsor.Bint[f.size] plate_to_dim[f.name] = f.dim - eliminate.add(f.name) + if f.name not in self._broken_plates[d]: + inputs[f.name] = funsor.Bint[f.size] + eliminate.add(f.name) if not site["is_observed"]: - inputs[d] = funsor.Reals[self._unconstrained_event_shapes[d]] + inputs[d] = funsor.Reals[self._broken_event_shapes[d]] eliminate.add(d) for u in self.dependencies[d]: - inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]] + inputs[u] = funsor.Reals[self._broken_event_shapes[u]] assert u in eliminate precision = _deep_getattr(self.precisions, d) @@ -1884,7 +1898,8 @@ def _sample_aux_values_funsor( # Convert funsor to torch. samples = { - k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items() + k: funsor.to_data(v[self._broken_plates[k]], name_to_dim=plate_to_dim) + for k, v in samples.items() } log_density = funsor.to_data(log_prob) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index dfa5357bdb..98d402334e 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1225,7 +1225,7 @@ def model(data): # Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py def pyrocov_model(dataset): - # Tensor shapes are commented at at the end of some lines. + # Tensor shapes are commented at the end of some lines. 
features = dataset["features"] local_time = dataset["local_time"][..., None] # [T, P, 1] T, P, _ = local_time.shape @@ -1233,11 +1233,6 @@ def pyrocov_model(dataset): weekly_strains = dataset["weekly_strains"] assert weekly_strains.shape == (T, P, S) - # Configure reparametrization (which does not affect model density). - local_time = local_time + pyro.param( - "local_time", lambda: torch.zeros(P, S) - ) # [T, P, S] - # Sample global random variables. coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] @@ -1273,8 +1268,109 @@ def pyrocov_model(dataset): ) -@pytest.mark.parametrize("backend", ["smoke_test"]) -def test_autogaussian_pyrocov(backend): +# This is modified by relaxing rate from deterministic to latent. +def pyrocov_model_relaxed(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] + assert weekly_strains.shape == (T, P, S) + + # Sample global random variables. + coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None] + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] + + # Assume relative growth rate depends strongly on mutations and weakly on place. + coef_loc = torch.zeros(F) + coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] + rate_loc = pyro.sample( + "rate_loc", + dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1), + ) # [S] + + # Assume initial infections depend strongly on strain and place. 
+ init_loc = pyro.sample( + "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) + ) # [S] + with pyro.plate("place", P, dim=-1): + rate = pyro.sample( + "rate", dist.Normal(rate_loc, rate_scale).to_event(1) + ) # [P, S] + init = pyro.sample( + "init", dist.Normal(init_loc, init_scale).to_event(1) + ) # [P, S] + + # Finally observe counts. + with pyro.plate("time", T, dim=-2): + logits = init + rate * local_time # [T, P, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains, + ) + + +# This is modified by relaxing rate from deterministic to latent. +def pyrocov_model_plated(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] # [T, P, S] + assert weekly_strains.shape == (T, P, S) + feature_plate = pyro.plate("feature", F, dim=-1) + strain_plate = pyro.plate("strain", S, dim=-1) + place_plate = pyro.plate("place", P, dim=-2) + time_plate = pyro.plate("time", T, dim=-3) + + # Sample global random variables. 
+ coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2)) + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) + + with feature_plate: + # FIXME + # coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] + coef = pyro.sample("coef", dist.Normal(0, coef_scale)) # [F] + rate_loc_loc = 0.01 * coef @ features.T + with strain_plate: + rate_loc = pyro.sample( + "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale) + ) # [S] + init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale)) # [S] + with place_plate, strain_plate: + rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale)) # [P, S] + init = pyro.sample("init", dist.Normal(init_loc, init_scale)) # [P, S] + + # Finally observe counts. + with time_plate, place_plate: + logits = (init + rate * local_time)[..., None, :] # [T, P, 1, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains[..., None, :], + ) + + +@pytest.mark.parametrize( + "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] +) +@pytest.mark.parametrize( + "backend", + [ + # "smoke_test", + "funsor", + ], +) +def test_autogaussian_pyrocov_smoke(model, backend): T, P, S, F = 3, 4, 5, 6 dataset = { "features": torch.randn(S, F), @@ -1282,37 +1378,23 @@ def test_autogaussian_pyrocov(backend): "weekly_strains": torch.randn(T, P, S).exp().round(), } - guide = AutoGaussian(pyrocov_model, backend=backend) - with poutine.mask(mask=False): - guide(dataset) # initialize guide - - # Check automatically determined dependencies. - # Without reparametrization the posterior dependencies are sparse. 
- _ = set() - expected_dependencies = { - "coef": { - "coef_scale": _, # direct - "rate_scale": _, # moralized - }, - "init_loc": { - "init_loc_scale": _, # direct - "init_scale": _, # moralized - }, - "rate": { - "coef": _, # direct - "rate_scale": _, # direct - }, - "init": { - "init_loc": _, # direct - "init_scale": _, # direct - "rate": _, # moralized - }, - } - assert guide.dependencies == expected_dependencies + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + svi.step(dataset) + svi.step(dataset) -@pytest.mark.parametrize("backend", ["smoke_test"]) -def test_structured_pyrocov_reparam(backend): +@pytest.mark.parametrize( + "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] +) +@pytest.mark.parametrize( + "backend", + [ + # "smoke_test", + "funsor", + ], +) +def test_structured_pyrocov_reparam(model, backend): T, P, S, F = 3, 4, 5, 6 dataset = { "features": torch.randn(S, F), @@ -1323,62 +1405,13 @@ def test_structured_pyrocov_reparam(backend): # Reparametrize the model. config = { "coef": LocScaleReparam(), + "rate_loc": LocScaleReparam(), # only in relaxed model "rate": LocScaleReparam(), "init_loc": LocScaleReparam(), "init": LocScaleReparam(), } - model = poutine.reparam(pyrocov_model, config) + model = poutine.reparam(model, config) guide = AutoGaussian(model, backend=backend) - with poutine.mask(mask=False): - guide(dataset) # initialize guide - - # Check automatically determined dependencies. - # With reparametrization the posterior dependencies are dense. 
- _ = set() - expected_dependencies = { - "rate_scale": { - "coef_scale": _, # reparam moralized - }, - "init_loc_scale": { - "coef_scale": _, # reparam moralized - "rate_scale": _, # reparam moralized - }, - "init_scale": { - "coef_scale": _, # reparam moralized - "rate_scale": _, # reparam moralized - "init_loc_scale": _, # reparam moralized - }, - "coef_decentered": { - "coef_scale": _, # direct - "rate_scale": _, # moralized - "init_loc_scale": _, # reparam moralized - "init_scale": _, # reparam moralized - }, - "init_loc_decentered": { - "init_loc_scale": _, # direct - "init_scale": _, # moralized - "coef_scale": _, # reparam moralized - "rate_scale": _, # reparam moralized - "coef_decentered": _, # reparam moralized - }, - "rate_decentered": { - "coef_decentered": _, # direct - "coef_scale": _, # reparam direct - "rate_scale": _, # direct - "rate_scale": _, # reparam direct - "init_loc_scale": _, # reparam moralized - "init_scale": _, # reparam moralized - "init_loc_decentered": _, # reparam moralized - }, - "init_decentered": { - "init_loc_decentered": _, # direct - "init_scale": _, # direct - "rate_decentered": _, # moralized - "init_loc_scale": _, # reparam direct - "coef_scale": _, # reparam moralized - "rate_scale": _, # reparam moralized - "coef_decentered": _, # reparam moralized - "rate_decentered": _, # reparam moralized - }, - } - assert guide.dependencies == expected_dependencies + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + svi.step(dataset) + svi.step(dataset) From 068b3be0e8ec3e7692d508e73359898b6cee3dfc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 14 Sep 2021 14:46:33 -0400 Subject: [PATCH 04/41] Fix some tests --- pyro/distributions/logistic.py | 2 +- pyro/infer/autoguide/guides.py | 6 ++++-- tests/infer/test_autoguide.py | 26 ++++++-------------------- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/pyro/distributions/logistic.py b/pyro/distributions/logistic.py index 354aaef933..77be18f700 100644 
--- a/pyro/distributions/logistic.py +++ b/pyro/distributions/logistic.py @@ -42,7 +42,7 @@ def __init__(self, loc, scale, *, validate_args=None): super().__init__(self.loc.shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): - new = self._get_checked_instance(SkewLogistic, _instance) + new = self._get_checked_instance(Logistic, _instance) batch_shape = torch.Size(batch_shape) new.loc = self.loc.expand(batch_shape) new.scale = self.scale.expand(batch_shape) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 551fbef565..1d3ceaef10 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1845,7 +1845,7 @@ def _sample_aux_values_funsor( # Construct TVE problem inputs, converting torch to funsor. factors = {} - plate_to_dim = {} + plate_to_dim = {} # TODO including enclosing particle plates eliminate = set() for d, site in self.latents.items(): inputs = OrderedDict() @@ -1878,6 +1878,7 @@ def _sample_aux_values_funsor( plates=plates, ) log_Z = funsor.optimizer.apply_optimizer(log_Z) + # TODO Support enclosing particle plates. with funsor.montecarlo.MonteCarlo(): samples = funsor.adjoint.adjoint( funsor.ops.logaddexp, funsor.ops.add, log_Z @@ -1891,7 +1892,8 @@ def _sample_aux_values_funsor( } # Compute density. - log_prob = 0.0 + # TODO Avoid recomputing this by obtaining it from adjoint above. + log_prob = -funsor.reinterpret(log_Z) for f in factors.values(): # Substitute samples and eliminate plates. log_prob += f(**samples).reduce(funsor.ops.add) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 98d402334e..ec01f1a008 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1315,7 +1315,7 @@ def pyrocov_model_relaxed(dataset): ) -# This is modified by relaxing rate from deterministic to latent. +# This is modified by more precisely tracking plates for features and strains. 
def pyrocov_model_plated(dataset): # Tensor shapes are commented at the end of some lines. features = dataset["features"] @@ -1337,9 +1337,7 @@ def pyrocov_model_plated(dataset): init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) with feature_plate: - # FIXME - # coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] - coef = pyro.sample("coef", dist.Normal(0, coef_scale)) # [F] + coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] rate_loc_loc = 0.01 * coef @ features.T with strain_plate: rate_loc = pyro.sample( @@ -1363,13 +1361,7 @@ def pyrocov_model_plated(dataset): @pytest.mark.parametrize( "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] ) -@pytest.mark.parametrize( - "backend", - [ - # "smoke_test", - "funsor", - ], -) +@pytest.mark.parametrize("backend", ["funsor"]) def test_autogaussian_pyrocov_smoke(model, backend): T, P, S, F = 3, 4, 5, 6 dataset = { @@ -1387,15 +1379,9 @@ def test_autogaussian_pyrocov_smoke(model, backend): @pytest.mark.parametrize( "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] ) -@pytest.mark.parametrize( - "backend", - [ - # "smoke_test", - "funsor", - ], -) +@pytest.mark.parametrize("backend", ["funsor"]) def test_structured_pyrocov_reparam(model, backend): - T, P, S, F = 3, 4, 5, 6 + T, P, S, F = 2, 3, 4, 5 dataset = { "features": torch.randn(S, F), "local_time": torch.randn(T, P), @@ -1405,7 +1391,7 @@ def test_structured_pyrocov_reparam(model, backend): # Reparametrize the model. 
config = { "coef": LocScaleReparam(), - "rate_loc": LocScaleReparam(), # only in relaxed model + "rate_loc": None if model is pyrocov_model else LocScaleReparam(), "rate": LocScaleReparam(), "init_loc": LocScaleReparam(), "init": LocScaleReparam(), From 154a765dae799d7ed5207c092824e741e7e4c0fc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 14 Sep 2021 17:15:00 -0400 Subject: [PATCH 05/41] Support monte carlo particles --- .../contrib/funsor/handlers/enum_messenger.py | 16 ++-- pyro/infer/autoguide/guides.py | 80 ++++++++----------- pyro/poutine/runtime.py | 14 +++- tests/common.py | 2 + tests/infer/test_autoguide.py | 6 +- tests/poutine/test_runtime.py | 39 +++++++++ 6 files changed, 101 insertions(+), 56 deletions(-) create mode 100644 tests/poutine/test_runtime.py diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 98e9a7ee33..bf994a0710 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -11,6 +11,7 @@ import funsor import torch +from funsor.adjoint import _alpha_unmangle as alpha_unmangle # FIXME publish import pyro.poutine.runtime import pyro.poutine.util @@ -26,6 +27,10 @@ @functools.singledispatch def _get_support_value(funsor_dist, name, **kwargs): + """ + Extracts the sample value out of a funsor Delta, + possibly wrapped in reductions over sample plates. 
+ """ raise ValueError( "Could not extract point from {} at name {}".format(funsor_dist, name) ) @@ -33,13 +38,10 @@ def _get_support_value(funsor_dist, name, **kwargs): @_get_support_value.register(funsor.cnf.Contraction) def _get_support_value_contraction(funsor_dist, name, **kwargs): - delta_terms = [ - v - for v in funsor_dist.terms - if isinstance(v, funsor.delta.Delta) and name in v.fresh - ] - assert len(delta_terms) == 1 - return _get_support_value(delta_terms[0], name, **kwargs) + unmangled_terms = alpha_unmangle(funsor_dist)[-1] + terms = [v for v in unmangled_terms if name in v.inputs] + assert len(terms) == 1 + return _get_support_value(terms[0], name, **kwargs) @_get_support_value.register(funsor.delta.Delta) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 1d3ceaef10..ea55273eb8 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -45,7 +45,7 @@ def model(): from pyro.ops.hessian import hessian from pyro.ops.tensor_utils import periodic_repeat from pyro.poutine.indep_messenger import CondIndepStackFrame -from pyro.poutine.runtime import am_i_wrapped +from pyro.poutine.runtime import am_i_wrapped, get_plates from pyro.poutine.util import site_is_subsample from .utils import _product, helpful_support_errors @@ -1718,7 +1718,7 @@ def _setup_prototype(self, *args, **kwargs) -> None: for plates in upstreams.values() for p in plates } - self.latents = { + self.latents: Dict[str, Dict[str, object]] = { name: site for name, site in self.prototype_trace.nodes.items() if site["type"] == "sample" @@ -1790,12 +1790,8 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: log_density = log_density_1 + log_density_2 # Replay via Pyro primitives. 
- compute_density = am_i_wrapped() and poutine.get_mask() is not False plates = self._create_plates(*args, **kwargs) for name, site in self.latents.items(): - if compute_density: - scale = _deep_getattr(self.scales, name) - log_density = log_density + scale.log().sum() with ExitStack() as stack: for frame in site["cond_indep_stack"]: if frame.vectorized: @@ -1803,36 +1799,40 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: pyro.sample( name, dist.Delta(values[name], event_dim=site["fn"].event_dim) ) - if compute_density: + if am_i_wrapped() and poutine.get_mask() is not False: pyro.factor("AutoGaussian_log_density", log_density) return values @torch.no_grad() def median(self) -> Dict[str, torch.Tensor]: with poutine.mask(mask=False): - aux_values = {name: 0.0 for name in self.dependencies} + aux_values = {name: 0.0 for name in self.latents} values, _ = self._transform_values(aux_values) return values + def _transform_values( + self, + aux_values: Dict[str, torch.Tensor], + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + # Learnably transform auxiliary values to user-facing values. + values = {} + log_density = 0.0 + compute_density = am_i_wrapped() and poutine.get_mask() is not False + for name, site in self.latents.items(): + loc = _deep_getattr(self.locs, name) + scale = _deep_getattr(self.scales, name) + values[name] = biject_to(site["fn"].support)(aux_values[name] * scale + loc) + if compute_density: + log_density = log_density - scale.log().sum() + + return values, log_density + def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.tensor]]: # Sample auxiliary values via Gaussian tensor variable elimination. if self.backend == "funsor": return self._sample_aux_values_funsor() - elif self.backend == "smoke_test": - # The following should be equivalent to AutoNormal. 
- aux_values = {} - log_density = 0.0 - for name, site in self.latents.items(): - aux_values[name] = pyro.sample( - name + "_aux", - dist.Normal( - torch.zeros_like(_deep_getattr(self.locs, name)), 1 - ).to_event(site["fn"].event_dim), - infer={"is_auxiliary": True}, - ) - return aux_values, log_density else: raise ValueError(f"Unknown backend: {self.backend}") @@ -1844,8 +1844,8 @@ def _sample_aux_values_funsor( import pyro.contrib.funsor # Construct TVE problem inputs, converting torch to funsor. + plate_to_dim = {} factors = {} - plate_to_dim = {} # TODO including enclosing particle plates eliminate = set() for d, site in self.latents.items(): inputs = OrderedDict() @@ -1868,7 +1868,7 @@ def _sample_aux_values_funsor( plates = frozenset(plate_to_dim) eliminate = frozenset(eliminate) - # Draw samples via Gaussian tensor variable elimination. + # Compute log normalizer via Gaussian tensor variable elimination. with funsor.interpretations.reflect: log_Z = funsor.sum_product.sum_product( funsor.ops.logaddexp, @@ -1878,11 +1878,16 @@ def _sample_aux_values_funsor( plates=plates, ) log_Z = funsor.optimizer.apply_optimizer(log_Z) - # TODO Support enclosing particle plates. - with funsor.montecarlo.MonteCarlo(): + + # Draw a batch of samples. + particle_plates = frozenset(get_plates()) + plate_to_dim.update({f.name: f.dim for f in particle_plates}) + sample_inputs = {f.name: funsor.Bint[f.size] for f in particle_plates} + with funsor.montecarlo.MonteCarlo(**sample_inputs): samples = funsor.adjoint.adjoint( funsor.ops.logaddexp, funsor.ops.add, log_Z ) + # Extract funsor.Tensor values from funsor.Delta samples. samples = { name: pyro.contrib.funsor.handlers.enum_messenger._get_support_value( @@ -1895,31 +1900,16 @@ def _sample_aux_values_funsor( # TODO Avoid recomputing this by obtaining it from adjoint above. log_prob = -funsor.reinterpret(log_Z) for f in factors.values(): - # Substitute samples and eliminate plates. 
- log_prob += f(**samples).reduce(funsor.ops.add) + term = f(**samples) + term = term.reduce(funsor.ops.add, eliminate.intersection(term.inputs)) + log_prob += term + assert all(f.name in log_prob.inputs for f in particle_plates) # Convert funsor to torch. samples = { k: funsor.to_data(v[self._broken_plates[k]], name_to_dim=plate_to_dim) for k, v in samples.items() } - log_density = funsor.to_data(log_prob) + log_density = funsor.to_data(log_prob, name_to_dim=plate_to_dim) return samples, log_density - - def _transform_values( - self, - aux_values: Dict[str, torch.Tensor], - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - # Learnably transform auxiliary values to user-facing values. - compute_density = am_i_wrapped() and poutine.get_mask() is not False - values = {} - log_density = 0.0 - for name, site in self.latents.items(): - loc = _deep_getattr(self.locs, name) - scale = _deep_getattr(self.scales, name) - values[name] = biject_to(site["fn"].support)(aux_values[name] * scale + loc) - if compute_density: - log_density = log_density + scale.log().sum() - - return values, log_density diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index e5a980c895..59b27c8911 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +from typing import Dict from pyro.params.param_store import ( # noqa: F401 _MODULE_NAMESPACE_DIVIDER, @@ -286,7 +287,7 @@ def _fn(*args, **kwargs): return _fn -def _inspect(): +def _inspect() -> Dict[str, object]: """ EXPERIMENTAL Inspect the Pyro stack. @@ -334,3 +335,14 @@ def model(): :rtype: None, bool, or torch.Tensor """ return _inspect()["mask"] + + +def get_plates() -> tuple: + """ + Records the effects of enclosing ``pyro.plate`` contexts. + + :returns: A tuple of + :class:`pyro.poutine.indep_messenger.CondIndepStackFrame` objects. 
+ :rtype: tuple + """ + return _inspect()["cond_indep_stack"] diff --git a/tests/common.py b/tests/common.py index 100bc83cf6..28708ba8b4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -145,6 +145,8 @@ def assert_tensors_equal(a, b, prec=0.0, msg=""): return b = b.type_as(a) b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu() + if not a.dtype.is_floating_point: + return (a == b).all() # check that NaNs are in the same locations nan_mask = a != a assert torch.equal(nan_mask, b != b), msg diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index ec01f1a008..31cfc4e559 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -40,7 +40,7 @@ from pyro.optim import Adam from pyro.poutine.util import prune_subsample_sites from pyro.util import check_model_guide_match -from tests.common import assert_close, assert_equal +from tests.common import assert_close, assert_equal, xfail_param @pytest.mark.parametrize( @@ -383,7 +383,7 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, - AutoGaussian, + xfail_param(AutoGaussian, reason="not jit compatible"), ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -843,7 +843,7 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, AutoStructured_predictive, - AutoGaussian, + xfail_param(AutoGaussian, reason="not jit compatible"), ], ) def test_predictive(auto_class): diff --git a/tests/poutine/test_runtime.py b/tests/poutine/test_runtime.py new file mode 100644 index 0000000000..1cd4395287 --- /dev/null +++ b/tests/poutine/test_runtime.py @@ -0,0 +1,39 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch + +import pyro +import pyro.poutine as poutine +from pyro.poutine.runtime import get_mask, get_plates +from tests.common import assert_equal + + +def test_get_mask(): + assert get_mask() is None + + with poutine.mask(mask=True): + assert get_mask() is True + with poutine.mask(mask=False): + assert get_mask() is False + + with pyro.plate("i", 2, dim=-1): + mask1 = torch.tensor([False, True, True]) + mask2 = torch.tensor([True, True, False]) + with poutine.mask(mask=mask1): + assert_equal(get_mask(), mask1) + with poutine.mask(mask=mask2): + assert_equal(get_mask(), mask1 & mask2) + + +def test_get_plates(): + def get_plate_names(): + plates = get_plates() + assert isinstance(plates, tuple) + return {f.name for f in plates} + + assert get_plate_names() == set() + with pyro.plate("foo", 5): + assert get_plate_names() == {"foo"} + with pyro.plate("bar", 3): + assert get_plate_names() == {"foo", "bar"} From 586a4175ef13ec41dd97b5dbffc95f908a75009b Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 15 Sep 2021 09:20:02 -0400 Subject: [PATCH 06/41] Tweak variable order of cholesky parametrization --- pyro/infer/autoguide/guides.py | 39 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index ea55273eb8..04de77ccfc 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1654,9 +1654,7 @@ class AutoGaussian(AutoGuide): but with a sparse precision matrix determined by dependencies and plates in the model. This can be orders of magnitude cheaper than the naive :class:`AutoMultivariateNormal` in terms of space, time, number of - parameters, and statistical complexity. This also parametrizes correlations - via precision matrices rather than Cholesky factors, and therefore is less - susceptible to variable ordering in the model. + parameters, and statistical complexity. 
The guide currently does not depend on the model's ``*args, **kwargs``. @@ -1680,6 +1678,17 @@ class AutoGaussian(AutoGuide): ``pip install pyro-ppl[funsor]``. """ + backend: str + locs: PyroModule + scales: PyroModule + precisions: PyroModule + latents: Dict[str, Dict[str, object]] + _init_scale: float + _original_model: Tuple[Callable] + _unconstrained_event_shapes: Dict[str, torch.Size] + _broken_event_shapes: Dict[str, torch.Size] + _broken_plates: Dict[str, Tuple[str, ...]] + scale_constraint = constraints.softplus_positive def __init__( @@ -1705,9 +1714,9 @@ def _setup_prototype(self, *args, **kwargs) -> None: self.locs = PyroModule() self.scales = PyroModule() self.precisions = PyroModule() - self._unconstrained_event_shapes: Dict[str, torch.Size] = {} - self._broken_event_shapes: Dict[str, torch.Size] = {} - self._broken_plates: Dict[str, Tuple[str, ...]] = defaultdict(tuple) + self._unconstrained_event_shapes = {} + self._broken_event_shapes = {} + self._broken_plates = defaultdict(tuple) model = self._original_model[0] meta = poutine.block(get_dependencies)(model, args, kwargs) @@ -1718,7 +1727,7 @@ def _setup_prototype(self, *args, **kwargs) -> None: for plates in upstreams.values() for p in plates } - self.latents: Dict[str, Dict[str, object]] = { + self.latents = { name: site for name, site in self.prototype_trace.nodes.items() if site["type"] == "sample" @@ -1829,7 +1838,7 @@ def _transform_values( def _sample_aux_values( self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.tensor]]: + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: # Sample auxiliary values via Gaussian tensor variable elimination. 
if self.backend == "funsor": return self._sample_aux_values_funsor() @@ -1838,15 +1847,15 @@ def _sample_aux_values( def _sample_aux_values_funsor( self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.tensor]]: + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: import funsor import pyro.contrib.funsor # Construct TVE problem inputs, converting torch to funsor. - plate_to_dim = {} factors = {} eliminate = set() + plate_to_dim = {} for d, site in self.latents.items(): inputs = OrderedDict() for f in site["cond_indep_stack"]: @@ -1855,12 +1864,14 @@ def _sample_aux_values_funsor( if f.name not in self._broken_plates[d]: inputs[f.name] = funsor.Bint[f.size] eliminate.add(f.name) - if not site["is_observed"]: - inputs[d] = funsor.Reals[self._broken_event_shapes[d]] - eliminate.add(d) + # Order inputs as in the model, so as to maximize sparsity of the + # lower Cholesky parametrization of the precision matrix. for u in self.dependencies[d]: inputs[u] = funsor.Reals[self._broken_event_shapes[u]] - assert u in eliminate + eliminate.add(u) + if not site["is_observed"]: + inputs[d] = funsor.Reals[self._broken_event_shapes[d]] + assert d in eliminate precision = _deep_getattr(self.precisions, d) info_vec = precision.new_zeros(precision.shape[:-1]) From 58702bdf36782e72b70160e9479ef4724047120b Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 15 Sep 2021 09:44:41 -0400 Subject: [PATCH 07/41] Register docstring --- docs/source/infer.autoguide.rst | 7 +++++++ pyro/infer/autoguide/guides.py | 24 +++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index 80e4db225c..e647ce9766 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -117,6 +117,13 @@ AutoStructured :special-members: __call__ :show-inheritance: +AutoGaussian +------------ +.. 
autoclass:: pyro.infer.autoguide.AutoGaussian + :members: + :undoc-members: + :show-inheritance: + .. _autoguide-initialization: Initialization diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 04de77ccfc..fd550c9c30 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1648,7 +1648,7 @@ def median(self, *args, **kwargs): class AutoGaussian(AutoGuide): """ - EXPERIMENTAL Gaussian Markov random field guide. + EXPERIMENTAL Gaussian tensor variable elimination guide [1,2]. This is equivalent to a full rank :class:`AutoMultivariateNormal` guide, but with a sparse precision matrix determined by dependencies and plates in @@ -1658,11 +1658,26 @@ class AutoGaussian(AutoGuide): The guide currently does not depend on the model's ``*args, **kwargs``. - Usage:: + Example:: guide = AutoGaussian(model) svi = SVI(model, guide, ...) + .. warning: This currently supports ``backend=funsor`` which depends on + the funsor package. You can install via + ``pip install pyro-ppl[funsor]``. + + **References** + + [1] F. Obermeyer, E. Bingham, M. Jankowiak, J. Chiu, N. Pradhan, A. M. Rush, N. Goodman + (2019) + "Tensor Variable Elimination for Plated Factor Graphs" + http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf + [2] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen + (2019) + "Functional Tensors for Probabilistic Programming" + https://arxiv.org/abs/1910.10775 + :param callable model: A Pyro model. :param callable init_loc_fn: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. @@ -1673,9 +1688,8 @@ class AutoGaussian(AutoGuide): or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. :param str backend: Back end for performing Gaussian tensor variable - elimination. 
Currently only the experimental "funsor" backend is supported, - and this requires the optional ``funsor`` dependency, installed via - ``pip install pyro-ppl[funsor]``. + elimination. Currently only the experimental "funsor" backend is + supported. """ backend: str From ece747373f577add437cac5168e2008b928e58d1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 15 Sep 2021 18:06:22 -0400 Subject: [PATCH 08/41] Split up methods, split up tests, add new tests, mark stage=funsor --- pyro/infer/autoguide/guides.py | 71 ++++--- pyro/infer/inspect.py | 4 + tests/infer/autoguide/__init__.py | 0 tests/infer/autoguide/conftest.py | 13 ++ tests/infer/autoguide/test_autogaussian.py | 234 +++++++++++++++++++++ tests/infer/test_autoguide.py | 205 ++---------------- 6 files changed, 315 insertions(+), 212 deletions(-) create mode 100644 tests/infer/autoguide/__init__.py create mode 100644 tests/infer/autoguide/conftest.py create mode 100644 tests/infer/autoguide/test_autogaussian.py diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index fd550c9c30..945bf1f51c 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1696,7 +1696,7 @@ class AutoGaussian(AutoGuide): locs: PyroModule scales: PyroModule precisions: PyroModule - latents: Dict[str, Dict[str, object]] + _sorted_sites: Dict[str, Dict[str, object]] _init_scale: float _original_model: Tuple[Callable] _unconstrained_event_shapes: Dict[str, torch.Size] @@ -1741,13 +1741,13 @@ def _setup_prototype(self, *args, **kwargs) -> None: for plates in upstreams.values() for p in plates } - self.latents = { + self._sorted_sites = { name: site for name, site in self.prototype_trace.nodes.items() if site["type"] == "sample" if not site_is_subsample(site) } - for d, site in self.latents.items(): + for d, site in self._sorted_sites.items(): precision_size = 0 precision_plates: Set[CondIndepStackFrame] = set() if not site["is_observed"]: @@ -1804,6 +1804,9 @@ def 
_setup_prototype(self, *args, **kwargs) -> None: ), ) + if self.backend == "funsor": + self._funsor_setup_prototype(*args, **kwargs) + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: self._setup_prototype(*args, **kwargs) @@ -1814,7 +1817,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # Replay via Pyro primitives. plates = self._create_plates(*args, **kwargs) - for name, site in self.latents.items(): + for name, site in self._sorted_sites.items(): with ExitStack() as stack: for frame in site["cond_indep_stack"]: if frame.vectorized: @@ -1823,13 +1826,12 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: name, dist.Delta(values[name], event_dim=site["fn"].event_dim) ) if am_i_wrapped() and poutine.get_mask() is not False: - pyro.factor("AutoGaussian_log_density", log_density) + pyro.factor("AutoGaussian.log_density", log_density) return values - @torch.no_grad() def median(self) -> Dict[str, torch.Tensor]: - with poutine.mask(mask=False): - aux_values = {name: 0.0 for name in self.latents} + with torch.no_grad(), poutine.mask(mask=False): + aux_values = {name: 0.0 for name in self._sorted_sites} values, _ = self._transform_values(aux_values) return values @@ -1841,7 +1843,7 @@ def _transform_values( values = {} log_density = 0.0 compute_density = am_i_wrapped() and poutine.get_mask() is not False - for name, site in self.latents.items(): + for name, site in self._sorted_sites.items(): loc = _deep_getattr(self.locs, name) scale = _deep_getattr(self.scales, name) values[name] = biject_to(site["fn"].support)(aux_values[name] * scale + loc) @@ -1855,22 +1857,21 @@ def _sample_aux_values( ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: # Sample auxiliary values via Gaussian tensor variable elimination. 
if self.backend == "funsor": - return self._sample_aux_values_funsor() + return self._funsor_sample_aux_values() else: raise ValueError(f"Unknown backend: {self.backend}") - def _sample_aux_values_funsor( - self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + def _funsor_setup_prototype(self, *args, **kwargs): import funsor - import pyro.contrib.funsor + # Determine TVE problem shape. + factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {} + eliminate: Set[str] = set() + plate_to_dim: Dict[str, int] = {} - # Construct TVE problem inputs, converting torch to funsor. - factors = {} - eliminate = set() - plate_to_dim = {} - for d, site in self.latents.items(): + for d, site in self._sorted_sites.items(): + # Order inputs as in the model, so as to maximize sparsity of the + # lower Cholesky parametrization of the precision matrix. inputs = OrderedDict() for f in site["cond_indep_stack"]: if f.vectorized: @@ -1878,20 +1879,32 @@ def _sample_aux_values_funsor( if f.name not in self._broken_plates[d]: inputs[f.name] = funsor.Bint[f.size] eliminate.add(f.name) - # Order inputs as in the model, so as to maximize sparsity of the - # lower Cholesky parametrization of the precision matrix. for u in self.dependencies[d]: inputs[u] = funsor.Reals[self._broken_event_shapes[u]] eliminate.add(u) if not site["is_observed"]: inputs[d] = funsor.Reals[self._broken_event_shapes[d]] assert d in eliminate + factor_inputs[d] = inputs + + self._funsor_factor_inputs = factor_inputs + self._funsor_eliminate = frozenset(eliminate) + self._funsor_plate_to_dim = plate_to_dim + self._funsor_plates = frozenset(plate_to_dim) + + def _funsor_sample_aux_values( + self, + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + import funsor + + import pyro.contrib.funsor + # Construct TVE problem inputs, converting torch to funsor. 
+ factors = {} + for d, inputs in self._funsor_factor_inputs.items(): precision = _deep_getattr(self.precisions, d) - info_vec = precision.new_zeros(precision.shape[:-1]) + info_vec = precision.new_zeros(()).expand(precision.shape[:-1]) factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) - plates = frozenset(plate_to_dim) - eliminate = frozenset(eliminate) # Compute log normalizer via Gaussian tensor variable elimination. with funsor.interpretations.reflect: @@ -1899,13 +1912,14 @@ def _sample_aux_values_funsor( funsor.ops.logaddexp, funsor.ops.add, list(factors.values()), - eliminate=eliminate, - plates=plates, + eliminate=self._funsor_eliminate, + plates=self._funsor_plates, ) log_Z = funsor.optimizer.apply_optimizer(log_Z) # Draw a batch of samples. particle_plates = frozenset(get_plates()) + plate_to_dim = self._funsor_plate_to_dim.copy() plate_to_dim.update({f.name: f.dim for f in particle_plates}) sample_inputs = {f.name: funsor.Bint[f.size] for f in particle_plates} with funsor.montecarlo.MonteCarlo(**sample_inputs): @@ -1918,7 +1932,7 @@ def _sample_aux_values_funsor( name: pyro.contrib.funsor.handlers.enum_messenger._get_support_value( samples[factors[name]], name ) - for name in eliminate - plates + for name in self._funsor_eliminate - self._funsor_plates } # Compute density. @@ -1926,7 +1940,8 @@ def _sample_aux_values_funsor( log_prob = -funsor.reinterpret(log_Z) for f in factors.values(): term = f(**samples) - term = term.reduce(funsor.ops.add, eliminate.intersection(term.inputs)) + plates = self._funsor_eliminate.intersection(term.inputs) + term = term.reduce(funsor.ops.add, plates) log_prob += term assert all(f.name in log_prob.inputs for f in particle_plates) diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index ccb642ebe2..291a3f6aa2 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -28,6 +28,10 @@ def is_sample_site(msg): if type(fn).__name__ == "Delta": return False + # Exclude factor statements. 
+ if type(fn).__name__ == "Unit": + return False + return True diff --git a/tests/infer/autoguide/__init__.py b/tests/infer/autoguide/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/infer/autoguide/conftest.py b/tests/infer/autoguide/conftest.py new file mode 100644 index 0000000000..17f112a1ad --- /dev/null +++ b/tests/infer/autoguide/conftest.py @@ -0,0 +1,13 @@ +# Copyright (c) 2017-2019 Uber Technologies, Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + + +def pytest_collection_modifyitems(items): + for item in items: + if item.nodeid.startswith("tests/infer/autoguide"): + if "stage" not in item.keywords: + item.add_marker(pytest.mark.stage("unit")) + if "init" not in item.keywords: + item.add_marker(pytest.mark.init(rng_seed=123)) diff --git a/tests/infer/autoguide/test_autogaussian.py b/tests/infer/autoguide/test_autogaussian.py new file mode 100644 index 0000000000..900799c0bc --- /dev/null +++ b/tests/infer/autoguide/test_autogaussian.py @@ -0,0 +1,234 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import pyro +import pyro.distributions as dist +import pyro.poutine as poutine +from pyro.infer import SVI, Predictive, Trace_ELBO +from pyro.infer.autoguide import AutoGaussian +from pyro.infer.reparam import LocScaleReparam +from pyro.optim import Adam + +# AutoGaussian currently depends on funsor. +pytestmark = pytest.mark.stage("funsor") + + +# Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py +def pyrocov_model(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] + assert weekly_strains.shape == (T, P, S) + + # Sample global random variables. 
+ coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] + + # Assume relative growth rate depends strongly on mutations and weakly on place. + coef_loc = torch.zeros(F) + coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] + rate_loc = pyro.deterministic( + "rate_loc", 0.01 * coef @ features.T, event_dim=1 + ) # [S] + + # Assume initial infections depend strongly on strain and place. + init_loc = pyro.sample( + "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) + ) # [S] + with pyro.plate("place", P, dim=-1): + rate = pyro.sample( + "rate", dist.Normal(rate_loc, rate_scale).to_event(1) + ) # [P, S] + init = pyro.sample( + "init", dist.Normal(init_loc, init_scale).to_event(1) + ) # [P, S] + + # Finally observe counts. + with pyro.plate("time", T, dim=-2): + logits = init + rate * local_time # [T, P, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains, + ) + + +# This is modified by relaxing rate from deterministic to latent. +def pyrocov_model_relaxed(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] + assert weekly_strains.shape == (T, P, S) + + # Sample global random variables. 
+ coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None] + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] + + # Assume relative growth rate depends strongly on mutations and weakly on place. + coef_loc = torch.zeros(F) + coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] + rate_loc = pyro.sample( + "rate_loc", + dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1), + ) # [S] + + # Assume initial infections depend strongly on strain and place. + init_loc = pyro.sample( + "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) + ) # [S] + with pyro.plate("place", P, dim=-1): + rate = pyro.sample( + "rate", dist.Normal(rate_loc, rate_scale).to_event(1) + ) # [P, S] + init = pyro.sample( + "init", dist.Normal(init_loc, init_scale).to_event(1) + ) # [P, S] + + # Finally observe counts. + with pyro.plate("time", T, dim=-2): + logits = init + rate * local_time # [T, P, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains, + ) + + +# This is modified by more precisely tracking plates for features and strains. +def pyrocov_model_plated(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] # [T, P, S] + assert weekly_strains.shape == (T, P, S) + feature_plate = pyro.plate("feature", F, dim=-1) + strain_plate = pyro.plate("strain", S, dim=-1) + place_plate = pyro.plate("place", P, dim=-2) + time_plate = pyro.plate("time", T, dim=-3) + + # Sample global random variables. 
+ coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2)) + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) + + with feature_plate: + coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] + rate_loc_loc = 0.01 * coef @ features.T + with strain_plate: + rate_loc = pyro.sample( + "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale) + ) # [S] + init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale)) # [S] + with place_plate, strain_plate: + rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale)) # [P, S] + init = pyro.sample("init", dist.Normal(init_loc, init_scale)) # [P, S] + + # Finally observe counts. + with time_plate, place_plate: + logits = (init + rate * local_time)[..., None, :] # [T, P, 1, S] + pyro.sample( + "obs", + dist.Multinomial(logits=logits, validate_args=False), + obs=weekly_strains[..., None, :], + ) + + +@pytest.mark.parametrize( + "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] +) +@pytest.mark.parametrize("backend", ["funsor"]) +def test_autogaussian_pyrocov_smoke(model, backend): + T, P, S, F = 3, 4, 5, 6 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + svi.step(dataset) + guide(dataset) + predictive = Predictive(model, guide=guide, num_samples=2) + predictive(dataset) + + +@pytest.mark.parametrize( + "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] +) +@pytest.mark.parametrize("backend", ["funsor"]) +def test_structured_pyrocov_reparam(model, backend): + T, P, S, F = 2, 3, 4, 5 + dataset = { + "features": torch.randn(S, F), 
+ "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + # Reparametrize the model. + config = { + "coef": LocScaleReparam(), + "rate_loc": None if model is pyrocov_model else LocScaleReparam(), + "rate": LocScaleReparam(), + "init_loc": LocScaleReparam(), + "init": LocScaleReparam(), + } + model = poutine.reparam(model, config) + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + svi.step(dataset) + guide(dataset) + predictive = Predictive(model, guide=guide, num_samples=2) + predictive(dataset) + + +def test_profile(n=1, num_steps=1): + """ + Helper function for profiling. + """ + model = pyrocov_model_plated + T, P, S, F = 2 * n, 3 * n, 4 * n, 5 * n + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(model) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + guide(dataset) # initialize + print("Factor inputs:") + for name, inputs in guide._funsor_factor_inputs.items(): + print(f" {name}:") + for k, v in inputs.items(): + print(f" {k}: {v}") + print("Parameter shapes:") + for name, param in guide.named_parameters(): + print(f" {name}: {tuple(param.shape)}") + + for step in range(num_steps): + svi.step(dataset) + + +if __name__ == "__main__": + test_profile(n=10, num_steps=100) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 31cfc4e559..fc011edef6 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -35,12 +35,15 @@ init_to_median, init_to_sample, ) -from pyro.infer.reparam import LocScaleReparam, ProjectedNormalReparam +from pyro.infer.reparam import ProjectedNormalReparam from pyro.nn.module import PyroModule, PyroParam, PyroSample from pyro.optim import Adam from pyro.poutine.util import prune_subsample_sites from pyro.util import check_model_guide_match 
-from tests.common import assert_close, assert_equal, xfail_param +from tests.common import assert_close, assert_equal + +# AutoGaussian currently depends on funsor. +AutoGaussian = pytest.param(AutoGaussian, marks=[pytest.mark.stage("funsor")]) @pytest.mark.parametrize( @@ -88,6 +91,7 @@ def model(): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) def test_factor(auto_class, Elbo): @@ -222,6 +226,7 @@ def model(): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO]) @@ -383,7 +388,6 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, - xfail_param(AutoGaussian, reason="not jit compatible"), ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -503,6 +507,7 @@ def model(): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) def test_discrete_parallel(continuous_class): @@ -538,6 +543,7 @@ def model(data): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) def test_guide_list(auto_class): @@ -560,6 +566,7 @@ def model(): AutoMultivariateNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) def test_callable(auto_class): @@ -582,11 +589,13 @@ def guide_x(): "auto_class", [ AutoDelta, + AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) def test_callable_return_dict(auto_class): @@ -629,9 +638,11 @@ def model(): "auto_class", [ AutoDelta, + AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, AutoLowRankMultivariateNormal, + AutoGaussian, ], ) def test_init_loc_fn(auto_class): @@ -694,6 +705,7 @@ def model(): auto_guide_module_callable, functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), 
functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoGaussian, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -779,6 +791,7 @@ def forward(self): AutoLaplaceApproximation, functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoGaussian, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -843,7 +856,6 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, AutoStructured_predictive, - xfail_param(AutoGaussian, reason="not jit compatible"), ], ) def test_predictive(auto_class): @@ -1020,6 +1032,7 @@ def create_plates(data): AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) @pytest.mark.parametrize( @@ -1054,6 +1067,7 @@ def model(): AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) @pytest.mark.parametrize( @@ -1085,6 +1099,7 @@ def model(): AutoNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, + AutoGaussian, ], ) @pytest.mark.parametrize( @@ -1153,6 +1168,7 @@ def __init__(self, model): AutoMultivariateNormal, AutoStructured_exact_normal, AutoStructured_exact_mvn, + AutoGaussian, ], ) def test_exact(Guide): @@ -1192,6 +1208,7 @@ def model(data): AutoMultivariateNormal, AutoStructured_exact_normal, AutoStructured_exact_mvn, + AutoGaussian, ], ) def test_exact_batch(Guide): @@ -1221,183 +1238,3 @@ def model(data): actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) assert_close(actual_std, expected_std, rtol=0.05) - - -# Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py -def pyrocov_model(dataset): - # Tensor shapes are commented at the end of some lines. 
- features = dataset["features"] - local_time = dataset["local_time"][..., None] # [T, P, 1] - T, P, _ = local_time.shape - S, F = features.shape - weekly_strains = dataset["weekly_strains"] - assert weekly_strains.shape == (T, P, S) - - # Sample global random variables. - coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] - rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] - init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] - init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] - - # Assume relative growth rate depends strongly on mutations and weakly on place. - coef_loc = torch.zeros(F) - coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] - rate_loc = pyro.deterministic( - "rate_loc", 0.01 * coef @ features.T, event_dim=1 - ) # [S] - - # Assume initial infections depend strongly on strain and place. - init_loc = pyro.sample( - "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) - ) # [S] - with pyro.plate("place", P, dim=-1): - rate = pyro.sample( - "rate", dist.Normal(rate_loc, rate_scale).to_event(1) - ) # [P, S] - init = pyro.sample( - "init", dist.Normal(init_loc, init_scale).to_event(1) - ) # [P, S] - - # Finally observe counts. - with pyro.plate("time", T, dim=-2): - logits = init + rate * local_time # [T, P, S] - pyro.sample( - "obs", - dist.Multinomial(logits=logits, validate_args=False), - obs=weekly_strains, - ) - - -# This is modified by relaxing rate from deterministic to latent. -def pyrocov_model_relaxed(dataset): - # Tensor shapes are commented at the end of some lines. - features = dataset["features"] - local_time = dataset["local_time"][..., None] # [T, P, 1] - T, P, _ = local_time.shape - S, F = features.shape - weekly_strains = dataset["weekly_strains"] - assert weekly_strains.shape == (T, P, S) - - # Sample global random variables. 
- coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None] - rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None] - rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None] - init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None] - init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None] - - # Assume relative growth rate depends strongly on mutations and weakly on place. - coef_loc = torch.zeros(F) - coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1)) # [F] - rate_loc = pyro.sample( - "rate_loc", - dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1), - ) # [S] - - # Assume initial infections depend strongly on strain and place. - init_loc = pyro.sample( - "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1) - ) # [S] - with pyro.plate("place", P, dim=-1): - rate = pyro.sample( - "rate", dist.Normal(rate_loc, rate_scale).to_event(1) - ) # [P, S] - init = pyro.sample( - "init", dist.Normal(init_loc, init_scale).to_event(1) - ) # [P, S] - - # Finally observe counts. - with pyro.plate("time", T, dim=-2): - logits = init + rate * local_time # [T, P, S] - pyro.sample( - "obs", - dist.Multinomial(logits=logits, validate_args=False), - obs=weekly_strains, - ) - - -# This is modified by more precisely tracking plates for features and strains. -def pyrocov_model_plated(dataset): - # Tensor shapes are commented at the end of some lines. - features = dataset["features"] - local_time = dataset["local_time"][..., None] # [T, P, 1] - T, P, _ = local_time.shape - S, F = features.shape - weekly_strains = dataset["weekly_strains"] # [T, P, S] - assert weekly_strains.shape == (T, P, S) - feature_plate = pyro.plate("feature", F, dim=-1) - strain_plate = pyro.plate("strain", S, dim=-1) - place_plate = pyro.plate("place", P, dim=-2) - time_plate = pyro.plate("time", T, dim=-3) - - # Sample global random variables. 
- coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2)) - rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) - rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) - init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) - init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) - - with feature_plate: - coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] - rate_loc_loc = 0.01 * coef @ features.T - with strain_plate: - rate_loc = pyro.sample( - "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale) - ) # [S] - init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale)) # [S] - with place_plate, strain_plate: - rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale)) # [P, S] - init = pyro.sample("init", dist.Normal(init_loc, init_scale)) # [P, S] - - # Finally observe counts. - with time_plate, place_plate: - logits = (init + rate * local_time)[..., None, :] # [T, P, 1, S] - pyro.sample( - "obs", - dist.Multinomial(logits=logits, validate_args=False), - obs=weekly_strains[..., None, :], - ) - - -@pytest.mark.parametrize( - "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] -) -@pytest.mark.parametrize("backend", ["funsor"]) -def test_autogaussian_pyrocov_smoke(model, backend): - T, P, S, F = 3, 4, 5, 6 - dataset = { - "features": torch.randn(S, F), - "local_time": torch.randn(T, P), - "weekly_strains": torch.randn(T, P, S).exp().round(), - } - - guide = AutoGaussian(model, backend=backend) - svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) - svi.step(dataset) - svi.step(dataset) - - -@pytest.mark.parametrize( - "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] -) -@pytest.mark.parametrize("backend", ["funsor"]) -def test_structured_pyrocov_reparam(model, backend): - T, P, S, F = 2, 3, 4, 5 - dataset = { - "features": torch.randn(S, F), - "local_time": torch.randn(T, P), - "weekly_strains": torch.randn(T, P, S).exp().round(), - } - - # 
Reparametrize the model. - config = { - "coef": LocScaleReparam(), - "rate_loc": None if model is pyrocov_model else LocScaleReparam(), - "rate": LocScaleReparam(), - "init_loc": LocScaleReparam(), - "init": LocScaleReparam(), - } - model = poutine.reparam(model, config) - guide = AutoGaussian(model, backend=backend) - svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) - svi.step(dataset) - svi.step(dataset) From f3f170c0d93fa08103850541b0550d48a7085eae Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 15 Sep 2021 18:46:37 -0400 Subject: [PATCH 09/41] Fix get_dependencies() to handle pyro.factor --- pyro/infer/inspect.py | 35 ++++++++++++++++------------------- tests/infer/test_inspect.py | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index 291a3f6aa2..b4f0c8865d 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Optional +from typing import Callable, Dict, List, Optional import torch @@ -28,10 +28,6 @@ def is_sample_site(msg): if type(fn).__name__ == "Delta": return False - # Exclude factor statements. 
- if type(fn).__name__ == "Unit": - return False - return True @@ -48,6 +44,7 @@ def _pyro_post_sample(self, msg): msg["value"] = msg["value"].detach() +@torch.enable_grad() def get_dependencies( model: Callable, model_args: Optional[tuple] = None, @@ -175,7 +172,7 @@ def model_3(): model_kwargs = {} def get_sample_sites(predicate=lambda msg: True): - with torch.enable_grad(), torch.random.fork_rng(): + with torch.random.fork_rng(): with pyro.validation_enabled(False), RequiresGradMessenger(predicate): trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) return [msg for msg in trace.nodes.values() if is_sample_site(msg)] @@ -191,15 +188,13 @@ def get_sample_sites(predicate=lambda msg: True): # First find transitive dependencies among latent and observed sites prior_dependencies = {n: {n: set()} for n in plates} # no deps yet for i, downstream in enumerate(sample_sites): - upstreams = [u for u in sample_sites[:i] if not u["is_observed"]] + upstreams = [ + u for u in sample_sites[:i] if not u["is_observed"] if u["value"].numel() + ] if not upstreams: continue - grads = torch.autograd.grad( - downstream["fn"].log_prob(downstream["value"]).sum(), - [u["value"] for u in upstreams], - allow_unused=True, - retain_graph=True, - ) + log_prob = downstream["fn"].log_prob(downstream["value"]).sum() + grads = _safe_grad(log_prob, [u["value"] for u in upstreams]) for upstream, grad in zip(upstreams, grads): if grad is not None: d = downstream["name"] @@ -215,12 +210,8 @@ def get_sample_sites(predicate=lambda msg: True): sample_sites_ij = get_sample_sites(lambda msg: msg["name"] in names) d = sample_sites_ij[i] u = sample_sites_ij[j] - grad = torch.autograd.grad( - d["fn"].log_prob(d["value"]).sum(), - [u["value"]], - allow_unused=True, - retain_graph=True, - )[0] + log_prob = d["fn"].log_prob(d["value"]).sum() + grad = _safe_grad(log_prob, [u["value"]])[0] if grad is None: prior_dependencies[d["name"]].pop(u["name"]) @@ -252,6 +243,12 @@ def 
get_sample_sites(predicate=lambda msg: True): } +def _safe_grad(root: torch.Tensor, args: List[torch.Tensor]): + if not root.requires_grad: + return [None] * len(args) + return torch.autograd.grad(root, args, allow_unused=True, retain_graph=True) + + __all__ = [ "get_dependencies", ] diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py index a0916ef3ac..1480f283b1 100644 --- a/tests/infer/test_inspect.py +++ b/tests/infer/test_inspect.py @@ -118,6 +118,26 @@ def model_3(): assert actual == expected +def test_factor(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.factor("b", torch.tensor(0.0)) + pyro.factor("c", a) + + actual = get_dependencies(model) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"b": set()}, + "c": {"c": set(), "a": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "c": set()}, + }, + } + assert actual == expected + + def test_plate_coupling(): # x x # || From f5bc92fc2dfcd7a0e545538012034e68abfe6444 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 15 Sep 2021 20:09:08 -0400 Subject: [PATCH 10/41] Add failing test for plated sampling --- pyro/infer/inspect.py | 2 +- tests/infer/autoguide/test_autogaussian.py | 76 ++++++++++++++++++++-- tests/infer/test_autoguide.py | 38 +++++++++++ 3 files changed, 109 insertions(+), 7 deletions(-) diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index b4f0c8865d..f8a91f65d7 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -37,7 +37,7 @@ def __init__(self, predicate=lambda msg: True): super().__init__() def _pyro_post_sample(self, msg): - if is_sample_site(msg): + if is_sample_site(msg) and msg["value"].dtype.is_floating_point: if self.predicate(msg): msg["value"].requires_grad_() elif not msg["is_observed"] and msg["value"].requires_grad: diff --git a/tests/infer/autoguide/test_autogaussian.py b/tests/infer/autoguide/test_autogaussian.py index 900799c0bc..223bc9edb6 100644 --- 
a/tests/infer/autoguide/test_autogaussian.py +++ b/tests/infer/autoguide/test_autogaussian.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections import OrderedDict + import pytest import torch @@ -154,8 +156,7 @@ def pyrocov_model_plated(dataset): @pytest.mark.parametrize( "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] ) -@pytest.mark.parametrize("backend", ["funsor"]) -def test_autogaussian_pyrocov_smoke(model, backend): +def test_pyrocov_smoke(model): T, P, S, F = 3, 4, 5, 6 dataset = { "features": torch.randn(S, F), @@ -163,7 +164,7 @@ def test_autogaussian_pyrocov_smoke(model, backend): "weekly_strains": torch.randn(T, P, S).exp().round(), } - guide = AutoGaussian(model, backend=backend) + guide = AutoGaussian(model) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) for step in range(2): svi.step(dataset) @@ -175,8 +176,7 @@ def test_autogaussian_pyrocov_smoke(model, backend): @pytest.mark.parametrize( "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] ) -@pytest.mark.parametrize("backend", ["funsor"]) -def test_structured_pyrocov_reparam(model, backend): +def test_pyrocov_reparam(model): T, P, S, F = 2, 3, 4, 5 dataset = { "features": torch.randn(S, F), @@ -193,7 +193,7 @@ def test_structured_pyrocov_reparam(model, backend): "init": LocScaleReparam(), } model = poutine.reparam(model, config) - guide = AutoGaussian(model, backend=backend) + guide = AutoGaussian(model) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) for step in range(2): svi.step(dataset) @@ -202,6 +202,70 @@ def test_structured_pyrocov_reparam(model, backend): predictive(dataset) +def test_pyrocov_structure(): + from funsor import Bint, Real, Reals + + T, P, S, F = 2, 3, 4, 5 + dataset = { + "features": torch.randn(S, F), + "local_time": torch.randn(T, P), + "weekly_strains": torch.randn(T, P, S).exp().round(), + } + + guide = AutoGaussian(pyrocov_model_plated) + 
guide(dataset) # initialize + + expected_plates = frozenset(["place", "feature", "strain"]) + assert guide._funsor_plates == expected_plates + + expected_eliminate = frozenset( + [ + "place", + "coef_scale", + "rate_loc_scale", + "rate_scale", + "init_loc_scale", + "init_scale", + "coef", + "rate_loc", + "init_loc", + "rate", + "init", + ] + ) + assert guide._funsor_eliminate == expected_eliminate + + expected_factor_inputs = { + "coef_scale": OrderedDict([("coef_scale", Real)]), + "rate_loc_scale": OrderedDict([("rate_loc_scale", Real)]), + "rate_scale": OrderedDict([("rate_scale", Real)]), + "init_loc_scale": OrderedDict([("init_loc_scale", Real)]), + "init_scale": OrderedDict([("init_scale", Real)]), + "coef": OrderedDict([("coef", Reals[5]), ("coef_scale", Real)]), + "rate_loc": OrderedDict( + [("rate_loc", Reals[4]), ("rate_loc_scale", Real), ("coef", Reals[5])] + ), + "init_loc": OrderedDict([("init_loc", Reals[4]), ("init_loc_scale", Real)]), + "rate": OrderedDict( + [ + ("place", Bint[3]), + ("rate", Reals[4]), + ("rate_scale", Real), + ("rate_loc", Reals[4]), + ] + ), + "init": OrderedDict( + [ + ("place", Bint[3]), + ("init", Reals[4]), + ("init_scale", Real), + ("init_loc", Reals[4]), + ] + ), + } + assert guide._funsor_factor_inputs == expected_factor_inputs + + def test_profile(n=1, num_steps=1): """ Helper function for profiling. 
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index fc011edef6..2aeae16ec4 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -908,6 +908,44 @@ def forward(self, x, y=None): assert len(samples) == len(samples_deser) +@pytest.mark.parametrize("sample_shape", [(), (6,), (5, 4)], ids=str) +@pytest.mark.parametrize( + "auto_class", + [ + AutoDelta, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoNormal, + AutoLowRankMultivariateNormal, + AutoLaplaceApproximation, + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), + functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoStructured, + AutoGaussian, + ], +) +def test_replay_plates(auto_class, sample_shape): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a[..., None], torch.ones(3)).to_event(1)) + c = pyro.sample( + "c", dist.MultivariateNormal(torch.zeros(3) + a[..., None], torch.eye(3)) + ) + with pyro.plate("i", 2): + d = pyro.sample("d", dist.Dirichlet((b + c).exp())) + pyro.sample("e", dist.Categorical(logits=d), obs=torch.tensor([0, 0])) + return a, b, c, d + + guide = auto_class(model) + with pyro.plate_stack("plate", sample_shape, rightmost_dim=-2): + guide_trace = poutine.trace(guide).get_trace() + a, b, c, d = poutine.replay(model, guide_trace)() + assert a.shape == (sample_shape + (1,) if sample_shape else ()) + assert b.shape == (sample_shape + (1, 3) if sample_shape else (3,)) + assert c.shape == (sample_shape + (1, 3) if sample_shape else (3,)) + assert d.shape == sample_shape + (2, 3) + + @pytest.mark.parametrize("auto_class", [AutoDelta, AutoNormal]) def test_subsample_model(auto_class): def model(x, y=None, batch_size=None): From 945c66ee28bc07ca804b54fefd914ce7b0dfa824 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 20 Sep 2021 21:20:12 -0400 Subject: [PATCH 11/41] Add regression tests for get_dependencies() --- tests/infer/test_autoguide.py | 2 +- 
tests/infer/test_inspect.py | 39 +++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 2aeae16ec4..367965ea63 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -937,7 +937,7 @@ def model(): return a, b, c, d guide = auto_class(model) - with pyro.plate_stack("plate", sample_shape, rightmost_dim=-2): + with pyro.plate_stack("particles", sample_shape, rightmost_dim=-2): guide_trace = poutine.trace(guide).get_trace() a, b, c, d = poutine.replay(model, guide_trace)() assert a.shape == (sample_shape + (1,) if sample_shape else ()) diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py index 1480f283b1..eb4b82ac77 100644 --- a/tests/infer/test_inspect.py +++ b/tests/infer/test_inspect.py @@ -8,8 +8,11 @@ from pyro.distributions.testing.fakes import NonreparameterizedNormal from pyro.infer.inspect import get_dependencies +import pytest -def test_get_dependencies(): + +@pytest.mark.parametrize("grad_enabled", [True, False]) +def test_get_dependencies(grad_enabled): def model(data): a = pyro.sample("a", dist.Normal(0, 1)) b = pyro.sample("b", NonreparameterizedNormal(a, 0)) @@ -30,7 +33,8 @@ def model(data): return [a, b, c, d, e, f, g, h, i, j, k] data = torch.randn(3) - actual = get_dependencies(model, (data,)) + with torch.set_grad_enabled(grad_enabled): + actual = get_dependencies(model, (data,)) _ = set() expected = { "prior_dependencies": { @@ -138,6 +142,37 @@ def model(): assert actual == expected +def test_discrete_obs(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a[..., None], torch.ones(3)).to_event(1)) + c = pyro.sample( + "c", dist.MultivariateNormal(torch.zeros(3) + a[..., None], torch.eye(3)) + ) + with pyro.plate("i", 2): + d = pyro.sample("d", dist.Dirichlet((b + c).exp())) + pyro.sample("e", dist.Categorical(logits=d), obs=torch.tensor([0, 0])) + 
return a, b, c, d + + actual = get_dependencies(model) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"a": set(), "b": set()}, + "c": {"a": set(), "c": set()}, + "d": {"b": set(), "c": set(), "d": set()}, + "e": {"d": set(), "e": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "b": set(), "c": set()}, + "b": {"b": set(), "c": set(), "d": set()}, + "c": {"c": set(), "d": set()}, + "d": {"d": set(), "e": set()}, + }, + } + assert actual == expected + + def test_plate_coupling(): # x x # || From ff0eaf54d9fb3bec065881fbe29d8e8ed43727fc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 20 Sep 2021 21:40:29 -0400 Subject: [PATCH 12/41] lint --- tests/infer/test_inspect.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py index eb4b82ac77..2f600c28a2 100644 --- a/tests/infer/test_inspect.py +++ b/tests/infer/test_inspect.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import pytest import torch import pyro @@ -8,8 +9,6 @@ from pyro.distributions.testing.fakes import NonreparameterizedNormal from pyro.infer.inspect import get_dependencies -import pytest - @pytest.mark.parametrize("grad_enabled", [True, False]) def test_get_dependencies(grad_enabled): From 44cc8701e4586a8c1deb1e13ddcbadee0539254f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 21 Sep 2021 13:59:17 -0400 Subject: [PATCH 13/41] Use funsor.recipes.forward_filter_backward_rsample --- pyro/infer/autoguide/guides.py | 61 +++++++++++----------------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 945bf1f51c..c3cfc0e322 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1862,7 +1862,14 @@ def _sample_aux_values( raise ValueError(f"Unknown backend: {self.backend}") def _funsor_setup_prototype(self, *args, **kwargs): - import funsor + try: + import funsor + except ImportError as e: + raise ImportError( + 'AutoGaussian(..., backend="funsor") requires funsor. ' + "Try installing via: pip install pyro-ppl[funsor]" + ) from e + funsor.set_backend("torch") # Determine TVE problem shape. factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {} @@ -1897,53 +1904,23 @@ def _funsor_sample_aux_values( ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: import funsor - import pyro.contrib.funsor - - # Construct TVE problem inputs, converting torch to funsor. + # Convert torch to funsor. 
+ particle_plates = frozenset(get_plates()) + plate_to_dim = self._funsor_plate_to_dim.copy() + plate_to_dim.update({f.name: f.dim for f in particle_plates}) factors = {} for d, inputs in self._funsor_factor_inputs.items(): precision = _deep_getattr(self.precisions, d) info_vec = precision.new_zeros(()).expand(precision.shape[:-1]) factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) - # Compute log normalizer via Gaussian tensor variable elimination. - with funsor.interpretations.reflect: - log_Z = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, - list(factors.values()), - eliminate=self._funsor_eliminate, - plates=self._funsor_plates, - ) - log_Z = funsor.optimizer.apply_optimizer(log_Z) - - # Draw a batch of samples. - particle_plates = frozenset(get_plates()) - plate_to_dim = self._funsor_plate_to_dim.copy() - plate_to_dim.update({f.name: f.dim for f in particle_plates}) - sample_inputs = {f.name: funsor.Bint[f.size] for f in particle_plates} - with funsor.montecarlo.MonteCarlo(**sample_inputs): - samples = funsor.adjoint.adjoint( - funsor.ops.logaddexp, funsor.ops.add, log_Z - ) - - # Extract funsor.Tensor values from funsor.Delta samples. - samples = { - name: pyro.contrib.funsor.handlers.enum_messenger._get_support_value( - samples[factors[name]], name - ) - for name in self._funsor_eliminate - self._funsor_plates - } - - # Compute density. - # TODO Avoid recomputing this by obtaining it from adjoint above. - log_prob = -funsor.reinterpret(log_Z) - for f in factors.values(): - term = f(**samples) - plates = self._funsor_eliminate.intersection(term.inputs) - term = term.reduce(funsor.ops.add, plates) - log_prob += term - assert all(f.name in log_prob.inputs for f in particle_plates) + # Perform Gaussian tensor variable elimination. 
+ samples, log_prob = funsor.recipes.forward_filter_backward_rsample( + factors=factors, + eliminate=self._funsor_eliminate, + plates=frozenset(plate_to_dim), + sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates}, + ) # Convert funsor to torch. samples = { From f30019963e96bad1843a4e720c1aa7cd8b67a2cb Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 22 Sep 2021 18:42:10 -0400 Subject: [PATCH 14/41] Fix & strengthen tests --- pyro/infer/autoguide/guides.py | 34 +++++++++++++++------- tests/infer/test_autoguide.py | 52 +++++++++++++++++++++++++++++++--- 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index c3cfc0e322..a8470b2807 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1752,7 +1752,8 @@ def _setup_prototype(self, *args, **kwargs) -> None: precision_plates: Set[CondIndepStackFrame] = set() if not site["is_observed"]: # Initialize latent variable location-scale parameters. - init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() + with helpful_support_errors(site): + init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() init_scale = torch.full_like(init_loc, self._init_scale) batch_shape = site["fn"].batch_shape event_shape = init_loc.shape[len(batch_shape) :] @@ -1811,9 +1812,8 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: self._setup_prototype(*args, **kwargs) - aux_values, log_density_1 = self._sample_aux_values() - values, log_density_2 = self._transform_values(aux_values) - log_density = log_density_1 + log_density_2 + aux_values, log_density = self._sample_aux_values() + values, log_densities = self._transform_values(aux_values) # Replay via Pyro primitives. 
plates = self._create_plates(*args, **kwargs) @@ -1823,10 +1823,12 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if frame.vectorized: stack.enter_context(plates[frame.name]) pyro.sample( - name, dist.Delta(values[name], event_dim=site["fn"].event_dim) + name, + dist.Delta(values[name], log_densities[name], site["fn"].event_dim), ) if am_i_wrapped() and poutine.get_mask() is not False: - pyro.factor("AutoGaussian.log_density", log_density) + log_density = log_density + log_densities["AutoGaussian"] + pyro.factor("AutoGaussian", log_density) return values def median(self) -> Dict[str, torch.Tensor]: @@ -1841,16 +1843,28 @@ def _transform_values( ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: # Learnably transform auxiliary values to user-facing values. values = {} - log_density = 0.0 + log_densities = defaultdict(float) compute_density = am_i_wrapped() and poutine.get_mask() is not False for name, site in self._sorted_sites.items(): loc = _deep_getattr(self.locs, name) scale = _deep_getattr(self.scales, name) - values[name] = biject_to(site["fn"].support)(aux_values[name] * scale + loc) + unconstrained = aux_values[name] * scale + loc + + # Transform to constrained space. + transform = biject_to(site["fn"].support) + values[name] = transform(unconstrained) if compute_density: - log_density = log_density - scale.log().sum() + # Split the density into a aggregated unshaped part + # "AutoGaussian" and a per-site shaped part. 
+ log_densities["AutoGaussian"] = ( + log_densities["AutoGaussian"] - scale.log().sum() + ) + assert transform.codomain.event_dim == site["fn"].event_dim + log_densities[name] = transform.inv.log_abs_det_jacobian( + values[name], unconstrained + ) - return values, log_density + return values, log_densities def _sample_aux_values( self, diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 367965ea63..8f27157ffb 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -37,12 +37,17 @@ ) from pyro.infer.reparam import ProjectedNormalReparam from pyro.nn.module import PyroModule, PyroParam, PyroSample +from pyro.ops.gaussian import Gaussian from pyro.optim import Adam from pyro.poutine.util import prune_subsample_sites from pyro.util import check_model_guide_match from tests.common import assert_close, assert_equal # AutoGaussian currently depends on funsor. +AutoGaussian_median = pytest.param( + functools.partial(AutoGaussian, init_loc_fn=init_to_median), + marks=[pytest.mark.stage("funsor")], +) AutoGaussian = pytest.param(AutoGaussian, marks=[pytest.mark.stage("funsor")]) @@ -705,7 +710,8 @@ def model(): auto_guide_module_callable, functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), - AutoGaussian, + functools.partial(AutoNormal, init_loc_fn=init_to_median), + AutoGaussian_median, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -717,8 +723,9 @@ def __init__(self): self.x_scale = PyroParam(torch.tensor(0.1), constraints.positive) def forward(self): - pyro.sample("x", dist.Normal(self.x_loc, self.x_scale)) + x = pyro.sample("x", dist.Normal(self.x_loc, self.x_scale)) pyro.sample("y", dist.Normal(2.0, 0.1)) + pyro.sample("z", dist.Normal(1.0, 0.1), obs=x) model = Model() guide = auto_class(model) @@ -1153,7 +1160,9 @@ def test_sphere_reparam_ok(auto_class, init_loc_fn): def model(): x = 
pyro.sample("x", dist.Normal(0.0, 1.0).expand([3]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) + pyro.sample( + "obs", dist.Normal(y, 1).to_event(1), obs=torch.tensor([1.0, 0.0, 0.0]) + ) model = poutine.reparam(model, {"y": ProjectedNormalReparam()}) guide = auto_class(model) @@ -1174,7 +1183,9 @@ def test_sphere_raw_ok(auto_class, init_loc_fn): def model(): x = pyro.sample("x", dist.Normal(0.0, 1.0).expand([3]).to_event(1)) y = pyro.sample("y", dist.ProjectedNormal(x)) - pyro.sample("obs", dist.Normal(y, 1), obs=torch.tensor([1.0, 0.0])) + pyro.sample( + "obs", dist.Normal(y, 1).to_event(1), obs=torch.tensor([1.0, 0.0, 0.0]) + ) guide = auto_class(model, init_loc_fn=init_loc_fn) poutine.trace(guide).get_trace().compute_log_prob() @@ -1204,6 +1215,7 @@ def __init__(self, model): AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, + AutoLowRankMultivariateNormal, AutoStructured_exact_normal, AutoStructured_exact_mvn, AutoGaussian, @@ -1219,6 +1231,15 @@ def model(data): data = torch.randn(3) expected_mean = (0 + data.sum().item()) / (1 + len(data)) expected_std = (1 + len(data)) ** (-0.5) + g = Gaussian( + log_normalizer=torch.zeros(()), + info_vec=torch.zeros(4), + precision=torch.tensor( + [[4, -1, -1, -1], [-1, 1, 0, 0], [-1, 0, 1, 0], [-1, 0, 0, 1]], + dtype=data.dtype, + ), + ) + expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -1229,6 +1250,7 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): + # Check moments. 
vectorize = pyro.plate("particles", 10000, dim=-2) guide_trace = poutine.trace(vectorize(guide)).get_trace(data) samples = poutine.replay(vectorize(model), guide_trace)(data) @@ -1237,6 +1259,10 @@ def model(data): assert_close(actual_mean, expected_mean, atol=0.05) assert_close(actual_std, expected_std, rtol=0.05) + # Check ELBO loss. + actual_loss = elbo.loss(model, guide, data) + assert_close(actual_loss, expected_loss, atol=0.01) + @pytest.mark.parametrize( "Guide", @@ -1244,6 +1270,7 @@ def model(data): AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, + AutoLowRankMultivariateNormal, AutoStructured_exact_normal, AutoStructured_exact_mvn, AutoGaussian, @@ -1259,6 +1286,19 @@ def model(data): data = torch.randn(3) expected_mean = (0 + data) / (1 + 1) expected_std = (1 + torch.ones_like(data)) ** (-0.5) + g = Gaussian( + log_normalizer=torch.zeros(3), + info_vec=torch.zeros(3, 2), + precision=torch.tensor( + [[[2, -1], [-1, 1]]] * 3, + dtype=data.dtype, + ), + ) + expected_loss = ( + (g.event_logsumexp() - g.condition(data[:, None]).event_logsumexp()) + .sum() + .item() + ) guide = Guide(model) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -1276,3 +1316,7 @@ def model(data): actual_std = samples.std(0) assert_close(actual_mean, expected_mean, atol=0.05) assert_close(actual_std, expected_std, rtol=0.05) + + # Check ELBO loss. 
+ actual_loss = elbo.loss(model, guide, data) + assert_close(actual_loss, expected_loss, atol=0.01) From c00439a9418d4ced883a342cab5001fabb1e419c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 22 Sep 2021 18:55:20 -0400 Subject: [PATCH 15/41] Reflect --- pyro/infer/autoguide/guides.py | 16 +++++++++------- tests/infer/test_autoguide.py | 7 +++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index a8470b2807..9860d11952 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1805,8 +1805,10 @@ def _setup_prototype(self, *args, **kwargs) -> None: ), ) - if self.backend == "funsor": - self._funsor_setup_prototype(*args, **kwargs) + # Dispatch to backend logic. + backend_fn = getattr(self, f"_{self.backend}_setup_prototype", None) + if backend_fn is not None: + backend_fn(*args, **kwargs) def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: @@ -1869,11 +1871,11 @@ def _transform_values( def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - # Sample auxiliary values via Gaussian tensor variable elimination. - if self.backend == "funsor": - return self._funsor_sample_aux_values() - else: - raise ValueError(f"Unknown backend: {self.backend}") + # Dispatch to backend logic. 
+ backend_fn = getattr(self, f"_{self.backend}_sample_aux_values", None) + if backend_fn is None: + raise NotImplementedError(f"Unknown AutoGaussian backend: {self.backend}") + return backend_fn() def _funsor_setup_prototype(self, *args, **kwargs): try: diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 8f27157ffb..7e4d8863c8 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1294,10 +1294,8 @@ def model(data): dtype=data.dtype, ), ) - expected_loss = ( - (g.event_logsumexp() - g.condition(data[:, None]).event_logsumexp()) - .sum() - .item() + expected_loss = float( + g.event_logsumexp().sum() - g.condition(data[:, None]).event_logsumexp().sum() ) guide = Guide(model) @@ -1309,6 +1307,7 @@ def model(data): guide.requires_grad_(False) with torch.no_grad(): + # Check moments. vectorize = pyro.plate("particles", 10000, dim=-2) guide_trace = poutine.trace(vectorize(guide)).get_trace(data) samples = poutine.replay(vectorize(model), guide_trace)(data) From bfab03429a5e815424b445f56358dd2150c80a9a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 23 Sep 2021 10:28:53 -0400 Subject: [PATCH 16/41] Add another exact test --- pyro/infer/autoguide/guides.py | 6 +-- tests/infer/test_autoguide.py | 77 +++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 9860d11952..8dbe05a25a 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -34,11 +34,6 @@ def model(): from pyro.distributions import constraints from pyro.distributions.transforms import affine_autoregressive, iterated from pyro.distributions.util import eye_like, is_identically_zero, sum_rightmost -from pyro.infer.autoguide.initialization import ( - InitMessenger, - init_to_feasible, - init_to_median, -) from pyro.infer.enum import config_enumerate from pyro.infer.inspect import get_dependencies from pyro.nn 
import PyroModule, PyroParam @@ -48,6 +43,7 @@ def model(): from pyro.poutine.runtime import am_i_wrapped, get_plates from pyro.poutine.util import site_is_subsample +from .initialization import InitMessenger, init_to_feasible, init_to_median from .utils import _product, helpful_support_errors diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 7e4d8863c8..9c6a8f4573 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1235,7 +1235,12 @@ def model(data): log_normalizer=torch.zeros(()), info_vec=torch.zeros(4), precision=torch.tensor( - [[4, -1, -1, -1], [-1, 1, 0, 0], [-1, 0, 1, 0], [-1, 0, 0, 1]], + [ + [4, -1, -1, -1], + [-1, 1, 0, 0], + [-1, 0, 1, 0], + [-1, 0, 0, 1], + ], dtype=data.dtype, ), ) @@ -1319,3 +1324,73 @@ def model(data): # Check ELBO loss. actual_loss = elbo.loss(model, guide, data) assert_close(actual_loss, expected_loss, atol=0.01) + + +@pytest.mark.parametrize( + "Guide", + [ + AutoNormal, + AutoDiagonalNormal, + AutoMultivariateNormal, + AutoLowRankMultivariateNormal, + AutoStructured, + AutoGaussian, + ], +) +def test_exact_tree(Guide): + is_exact = Guide not in (AutoNormal, AutoDiagonalNormal) + + def model(data): + x = pyro.sample("x", dist.Normal(0, 1)) + with pyro.plate("data", len(data)): + y = pyro.sample("y", dist.Normal(x, 1)) + pyro.sample("obs", dist.Normal(y, 1), obs=data) + return {"x": x, "y": y} + + data = torch.randn(2) + g = Gaussian( + log_normalizer=torch.zeros(()), + info_vec=torch.zeros(5), + precision=torch.tensor( + [ + [3, -1, -1, 0, 0], # x + [-1, 2, 0, -1, 0], # y[0] + [-1, 0, 2, 0, -1], # y[1] + [0, -1, 0, 1, 0], # obs[0] + [0, 0, -1, 0, 1], # obs[1] + ], + dtype=data.dtype, + ), + ) + g_cond = g.condition(data) + mean = torch.linalg.solve(g_cond.precision, g_cond.info_vec) + std = torch.inverse(g_cond.precision).diag().sqrt() + expected_mean = {"x": mean[0], "y": mean[1:]} + expected_std = {"x": std[0], "y": std[1:]} + expected_loss = 
float(g.event_logsumexp() - g_cond.event_logsumexp()) + + guide = Guide(model) + elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) + optim = Adam({"lr": 0.01}) + svi = SVI(model, guide, optim, elbo) + for step in range(500): + svi.step(data) + + guide.train(False) + guide.requires_grad_(False) + with torch.no_grad(): + # Check moments. + vectorize = pyro.plate("particles", 10000, dim=-2) + guide_trace = poutine.trace(vectorize(guide)).get_trace(data) + samples = poutine.replay(vectorize(model), guide_trace)(data) + for name in ["x", "y"]: + actual_mean = samples[name].mean(0).squeeze() + actual_std = samples[name].std(0).squeeze() + assert_close(actual_mean, expected_mean[name], atol=0.05) + if is_exact: + assert_close(actual_std, expected_std[name], rtol=0.05) + + if is_exact: + # Check ELBO loss. + actual_loss = elbo.loss(model, guide, data) + assert_close(actual_loss, expected_loss, atol=0.01) From 20ca07e9192f70b1b5bbfb15005f8baaaaa3c65f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 23 Sep 2021 11:28:47 -0400 Subject: [PATCH 17/41] Simplify log_density computation --- pyro/infer/autoguide/guides.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 8dbe05a25a..9d9592ecf1 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1810,7 +1810,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: self._setup_prototype(*args, **kwargs) - aux_values, log_density = self._sample_aux_values() + aux_values, global_log_density = self._sample_aux_values() values, log_densities = self._transform_values(aux_values) # Replay via Pyro primitives. 
@@ -1825,8 +1825,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: dist.Delta(values[name], log_densities[name], site["fn"].event_dim), ) if am_i_wrapped() and poutine.get_mask() is not False: - log_density = log_density + log_densities["AutoGaussian"] - pyro.factor("AutoGaussian", log_density) + pyro.factor(self._pyro_name, global_log_density) return values def median(self) -> Dict[str, torch.Tensor]: @@ -1852,15 +1851,10 @@ def _transform_values( transform = biject_to(site["fn"].support) values[name] = transform(unconstrained) if compute_density: - # Split the density into a aggregated unshaped part - # "AutoGaussian" and a per-site shaped part. - log_densities["AutoGaussian"] = ( - log_densities["AutoGaussian"] - scale.log().sum() - ) assert transform.codomain.event_dim == site["fn"].event_dim log_densities[name] = transform.inv.log_abs_det_jacobian( values[name], unconstrained - ) + ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1) return values, log_densities From b28a5ff250d8e0b991b4aaef67f547770a7363de Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 23 Sep 2021 11:55:17 -0400 Subject: [PATCH 18/41] Switch from precision to precision_chol parameters --- pyro/infer/autoguide/guides.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 9d9592ecf1..a570a6a471 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1691,7 +1691,7 @@ class AutoGaussian(AutoGuide): backend: str locs: PyroModule scales: PyroModule - precisions: PyroModule + precision_chols: PyroModule _sorted_sites: Dict[str, Dict[str, object]] _init_scale: float _original_model: Tuple[Callable] @@ -1723,7 +1723,7 @@ def _setup_prototype(self, *args, **kwargs) -> None: self.locs = PyroModule() self.scales = PyroModule() - self.precisions = PyroModule() + self.precision_chols = PyroModule() self._unconstrained_event_shapes 
= {} self._broken_event_shapes = {} self._broken_plates = defaultdict(tuple) @@ -1785,20 +1785,11 @@ def _setup_prototype(self, *args, **kwargs) -> None: batch_shape = torch.Size( f.size for f in sorted(precision_plates, key=lambda f: f.dim) ) - init_precision = torch.zeros(*batch_shape, precision_size, precision_size) - init_precision.view(-1, precision_size ** 2)[ - ..., :: precision_size + 1 - ].fill_( - 1 - ) # init to eye + eye = torch.eye(precision_size) + torch.zeros(batch_shape + (1, 1)) _deep_setattr( - self.precisions, + self.precision_chols, d, - PyroParam( - init_precision, - constraint=constraints.positive_definite, - event_dim=2, - ), + PyroParam(eye, constraint=constraints.lower_cholesky, event_dim=2), ) # Dispatch to backend logic. @@ -1916,9 +1907,11 @@ def _funsor_sample_aux_values( plate_to_dim.update({f.name: f.dim for f in particle_plates}) factors = {} for d, inputs in self._funsor_factor_inputs.items(): - precision = _deep_getattr(self.precisions, d) + precision_chol = _deep_getattr(self.precision_chols, d) + precision = precision_chol @ precision_chol.transpose(-1, -2) info_vec = precision.new_zeros(()).expand(precision.shape[:-1]) factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) + factors[d]._precision_chol = precision_chol # avoid recomputing # Perform Gaussian tensor variable elimination. 
samples, log_prob = funsor.recipes.forward_filter_backward_rsample( From 061875d13c1881e6a325fb31f426d0d1f04eae7a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 23 Sep 2021 12:56:39 -0400 Subject: [PATCH 19/41] Add test of Gaussian .rsample() and .log_prob() --- tests/ops/test_gaussian.py | 69 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/ops/test_gaussian.py b/tests/ops/test_gaussian.py index b602a0e07b..904aafc951 100644 --- a/tests/ops/test_gaussian.py +++ b/tests/ops/test_gaussian.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections import OrderedDict import pytest import torch +from torch.distributions import constraints, transform_to from torch.nn.functional import pad import pyro.distributions as dist @@ -418,3 +420,70 @@ def test_gaussian_tensordot( # TODO(fehiepsi): find some condition to make this test stable, so we can compare large value # log densities. assert_close(actual.clamp(max=10.0), expect.clamp(max=10.0), atol=0.1, rtol=0.1) + + +@pytest.mark.stage("funsor") +@pytest.mark.parametrize("batch_shape", [(), (5,), (4, 2)], ids=str) +def test_gaussian_funsor(batch_shape): + # This tests sample distribution, rsample gradients, log_prob, and log_prob + # gradients for both Pyro's and Funsor's Gaussian. + import funsor + + funsor.set_backend("torch") + num_samples = 100000 + + # Declare unconstrained parameters. + loc = torch.randn(batch_shape + (3,)).requires_grad_() + t = transform_to(constraints.positive_definite) + m = torch.randn(batch_shape + (3, 3)) + precision_unconstrained = t.inv(m @ m.transpose(-1, -2)).requires_grad_() + + # Transform to constrained space. 
+ log_normalizer = torch.zeros(batch_shape) + precision = t(precision_unconstrained) + info_vec = (precision @ loc[..., None])[..., 0] + + def check_equal(actual, expected, atol=0.01, rtol=0): + assert_close(actual.data, expected.data, atol=atol, rtol=rtol) + grads = torch.autograd.grad( + (actual - expected).abs().sum(), + [loc, precision_unconstrained], + retain_graph=True, + ) + for grad in grads: + assert grad.abs().max() < atol + + entropy = dist.MultivariateNormal(loc, precision_matrix=precision).entropy() + + # Monte carlo estimate entropy via pyro. + p_gaussian = Gaussian(log_normalizer, info_vec, precision) + p_log_Z = p_gaussian.event_logsumexp() + p_rsamples = p_gaussian.rsample((num_samples,)) + pp_entropy = (p_log_Z - p_gaussian.log_density(p_rsamples)).mean(0) + check_equal(pp_entropy, entropy) + + # Monte carlo estimate entropy via funsor. + inputs = OrderedDict([(k, funsor.Bint[v]) for k, v in zip("ij", batch_shape)]) + inputs["x"] = funsor.Reals[3] + f_gaussian = funsor.gaussian.Gaussian(info_vec, precision, inputs) + f_log_Z = f_gaussian.reduce(funsor.ops.logaddexp, "x") + sample_inputs = OrderedDict(particle=funsor.Bint[num_samples]) + deltas = f_gaussian.sample("x", sample_inputs) + f_rsamples = funsor.montecarlo.extract_samples(deltas)["x"] + ff_entropy = (f_log_Z - f_gaussian(x=f_rsamples)).reduce( + funsor.ops.mean, "particle" + ) + check_equal(ff_entropy.data, entropy) + + # Check Funsor's .rsample against Pyro's .log_prob. + pf_entropy = (p_log_Z - p_gaussian.log_density(f_rsamples.data)).mean(0) + check_equal(pf_entropy, entropy) + + # Check Pyro's .rsample against Funsor's .log_prob. 
+ fp_rsamples = funsor.Tensor(p_rsamples)["particle"] + for i in "ij"[: len(batch_shape)]: + fp_rsamples = fp_rsamples[i] + fp_entropy = (f_log_Z - f_gaussian(x=fp_rsamples)).reduce( + funsor.ops.mean, "particle" + ) + check_equal(fp_entropy.data, entropy) From e512ece476480f6c86f591c14cc807398459c6aa Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 23 Sep 2021 19:10:42 -0400 Subject: [PATCH 20/41] Sketch dense backend --- pyro/infer/autoguide/guides.py | 123 ++++++++++++++++++--- tests/infer/autoguide/test_autogaussian.py | 26 +++-- tests/infer/test_autoguide.py | 37 +++++-- 3 files changed, 156 insertions(+), 30 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index a570a6a471..c0c659c8a1 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1644,13 +1644,19 @@ def median(self, *args, **kwargs): class AutoGaussian(AutoGuide): """ - EXPERIMENTAL Gaussian tensor variable elimination guide [1,2]. + Gaussian guide with optimal conditional independence structure. This is equivalent to a full rank :class:`AutoMultivariateNormal` guide, but with a sparse precision matrix determined by dependencies and plates in - the model. This can be orders of magnitude cheaper than the naive - :class:`AutoMultivariateNormal` in terms of space, time, number of - parameters, and statistical complexity. + the model [1]. Depending on model structure, this can have asymptotically + better statistical efficiency than :class:`AutoMultivariateNormal` . + + The default "dense" backend should have similar computational complexity to + :class:`AutoMultivariateNormal` . The experimental "funsor" backend can be + asymptotically cheaper in terms of time and space (using Gaussian tensor + variable elimination [2,3]), but incurs large constant overhead. The + "funsor" backend requires `funsor <https://funsor.pyro.ai>`_ which can be + installed via ``pip install pyro-ppl[funsor]``.
The guide currently does not depend on the model's ``*args, **kwargs``. @@ -1659,17 +1665,22 @@ class AutoGaussian(AutoGuide): guide = AutoGaussian(model) svi = SVI(model, guide, ...) - .. warning: This currently supports ``backend=funsor`` which depends on - the funsor package. You can install via - ``pip install pyro-ppl[funsor]``. + Example using funsor backend:: + + !pip install pyro-ppl[funsor] + guide = AutoGaussian(model, backend="funsor") + svi = SVI(model, guide, ...) **References** - [1] F. Obermeyer, E. Bingham, M. Jankowiak, J. Chiu, N. Pradhan, A. M. Rush, N. Goodman + [1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018) + "Faithful inversion of generative models for effective amortized inference" + https://dl.acm.org/doi/10.5555/3327144.3327229 + [2] F.Obermeyer, E.Bingham, M.Jankowiak, J.Chiu, N.Pradhan, A.M.Rush, N.Goodman (2019) "Tensor Variable Elimination for Plated Factor Graphs" http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf - [2] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen + [3] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen (2019) "Functional Tensors for Probabilistic Programming" https://arxiv.org/abs/1910.10775 @@ -1684,8 +1695,7 @@ class AutoGaussian(AutoGuide): or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. :param str backend: Back end for performing Gaussian tensor variable - elimination. Currently only the experimental "funsor" backend is - supported. + elimination. Defaults to "dense". """ backend: str @@ -1699,6 +1709,8 @@ class AutoGaussian(AutoGuide): _broken_event_shapes: Dict[str, torch.Size] _broken_plates: Dict[str, Tuple[str, ...]] + # Class configurable parameters. 
+ default_backend = "dense" scale_constraint = constraints.softplus_positive def __init__( @@ -1708,7 +1720,7 @@ def __init__( init_loc_fn: Callable = init_to_feasible, init_scale: float = 0.1, create_plates: Optional[Callable] = None, - backend="funsor", + backend=None, ): if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError(f"Expected init_scale > 0. but got {init_scale}") @@ -1716,7 +1728,7 @@ def __init__( self._original_model = (model,) model = InitMessenger(init_loc_fn)(model) super().__init__(model, create_plates=create_plates) - self.backend = backend + self.backend = self.default_backend if backend is None else backend def _setup_prototype(self, *args, **kwargs) -> None: super()._setup_prototype(*args, **kwargs) @@ -1858,6 +1870,91 @@ def _sample_aux_values( raise NotImplementedError(f"Unknown AutoGaussian backend: {self.backend}") return backend_fn() + ############################################################################ + # Dense backend + + def _dense_setup_prototype(self, *args, **kwargs): + # Collect flat and individual shapes. + self._dense_shapes = {} + pos = 0 + offsets = {} + for d, event_shape in self._unconstrained_event_shapes.items(): + batch_shape = self.prototype_trace.nodes[d]["fn"].batch_shape + self._dense_shapes[d] = batch_shape, event_shape + offsets[d] = pos + pos += (batch_shape + event_shape).numel() + self._dense_size = pos + + # Create sparse -> dense precision matrix maps. + self._dense_factor_scatter = {} + for d, site in self._sorted_sites.items(): + # Order inputs as in the model, so as to maximize sparsity of the + # lower Cholesky parametrization of the precision matrix. 
+ event_indices = [] + for f in site["cond_indep_stack"]: + if f.vectorized: + if f.name not in self._broken_plates[d]: + f.size + for u in self.dependencies[d]: + start = offsets[u] + stop = start + self._broken_event_shapes[u].numel() + event_indices.append(torch.arange(start, stop)) + if not site["is_observed"]: + start = offsets[d] + stop = start + self._broken_event_shapes[d].numel() + event_indices.append(torch.arange(start, stop)) + event_index = torch.cat(event_indices) + precision_shape = _deep_getattr(self.precision_chols, d).shape + index = torch.zeros(precision_shape, dtype=torch.long) + stride = 1 + index += event_index * stride + stride *= self._dense_size + index += event_index[:, None] * stride + stride *= self._dense_size + # TODO add batch shapes + self._dense_factor_scatter[d] = index.reshape(-1) + + def _dense_sample_aux_values( + self, + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + from pyro.ops import Gaussian + + # Convert to a flat dense joint Gaussian. + flat_precision = torch.zeros(self._dense_size ** 2) + for d, index in self._dense_factor_scatter: + precision_chol = _deep_getattr(self.precision_chols, d) + precision = precision_chol @ precision_chol.transpose(-1, -2) + flat_precision.scatter_add_(0, index, precision.reshape(-1)) + precision = flat_precision.reshape(self._dense_size, self._dense_size) + info_vec = torch.zeros(self._dense_size) + log_normalizer = torch.zeros(()) + g = Gaussian(log_normalizer, info_vec, precision) + + # Draw a batch of samples. + particle_plates = frozenset(get_plates()) + sample_shape = [1] * max([0] + [-p.dim for p in particle_plates]) + for p in particle_plates: + sample_shape[p.dim] = p.size + sample_shape = torch.Size(sample_shape) + flat_samples = g.rsample(sample_shape) + log_density = g.log_density(flat_samples) - g.event_logsumexp() + + # Convert flat to shaped tensors. 
+ samples = {} + pos = 0 + for d, (batch_shape, event_shape) in self._dense_shapes.items(): + numel = _deep_getattr(self.locs, d).numel() + flat_sample = flat_samples[pos : pos + numel] + pos += numel + # Assumes sample shapes are left of batch shapes. + samples[d] = flat_sample.reshape( + torch.broadcast_shapes(sample_shape, batch_shape) + event_shape + ) + return samples, log_density + + ############################################################################ + # Funsor backend + def _funsor_setup_prototype(self, *args, **kwargs): try: import funsor diff --git a/tests/infer/autoguide/test_autogaussian.py b/tests/infer/autoguide/test_autogaussian.py index 223bc9edb6..8871578523 100644 --- a/tests/infer/autoguide/test_autogaussian.py +++ b/tests/infer/autoguide/test_autogaussian.py @@ -14,8 +14,10 @@ from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam -# AutoGaussian currently depends on funsor. -pytestmark = pytest.mark.stage("funsor") +BACKENDS = [ + "dense", + pytest.param("funsor", marks=[pytest.mark.stage("funsor")]), +] # Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py @@ -156,7 +158,8 @@ def pyrocov_model_plated(dataset): @pytest.mark.parametrize( "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] ) -def test_pyrocov_smoke(model): +@pytest.mark.parametrize("backend", BACKENDS) +def test_pyrocov_smoke(model, backend): T, P, S, F = 3, 4, 5, 6 dataset = { "features": torch.randn(S, F), @@ -164,7 +167,7 @@ def test_pyrocov_smoke(model): "weekly_strains": torch.randn(T, P, S).exp().round(), } - guide = AutoGaussian(model) + guide = AutoGaussian(model, backend=backend) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) for step in range(2): svi.step(dataset) @@ -176,7 +179,8 @@ def test_pyrocov_smoke(model): @pytest.mark.parametrize( "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] ) -def test_pyrocov_reparam(model): +@pytest.mark.parametrize("backend", BACKENDS) 
+def test_pyrocov_reparam(model, backend): T, P, S, F = 2, 3, 4, 5 dataset = { "features": torch.randn(S, F), @@ -193,7 +197,7 @@ def test_pyrocov_reparam(model): "init": LocScaleReparam(), } model = poutine.reparam(model, config) - guide = AutoGaussian(model) + guide = AutoGaussian(model, backend=backend) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) for step in range(2): svi.step(dataset) @@ -202,7 +206,8 @@ def test_pyrocov_reparam(model): predictive(dataset) -def test_pyrocov_structure(): +@pytest.mark.parametrize("backend", BACKENDS) +def test_pyrocov_structure(backend): from funsor import Bint, Real, Reals T, P, S, F = 2, 3, 4, 5 @@ -212,7 +217,7 @@ def test_pyrocov_structure(): "weekly_strains": torch.randn(T, P, S).exp().round(), } - guide = AutoGaussian(pyrocov_model_plated) + guide = AutoGaussian(pyrocov_model_plated, backend=backend) guide(dataset) # initialize expected_plates = frozenset(["place", "feature", "strain"]) @@ -266,7 +271,8 @@ def test_pyrocov_structure(): assert guide._funsor_factor_inputs == expected_factor_inputs -def test_profile(n=1, num_steps=1): +@pytest.mark.parametrize("backend", BACKENDS) +def test_profile(n=1, num_steps=1, backend="funsor"): """ Helper function for profiling. """ @@ -295,4 +301,4 @@ def test_profile(n=1, num_steps=1): if __name__ == "__main__": - test_profile(n=10, num_steps=100) + test_profile(n=10, num_steps=100, backend="funsor") diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 9c6a8f4573..21e4ec4214 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -43,12 +43,10 @@ from pyro.util import check_model_guide_match from tests.common import assert_close, assert_equal -# AutoGaussian currently depends on funsor. 
-AutoGaussian_median = pytest.param( - functools.partial(AutoGaussian, init_loc_fn=init_to_median), - marks=[pytest.mark.stage("funsor")], -) -AutoGaussian = pytest.param(AutoGaussian, marks=[pytest.mark.stage("funsor")]) + +@functools.partial(pytest.param, marks=[pytest.mark.stage("funsor")]) +class AutoGaussian_funsor(AutoGaussian): + default_backend = "funsor" @pytest.mark.parametrize( @@ -97,6 +95,7 @@ def model(): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) def test_factor(auto_class, Elbo): @@ -194,6 +193,7 @@ def dependency_z6_z5(z5): AutoStructured, AutoStructured_shapes, AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.filterwarnings("ignore::FutureWarning") @@ -232,6 +232,7 @@ def model(): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO]) @@ -345,6 +346,7 @@ def __init__(self, model): AutoStructured, AutoStructured_median, AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -513,6 +515,7 @@ def model(): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) def test_discrete_parallel(continuous_class): @@ -549,6 +552,7 @@ def model(data): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) def test_guide_list(auto_class): @@ -572,6 +576,7 @@ def model(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) def test_callable(auto_class): @@ -601,6 +606,7 @@ def guide_x(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) def test_callable_return_dict(auto_class): @@ -648,6 +654,7 @@ def model(): AutoMultivariateNormal, AutoLowRankMultivariateNormal, AutoGaussian, + AutoGaussian_funsor, ], ) def test_init_loc_fn(auto_class): @@ -711,7 +718,7 @@ def model(): functools.partial(AutoDiagonalNormal, 
init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), functools.partial(AutoNormal, init_loc_fn=init_to_median), - AutoGaussian_median, + functools.partial(AutoGaussian, init_loc_fn=init_to_median), ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -799,6 +806,7 @@ def forward(self): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -929,6 +937,7 @@ def forward(self, x, y=None): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, AutoGaussian, + AutoGaussian_funsor, ], ) def test_replay_plates(auto_class, sample_shape): @@ -1078,6 +1087,7 @@ def create_plates(data): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.parametrize( @@ -1113,6 +1123,7 @@ def model(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.parametrize( @@ -1145,6 +1156,7 @@ def model(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, + AutoGaussian_funsor, ], ) @pytest.mark.parametrize( @@ -1219,6 +1231,7 @@ def __init__(self, model): AutoStructured_exact_normal, AutoStructured_exact_mvn, AutoGaussian, + AutoGaussian_funsor, ], ) def test_exact(Guide): @@ -1247,6 +1260,14 @@ def model(data): expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) + + # DEBUG + guide(data) + # guide.scales.train(False) + # guide.scales.requires_grad_(False) + # guide.precision_chols.train(False) + # guide.precision_chols.requires_grad_(False) + elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) optim = Adam({"lr": 0.01}) svi = SVI(model, guide, optim, elbo) @@ -1279,6 +1300,7 @@ def model(data): 
AutoStructured_exact_normal, AutoStructured_exact_mvn, AutoGaussian, + AutoGaussian_funsor, ], ) def test_exact_batch(Guide): @@ -1335,6 +1357,7 @@ def model(data): AutoLowRankMultivariateNormal, AutoStructured, AutoGaussian, + AutoGaussian_funsor, ], ) def test_exact_tree(Guide): From 98ead93dae46ecf903101614418e38b1813cd336 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 23 Sep 2021 19:23:46 -0400 Subject: [PATCH 21/41] Update docs --- docs/source/infer.autoguide.rst | 31 ++++++++++++++++--------------- pyro/infer/autoguide/guides.py | 17 ++++++++++++----- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index e647ce9766..6ea6564233 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -8,7 +8,7 @@ AutoGuide .. autoclass:: pyro.infer.autoguide.AutoGuide :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoGuideList @@ -16,7 +16,7 @@ AutoGuideList .. autoclass:: pyro.infer.autoguide.AutoGuideList :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoCallable @@ -24,7 +24,7 @@ AutoCallable .. autoclass:: pyro.infer.autoguide.AutoCallable :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoNormal @@ -32,7 +32,7 @@ AutoNormal .. autoclass:: pyro.infer.autoguide.AutoNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoDelta @@ -40,7 +40,7 @@ AutoDelta .. autoclass:: pyro.infer.autoguide.AutoDelta :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoContinuous @@ -48,7 +48,7 @@ AutoContinuous .. 
autoclass:: pyro.infer.autoguide.AutoContinuous :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoMultivariateNormal @@ -56,7 +56,7 @@ AutoMultivariateNormal .. autoclass:: pyro.infer.autoguide.AutoMultivariateNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoDiagonalNormal @@ -64,7 +64,7 @@ AutoDiagonalNormal .. autoclass:: pyro.infer.autoguide.AutoDiagonalNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoLowRankMultivariateNormal @@ -72,7 +72,7 @@ AutoLowRankMultivariateNormal .. autoclass:: pyro.infer.autoguide.AutoLowRankMultivariateNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: @@ -81,7 +81,7 @@ AutoNormalizingFlow .. autoclass:: pyro.infer.autoguide.AutoNormalizingFlow :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: @@ -90,7 +90,7 @@ AutoIAFNormal .. autoclass:: pyro.infer.autoguide.AutoIAFNormal :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoLaplaceApproximation @@ -98,7 +98,7 @@ AutoLaplaceApproximation .. autoclass:: pyro.infer.autoguide.AutoLaplaceApproximation :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoDiscreteParallel @@ -106,7 +106,7 @@ AutoDiscreteParallel .. autoclass:: pyro.infer.autoguide.AutoDiscreteParallel :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoStructured @@ -114,7 +114,7 @@ AutoStructured .. autoclass:: pyro.infer.autoguide.AutoStructured :members: :undoc-members: - :special-members: __call__ + :member-order: bysource :show-inheritance: AutoGaussian @@ -122,6 +122,7 @@ AutoGaussian .. 
autoclass:: pyro.infer.autoguide.AutoGaussian :members: :undoc-members: + :member-order: bysource :show-inheritance: .. _autoguide-initialization: @@ -132,5 +133,5 @@ Initialization :members: :undoc-members: :special-members: __call__ - :show-inheritance: :member-order: bysource + :show-inheritance: diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index c0c659c8a1..8d90582dd5 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1695,9 +1695,14 @@ class AutoGaussian(AutoGuide): or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. :param str backend: Back end for performing Gaussian tensor variable - elimination. Defaults to "dense". + elimination. Defaults to "dense"; other options include "funsor". """ + # Class configurable parameters. + default_backend: str = "dense" + scale_constraint = constraints.softplus_positive + + # Type hints for instance variables. backend: str locs: PyroModule scales: PyroModule @@ -1709,10 +1714,6 @@ class AutoGaussian(AutoGuide): _broken_event_shapes: Dict[str, torch.Size] _broken_plates: Dict[str, Tuple[str, ...]] - # Class configurable parameters. - default_backend = "dense" - scale_constraint = constraints.softplus_positive - def __init__( self, model: Callable, @@ -1832,6 +1833,12 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: return values def median(self) -> Dict[str, torch.Tensor]: + """ + Returns the posterior median value of each latent variable. + + :return: A dict mapping sample site name to median tensor. 
+ :rtype: dict + """ with torch.no_grad(), poutine.mask(mask=False): aux_values = {name: 0.0 for name in self._sorted_sites} values, _ = self._transform_values(aux_values) From 76a8f51e9c014cf38df9ac7b6b4f3774c00c23fa Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 28 Sep 2021 00:07:57 -0400 Subject: [PATCH 22/41] Perfect precision parametrization (breaking both backends) --- pyro/infer/autoguide/guides.py | 47 +++++++++++++++++----------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 8d90582dd5..56a92de8de 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -1737,30 +1737,23 @@ def _setup_prototype(self, *args, **kwargs) -> None: self.locs = PyroModule() self.scales = PyroModule() self.precision_chols = PyroModule() + self._sorted_sites = OrderedDict() self._unconstrained_event_shapes = {} - self._broken_event_shapes = {} - self._broken_plates = defaultdict(tuple) model = self._original_model[0] meta = poutine.block(get_dependencies)(model, args, kwargs) self.dependencies = meta["prior_dependencies"] - broken_plates: Set[str] = { - p - for upstreams in meta["posterior_dependencies"].values() - for plates in upstreams.values() - for p in plates - } - self._sorted_sites = { - name: site - for name, site in self.prototype_trace.nodes.items() - if site["type"] == "sample" - if not site_is_subsample(site) - } + for name, site in self.prototype_trace.nodes.items(): + if site["type"] == "sample": + if not site_is_subsample(site): + self._sorted_sites[name] = site for d, site in self._sorted_sites.items(): precision_size = 0 precision_plates: Set[CondIndepStackFrame] = set() if not site["is_observed"]: # Initialize latent variable location-scale parameters. + # The scale parameters are statistically redundant, but improve + # learning with coordinate-wise optimizers. 
with helpful_support_errors(site): init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() init_scale = torch.full_like(init_loc, self._init_scale) @@ -1780,21 +1773,27 @@ def _setup_prototype(self, *args, **kwargs) -> None: ) # Gather shapes for precision matrices. - broken_shape = torch.Size() for f in site["cond_indep_stack"]: if f.vectorized: - if f.name in broken_plates: - self._broken_plates[d] += (f.name,) - broken_shape += (f.size,) - else: - precision_plates.add(f) - self._broken_event_shapes[d] = broken_shape + event_shape + precision_plates.add(f) + precision_size.add(event_shape.numel()) # Initialize precision matrices. + # This adds a batched dense matrix for each factor, achieving + # statistically optimal sparsity structure of the model's joint + # precision matrix. Multiple factors may redundantly parametrize + # entries of the precision matrix on which they overlap, incurring + # slight computational cost but no cost to statistical efficiency. for u in self.dependencies[d]: - precision_size += self._broken_event_shapes[u].numel() - for f in self.prototype_trace.nodes[u]["cond_indep_stack"]: - assert f in precision_plates or f.name in broken_plates + u_site = self.prototype_trace.nodes[u] + u_numel = self._unconstrained_event_shapes[u].numel() + for f in u_site["cond_indep_stack"]: + if f.vectorized: + if f in site["cond_indep_stack"]: + precision_plates.add(f) + else: + u_numel *= f.size + precision_size += u_numel batch_shape = torch.Size( f.size for f in sorted(precision_plates, key=lambda f: f.dim) ) From 2b8eb08b242e06ae574050b97ca207694eb9b4c8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 29 Sep 2021 08:39:50 -0400 Subject: [PATCH 23/41] Minor updates --- pyro/infer/autoguide/guides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 56a92de8de..b58f1ea8a4 100644 --- a/pyro/infer/autoguide/guides.py +++ 
b/pyro/infer/autoguide/guides.py @@ -1880,7 +1880,7 @@ def _sample_aux_values( # Dense backend def _dense_setup_prototype(self, *args, **kwargs): - # Collect flat and individual shapes. + # Collect flat and individual and aggregated flat shapes. self._dense_shapes = {} pos = 0 offsets = {} @@ -1907,7 +1907,7 @@ def _dense_setup_prototype(self, *args, **kwargs): event_indices.append(torch.arange(start, stop)) if not site["is_observed"]: start = offsets[d] - stop = start + self._broken_event_shapes[d].numel() + stop = start + self._unconstrained_event_shapes[d].numel() event_indices.append(torch.arange(start, stop)) event_index = torch.cat(event_indices) precision_shape = _deep_getattr(self.precision_chols, d).shape From 49dfd8c2c32ea848952216255b321ed56cb0e015 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 30 Sep 2021 21:39:23 -0400 Subject: [PATCH 24/41] Change precision representation, start fixing dense backend --- pyro/infer/autoguide/gaussian.py | 194 ++++++++++++++++--------------- 1 file changed, 100 insertions(+), 94 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 50e187d7a6..cddba211f9 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -14,7 +14,6 @@ from pyro.distributions import constraints from pyro.infer.inspect import get_dependencies from pyro.nn.module import PyroModule, PyroParam -from pyro.poutine.indep_messenger import CondIndepStackFrame from pyro.poutine.runtime import am_i_wrapped, get_plates from pyro.poutine.util import site_is_subsample @@ -79,22 +78,9 @@ class AutoGaussian(AutoGuide): elimination. Defaults to "dense"; other options include "funsor". """ - # Class configurable parameters. default_backend: str = "dense" scale_constraint = constraints.softplus_positive - # Type hints for instance variables. 
- backend: str - locs: PyroModule - scales: PyroModule - precision_chols: PyroModule - _sorted_sites: Dict[str, Dict[str, object]] - _init_scale: float - _original_model: Tuple[Callable] - _unconstrained_event_shapes: Dict[str, torch.Size] - _broken_event_shapes: Dict[str, torch.Size] - _broken_plates: Dict[str, Tuple[str, ...]] - def __init__( self, model: Callable, @@ -117,73 +103,69 @@ def _setup_prototype(self, *args, **kwargs) -> None: self.locs = PyroModule() self.scales = PyroModule() - self.precision_chols = PyroModule() - self._sorted_sites = OrderedDict() - self._unconstrained_event_shapes = {} + self.factors = PyroModule() + self._factors = OrderedDict() + self._plates = OrderedDict() + self._event_numel = OrderedDict() + self._unconstrained_event_shapes = OrderedDict() + # Trace model dependencies. model = self._original_model[0] - meta = poutine.block(get_dependencies)(model, args, kwargs) - self.dependencies = meta["prior_dependencies"] + self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[ + "prior_dependencies" + ] + + # Collect factors and plates. for name, site in self.prototype_trace.nodes.items(): - if site["type"] == "sample": - if not site_is_subsample(site): - self._sorted_sites[name] = site - for d, site in self._sorted_sites.items(): - precision_size = 0 - precision_plates: Set[CondIndepStackFrame] = set() + if site["type"] == "sample" and not site_is_subsample(site): + assert all(f.vectorized for f in site["cond_indep_stack"]) + self._factors[name] = site + plates = frozenset(site["cond_indep_stack"]) + if site["is_observed"]: + # Eagerly eliminate irrelevant observation plates. + plates &= frozenset.union( + *(self._plates[u] for u in self.dependencies[d] if u != d) + ) + self._plates[name] = plates + + # Create location-scale parameters, one per latent variable. + for d, site in self._factors.items(): if not site["is_observed"]: - # Initialize latent variable location-scale parameters. 
- # The scale parameters are statistically redundant, but improve - # learning with coordinate-wise optimizers. with helpful_support_errors(site): init_loc = biject_to(site["fn"].support).inv(site["value"]).detach() - init_scale = torch.full_like(init_loc, self._init_scale) batch_shape = site["fn"].batch_shape event_shape = init_loc.shape[len(batch_shape) :] self._unconstrained_event_shapes[d] = event_shape + self._event_numel[d] = event_shape.numel() event_dim = len(event_shape) deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim)) deep_setattr( self.scales, d, PyroParam( - init_scale, + torch.full_like(init_loc, self._init_scale), constraint=self.scale_constraint, event_dim=event_dim, ), ) - # Gather shapes for precision matrices. - for f in site["cond_indep_stack"]: - if f.vectorized: - precision_plates.add(f) - precision_size.add(event_shape.numel()) - - # Initialize precision matrices. - # This adds a batched dense matrix for each factor, achieving - # statistically optimal sparsity structure of the model's joint - # precision matrix. Multiple factors may redundantly parametrize - # entries of the precision matrix on which they overlap, incurring - # slight computational cost but no cost to statistical efficiency. + # Create parameters for dependencies, one per factor. 
+ for d, site in self._factors.items(): + u_size = 0 for u in self.dependencies[d]: - u_site = self.prototype_trace.nodes[u] - u_numel = self._unconstrained_event_shapes[u].numel() - for f in u_site["cond_indep_stack"]: - if f.vectorized: - if f in site["cond_indep_stack"]: - precision_plates.add(f) - else: - u_numel *= f.size - precision_size += u_numel - batch_shape = torch.Size( - f.size for f in sorted(precision_plates, key=lambda f: f.dim) - ) - eye = torch.eye(precision_size) + torch.zeros(batch_shape + (1, 1)) - deep_setattr( - self.precision_chols, - d, - PyroParam(eye, constraint=constraints.lower_cholesky, event_dim=2), - ) + if not self._factors[u]["is_observed"]: + broken_shape = _plates_to_batch_shape( + self._plates[u] - self._plates[d] + ) + u_size += broken_shape.numel() * self._event_numel[u] + d_size = self._event_numel[d] + if site["is_observed"]: + d_size = min(d_size, u_size) # just an optimization + batch_shape = _plates_to_batch_shape(self._plates[d]) + + # Create a square root (not necessarily lower triangular). + raw = init_loc.new_zeros(batch_shape, u_size, d_size) + deep_setattr(self, self.factors, raw, event_dim=2) # Dispatch to backend logic. backend_fn = getattr(self, f"_{self.backend}_setup_prototype", None) @@ -199,7 +181,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # Replay via Pyro primitives. 
plates = self._create_plates(*args, **kwargs) - for name, site in self._sorted_sites.items(): + for name, site in self._factors.items(): with ExitStack() as stack: for frame in site["cond_indep_stack"]: if frame.vectorized: @@ -220,7 +202,7 @@ def median(self) -> Dict[str, torch.Tensor]: :rtype: dict """ with torch.no_grad(), poutine.mask(mask=False): - aux_values = {name: 0.0 for name in self._sorted_sites} + aux_values = {name: 0.0 for name in self._factors} values, _ = self._transform_values(aux_values) return values @@ -232,7 +214,7 @@ def _transform_values( values = {} log_densities = defaultdict(float) compute_density = am_i_wrapped() and poutine.get_mask() is not False - for name, site in self._sorted_sites.items(): + for name, site in self._factors.items(): loc = deep_getattr(self.locs, name) scale = deep_getattr(self.scales, name) unconstrained = aux_values[name] * scale + loc @@ -258,40 +240,48 @@ def _sample_aux_values( return backend_fn() ############################################################################ - # Dense backend + # Dense backend. Methods and attributes are prefixed by ._dense_ def _dense_setup_prototype(self, *args, **kwargs): # Collect flat and individual and aggregated flat shapes. self._dense_shapes = {} + dense_gather = {} pos = 0 - offsets = {} for d, event_shape in self._unconstrained_event_shapes.items(): - batch_shape = self.prototype_trace.nodes[d]["fn"].batch_shape + batch_shape = self._factors[d]["fn"].batch_shape self._dense_shapes[d] = batch_shape, event_shape - offsets[d] = pos - pos += (batch_shape + event_shape).numel() + shape = batch_shape + event_shape + end = pos + shape.numel() + dense_gather[d] = torch.arange(pos, end).reshape(shape) + pos = end self._dense_size = pos - # Create sparse -> dense precision matrix maps. 
- self._dense_factor_scatter = {} - for d, site in self._sorted_sites.items(): - # Order inputs as in the model, so as to maximize sparsity of the - # lower Cholesky parametrization of the precision matrix. + # Create sparse -> dense precision scatter indices. + self._dense_scatter = {} + for d, site in self._factors.items(): + raw_shape = deep_getattr(self.factors, d).shape + precision_shape = raw_shape[:-1] + raw_shape[-2:-1] + index = torch.zeros(precision_shape, dtype=torch.long) + for u in self.dependencies[d]: + if not self._factors[u]["is_observed"]: + batch_plates = self._plates[u] & self._plates[d] # linear + broken_plates = self._plates[u] - self._plates[d] # quadratic + event_numel = self._event_numel[u] # quadratic + "TODO" + self._dense_scatter[d] = index.reshape(-1) + + ######################################################### + # OLD + for d, site in self._factors.items(): event_indices = [] - for f in site["cond_indep_stack"]: - if f.vectorized: - if f.name not in self._broken_plates[d]: - f.size for u in self.dependencies[d]: - start = offsets[u] - stop = start + self._broken_event_shapes[u].numel() - event_indices.append(torch.arange(start, stop)) - if not site["is_observed"]: - start = offsets[d] - stop = start + self._unconstrained_event_shapes[d].numel() - event_indices.append(torch.arange(start, stop)) + if not self._factors[u]["is_observed"]: + start = offsets[u] + stop = start + "TODO" + event_indices.append(torch.arange(start, stop)) event_index = torch.cat(event_indices) - precision_shape = deep_getattr(self.precision_chols, d).shape + raw_shape = deep_getattr(self.factors, d).shape + precision_shape = raw_shape[:-1] + raw_shape[-2:-1] index = torch.zeros(precision_shape, dtype=torch.long) stride = 1 index += event_index * stride @@ -299,7 +289,7 @@ def _dense_setup_prototype(self, *args, **kwargs): index += event_index[:, None] * stride stride *= self._dense_size # TODO add batch shapes - self._dense_factor_scatter[d] = index.reshape(-1) 
+ self._dense_scatter[d] = index.reshape(-1) def _dense_sample_aux_values( self, @@ -308,9 +298,9 @@ def _dense_sample_aux_values( # Convert to a flat dense joint Gaussian. flat_precision = torch.zeros(self._dense_size ** 2) - for d, index in self._dense_factor_scatter: - precision_chol = deep_getattr(self.precision_chols, d) - precision = precision_chol @ precision_chol.transpose(-1, -2) + for d, index in self._dense_scatter: + raw = deep_getattr(self.factors, d) + precision = _raw_to_precision(raw) flat_precision.scatter_add_(0, index, precision.reshape(-1)) precision = flat_precision.reshape(self._dense_size, self._dense_size) info_vec = torch.zeros(self._dense_size) @@ -330,9 +320,9 @@ def _dense_sample_aux_values( samples = {} pos = 0 for d, (batch_shape, event_shape) in self._dense_shapes.items(): - numel = deep_getattr(self.locs, d).numel() - flat_sample = flat_samples[pos : pos + numel] - pos += numel + end = pos + self._event_numel[d] + flat_sample = flat_samples[pos:end] + pos = end # Assumes sample shapes are left of batch shapes. samples[d] = flat_sample.reshape( torch.broadcast_shapes(sample_shape, batch_shape) + event_shape @@ -340,7 +330,7 @@ def _dense_sample_aux_values( return samples, log_density ############################################################################ - # Funsor backend + # Funsor backend. Methods and attributes are prefixed by ._funsor_ def _funsor_setup_prototype(self, *args, **kwargs): try: @@ -357,7 +347,7 @@ def _funsor_setup_prototype(self, *args, **kwargs): eliminate: Set[str] = set() plate_to_dim: Dict[str, int] = {} - for d, site in self._sorted_sites.items(): + for d, site in self._factors.items(): # Order inputs as in the model, so as to maximize sparsity of the # lower Cholesky parametrization of the precision matrix. 
inputs = OrderedDict() @@ -413,3 +403,19 @@ def _funsor_sample_aux_values( log_density = funsor.to_data(log_prob, name_to_dim=plate_to_dim) return samples, log_density + + +def _raw_to_precision(raw): + """ + Transform an unconstrained matrix of shape ``batch_shape + (m, n)`` to a + positive semidefinite precision matrix of shape ``batch_shape + (m, m)``. + Typically ``m >= n``. + """ + return raw @ raw.transpose(dim1=-2, dim2=-1) + + +def _plates_to_batch_shape(plates): + shape = [1] * max([0] + [-f.dim for f in plates]) + for f in plates: + shape[f.dim] = f.size + return torch.Size(shape) From 56a7e1f7dd415df9d2a3c40b138e5ebe8fade60a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 1 Oct 2021 21:08:39 -0400 Subject: [PATCH 25/41] Add more tests --- pyro/infer/autoguide/gaussian.py | 26 ++-- ...{test_autogaussian.py => test_gaussian.py} | 117 +++++++++++++++++- 2 files changed, 130 insertions(+), 13 deletions(-) rename tests/infer/autoguide/{test_autogaussian.py => test_gaussian.py} (77%) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index cddba211f9..94c49c8a8b 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -116,17 +116,17 @@ def _setup_prototype(self, *args, **kwargs) -> None: ] # Collect factors and plates. - for name, site in self.prototype_trace.nodes.items(): + for d, site in self.prototype_trace.nodes.items(): if site["type"] == "sample" and not site_is_subsample(site): assert all(f.vectorized for f in site["cond_indep_stack"]) - self._factors[name] = site + self._factors[d] = site plates = frozenset(site["cond_indep_stack"]) if site["is_observed"]: # Eagerly eliminate irrelevant observation plates. plates &= frozenset.union( *(self._plates[u] for u in self.dependencies[d] if u != d) ) - self._plates[name] = plates + self._plates[d] = plates # Create location-scale parameters, one per latent variable. 
for d, site in self._factors.items(): @@ -164,8 +164,8 @@ def _setup_prototype(self, *args, **kwargs) -> None: batch_shape = _plates_to_batch_shape(self._plates[d]) # Create a square root (not necessarily lower triangular). - raw = init_loc.new_zeros(batch_shape, u_size, d_size) - deep_setattr(self, self.factors, raw, event_dim=2) + raw = init_loc.new_zeros(batch_shape + (u_size, d_size)) + deep_setattr(self.factors, d, PyroParam(raw, event_dim=2)) # Dispatch to backend logic. backend_fn = getattr(self, f"_{self.backend}_setup_prototype", None) @@ -291,18 +291,22 @@ def _dense_setup_prototype(self, *args, **kwargs): # TODO add batch shapes self._dense_scatter[d] = index.reshape(-1) - def _dense_sample_aux_values( - self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - from pyro.ops import Gaussian - - # Convert to a flat dense joint Gaussian. + def _dense_get_precision(self): flat_precision = torch.zeros(self._dense_size ** 2) for d, index in self._dense_scatter: raw = deep_getattr(self.factors, d) precision = _raw_to_precision(raw) flat_precision.scatter_add_(0, index, precision.reshape(-1)) precision = flat_precision.reshape(self._dense_size, self._dense_size) + return precision + + def _dense_sample_aux_values( + self, + ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + from pyro.ops import Gaussian + + # Convert to a flat dense joint Gaussian. 
+ precision = self._dense_get_precision() info_vec = torch.zeros(self._dense_size) log_normalizer = torch.zeros(()) g = Gaussian(log_normalizer, info_vec, precision) diff --git a/tests/infer/autoguide/test_autogaussian.py b/tests/infer/autoguide/test_gaussian.py similarity index 77% rename from tests/infer/autoguide/test_autogaussian.py rename to tests/infer/autoguide/test_gaussian.py index 8871578523..2f62fe5a45 100644 --- a/tests/infer/autoguide/test_autogaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -14,12 +14,125 @@ from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam +from tests.common import assert_equal + BACKENDS = [ "dense", pytest.param("funsor", marks=[pytest.mark.stage("funsor")]), ] +def check_structure(model, expected_str): + guide = AutoGaussian(model, backend="dense") + guide() # initialize + + # Inject random noise into all unconstrained parameters. + for parameter in guide.parameters(): + parameter.normal_() + + with torch.no_grad(): + precision = guide._dense_get_gaussian().precision() + actual = precision.abs().gt(1e-5).long() + + str_to_number = {"?": 1, ".": 0} + expected = torch.tensor( + [[str_to_number[c] for c in row if c != " "] for row in expected_str] + ) + assert_equal(actual, expected) + + +def test_structure_1(): + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a, 1)) + c = pyro.sample("c", dist.Normal(b, 1)) + pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0)) + + expected = [ + "? ? .", + "? ? ?", + ". ? ?", + ] + check_structure(model, expected) + + +def test_structure_2(): + + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + with pyro.plate("i", 2): + c = pyro.sample("c", dist.Normal(a, b.exp())) + pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0)) + + # size = 1 + 1 + 2 = 4 + expected = [ + "? . ? ?", + ". ? ? ?", + "? ? ? .", + "? ? . 
?", + ] + check_structure(model, expected) + + +def test_structure_3(): + I, J = 2, 3 + + def model(): + i_plate = pyro.plate("i", I, dim=-1) + j_plate = pyro.plate("j", J, dim=-2) + with i_plate: + w = pyro.sample("w", dist.Normal(0, 1)) + with j_plate: + x = pyro.sample("x", dist.Normal(0, 1)) + with i_plate, j_plate: + y = pyro.sample("y", dist.Normal(w, x.exp())) + pyro.sample("z", dist.Normal(0, 1), obs=y) + + # size = 2 + 3 + 2 * 3 = 2 + 3 + 6 = 11 + expected = [ + "? . . . . ? . ? . ? .", + ". ? . . . . ? . ? . ?", + ". . ? . . ? ? . . . .", + ". . . ? . . . ? ? . .", + ". . . . ? . . . . ? ?", + "? . ? . . ? . . . . .", + ". ? ? . . . ? . . . .", + "? . . ? . . . ? . . .", + ". ? . ? . . . . ? . .", + "? . . . ? . . . . ? .", + ". ? . . ? . . . . . ?", + ] + check_structure(model, expected) + + +def test_structure_4(): + I, J = 2, 3 + + def model(): + i_plate = pyro.plate("i", I, dim=-1) + j_plate = pyro.plate("j", J, dim=-2) + a = pyro.sample("a", dist.Normal(0, 1)) + with i_plate: + b = pyro.sample("b", dist.Normal(a, 1)) + with j_plate: + c = pyro.sample("c", dist.Normal(b.mean(), 1)) + d = pyro.sample("d", dist.Normal(c.mean(), 1)) + pyro.sample("e", dist.Normal(0, 1), obs=d) + + # size = 1 + 2 + 3 + 1 = 7 + expected = [ + "? ? ? . . . .", + "? ? . ? ? ? .", + "? . ? ? ? ? .", + ". ? ? ? . . ?", + ". ? ? . ? . ?", + ". ? ? . . ? ?", + ". . . ? ? ? ?", + ] + check_structure(model, expected) + + # Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py def pyrocov_model(dataset): # Tensor shapes are commented at the end of some lines. @@ -272,7 +385,7 @@ def test_pyrocov_structure(backend): @pytest.mark.parametrize("backend", BACKENDS) -def test_profile(n=1, num_steps=1, backend="funsor"): +def test_profile(backend, n=1, num_steps=1): """ Helper function for profiling. 
""" @@ -301,4 +414,4 @@ def test_profile(n=1, num_steps=1, backend="funsor"): if __name__ == "__main__": - test_profile(n=10, num_steps=100, backend="funsor") + test_profile(backend="funsor", n=10, num_steps=100) From 7cbaa3d9f4acc3b63943e5713d2f9e43d1309819 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 2 Oct 2021 12:18:24 -0400 Subject: [PATCH 26/41] Refactor to use a class hierarchy --- pyro/infer/autoguide/gaussian.py | 115 ++++++++++++++----------- tests/infer/autoguide/test_gaussian.py | 2 - tests/infer/test_autoguide.py | 42 ++++----- 3 files changed, 88 insertions(+), 71 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 94c49c8a8b..340c074c45 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -3,7 +3,7 @@ from collections import OrderedDict, defaultdict from contextlib import ExitStack -from typing import Callable, Dict, Optional, Set, Tuple, Union +from typing import Callable, Dict, Set, Tuple, Union import torch from torch.distributions import biject_to @@ -22,7 +22,29 @@ from .utils import deep_getattr, deep_setattr, helpful_support_errors -class AutoGaussian(AutoGuide): +# Helper to dispatch to concrete subclasses of AutoGaussian, e.g. +# AutoGaussian(model, backend="dense") +# is converted to +# AutoGaussianDense(model) +# The intent is to avoid proliferation of subclasses and docstrings, +# and provide a single interface AutoGaussian(...). 
+class AutoGaussianMeta(type(AutoGuide)): + backends = {} + default_backend = "dense" + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + assert cls.__name__.startswith("AutoGaussian") + key = cls.__name__.replace("AutoGaussian", "").lower() + cls.backends[key] = cls + + def __call__(cls, *args, **kwargs): + backend = kwargs.pop("backend", cls.default_backend) + cls = cls.backends[backend] + return super(AutoGaussianMeta, cls).__call__(*args, **kwargs) + + +class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta): """ Gaussian guide with optimal conditional independence structure. @@ -70,15 +92,10 @@ class AutoGaussian(AutoGuide): See :ref:`autoguide-initialization` section for available functions. :param float init_scale: Initial scale for the standard deviation of each (unconstrained transformed) latent variable. - :param callable create_plates: An optional function inputing the same - ``*args,**kwargs`` as ``model()`` and returning a :class:`pyro.plate` - or iterable of plates. Plates not returned will be created - automatically as usual. This is useful for data subsampling. :param str backend: Back end for performing Gaussian tensor variable elimination. Defaults to "dense"; other options include "funsor". """ - default_backend: str = "dense" scale_constraint = constraints.softplus_positive def __init__( @@ -87,16 +104,15 @@ def __init__( *, init_loc_fn: Callable = init_to_feasible, init_scale: float = 0.1, - create_plates: Optional[Callable] = None, backend=None, ): + assert backend is not None if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError(f"Expected init_scale > 0. 
but got {init_scale}") self._init_scale = init_scale self._original_model = (model,) model = InitMessenger(init_loc_fn)(model) - super().__init__(model, create_plates=create_plates) - self.backend = self.default_backend if backend is None else backend + super().__init__(model) def _setup_prototype(self, *args, **kwargs) -> None: super()._setup_prototype(*args, **kwargs) @@ -163,14 +179,9 @@ def _setup_prototype(self, *args, **kwargs) -> None: d_size = min(d_size, u_size) # just an optimization batch_shape = _plates_to_batch_shape(self._plates[d]) - # Create a square root (not necessarily lower triangular). - raw = init_loc.new_zeros(batch_shape + (u_size, d_size)) - deep_setattr(self.factors, d, PyroParam(raw, event_dim=2)) - - # Dispatch to backend logic. - backend_fn = getattr(self, f"_{self.backend}_setup_prototype", None) - if backend_fn is not None: - backend_fn(*args, **kwargs) + # Create a square root parameter (full, not lower triangular). + sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size)) + deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2)) def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: @@ -233,16 +244,19 @@ def _transform_values( def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - # Dispatch to backend logic. - backend_fn = getattr(self, f"_{self.backend}_sample_aux_values", None) - if backend_fn is None: - raise NotImplementedError(f"Unknown AutoGaussian backend: {self.backend}") - return backend_fn() + raise NotImplementedError + - ############################################################################ - # Dense backend. Methods and attributes are prefixed by ._dense_ +class AutoGaussianDense(AutoGaussian): + # Dense implementation of :class:`AutoGaussian` . 
+ # Attributes are prefixed by ._dense_ + # The following are equivalent: + # guide = AutoGaussian(model, backend="dense") + # guide = AutoGaussianDense(model) + + def _setup_prototype(self, *args, **kwargs): + super()._setup_prototype(*args, **kwargs) - def _dense_setup_prototype(self, *args, **kwargs): # Collect flat and individual and aggregated flat shapes. self._dense_shapes = {} dense_gather = {} @@ -259,8 +273,8 @@ def _dense_setup_prototype(self, *args, **kwargs): # Create sparse -> dense precision scatter indices. self._dense_scatter = {} for d, site in self._factors.items(): - raw_shape = deep_getattr(self.factors, d).shape - precision_shape = raw_shape[:-1] + raw_shape[-2:-1] + sqrt_shape = deep_getattr(self.factors, d).shape + precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1] index = torch.zeros(precision_shape, dtype=torch.long) for u in self.dependencies[d]: if not self._factors[u]["is_observed"]: @@ -280,8 +294,8 @@ def _dense_setup_prototype(self, *args, **kwargs): stop = start + "TODO" event_indices.append(torch.arange(start, stop)) event_index = torch.cat(event_indices) - raw_shape = deep_getattr(self.factors, d).shape - precision_shape = raw_shape[:-1] + raw_shape[-2:-1] + sqrt_shape = deep_getattr(self.factors, d).shape + precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1] index = torch.zeros(precision_shape, dtype=torch.long) stride = 1 index += event_index * stride @@ -291,22 +305,22 @@ def _dense_setup_prototype(self, *args, **kwargs): # TODO add batch shapes self._dense_scatter[d] = index.reshape(-1) - def _dense_get_precision(self): + def _get_precision(self): flat_precision = torch.zeros(self._dense_size ** 2) for d, index in self._dense_scatter: - raw = deep_getattr(self.factors, d) - precision = _raw_to_precision(raw) + sqrt = deep_getattr(self.factors, d) + precision = sqrt @ sqrt.transpose(dim1=-2, dim2=-1) flat_precision.scatter_add_(0, index, precision.reshape(-1)) precision = flat_precision.reshape(self._dense_size, 
self._dense_size) return precision - def _dense_sample_aux_values( + def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: from pyro.ops import Gaussian # Convert to a flat dense joint Gaussian. - precision = self._dense_get_precision() + precision = self._get_precision() info_vec = torch.zeros(self._dense_size) log_normalizer = torch.zeros(()) g = Gaussian(log_normalizer, info_vec, precision) @@ -333,10 +347,16 @@ def _dense_sample_aux_values( ) return samples, log_density - ############################################################################ - # Funsor backend. Methods and attributes are prefixed by ._funsor_ - def _funsor_setup_prototype(self, *args, **kwargs): +class AutoGaussianFunsor(AutoGaussian): + # Funsor implementation of :class:`AutoGaussian` . + # Attributes are prefixed by ._funsor_ + # The following are equivalent: + # guide = AutoGaussian(model, backend="funsor") + # guide = AutoGaussianFunsor(model) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) try: import funsor except ImportError as e: @@ -346,6 +366,12 @@ def _funsor_setup_prototype(self, *args, **kwargs): ) from e funsor.set_backend("torch") + def _setup_prototype(self, *args, **kwargs): + super()._setup_prototype(*args, **kwargs) + import funsor + + funsor.set_backend("torch") + # Determine TVE problem shape. factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {} eliminate: Set[str] = set() @@ -374,11 +400,13 @@ def _funsor_setup_prototype(self, *args, **kwargs): self._funsor_plate_to_dim = plate_to_dim self._funsor_plates = frozenset(plate_to_dim) - def _funsor_sample_aux_values( + def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: import funsor + funsor.set_backend("torch") + # Convert torch to funsor. 
particle_plates = frozenset(get_plates()) plate_to_dim = self._funsor_plate_to_dim.copy() @@ -409,15 +437,6 @@ def _funsor_sample_aux_values( return samples, log_density -def _raw_to_precision(raw): - """ - Transform an unconstrained matrix of shape ``batch_shape + (m, n)`` to a - positive semidefinite precision matrix of shape ``batch_shape + (m, m)``. - Typically ``m >= n``. - """ - return raw @ raw.transpose(dim1=-2, dim2=-1) - - def _plates_to_batch_shape(plates): shape = [1] * max([0] + [-f.dim for f in plates]) for f in plates: diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 2f62fe5a45..51e119c7d2 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -13,7 +13,6 @@ from pyro.infer.autoguide import AutoGaussian from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam - from tests.common import assert_equal BACKENDS = [ @@ -57,7 +56,6 @@ def model(): def test_structure_2(): - def model(): a = pyro.sample("a", dist.Normal(0, 1)) b = pyro.sample("b", dist.Normal(0, 1)) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 21e4ec4214..30b7d9475d 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -22,6 +22,7 @@ AutoDiagonalNormal, AutoDiscreteParallel, AutoGaussian, + AutoGaussianFunsor, AutoGuide, AutoGuideList, AutoIAFNormal, @@ -43,10 +44,9 @@ from pyro.util import check_model_guide_match from tests.common import assert_close, assert_equal - -@functools.partial(pytest.param, marks=[pytest.mark.stage("funsor")]) -class AutoGaussian_funsor(AutoGaussian): - default_backend = "funsor" +AutoGaussianFunsor = pytest.param( + AutoGaussianFunsor, marks=[pytest.mark.stage("funsor")] +) @pytest.mark.parametrize( @@ -95,7 +95,7 @@ def model(): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_factor(auto_class, Elbo): @@ -193,7 +193,7 @@ def 
dependency_z6_z5(z5): AutoStructured, AutoStructured_shapes, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.filterwarnings("ignore::FutureWarning") @@ -232,7 +232,7 @@ def model(): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO]) @@ -346,7 +346,7 @@ def __init__(self, model): AutoStructured, AutoStructured_median, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -515,7 +515,7 @@ def model(): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_discrete_parallel(continuous_class): @@ -552,7 +552,7 @@ def model(data): AutoIAFNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_guide_list(auto_class): @@ -576,7 +576,7 @@ def model(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_callable(auto_class): @@ -606,7 +606,7 @@ def guide_x(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_callable_return_dict(auto_class): @@ -654,7 +654,7 @@ def model(): AutoMultivariateNormal, AutoLowRankMultivariateNormal, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_init_loc_fn(auto_class): @@ -806,7 +806,7 @@ def forward(self): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -937,7 +937,7 @@ def forward(self, x, y=None): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, 
AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_replay_plates(auto_class, sample_shape): @@ -1087,7 +1087,7 @@ def create_plates(data): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize( @@ -1123,7 +1123,7 @@ def model(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize( @@ -1156,7 +1156,7 @@ def model(): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) @pytest.mark.parametrize( @@ -1231,7 +1231,7 @@ def __init__(self, model): AutoStructured_exact_normal, AutoStructured_exact_mvn, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_exact(Guide): @@ -1300,7 +1300,7 @@ def model(data): AutoStructured_exact_normal, AutoStructured_exact_mvn, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_exact_batch(Guide): @@ -1357,7 +1357,7 @@ def model(data): AutoLowRankMultivariateNormal, AutoStructured, AutoGaussian, - AutoGaussian_funsor, + AutoGaussianFunsor, ], ) def test_exact_tree(Guide): From 9fbb797139d403a79950f236cea9420cbc745c87 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 2 Oct 2021 15:59:30 -0400 Subject: [PATCH 27/41] Flesh out dense backend (some tests fail) --- pyro/infer/autoguide/gaussian.py | 156 ++++++++++++++++--------- tests/infer/autoguide/test_gaussian.py | 50 +++++++- 2 files changed, 146 insertions(+), 60 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 340c074c45..6b14112350 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import itertools from collections import OrderedDict, defaultdict from contextlib import ExitStack from typing import Callable, Dict, Set, Tuple, Union @@ -106,7 +107,6 @@ def __init__( init_scale: float = 0.1, backend=None, ): - assert backend is not None if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError(f"Expected init_scale > 0. but got {init_scale}") self._init_scale = init_scale @@ -170,17 +170,17 @@ def _setup_prototype(self, *args, **kwargs) -> None: u_size = 0 for u in self.dependencies[d]: if not self._factors[u]["is_observed"]: - broken_shape = _plates_to_batch_shape( - self._plates[u] - self._plates[d] - ) + broken_shape = _plates_to_shape(self._plates[u] - self._plates[d]) u_size += broken_shape.numel() * self._event_numel[u] d_size = self._event_numel[d] if site["is_observed"]: d_size = min(d_size, u_size) # just an optimization - batch_shape = _plates_to_batch_shape(self._plates[d]) + batch_shape = _plates_to_shape(self._plates[d]) # Create a square root parameter (full, not lower triangular). sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size)) + if d in self.dependencies[d]: + sqrt += torch.eye(u_size, d_size) deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2)) def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: @@ -248,25 +248,29 @@ def _sample_aux_values( class AutoGaussianDense(AutoGaussian): - # Dense implementation of :class:`AutoGaussian` . + """ + Dense implementation of :class:`AutoGaussian` . + + The following are equivalent:: + + guide = AutoGaussian(model, backend="dense") + guide = AutoGaussianDense(model) + """ + # Attributes are prefixed by ._dense_ - # The following are equivalent: - # guide = AutoGaussian(model, backend="dense") - # guide = AutoGaussianDense(model) def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) - # Collect flat and individual and aggregated flat shapes. 
+ # Collect global shapes and per-axis indices. self._dense_shapes = {} - dense_gather = {} + global_indices = {} pos = 0 for d, event_shape in self._unconstrained_event_shapes.items(): batch_shape = self._factors[d]["fn"].batch_shape self._dense_shapes[d] = batch_shape, event_shape - shape = batch_shape + event_shape - end = pos + shape.numel() - dense_gather[d] = torch.arange(pos, end).reshape(shape) + end = pos + (batch_shape + event_shape).numel() + global_indices[d] = torch.arange(pos, end).reshape(batch_shape + (-1,)) pos = end self._dense_size = pos @@ -276,50 +280,42 @@ def _setup_prototype(self, *args, **kwargs): sqrt_shape = deep_getattr(self.factors, d).shape precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1] index = torch.zeros(precision_shape, dtype=torch.long) - for u in self.dependencies[d]: - if not self._factors[u]["is_observed"]: - batch_plates = self._plates[u] & self._plates[d] # linear - broken_plates = self._plates[u] - self._plates[d] # quadratic - event_numel = self._event_numel[u] # quadratic - "TODO" - self._dense_scatter[d] = index.reshape(-1) - ######################################################### - # OLD - for d, site in self._factors.items(): - event_indices = [] + # Collect local offsets. 
+ local_offsets = {} + pos = 0 for u in self.dependencies[d]: - if not self._factors[u]["is_observed"]: - start = offsets[u] - stop = start + "TODO" - event_indices.append(torch.arange(start, stop)) - event_index = torch.cat(event_indices) - sqrt_shape = deep_getattr(self.factors, d).shape - precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1] - index = torch.zeros(precision_shape, dtype=torch.long) - stride = 1 - index += event_index * stride - stride *= self._dense_size - index += event_index[:, None] * stride - stride *= self._dense_size - # TODO add batch shapes - self._dense_scatter[d] = index.reshape(-1) + local_offsets[u] = pos + broken_plates = self._plates[u] - self._plates[d] + pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel() + + # Create indices blockwise. + for u, v in itertools.product(self.dependencies[d], self.dependencies[d]): + u_index = global_indices[u] + v_index = global_indices[v] + + # Permute broken plates to the right of preserved plates. + # FIXME what happens if d has a plate not in u or v? + u_index = _break_plates(u_index, self._plates[u], self._plates[d]) + v_index = _break_plates(v_index, self._plates[v], self._plates[d]) + + # Scatter global indices into the [u,v] block. 
+ u_start = local_offsets[u] + u_stop = u_start + u_index.size(-1) + v_start = local_offsets[v] + v_stop = v_start + v_index.size(-1) + index[ + ..., u_start:u_stop, v_start:v_stop + ] = self._dense_size * u_index.unsqueeze(-1) + v_index.unsqueeze(-2) - def _get_precision(self): - flat_precision = torch.zeros(self._dense_size ** 2) - for d, index in self._dense_scatter: - sqrt = deep_getattr(self.factors, d) - precision = sqrt @ sqrt.transpose(dim1=-2, dim2=-1) - flat_precision.scatter_add_(0, index, precision.reshape(-1)) - precision = flat_precision.reshape(self._dense_size, self._dense_size) - return precision + self._dense_scatter[d] = index.reshape(-1) def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - from pyro.ops import Gaussian + from pyro.ops.gaussian import Gaussian - # Convert to a flat dense joint Gaussian. + # Convert to a dense joint Gaussian over flattened variables. precision = self._get_precision() info_vec = torch.zeros(self._dense_size) log_normalizer = torch.zeros(()) @@ -338,8 +334,8 @@ def _sample_aux_values( samples = {} pos = 0 for d, (batch_shape, event_shape) in self._dense_shapes.items(): - end = pos + self._event_numel[d] - flat_sample = flat_samples[pos:end] + end = pos + (batch_shape + event_shape).numel() + flat_sample = flat_samples[..., pos:end] pos = end # Assumes sample shapes are left of batch shapes. samples[d] = flat_sample.reshape( @@ -347,13 +343,26 @@ def _sample_aux_values( ) return samples, log_density + def _get_precision(self): + flat_precision = torch.zeros(self._dense_size ** 2) + for d, index in self._dense_scatter.items(): + sqrt = deep_getattr(self.factors, d) + precision = sqrt @ sqrt.transpose(dim0=-2, dim1=-1) + flat_precision.scatter_add_(0, index, precision.reshape(-1)) + precision = flat_precision.reshape(self._dense_size, self._dense_size) + return precision + class AutoGaussianFunsor(AutoGaussian): - # Funsor implementation of :class:`AutoGaussian` . 
+ """ + Funsor implementation of :class:`AutoGaussian` . + + The following are equivalent:: + guide = AutoGaussian(model, backend="funsor") + guide = AutoGaussianFunsor(model) + """ + # Attributes are prefixed by ._funsor_ - # The following are equivalent: - # guide = AutoGaussian(model, backend="funsor") - # guide = AutoGaussianFunsor(model) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -437,8 +446,41 @@ def _sample_aux_values( return samples, log_density -def _plates_to_batch_shape(plates): +def _plates_to_shape(plates): shape = [1] * max([0] + [-f.dim for f in plates]) for f in plates: shape[f.dim] = f.size return torch.Size(shape) + + +def _break_plates(x, all_plates, kept_plates): + """ + Reshapes and permutes a tensor ``x`` with event_dim=1 and batch shape given + by ``all_plates`` by breaking all plates not in ``kept_plates``. Each + broken plate is moved into the event shape, and finally the event shape is + flattend back to a single dimension. + """ + assert x.shape[:-1] == _plates_to_shape(all_plates) # event_dim == 1 + broken_plates = all_plates - kept_plates + + if not broken_plates: + return x + + if not kept_plates: + # Empty batch shape. + return x.reshape(-1) + + batch_shape = _plates_to_shape(kept_plates) + if max(p.dim for p in kept_plates) < min(p.dim for p in broken_plates): + # No permutation is necessary. + return x.reshape(batch_shape + (-1,)) + + # We need to permute broken plates left past kept plates. + event_dims = {-1} | {p.dim - 1 for p in broken_plates} + perm = sorted(range(-x.dim(), 0), key=lambda d: (d in event_dims, d)) + return x.permute(perm).reshape(batch_shape + (-1,)) + + +__all__ = [ + "AutoGaussian", +] diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 51e119c7d2..1a0f863ae8 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict +from collections import OrderedDict, namedtuple import pytest import torch @@ -11,6 +11,7 @@ import pyro.poutine as poutine from pyro.infer import SVI, Predictive, Trace_ELBO from pyro.infer.autoguide import AutoGaussian +from pyro.infer.autoguide.gaussian import _break_plates from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam from tests.common import assert_equal @@ -21,16 +22,59 @@ ] +MockPlate = namedtuple("MockPlate", "dim, size") + + +def test_break_plates(): + shape = torch.Size([5, 4, 3, 2]) + i = MockPlate(-3, 5) + j = MockPlate(-2, 4) + k = MockPlate(-1, 3) + x = torch.arange(shape.numel()).reshape(shape) + + actual = _break_plates(x, {i, j, k}, set()) + expected = x.reshape(-1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i}) + expected = x.reshape(5, 1, 1, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {j}) + expected = x.permute((1, 0, 2, 3)).reshape(4, 1, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {k}) + expected = x.permute((2, 0, 1, 3)).reshape(3, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i, j}) + expected = x.reshape(5, 4, 1, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i, k}) + expected = x.permute((0, 2, 1, 3)).reshape(5, 1, 3, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {j, k}) + expected = x.permute((1, 2, 0, 3)).reshape(4, 3, -1) + assert_equal(actual, expected) + + actual = _break_plates(x, {i, j, k}, {i, j, k}) + expected = x + assert_equal(actual, expected) + + def check_structure(model, expected_str): guide = AutoGaussian(model, backend="dense") guide() # initialize # Inject random noise into all unconstrained parameters. 
for parameter in guide.parameters(): - parameter.normal_() + parameter.data.normal_() with torch.no_grad(): - precision = guide._dense_get_gaussian().precision() + precision = guide._get_precision() actual = precision.abs().gt(1e-5).long() str_to_number = {"?": 1, ".": 0} From 8c3a6ab3cd84ee33fa04eea55605ddd5ab5a3f22 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 2 Oct 2021 16:37:22 -0400 Subject: [PATCH 28/41] Fix tests, simplify init logic --- pyro/infer/autoguide/gaussian.py | 10 ++++++++-- tests/infer/autoguide/test_gaussian.py | 5 +++++ tests/infer/test_autoguide.py | 4 ++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 6b14112350..0556002b6f 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -137,6 +137,11 @@ def _setup_prototype(self, *args, **kwargs) -> None: assert all(f.vectorized for f in site["cond_indep_stack"]) self._factors[d] = site plates = frozenset(site["cond_indep_stack"]) + if site["fn"].batch_shape != _plates_to_shape(plates): + raise ValueError( + f"Shape mismatch at site '{d}'. " + "Are you missing a pyro.plate() or .to_event()?" + ) if site["is_observed"]: # Eagerly eliminate irrelevant observation plates. plates &= frozenset.union( @@ -180,7 +185,8 @@ def _setup_prototype(self, *args, **kwargs) -> None: # Create a square root parameter (full, not lower triangular). sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size)) if d in self.dependencies[d]: - sqrt += torch.eye(u_size, d_size) + # Initialize the [d,d] block to the identity matrix. + sqrt.diagonal(dim1=-2, dim2=-1).fill_(1) deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2)) def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: @@ -295,7 +301,6 @@ def _setup_prototype(self, *args, **kwargs): v_index = global_indices[v] # Permute broken plates to the right of preserved plates. 
- # FIXME what happens if d has a plate not in u or v? u_index = _break_plates(u_index, self._plates[u], self._plates[d]) v_index = _break_plates(v_index, self._plates[v], self._plates[d]) @@ -461,6 +466,7 @@ def _break_plates(x, all_plates, kept_plates): flattend back to a single dimension. """ assert x.shape[:-1] == _plates_to_shape(all_plates) # event_dim == 1 + kept_plates = kept_plates & all_plates broken_plates = all_plates - kept_plates if not broken_plates: diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 1a0f863ae8..b240bddbaf 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -27,6 +27,7 @@ def test_break_plates(): shape = torch.Size([5, 4, 3, 2]) + h = MockPlate(-4, 6) i = MockPlate(-3, 5) j = MockPlate(-2, 4) k = MockPlate(-1, 3) @@ -64,6 +65,10 @@ def test_break_plates(): expected = x assert_equal(actual, expected) + actual = _break_plates(x, {i, j, k}, {h, i, j, k}) + expected = x + assert_equal(actual, expected) + def check_structure(model, expected_str): guide = AutoGaussian(model, backend="dense") diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 30b7d9475d..31f4b64749 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -22,7 +22,6 @@ AutoDiagonalNormal, AutoDiscreteParallel, AutoGaussian, - AutoGaussianFunsor, AutoGuide, AutoGuideList, AutoIAFNormal, @@ -36,6 +35,7 @@ init_to_median, init_to_sample, ) +from pyro.infer.autoguide.gaussian import AutoGaussianFunsor from pyro.infer.reparam import ProjectedNormalReparam from pyro.nn.module import PyroModule, PyroParam, PyroSample from pyro.ops.gaussian import Gaussian @@ -557,7 +557,7 @@ def model(data): ) def test_guide_list(auto_class): def model(): - pyro.sample("x", dist.Normal(0.0, 1.0).expand([2])) + pyro.sample("x", dist.Normal(0.0, 1.0).expand([2]).to_event(1)) pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5))) 
guide = AutoGuideList(model) From 58c344c2528cbea8adbf2ea3eba8d7c37fc14763 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 2 Oct 2021 17:06:36 -0400 Subject: [PATCH 29/41] Add a pyro-cov poisson example model --- tests/infer/autoguide/test_gaussian.py | 69 +++++++++++++++++++++----- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index b240bddbaf..4a7a61e521 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -315,9 +315,61 @@ def pyrocov_model_plated(dataset): ) -@pytest.mark.parametrize( - "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] -) +# This is modified by replacing the multinomial likelihood with poisson. +def pyrocov_model_poisson(dataset): + # Tensor shapes are commented at the end of some lines. + features = dataset["features"] + local_time = dataset["local_time"][..., None] # [T, P, 1] + T, P, _ = local_time.shape + S, F = features.shape + weekly_strains = dataset["weekly_strains"] # [T, P, S] + assert weekly_strains.shape == (T, P, S) + feature_plate = pyro.plate("feature", F, dim=-1) + strain_plate = pyro.plate("strain", S, dim=-1) + place_plate = pyro.plate("place", P, dim=-2) + time_plate = pyro.plate("time", T, dim=-3) + + # Sample global random variables. 
+ coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2)) + rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) + rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) + init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) + init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2)) + pois_loc = pyro.sample("pois_loc", dist.Normal(0, 2)) + pois_scale = pyro.sample("pois_scale", dist.LogNormal(0, 2)) + + with feature_plate: + coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] + rate_loc_loc = 0.01 * coef @ features.T + with strain_plate: + rate_loc = pyro.sample( + "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale) + ) # [S] + init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale)) # [S] + with place_plate, strain_plate: + rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale)) # [P, S] + init = pyro.sample("init", dist.Normal(init_loc, init_scale)) # [P, S] + + # Finally observe counts. + with time_plate, place_plate: + pois = pyro.sample("pois", dist.LogNormal(pois_loc, pois_scale)) + with time_plate, place_plate, strain_plate: + # Note .softmax() breaks conditional independence over strain, but only + # weakly. We could directly call .exp(), but .softmax is more + # numerically stable. 
+ logits = pois * (init + rate * local_time).softmax(-1) # [T, P, S] + pyro.sample("obs", dist.Poisson(logits), obs=weekly_strains) + + +PYRO_COV_MODELS = [ + pyrocov_model, + pyrocov_model_relaxed, + pyrocov_model_plated, + pyrocov_model_poisson, +] + + +@pytest.mark.parametrize("model", PYRO_COV_MODELS) @pytest.mark.parametrize("backend", BACKENDS) def test_pyrocov_smoke(model, backend): T, P, S, F = 3, 4, 5, 6 @@ -336,9 +388,7 @@ def test_pyrocov_smoke(model, backend): predictive(dataset) -@pytest.mark.parametrize( - "model", [pyrocov_model, pyrocov_model_relaxed, pyrocov_model_plated] -) +@pytest.mark.parametrize("model", PYRO_COV_MODELS) @pytest.mark.parametrize("backend", BACKENDS) def test_pyrocov_reparam(model, backend): T, P, S, F = 2, 3, 4, 5 @@ -366,7 +416,7 @@ def test_pyrocov_reparam(model, backend): predictive(dataset) -@pytest.mark.parametrize("backend", BACKENDS) +@pytest.mark.parametrize("backend", ["funsor"]) def test_pyrocov_structure(backend): from funsor import Bint, Real, Reals @@ -447,11 +497,6 @@ def test_profile(backend, n=1, num_steps=1): guide = AutoGaussian(model) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) guide(dataset) # initialize - print("Factor inputs:") - for name, inputs in guide._funsor_factor_inputs.items(): - print(f" {name}:") - for k, v in inputs.items(): - print(f" {k}: {v}") print("Parameter shapes:") for name, param in guide.named_parameters(): print(f" {name}: {tuple(param.shape)}") From 670fdda46bc39912d219f979542d72dd8e4c007d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 07:28:10 -0400 Subject: [PATCH 30/41] Flesh out funsor backend --- pyro/infer/autoguide/gaussian.py | 51 +++++++++++++++++++------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 0556002b6f..53c7c76ea2 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -68,7 +68,7 @@ class 
AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta): guide = AutoGaussian(model) svi = SVI(model, guide, ...) - Example using funsor backend:: + Example using experimental funsor backend:: !pip install pyro-ppl[funsor] guide = AutoGaussian(model, backend="funsor") @@ -201,8 +201,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: for name, site in self._factors.items(): with ExitStack() as stack: for frame in site["cond_indep_stack"]: - if frame.vectorized: - stack.enter_context(plates[frame.name]) + stack.enter_context(plates[frame.name]) pyro.sample( name, dist.Delta(values[name], log_densities[name], site["fn"].event_dim), @@ -352,7 +351,7 @@ def _get_precision(self): flat_precision = torch.zeros(self._dense_size ** 2) for d, index in self._dense_scatter.items(): sqrt = deep_getattr(self.factors, d) - precision = sqrt @ sqrt.transpose(dim0=-2, dim1=-1) + precision = sqrt @ sqrt.transpose(-1, -2) flat_precision.scatter_add_(0, index, precision.reshape(-1)) precision = flat_precision.reshape(self._dense_size, self._dense_size) return precision @@ -368,6 +367,7 @@ class AutoGaussianFunsor(AutoGaussian): """ # Attributes are prefixed by ._funsor_ + # This uses tensor variable elimination (TVE). def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -386,29 +386,39 @@ def _setup_prototype(self, *args, **kwargs): funsor.set_backend("torch") + # Break plates globally to fit this into a TVE problem. 
+ broken_plates = frozenset() + for d in self._factors: + for u in self.dependencies[d]: + broken_plates |= self._plates[u] - self._plates[d] + broken_vars: Dict[str, Tuple[funsor.Variable, ...]] = {} + broken_event_shapes: Dict[str, Tuple[int, ...]] = {} + for u, event_shape in self._unconstrained_event_shapes.items(): + plates = sorted(self._plates[u] & broken_plates, key=lambda p: p.size) + broken_vars[u] = tuple( + funsor.Variable(p.name, funsor.Bint[p.size]) for p in plates + ) + broken_event_shapes[u] = tuple(p.size for p in plates) + event_shape + # Determine TVE problem shape. factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {} eliminate: Set[str] = set() plate_to_dim: Dict[str, int] = {} - for d, site in self._factors.items(): - # Order inputs as in the model, so as to maximize sparsity of the - # lower Cholesky parametrization of the precision matrix. inputs = OrderedDict() for f in site["cond_indep_stack"]: - if f.vectorized: - plate_to_dim[f.name] = f.dim - if f.name not in self._broken_plates[d]: - inputs[f.name] = funsor.Bint[f.size] - eliminate.add(f.name) + plate_to_dim[f.name] = f.dim + if f not in broken_plates: + inputs[f.name] = funsor.Bint[f.size] + eliminate.add(f.name) + if not site["is_observed"]: + inputs[d] = funsor.Reals[broken_event_shapes[d]] for u in self.dependencies[d]: - inputs[u] = funsor.Reals[self._broken_event_shapes[u]] + inputs[u] = funsor.Reals[broken_event_shapes[u]] eliminate.add(u) - if not site["is_observed"]: - inputs[d] = funsor.Reals[self._broken_event_shapes[d]] - assert d in eliminate factor_inputs[d] = inputs + self._funsor_broken_vars = broken_vars self._funsor_factor_inputs = factor_inputs self._funsor_eliminate = frozenset(eliminate) self._funsor_plate_to_dim = plate_to_dim @@ -427,11 +437,12 @@ def _sample_aux_values( plate_to_dim.update({f.name: f.dim for f in particle_plates}) factors = {} for d, inputs in self._funsor_factor_inputs.items(): - precision_chol = deep_getattr(self.precision_chols, d) - 
precision = precision_chol @ precision_chol.transpose(-1, -2) + sqrt = deep_getattr(self.factors, d) + if self._funsor_broken_vars: + raise NotImplementedError("TODO break plates in sqrt") + precision = sqrt @ sqrt.transpose(-1, -2) info_vec = precision.new_zeros(()).expand(precision.shape[:-1]) factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) - factors[d]._precision_chol = precision_chol # avoid recomputing # Perform Gaussian tensor variable elimination. samples, log_prob = funsor.recipes.forward_filter_backward_rsample( @@ -443,7 +454,7 @@ def _sample_aux_values( # Convert funsor to torch. samples = { - k: funsor.to_data(v[self._broken_plates[k]], name_to_dim=plate_to_dim) + k: funsor.to_data(v[self._funsor_broken_vars[k]], name_to_dim=plate_to_dim) for k, v in samples.items() } log_density = funsor.to_data(log_prob, name_to_dim=plate_to_dim) From 5db1109bc4fdf400afe0a2c7d77404b96e3693ae Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 07:45:25 -0400 Subject: [PATCH 31/41] Add tests, update docs --- pyro/infer/autoguide/gaussian.py | 18 ++++++------ tests/infer/autoguide/test_gaussian.py | 38 ++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 53c7c76ea2..7df62c3d51 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -4,7 +4,7 @@ import itertools from collections import OrderedDict, defaultdict from contextlib import ExitStack -from typing import Callable, Dict, Set, Tuple, Union +from typing import Callable, Dict, Optional, Set, Tuple, Union import torch from torch.distributions import biject_to @@ -54,12 +54,14 @@ class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta): the model [1]. Depending on model structure, this can have asymptotically better statistical efficiency than :class:`AutoMultivariateNormal` . 
- The default "dense" backend should have similar computational complexity to - :class:`AutoMultivariateNormal` . The experimental "funsor" backend can be - asymptotically cheaper in terms of time and space (using Gaussian tensor - variable elimination [2,3]), but incurs large constant overhead. The - "funsor" backend requires `funsor `_ which can be - installed via ``pip install pyro-ppl[funsor]``. + This guide implements multiple backends for computation. All backends use + the same statistically optimal parametrization. The default "dense" backend + has computational complexity similar to :class:`AutoMultivariateNormal` . + The experimental "funsor" backend can be asymptotically cheaper in terms of + time and space (using Gaussian tensor variable elimination [2,3]), but + incurs large constant overhead. The "funsor" backend requires `funsor + `_ which can be installed via ``pip install + pyro-ppl[funsor]``. The guide currently does not depend on the model's ``*args, **kwargs``. @@ -105,7 +107,7 @@ def __init__( *, init_loc_fn: Callable = init_to_feasible, init_scale: float = 0.1, - backend=None, + backend: Optional[str] = None, # used only by metaclass ): if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError(f"Expected init_scale > 0. 
but got {init_scale}") diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 4a7a61e521..dfa0b0b9ac 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -14,7 +14,7 @@ from pyro.infer.autoguide.gaussian import _break_plates from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam -from tests.common import assert_equal +from tests.common import assert_equal, xfail_if_not_implemented BACKENDS = [ "dense", @@ -180,6 +180,40 @@ def model(): check_structure(model, expected) +@pytest.mark.parametrize("backend", BACKENDS) +def test_broken_plates_smoke(backend): + def model(): + with pyro.plate("i", 2): + x = pyro.sample("x", dist.Normal(0, 1)) + pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0)) + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step() + guide() + predictive = Predictive(model, guide=guide, num_samples=2) + predictive() + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_intractable_smoke(backend): + def model(): + with pyro.plate("i", 2): + x = pyro.sample("x", dist.Normal(0, 1)) + pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0)) + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step() + guide() + predictive = Predictive(model, guide=guide, num_samples=2) + predictive() + + # Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py def pyrocov_model(dataset): # Tensor shapes are commented at the end of some lines. @@ -486,7 +520,7 @@ def test_profile(backend, n=1, num_steps=1): """ Helper function for profiling. 
""" - model = pyrocov_model_plated + model = pyrocov_model_poisson T, P, S, F = 2 * n, 3 * n, 4 * n, 5 * n dataset = { "features": torch.randn(S, F), From fe194eef2cadbb673683da76b4bda0cefd67ae6c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 07:54:33 -0400 Subject: [PATCH 32/41] Be safer about importing funsor --- pyro/infer/autoguide/gaussian.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 7df62c3d51..dd9ee15393 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -373,20 +373,11 @@ class AutoGaussianFunsor(AutoGaussian): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - try: - import funsor - except ImportError as e: - raise ImportError( - 'AutoGaussian(..., backend="funsor") requires funsor. ' - "Try installing via: pip install pyro-ppl[funsor]" - ) from e - funsor.set_backend("torch") + _import_funsor() def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) - import funsor - - funsor.set_backend("torch") + funsor = _import_funsor() # Break plates globally to fit this into a TVE problem. broken_plates = frozenset() @@ -429,9 +420,7 @@ def _setup_prototype(self, *args, **kwargs): def _sample_aux_values( self, ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - import funsor - - funsor.set_backend("torch") + funsor = _import_funsor() # Convert torch to funsor. 
particle_plates = frozenset(get_plates()) @@ -440,7 +429,7 @@ def _sample_aux_values( factors = {} for d, inputs in self._funsor_factor_inputs.items(): sqrt = deep_getattr(self.factors, d) - if self._funsor_broken_vars: + if any(self._funsor_broken_vars[u] for u in self.dependencies[d]): raise NotImplementedError("TODO break plates in sqrt") precision = sqrt @ sqrt.transpose(-1, -2) info_vec = precision.new_zeros(()).expand(precision.shape[:-1]) @@ -500,6 +489,18 @@ def _break_plates(x, all_plates, kept_plates): return x.permute(perm).reshape(batch_shape + (-1,)) +def _import_funsor(): + try: + import funsor + except ImportError as e: + raise ImportError( + 'AutoGaussian(..., backend="funsor") requires funsor. ' + "Try installing via: pip install pyro-ppl[funsor]" + ) from e + funsor.set_backend("torch") + return funsor + + __all__ = [ "AutoGaussian", ] From 6cc53f8aa128c87e8004d6c890892105c3674ba9 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 08:39:06 -0400 Subject: [PATCH 33/41] Fix more tests --- pyro/infer/autoguide/gaussian.py | 33 +++++++++++++--- tests/infer/autoguide/test_gaussian.py | 55 +++++++++++++++++++------- 2 files changed, 69 insertions(+), 19 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index dd9ee15393..81c89e1665 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -380,7 +380,7 @@ def _setup_prototype(self, *args, **kwargs): funsor = _import_funsor() # Break plates globally to fit this into a TVE problem. 
- broken_plates = frozenset() + broken_plates: frozenset = frozenset() for d in self._factors: for u in self.dependencies[d]: broken_plates |= self._plates[u] - self._plates[d] @@ -404,13 +404,12 @@ def _setup_prototype(self, *args, **kwargs): if f not in broken_plates: inputs[f.name] = funsor.Bint[f.size] eliminate.add(f.name) - if not site["is_observed"]: - inputs[d] = funsor.Reals[broken_event_shapes[d]] for u in self.dependencies[d]: inputs[u] = funsor.Reals[broken_event_shapes[u]] eliminate.add(u) factor_inputs[d] = inputs + self._funsor_broken_plates = broken_plates self._funsor_broken_vars = broken_vars self._funsor_factor_inputs = factor_inputs self._funsor_eliminate = frozenset(eliminate) @@ -429,8 +428,17 @@ def _sample_aux_values( factors = {} for d, inputs in self._funsor_factor_inputs.items(): sqrt = deep_getattr(self.factors, d) - if any(self._funsor_broken_vars[u] for u in self.dependencies[d]): - raise NotImplementedError("TODO break plates in sqrt") + kept_plates = self._plates[d] - self._funsor_broken_plates + sqrt = _break_plates_sqrt( + sqrt, + self._plates[d], + kept_plates, + [(self._plates[u], self._event_numel[u]) for u in self.dependencies[d]], + ) + batch_shape = torch.Size( + p.size for p in sorted(kept_plates, key=lambda p: p.dim) + ) + sqrt = sqrt.reshape(batch_shape + sqrt.shape[-2:]) precision = sqrt @ sqrt.transpose(-1, -2) info_vec = precision.new_zeros(()).expand(precision.shape[:-1]) factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) @@ -489,6 +497,21 @@ def _break_plates(x, all_plates, kept_plates): return x.permute(perm).reshape(batch_shape + (-1,)) +def _break_plates_sqrt( + x, + d_plates, + kept_plates, + u_plates_and_event_numels, +): + """ + Reshapes a sqrt precision parameter ``x`` with event_dim=2 and batch shape + given by d_plates by breaking all plates not in ``kept_plates``. 
+ """ + if any(u_plates - kept_plates for u_plates, _ in u_plates_and_event_numels): + raise NotImplementedError("TODO break plates in sqrt") + return x + + def _import_funsor(): try: import funsor diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index dfa0b0b9ac..096e967571 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -358,13 +358,12 @@ def pyrocov_model_poisson(dataset): S, F = features.shape weekly_strains = dataset["weekly_strains"] # [T, P, S] assert weekly_strains.shape == (T, P, S) - feature_plate = pyro.plate("feature", F, dim=-1) strain_plate = pyro.plate("strain", S, dim=-1) place_plate = pyro.plate("place", P, dim=-2) time_plate = pyro.plate("time", T, dim=-3) # Sample global random variables. - coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2)) + coef_scale = pyro.sample("coef_scale", dist.LogNormal(-4, 2)) rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2)) rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2)) init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2)) @@ -372,8 +371,9 @@ def pyrocov_model_poisson(dataset): pois_loc = pyro.sample("pois_loc", dist.Normal(0, 2)) pois_scale = pyro.sample("pois_scale", dist.LogNormal(0, 2)) - with feature_plate: - coef = pyro.sample("coef", dist.Logistic(0, coef_scale)) # [F] + coef = pyro.sample( + "coef", dist.Logistic(torch.zeros(F), coef_scale).to_event(1) + ) # [F] rate_loc_loc = 0.01 * coef @ features.T with strain_plate: rate_loc = pyro.sample( @@ -416,7 +416,8 @@ def test_pyrocov_smoke(model, backend): guide = AutoGaussian(model, backend=backend) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) for step in range(2): - svi.step(dataset) + with xfail_if_not_implemented(): + svi.step(dataset) guide(dataset) predictive = Predictive(model, guide=guide, num_samples=2) predictive(dataset) @@ -444,7 +445,8 @@ def test_pyrocov_reparam(model, 
backend): guide = AutoGaussian(model, backend=backend) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) for step in range(2): - svi.step(dataset) + with xfail_if_not_implemented(): + svi.step(dataset) guide(dataset) predictive = Predictive(model, guide=guide, num_samples=2) predictive(dataset) @@ -461,15 +463,17 @@ def test_pyrocov_structure(backend): "weekly_strains": torch.randn(T, P, S).exp().round(), } - guide = AutoGaussian(pyrocov_model_plated, backend=backend) + guide = AutoGaussian(pyrocov_model_poisson, backend=backend) guide(dataset) # initialize - expected_plates = frozenset(["place", "feature", "strain"]) + expected_plates = frozenset(["time", "place", "strain"]) assert guide._funsor_plates == expected_plates expected_eliminate = frozenset( [ + "time", "place", + "strain", "coef_scale", "rate_loc_scale", "rate_scale", @@ -480,6 +484,9 @@ def test_pyrocov_structure(backend): "init_loc", "rate", "init", + "pois_loc", + "pois_scale", + "pois", ] ) assert guide._funsor_eliminate == expected_eliminate @@ -490,25 +497,45 @@ def test_pyrocov_structure(backend): "rate_scale": OrderedDict([("rate_scale", Real)]), "init_loc_scale": OrderedDict([("init_loc_scale", Real)]), "init_scale": OrderedDict([("init_scale", Real)]), + "pois_loc": OrderedDict([("pois_loc", Real)]), + "pois_scale": OrderedDict([("pois_scale", Real)]), "coef": OrderedDict([("coef", Reals[5]), ("coef_scale", Real)]), "rate_loc": OrderedDict( - [("rate_loc", Reals[4]), ("rate_loc_scale", Real), ("coef", Reals[5])] + [ + ("strain", Bint[4]), + ("rate_loc", Real), + ("rate_loc_scale", Real), + ("coef", Reals[5]), + ] + ), + "init_loc": OrderedDict( + [("strain", Bint[4]), ("init_loc", Real), ("init_loc_scale", Real)] ), - "init_loc": OrderedDict([("init_loc", Reals[4]), ("init_loc_scale", Real)]), "rate": OrderedDict( [ ("place", Bint[3]), - ("rate", Reals[4]), + ("strain", Bint[4]), + ("rate", Real), ("rate_scale", Real), - ("rate_loc", Reals[4]), + ("rate_loc", Real), ] ), "init": 
OrderedDict( [ ("place", Bint[3]), - ("init", Reals[4]), + ("strain", Bint[4]), + ("init", Real), ("init_scale", Real), - ("init_loc", Reals[4]), + ("init_loc", Real), + ] + ), + "pois": OrderedDict( + [ + ("time", Bint[2]), + ("place", Bint[3]), + ("pois", Real), + ("pois_loc", Real), + ("pois_scale", Real), ] ), } From 28ea4af31d9ae781f5a88eb016f27fc801852ead Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 14:18:00 -0400 Subject: [PATCH 34/41] Add a test --- tests/infer/autoguide/test_gaussian.py | 27 +++++++++++++++++++++----- tests/infer/test_autoguide.py | 4 ++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 096e967571..b20b83363a 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -11,7 +11,11 @@ import pyro.poutine as poutine from pyro.infer import SVI, Predictive, Trace_ELBO from pyro.infer.autoguide import AutoGaussian -from pyro.infer.autoguide.gaussian import _break_plates +from pyro.infer.autoguide.gaussian import ( + AutoGaussianDense, + AutoGaussianFunsor, + _break_plates, +) from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam from tests.common import assert_equal, xfail_if_not_implemented @@ -22,16 +26,15 @@ ] -MockPlate = namedtuple("MockPlate", "dim, size") - - def test_break_plates(): shape = torch.Size([5, 4, 3, 2]) + x = torch.arange(shape.numel()).reshape(shape) + + MockPlate = namedtuple("MockPlate", "dim, size") h = MockPlate(-4, 6) i = MockPlate(-3, 5) j = MockPlate(-2, 4) k = MockPlate(-1, 3) - x = torch.arange(shape.numel()).reshape(shape) actual = _break_plates(x, {i, j, k}, set()) expected = x.reshape(-1) @@ -70,6 +73,20 @@ def test_break_plates(): assert_equal(actual, expected) +@pytest.mark.parametrize("backend", BACKENDS) +def test_backend_dispatch(backend): + def model(): + pyro.sample("x", dist.Normal(0, 1)) + + guide = AutoGaussian(model, 
backend=backend) + if backend == "dense": + assert isinstance(guide, AutoGaussianDense) + elif backend == "funsor": + assert isinstance(guide, AutoGaussianFunsor) + else: + raise ValueError(f"Unknown backend: {backend}") + + def check_structure(model, expected_str): guide = AutoGaussian(model, backend="dense") guide() # initialize diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 31f4b64749..9cd9c89788 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1265,8 +1265,8 @@ def model(data): guide(data) # guide.scales.train(False) # guide.scales.requires_grad_(False) - # guide.precision_chols.train(False) - # guide.precision_chols.requires_grad_(False) + # guide.factor.train(False) + # guide.factor.requires_grad_(False) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) optim = Adam({"lr": 0.01}) From b774166fb2519e0f143880c884dadce47af7397a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 20:46:31 -0400 Subject: [PATCH 35/41] Simplify; fix bugs --- pyro/infer/autoguide/gaussian.py | 111 +++++++++---------------------- tests/infer/test_autoguide.py | 8 --- 2 files changed, 32 insertions(+), 87 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 81c89e1665..d9156f9422 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -195,7 +195,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: if self.prototype_trace is None: self._setup_prototype(*args, **kwargs) - aux_values, global_log_density = self._sample_aux_values() + aux_values = self._sample_aux_values() values, log_densities = self._transform_values(aux_values) # Replay via Pyro primitives. 
@@ -204,12 +204,10 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: with ExitStack() as stack: for frame in site["cond_indep_stack"]: stack.enter_context(plates[frame.name]) - pyro.sample( + values[name] = pyro.sample( name, dist.Delta(values[name], log_densities[name], site["fn"].event_dim), ) - if am_i_wrapped() and poutine.get_mask() is not False: - pyro.factor(self._pyro_name, global_log_density) return values def median(self) -> Dict[str, torch.Tensor]: @@ -248,9 +246,7 @@ def _transform_values( return values, log_densities - def _sample_aux_values( - self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + def _sample_aux_values(self) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -264,8 +260,6 @@ class AutoGaussianDense(AutoGaussian): guide = AutoGaussianDense(model) """ - # Attributes are prefixed by ._dense_ - def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) @@ -316,25 +310,16 @@ def _setup_prototype(self, *args, **kwargs): self._dense_scatter[d] = index.reshape(-1) - def _sample_aux_values( - self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: - from pyro.ops.gaussian import Gaussian - - # Convert to a dense joint Gaussian over flattened variables. + def _sample_aux_values(self) -> Dict[str, torch.Tensor]: + # Sample from a dense joint Gaussian over flattened variables. precision = self._get_precision() - info_vec = torch.zeros(self._dense_size) - log_normalizer = torch.zeros(()) - g = Gaussian(log_normalizer, info_vec, precision) - - # Draw a batch of samples. 
- particle_plates = frozenset(get_plates()) - sample_shape = [1] * max([0] + [-p.dim for p in particle_plates]) - for p in particle_plates: - sample_shape[p.dim] = p.size - sample_shape = torch.Size(sample_shape) - flat_samples = g.rsample(sample_shape) - log_density = g.log_density(flat_samples) - g.event_logsumexp() + loc = precision.new_zeros(self._dense_size) + flat_samples = pyro.sample( + f"_{self._pyro_name}", + dist.MultivariateNormal(loc, precision_matrix=precision), + infer={"is_auxiliary": True}, + ) + sample_shape = flat_samples.shape[:-1] # Convert flat to shaped tensors. samples = {} @@ -347,7 +332,7 @@ def _sample_aux_values( samples[d] = flat_sample.reshape( torch.broadcast_shapes(sample_shape, batch_shape) + event_shape ) - return samples, log_density + return samples def _get_precision(self): flat_precision = torch.zeros(self._dense_size ** 2) @@ -368,9 +353,6 @@ class AutoGaussianFunsor(AutoGaussian): guide = AutoGaussianFunsor(model) """ - # Attributes are prefixed by ._funsor_ - # This uses tensor variable elimination (TVE). - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) _import_funsor() @@ -379,19 +361,17 @@ def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) funsor = _import_funsor() - # Break plates globally to fit this into a TVE problem. - broken_plates: frozenset = frozenset() + # Check plates are strictly nested. 
for d in self._factors: for u in self.dependencies[d]: - broken_plates |= self._plates[u] - self._plates[d] - broken_vars: Dict[str, Tuple[funsor.Variable, ...]] = {} - broken_event_shapes: Dict[str, Tuple[int, ...]] = {} - for u, event_shape in self._unconstrained_event_shapes.items(): - plates = sorted(self._plates[u] & broken_plates, key=lambda p: p.size) - broken_vars[u] = tuple( - funsor.Variable(p.name, funsor.Bint[p.size]) for p in plates - ) - broken_event_shapes[u] = tuple(p.size for p in plates) + event_shape + broken_plates = self._plates[u] - self._plates[d] + if broken_plates: + raise NotImplementedError( + "Expected strictly increasing plates, but found dependency " + f"{u} -> {d} leaves plates {set(broken_plates)}. " + "Consider splitting into multiple guides via AutoGuideList, " + "or replacing the plate in the model by .to_event()." + ) # Determine TVE problem shape. factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {} @@ -401,24 +381,19 @@ def _setup_prototype(self, *args, **kwargs): inputs = OrderedDict() for f in site["cond_indep_stack"]: plate_to_dim[f.name] = f.dim - if f not in broken_plates: - inputs[f.name] = funsor.Bint[f.size] - eliminate.add(f.name) + inputs[f.name] = funsor.Bint[f.size] + eliminate.add(f.name) for u in self.dependencies[d]: - inputs[u] = funsor.Reals[broken_event_shapes[u]] + inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]] eliminate.add(u) factor_inputs[d] = inputs - self._funsor_broken_plates = broken_plates - self._funsor_broken_vars = broken_vars self._funsor_factor_inputs = factor_inputs self._funsor_eliminate = frozenset(eliminate) self._funsor_plate_to_dim = plate_to_dim self._funsor_plates = frozenset(plate_to_dim) - def _sample_aux_values( - self, - ) -> Tuple[Dict[str, torch.Tensor], Union[float, torch.Tensor]]: + def _sample_aux_values(self) -> Dict[str, torch.Tensor]: funsor = _import_funsor() # Convert torch to funsor. 
@@ -428,15 +403,8 @@ def _sample_aux_values( factors = {} for d, inputs in self._funsor_factor_inputs.items(): sqrt = deep_getattr(self.factors, d) - kept_plates = self._plates[d] - self._funsor_broken_plates - sqrt = _break_plates_sqrt( - sqrt, - self._plates[d], - kept_plates, - [(self._plates[u], self._event_numel[u]) for u in self.dependencies[d]], - ) batch_shape = torch.Size( - p.size for p in sorted(kept_plates, key=lambda p: p.dim) + p.size for p in sorted(self._plates[d], key=lambda p: p.dim) ) sqrt = sqrt.reshape(batch_shape + sqrt.shape[-2:]) precision = sqrt @ sqrt.transpose(-1, -2) @@ -452,13 +420,13 @@ def _sample_aux_values( ) # Convert funsor to torch. + if am_i_wrapped() and poutine.get_mask() is not False: + log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim) + pyro.factor(f"_{self._pyro_name}_latent", log_prob) samples = { - k: funsor.to_data(v[self._funsor_broken_vars[k]], name_to_dim=plate_to_dim) - for k, v in samples.items() + k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items() } - log_density = funsor.to_data(log_prob, name_to_dim=plate_to_dim) - - return samples, log_density + return samples def _plates_to_shape(plates): @@ -497,21 +465,6 @@ def _break_plates(x, all_plates, kept_plates): return x.permute(perm).reshape(batch_shape + (-1,)) -def _break_plates_sqrt( - x, - d_plates, - kept_plates, - u_plates_and_event_numels, -): - """ - Reshapes a sqrt precision parameter ``x`` with event_dim=2 and batch shape - given by d_plates by breaking all plates not in ``kept_plates``. 
- """ - if any(u_plates - kept_plates for u_plates, _ in u_plates_and_event_numels): - raise NotImplementedError("TODO break plates in sqrt") - return x - - def _import_funsor(): try: import funsor diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 9cd9c89788..7a77867ac0 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -1260,14 +1260,6 @@ def model(data): expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp()) guide = Guide(model) - - # DEBUG - guide(data) - # guide.scales.train(False) - # guide.scales.requires_grad_(False) - # guide.factor.train(False) - # guide.factor.requires_grad_(False) - elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) optim = Adam({"lr": 0.01}) svi = SVI(model, guide, optim, elbo) From 88912b29cc810a35a09a3c7b369c837db6f8d938 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 20:54:52 -0400 Subject: [PATCH 36/41] Tweak test parameters --- tests/infer/test_autoguide.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 7a77867ac0..d95348be45 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -39,7 +39,7 @@ from pyro.infer.reparam import ProjectedNormalReparam from pyro.nn.module import PyroModule, PyroParam, PyroSample from pyro.ops.gaussian import Gaussian -from pyro.optim import Adam +from pyro.optim import Adam, ClippedAdam from pyro.poutine.util import prune_subsample_sites from pyro.util import check_model_guide_match from tests.common import assert_close, assert_equal @@ -1386,9 +1386,10 @@ def model(data): guide = Guide(model) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) - optim = Adam({"lr": 0.01}) + num_steps = 500 + optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)}) svi = SVI(model, guide, optim, elbo) - for step in range(500): + for step in range(num_steps): 
svi.step(data) guide.train(False) From 2c2536af5fbf6e59370c7850fb607c8a15b76ea6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 21:11:55 -0400 Subject: [PATCH 37/41] Revert unnecessary change --- pyro/contrib/funsor/handlers/enum_messenger.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pyro/contrib/funsor/handlers/enum_messenger.py b/pyro/contrib/funsor/handlers/enum_messenger.py index 3b85963943..3815f1934e 100644 --- a/pyro/contrib/funsor/handlers/enum_messenger.py +++ b/pyro/contrib/funsor/handlers/enum_messenger.py @@ -11,7 +11,6 @@ import funsor import torch -from funsor.adjoint import _alpha_unmangle as alpha_unmangle # FIXME publish import pyro.poutine.runtime import pyro.poutine.util @@ -27,10 +26,6 @@ @functools.singledispatch def _get_support_value(funsor_dist, name, **kwargs): - """ - Extracts the sample value out of a funsor Delta, - possibly wrapped in reductions over sample plates. - """ raise ValueError( "Could not extract point from {} at name {}".format(funsor_dist, name) ) @@ -38,10 +33,13 @@ def _get_support_value(funsor_dist, name, **kwargs): @_get_support_value.register(funsor.cnf.Contraction) def _get_support_value_contraction(funsor_dist, name, **kwargs): - unmangled_terms = alpha_unmangle(funsor_dist)[-1] - terms = [v for v in unmangled_terms if name in v.inputs] - assert len(terms) == 1 - return _get_support_value(terms[0], name, **kwargs) + delta_terms = [ + v + for v in funsor_dist.terms + if isinstance(v, funsor.delta.Delta) and name in v.fresh + ] + assert len(delta_terms) == 1 + return _get_support_value(delta_terms[0], name, **kwargs) @_get_support_value.register(funsor.delta.Delta) From fcf5b1ea7925265b291f269a52112a4ad6c7a324 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 21:56:52 -0400 Subject: [PATCH 38/41] Mark test funsor stage --- tests/infer/autoguide/test_gaussian.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index b20b83363a..04d9ec7c5c 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -469,8 +469,8 @@ def test_pyrocov_reparam(model, backend): predictive(dataset) -@pytest.mark.parametrize("backend", ["funsor"]) -def test_pyrocov_structure(backend): +@pytest.mark.stage("funsor") +def test_pyrocov_structure(): from funsor import Bint, Real, Reals T, P, S, F = 2, 3, 4, 5 @@ -480,7 +480,7 @@ def test_pyrocov_structure(backend): "weekly_strains": torch.randn(T, P, S).exp().round(), } - guide = AutoGaussian(pyrocov_model_poisson, backend=backend) + guide = AutoGaussian(pyrocov_model_poisson, backend="funsor") guide(dataset) # initialize expected_plates = frozenset(["time", "place", "strain"]) From ade50a978fdeada01cc714e9bcd8ae437c220a89 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 4 Oct 2021 13:29:33 -0400 Subject: [PATCH 39/41] Wrap intractable error in NotImplemented --- pyro/infer/autoguide/gaussian.py | 49 ++++++++++------ tests/infer/autoguide/test_gaussian.py | 34 +++++++++-- tests/infer/test_autoguide.py | 80 +++++++++++++++++--------- 3 files changed, 111 insertions(+), 52 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index d9156f9422..4850d884ea 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -176,9 +176,8 @@ def _setup_prototype(self, *args, **kwargs) -> None: for d, site in self._factors.items(): u_size = 0 for u in self.dependencies[d]: - if not self._factors[u]["is_observed"]: - broken_shape = _plates_to_shape(self._plates[u] - self._plates[d]) - u_size += broken_shape.numel() * self._event_numel[u] + broken_shape = _plates_to_shape(self._plates[u] - self._plates[d]) + u_size += broken_shape.numel() * self._event_numel[u] d_size = self._event_numel[d] if site["is_observed"]: d_size = min(d_size, u_size) # just an optimization @@ 
-361,17 +360,19 @@ def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) funsor = _import_funsor() - # Check plates are strictly nested. + # Check TVE condition 1: plate nesting is monotone. for d in self._factors: + pd = {p.name for p in self._plates[d]} for u in self.dependencies[d]: - broken_plates = self._plates[u] - self._plates[d] - if broken_plates: - raise NotImplementedError( - "Expected strictly increasing plates, but found dependency " - f"{u} -> {d} leaves plates {set(broken_plates)}. " - "Consider splitting into multiple guides via AutoGuideList, " - "or replacing the plate in the model by .to_event()." - ) + pu = {p.name for p in self._plates[u]} + if pu <= pd: + continue # ok + raise NotImplementedError( + "Expected monotone plate nesting, but found dependency " + f"{repr(u)} -> {repr(d)} leaves plates {pu - pd}. " + "Consider splitting into multiple guides via AutoGuideList, " + "or replacing the plate in the model by .to_event()." + ) # Determine TVE problem shape. factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {} @@ -379,7 +380,7 @@ def _setup_prototype(self, *args, **kwargs): plate_to_dim: Dict[str, int] = {} for d, site in self._factors.items(): inputs = OrderedDict() - for f in site["cond_indep_stack"]: + for f in sorted(site["cond_indep_stack"], key=lambda f: f.dim): plate_to_dim[f.name] = f.dim inputs[f.name] = funsor.Bint[f.size] eliminate.add(f.name) @@ -412,12 +413,22 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]: factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs) # Perform Gaussian tensor variable elimination. - samples, log_prob = funsor.recipes.forward_filter_backward_rsample( - factors=factors, - eliminate=self._funsor_eliminate, - plates=frozenset(plate_to_dim), - sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates}, - ) + try: # Convert ValueError into NotImplementedError. 
+ samples, log_prob = funsor.recipes.forward_filter_backward_rsample( + factors=factors, + eliminate=self._funsor_eliminate, + plates=frozenset(plate_to_dim), + sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates}, + ) + except ValueError as e: + if str(e) != "intractable!": + raise e from None + raise NotImplementedError( + "Funsor backend found intractable plate nesting. " + 'Consider using AutoGaussian(..., backend="dense"), ' + "splitting into multiple guides via AutoGuideList, or " + "replacing some plates in the model by .to_event()." + ) from e # Convert funsor to torch. if am_i_wrapped() and poutine.get_mask() is not False: diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 04d9ec7c5c..56e7af64bc 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -197,12 +197,30 @@ def model(): check_structure(model, expected) +def test_structure_5(): + def model(): + i_plate = pyro.plate("i", 2, dim=-1) + with i_plate: + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(a.mean(-1), 1)) + with i_plate: + pyro.sample("c", dist.Normal(b, 1), obs=torch.zeros(2)) + + # size = 2 + 1 = 3 + expected = [ + "? . ?", + ". ? ?", + "? ? 
?", + ] + check_structure(model, expected) + + @pytest.mark.parametrize("backend", BACKENDS) def test_broken_plates_smoke(backend): def model(): with pyro.plate("i", 2): - x = pyro.sample("x", dist.Normal(0, 1)) - pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0)) + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0)) guide = AutoGaussian(model, backend=backend) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) @@ -217,9 +235,15 @@ def model(): @pytest.mark.parametrize("backend", BACKENDS) def test_intractable_smoke(backend): def model(): - with pyro.plate("i", 2): - x = pyro.sample("x", dist.Normal(0, 1)) - pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0)) + i_plate = pyro.plate("i", 2, dim=-1) + j_plate = pyro.plate("j", 3, dim=-2) + with i_plate: + a = pyro.sample("a", dist.Normal(0, 1)) + with j_plate: + b = pyro.sample("b", dist.Normal(0, 1)) + with i_plate, j_plate: + c = pyro.sample("c", dist.Normal(a + b, 1)) + pyro.sample("d", dist.Normal(c, 1), obs=torch.zeros(3, 2)) guide = AutoGaussian(model, backend=backend) svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index d95348be45..7f75629197 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -4,7 +4,6 @@ import functools import io import warnings -from operator import attrgetter import numpy as np import pytest @@ -44,6 +43,13 @@ from pyro.util import check_model_guide_match from tests.common import assert_close, assert_equal +AutoGaussianFunsor_xfail = pytest.param( + AutoGaussianFunsor, + marks=[ + pytest.mark.stage("funsor"), + pytest.mark.xfail(reason="jit is not supported"), + ], +) AutoGaussianFunsor = pytest.param( AutoGaussianFunsor, marks=[pytest.mark.stage("funsor")] ) @@ -377,6 +383,14 @@ def model(): assert_equal(median["z"], torch.tensor(0.5), prec=0.1) +def serialize_model(): + 
pyro.sample("x", dist.Normal(0.0, 1.0)) + with pyro.plate("plate", 2): + pyro.sample("y", dist.LogNormal(0.0, 1.0)) + pyro.sample("z", dist.Beta(2.0, 2.0)) + + +@pytest.mark.parametrize("jit", [False, True], ids=["nojit", "jit"]) @pytest.mark.parametrize( "auto_class", [ @@ -395,49 +409,57 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), AutoStructured, AutoStructured_median, + AutoGaussian, + AutoGaussianFunsor_xfail, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) -def test_autoguide_serialization(auto_class, Elbo): - def model(): - pyro.sample("x", dist.Normal(0.0, 1.0)) - with pyro.plate("plate", 2): - pyro.sample("y", dist.LogNormal(0.0, 1.0)) - pyro.sample("z", dist.Beta(2.0, 2.0)) - - guide = auto_class(model) +def test_serialization(auto_class, Elbo, jit): + guide = auto_class(serialize_model) guide() if auto_class is AutoLaplaceApproximation: guide = guide.laplace_approximation() pyro.set_rng_seed(0) expected = guide.call() - names = sorted(guide()) - - # Ignore tracer warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) - # XXX: check_trace=True fails for AutoLaplaceApproximation - traced_guide = torch.jit.trace_module(guide, {"call": ()}, check_trace=False) - f = io.BytesIO() - torch.jit.save(traced_guide, f) - f.seek(0) - guide_deser = torch.jit.load(f) + latent_names = sorted(guide()) + expected_params = {k: v.data for k, v in guide.named_parameters()} + + if jit: + # Ignore tracer warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + # XXX: check_trace=True fails for AutoLaplaceApproximation + traced_guide = torch.jit.trace_module( + guide, {"call": ()}, check_trace=False + ) + f = io.BytesIO() + torch.jit.save(traced_guide, f) + del guide, traced_guide + pyro.clear_param_store() + f.seek(0) + guide_deser = torch.jit.load(f) + else: + # Work around 
https://github.com/pytorch/pytorch/issues/27972 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + f = io.BytesIO() + torch.save(guide, f) + f.seek(0) + guide_deser = torch.load(f) # Check .call() result. pyro.set_rng_seed(0) actual = guide_deser.call() assert len(actual) == len(expected) - for name, a, e in zip(names, actual, expected): + for name, a, e in zip(latent_names, actual, expected): assert_equal(a, e, msg="{}: {} vs {}".format(name, a, e)) # Check named_parameters. - expected_names = {name for name, _ in guide.named_parameters()} - actual_names = {name for name, _ in guide_deser.named_parameters()} - assert actual_names == expected_names - for name in actual_names: - # Get nested attributes. - attr_get = attrgetter(name) - assert_equal(attr_get(guide_deser), attr_get(guide).data) + actual_params = {k: v.data for k, v in guide_deser.named_parameters()} + assert set(actual_params) == set(expected_params) + for name, expected in expected_params.items(): + actual = actual_params[name] + assert_equal(actual, expected) def AutoGuideList_x(model): @@ -871,6 +893,8 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), AutoStructured, AutoStructured_predictive, + AutoGaussian, + AutoGaussianFunsor_xfail, ], ) def test_predictive(auto_class): From 57c63d03f7fe8fe6f5d4fe945b2c84b7435a5a00 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 4 Oct 2021 14:02:50 -0400 Subject: [PATCH 40/41] Fix serialization tests --- pyro/infer/autoguide/gaussian.py | 1 + pyro/infer/autoguide/guides.py | 26 +++++++++++++++------ pyro/infer/autoguide/structured.py | 1 + tests/infer/test_autoguide.py | 36 +++++++++++++++--------------- 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 4850d884ea..b7fb190f3a 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -129,6 +129,7 
@@ def _setup_prototype(self, *args, **kwargs) -> None: # Trace model dependencies. model = self._original_model[0] + self._original_model = None self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[ "prior_dependencies" ] diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 63b64b9f52..ce42278610 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -76,8 +76,25 @@ def __init__(self, model, *, create_plates=None): def model(self): return self._model[0] + def __getstate__(self): + # Do not pickle weakrefs. + self._model = None + self.master = None + return getattr(super(), "__getstate__", self.__dict__.copy)() + + def __setstate__(self, state): + getattr(super(), "__getstate__", self.__dict__.update)(state) + assert self.master is None + master_ref = weakref.ref(self) + for _, mod in self.named_modules(): + if mod is not self and isinstance(mod, AutoGuide): + mod._update_master(master_ref) + def _update_master(self, master_ref): self.master = master_ref + for _, mod in self.named_modules(): + if mod is not self and isinstance(mod, AutoGuide): + mod._update_master(master_ref) def call(self, *args, **kwargs): """ @@ -103,8 +120,8 @@ def sample_latent(*args, **kwargs): def __setattr__(self, name, value): if isinstance(value, AutoGuide): - master_ref = self if self.master is None else self.master - value._update_master(weakref.ref(master_ref)) + master_ref = weakref.ref(self) if self.master is None else self.master + value._update_master(master_ref) super().__setattr__(name, value) def _create_plates(self, *args, **kwargs): @@ -184,11 +201,6 @@ def _check_prototype(self, part_trace): assert part_site["fn"].event_shape == self_site["fn"].event_shape assert part_site["value"].shape == self_site["value"].shape - def _update_master(self, master_ref): - self.master = master_ref - for submodule in self: - submodule._update_master(master_ref) - def append(self, part): """ Add an automatic guide 
for part of the model. The guide should diff --git a/pyro/infer/autoguide/structured.py b/pyro/infer/autoguide/structured.py index 2e29f3fb29..4825030a2e 100644 --- a/pyro/infer/autoguide/structured.py +++ b/pyro/infer/autoguide/structured.py @@ -159,6 +159,7 @@ def _auto_config(self, sample_sites, args, kwargs): elif prior_order[d] > prior_order[u]: dependencies[d][u] = self.dependencies self.dependencies = dict(dependencies) + self._original_model = None def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 7f75629197..67eb3319bb 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -289,19 +289,20 @@ def median_x(): return guide -def auto_guide_module_callable(model): - class GuideX(AutoGuide): - def __init__(self, model): - super().__init__(model) - self.x_loc = nn.Parameter(torch.tensor(1.0)) - self.x_scale = PyroParam(torch.tensor(0.1), constraint=constraints.positive) +class GuideX(AutoGuide): + def __init__(self, model): + super().__init__(model) + self.x_loc = nn.Parameter(torch.tensor(1.0)) + self.x_scale = PyroParam(torch.tensor(0.1), constraint=constraints.positive) + + def forward(self, *args, **kwargs): + return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))} - def forward(self, *args, **kwargs): - return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))} + def median(self, *args, **kwargs): + return {"x": self.x_loc.detach()} - def median(self, *args, **kwargs): - return {"x": self.x_loc.detach()} +def auto_guide_module_callable(model): guide = AutoGuideList(model) guide.custom = GuideX(model) guide.diagnorm = AutoDiagonalNormal(poutine.block(model, hide=["x"])) @@ -383,13 +384,6 @@ def model(): assert_equal(median["z"], torch.tensor(0.5), prec=0.1) -def serialize_model(): - pyro.sample("x", dist.Normal(0.0, 1.0)) - with pyro.plate("plate", 2): - pyro.sample("y", dist.LogNormal(0.0, 
1.0)) - pyro.sample("z", dist.Beta(2.0, 2.0)) - - @pytest.mark.parametrize("jit", [False, True], ids=["nojit", "jit"]) @pytest.mark.parametrize( "auto_class", @@ -415,7 +409,13 @@ def serialize_model(): ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_serialization(auto_class, Elbo, jit): - guide = auto_class(serialize_model) + def model(): + pyro.sample("x", dist.Normal(0.0, 1.0)) + with pyro.plate("plate", 2): + pyro.sample("y", dist.LogNormal(0.0, 1.0)) + pyro.sample("z", dist.Beta(2.0, 2.0)) + + guide = auto_class(model) guide() if auto_class is AutoLaplaceApproximation: guide = guide.laplace_approximation() From f51a64e1df415613a5976434375db416f77880ef Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 4 Oct 2021 14:15:45 -0400 Subject: [PATCH 41/41] fix typo --- pyro/infer/autoguide/guides.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index ce42278610..27a7fed53a 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -83,7 +83,7 @@ def __getstate__(self): return getattr(super(), "__getstate__", self.__dict__.copy)() def __setstate__(self, state): - getattr(super(), "__getstate__", self.__dict__.update)(state) + getattr(super(), "__setstate__", self.__dict__.update)(state) assert self.master is None master_ref = weakref.ref(self) for _, mod in self.named_modules():