Skip to content

Commit

Permalink
Wrap intractable error in NotImplemented
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Oct 4, 2021
1 parent fcf5b1e commit ade50a9
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 52 deletions.
49 changes: 30 additions & 19 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,8 @@ def _setup_prototype(self, *args, **kwargs) -> None:
for d, site in self._factors.items():
u_size = 0
for u in self.dependencies[d]:
if not self._factors[u]["is_observed"]:
broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
u_size += broken_shape.numel() * self._event_numel[u]
broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
u_size += broken_shape.numel() * self._event_numel[u]
d_size = self._event_numel[d]
if site["is_observed"]:
d_size = min(d_size, u_size) # just an optimization
Expand Down Expand Up @@ -361,25 +360,27 @@ def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
funsor = _import_funsor()

# Check plates are strictly nested.
# Check TVE condition 1: plate nesting is monotone.
for d in self._factors:
pd = {p.name for p in self._plates[d]}
for u in self.dependencies[d]:
broken_plates = self._plates[u] - self._plates[d]
if broken_plates:
raise NotImplementedError(
"Expected strictly increasing plates, but found dependency "
f"{u} -> {d} leaves plates {set(broken_plates)}. "
"Consider splitting into multiple guides via AutoGuideList, "
"or replacing the plate in the model by .to_event()."
)
pu = {p.name for p in self._plates[u]}
if pu <= pd:
continue # ok
raise NotImplementedError(
"Expected monotone plate nesting, but found dependency "
f"{repr(u)} -> {repr(d)} leaves plates {pu - pd}. "
"Consider splitting into multiple guides via AutoGuideList, "
"or replacing the plate in the model by .to_event()."
)

# Determine TVE problem shape.
factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {}
eliminate: Set[str] = set()
plate_to_dim: Dict[str, int] = {}
for d, site in self._factors.items():
inputs = OrderedDict()
for f in site["cond_indep_stack"]:
for f in sorted(site["cond_indep_stack"], key=lambda f: f.dim):
plate_to_dim[f.name] = f.dim
inputs[f.name] = funsor.Bint[f.size]
eliminate.add(f.name)
Expand Down Expand Up @@ -412,12 +413,22 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs)

# Perform Gaussian tensor variable elimination.
samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
factors=factors,
eliminate=self._funsor_eliminate,
plates=frozenset(plate_to_dim),
sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
)
try: # Convert ValueError into NotImplementedError.
samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
factors=factors,
eliminate=self._funsor_eliminate,
plates=frozenset(plate_to_dim),
sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
)
except ValueError as e:
if str(e) != "intractable!":
raise e from None
raise NotImplementedError(
"Funsor backend found intractable plate nesting. "
'Consider using AutoGaussian(..., backend="dense"), '
"splitting into multiple guides via AutoGuideList, or "
"replacing some plates in the model by .to_event()."
) from e

# Convert funsor to torch.
if am_i_wrapped() and poutine.get_mask() is not False:
Expand Down
34 changes: 29 additions & 5 deletions tests/infer/autoguide/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,30 @@ def model():
check_structure(model, expected)


def test_structure_5():
def model():
i_plate = pyro.plate("i", 2, dim=-1)
with i_plate:
a = pyro.sample("a", dist.Normal(0, 1))
b = pyro.sample("b", dist.Normal(a.mean(-1), 1))
with i_plate:
pyro.sample("c", dist.Normal(b, 1), obs=torch.zeros(2))

# size = 2 + 1 = 3
expected = [
"? . ?",
". ? ?",
"? ? ?",
]
check_structure(model, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_broken_plates_smoke(backend):
def model():
with pyro.plate("i", 2):
x = pyro.sample("x", dist.Normal(0, 1))
pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0))
a = pyro.sample("a", dist.Normal(0, 1))
pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0))

guide = AutoGaussian(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
Expand All @@ -217,9 +235,15 @@ def model():
@pytest.mark.parametrize("backend", BACKENDS)
def test_intractable_smoke(backend):
def model():
with pyro.plate("i", 2):
x = pyro.sample("x", dist.Normal(0, 1))
pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0))
i_plate = pyro.plate("i", 2, dim=-1)
j_plate = pyro.plate("j", 3, dim=-2)
with i_plate:
a = pyro.sample("a", dist.Normal(0, 1))
with j_plate:
b = pyro.sample("b", dist.Normal(0, 1))
with i_plate, j_plate:
c = pyro.sample("c", dist.Normal(a + b, 1))
pyro.sample("d", dist.Normal(c, 1), obs=torch.zeros(3, 2))

