From 5db1109bc4fdf400afe0a2c7d77404b96e3693ae Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 3 Oct 2021 07:45:25 -0400 Subject: [PATCH] Add tests, update docs --- pyro/infer/autoguide/gaussian.py | 18 ++++++------ tests/infer/autoguide/test_gaussian.py | 38 ++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index 53c7c76ea2..7df62c3d51 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -4,7 +4,7 @@ import itertools from collections import OrderedDict, defaultdict from contextlib import ExitStack -from typing import Callable, Dict, Set, Tuple, Union +from typing import Callable, Dict, Optional, Set, Tuple, Union import torch from torch.distributions import biject_to @@ -54,12 +54,14 @@ class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta): the model [1]. Depending on model structure, this can have asymptotically better statistical efficiency than :class:`AutoMultivariateNormal` . - The default "dense" backend should have similar computational complexity to - :class:`AutoMultivariateNormal` . The experimental "funsor" backend can be - asymptotically cheaper in terms of time and space (using Gaussian tensor - variable elimination [2,3]), but incurs large constant overhead. The - "funsor" backend requires `funsor `_ which can be - installed via ``pip install pyro-ppl[funsor]``. + This guide implements multiple backends for computation. All backends use + the same statistically optimal parametrization. The default "dense" backend + has computational complexity similar to :class:`AutoMultivariateNormal` . + The experimental "funsor" backend can be asymptotically cheaper in terms of + time and space (using Gaussian tensor variable elimination [2,3]), but + incurs large constant overhead. The "funsor" backend requires `funsor + `_ which can be installed via ``pip install + pyro-ppl[funsor]``. The guide currently does not depend on the model's ``*args, **kwargs``. @@ -105,7 +107,7 @@ def __init__( *, init_loc_fn: Callable = init_to_feasible, init_scale: float = 0.1, - backend=None, + backend: Optional[str] = None, # used only by metaclass ): if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError(f"Expected init_scale > 0. but got {init_scale}") diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py index 4a7a61e521..dfa0b0b9ac 100644 --- a/tests/infer/autoguide/test_gaussian.py +++ b/tests/infer/autoguide/test_gaussian.py @@ -14,7 +14,7 @@ from pyro.infer.autoguide.gaussian import _break_plates from pyro.infer.reparam import LocScaleReparam from pyro.optim import Adam -from tests.common import assert_equal +from tests.common import assert_equal, xfail_if_not_implemented BACKENDS = [ "dense", @@ -180,6 +180,40 @@ def model(): check_structure(model, expected) +@pytest.mark.parametrize("backend", BACKENDS) +def test_broken_plates_smoke(backend): + def model(): + with pyro.plate("i", 2): + x = pyro.sample("x", dist.Normal(0, 1)) + pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0)) + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step() + guide() + predictive = Predictive(model, guide=guide, num_samples=2) + predictive() + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_intractable_smoke(backend): + def model(): + with pyro.plate("i", 2): + x = pyro.sample("x", dist.Normal(0, 1)) + pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0)) + + guide = AutoGaussian(model, backend=backend) + svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO()) + for step in range(2): + with xfail_if_not_implemented(): + svi.step() + guide() + predictive = Predictive(model, guide=guide, num_samples=2) + predictive() + + # Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py def pyrocov_model(dataset): # Tensor shapes are commented at the end of some lines. @@ -486,7 +520,7 @@ def test_profile(backend, n=1, num_steps=1): """ Helper function for profiling. """ - model = pyrocov_model_plated + model = pyrocov_model_poisson T, P, S, F = 2 * n, 3 * n, 4 * n, 5 * n dataset = { "features": torch.randn(S, F),