guide = AutoGaussian(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
Expand Down
80 changes: 52 additions & 28 deletions tests/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import functools
import io
import warnings
from operator import attrgetter

import numpy as np
import pytest
Expand Down Expand Up @@ -44,6 +43,13 @@
from pyro.util import check_model_guide_match
from tests.common import assert_close, assert_equal

AutoGaussianFunsor_xfail = pytest.param(
AutoGaussianFunsor,
marks=[
pytest.mark.stage("funsor"),
pytest.mark.xfail(reason="jit is not supported"),
],
)
AutoGaussianFunsor = pytest.param(
AutoGaussianFunsor, marks=[pytest.mark.stage("funsor")]
)
Expand Down Expand Up @@ -377,6 +383,14 @@ def model():
assert_equal(median["z"], torch.tensor(0.5), prec=0.1)


def serialize_model():
pyro.sample("x", dist.Normal(0.0, 1.0))
with pyro.plate("plate", 2):
pyro.sample("y", dist.LogNormal(0.0, 1.0))
pyro.sample("z", dist.Beta(2.0, 2.0))


@pytest.mark.parametrize("jit", [False, True], ids=["nojit", "jit"])
@pytest.mark.parametrize(
"auto_class",
[
Expand All @@ -395,49 +409,57 @@ def model():
functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample),
AutoStructured,
AutoStructured_median,
AutoGaussian,
AutoGaussianFunsor_xfail,
],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_autoguide_serialization(auto_class, Elbo):
def model():
pyro.sample("x", dist.Normal(0.0, 1.0))
with pyro.plate("plate", 2):
pyro.sample("y", dist.LogNormal(0.0, 1.0))
pyro.sample("z", dist.Beta(2.0, 2.0))

guide = auto_class(model)
def test_serialization(auto_class, Elbo, jit):
guide = auto_class(serialize_model)
guide()
if auto_class is AutoLaplaceApproximation:
guide = guide.laplace_approximation()
pyro.set_rng_seed(0)
expected = guide.call()
names = sorted(guide())

# Ignore tracer warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
# XXX: check_trace=True fails for AutoLaplaceApproximation
traced_guide = torch.jit.trace_module(guide, {"call": ()}, check_trace=False)
f = io.BytesIO()
torch.jit.save(traced_guide, f)
f.seek(0)
guide_deser = torch.jit.load(f)
latent_names = sorted(guide())
expected_params = {k: v.data for k, v in guide.named_parameters()}

if jit:
# Ignore tracer warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
# XXX: check_trace=True fails for AutoLaplaceApproximation
traced_guide = torch.jit.trace_module(
guide, {"call": ()}, check_trace=False
)
f = io.BytesIO()
torch.jit.save(traced_guide, f)
del guide, traced_guide
pyro.clear_param_store()
f.seek(0)
guide_deser = torch.jit.load(f)
else:
# Work around https://github.com/pytorch/pytorch/issues/27972
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
f = io.BytesIO()
torch.save(guide, f)
f.seek(0)
guide_deser = torch.load(f)

# Check .call() result.
pyro.set_rng_seed(0)
actual = guide_deser.call()
assert len(actual) == len(expected)
for name, a, e in zip(names, actual, expected):
for name, a, e in zip(latent_names, actual, expected):
assert_equal(a, e, msg="{}: {} vs {}".format(name, a, e))

# Check named_parameters.
expected_names = {name for name, _ in guide.named_parameters()}
actual_names = {name for name, _ in guide_deser.named_parameters()}
assert actual_names == expected_names
for name in actual_names:
# Get nested attributes.
attr_get = attrgetter(name)
assert_equal(attr_get(guide_deser), attr_get(guide).data)
actual_params = {k: v.data for k, v in guide_deser.named_parameters()}
assert set(actual_params) == set(expected_params)
for name, expected in expected_params.items():
actual = actual_params[name]
assert_equal(actual, expected)


def AutoGuideList_x(model):
Expand Down Expand Up @@ -871,6 +893,8 @@ def __init__(self, model):
functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median),
AutoStructured,
AutoStructured_predictive,
AutoGaussian,
AutoGaussianFunsor_xfail,
],
)
def test_predictive(auto_class):
Expand Down

0 comments on commit ade50a9

Please sign in to comment.