Skip to content

Commit

Permalink
Flesh out funsor backend
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Oct 3, 2021
1 parent 58c344c commit 670fdda
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
guide = AutoGaussian(model)
svi = SVI(model, guide, ...)
Example using funsor backend::
Example using experimental funsor backend::
!pip install pyro-ppl[funsor]
guide = AutoGaussian(model, backend="funsor")
Expand Down Expand Up @@ -201,8 +201,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
for name, site in self._factors.items():
with ExitStack() as stack:
for frame in site["cond_indep_stack"]:
if frame.vectorized:
stack.enter_context(plates[frame.name])
stack.enter_context(plates[frame.name])
pyro.sample(
name,
dist.Delta(values[name], log_densities[name], site["fn"].event_dim),
Expand Down Expand Up @@ -352,7 +351,7 @@ def _get_precision(self):
flat_precision = torch.zeros(self._dense_size ** 2)
for d, index in self._dense_scatter.items():
sqrt = deep_getattr(self.factors, d)
precision = sqrt @ sqrt.transpose(dim0=-2, dim1=-1)
precision = sqrt @ sqrt.transpose(-1, -2)
flat_precision.scatter_add_(0, index, precision.reshape(-1))
precision = flat_precision.reshape(self._dense_size, self._dense_size)
return precision
Expand All @@ -368,6 +367,7 @@ class AutoGaussianFunsor(AutoGaussian):
"""

# Attributes are prefixed by ._funsor_
# This uses tensor variable elimination (TVE).

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -386,29 +386,39 @@ def _setup_prototype(self, *args, **kwargs):

funsor.set_backend("torch")

# Break plates globally to fit this into a TVE problem.
broken_plates = frozenset()
for d in self._factors:
for u in self.dependencies[d]:
broken_plates |= self._plates[u] - self._plates[d]
broken_vars: Dict[str, Tuple[funsor.Variable, ...]] = {}
broken_event_shapes: Dict[str, Tuple[int, ...]] = {}
for u, event_shape in self._unconstrained_event_shapes.items():
plates = sorted(self._plates[u] & broken_plates, key=lambda p: p.size)
broken_vars[u] = tuple(
funsor.Variable(p.name, funsor.Bint[p.size]) for p in plates
)
broken_event_shapes[u] = tuple(p.size for p in plates) + event_shape

# Determine TVE problem shape.
factor_inputs: Dict[str, OrderedDict[str, funsor.Domain]] = {}
eliminate: Set[str] = set()
plate_to_dim: Dict[str, int] = {}

for d, site in self._factors.items():
# Order inputs as in the model, so as to maximize sparsity of the
# lower Cholesky parametrization of the precision matrix.
inputs = OrderedDict()
for f in site["cond_indep_stack"]:
if f.vectorized:
plate_to_dim[f.name] = f.dim
if f.name not in self._broken_plates[d]:
inputs[f.name] = funsor.Bint[f.size]
eliminate.add(f.name)
plate_to_dim[f.name] = f.dim
if f not in broken_plates:
inputs[f.name] = funsor.Bint[f.size]
eliminate.add(f.name)
if not site["is_observed"]:
inputs[d] = funsor.Reals[broken_event_shapes[d]]
for u in self.dependencies[d]:
inputs[u] = funsor.Reals[self._broken_event_shapes[u]]
inputs[u] = funsor.Reals[broken_event_shapes[u]]
eliminate.add(u)
if not site["is_observed"]:
inputs[d] = funsor.Reals[self._broken_event_shapes[d]]
assert d in eliminate
factor_inputs[d] = inputs

self._funsor_broken_vars = broken_vars
self._funsor_factor_inputs = factor_inputs
self._funsor_eliminate = frozenset(eliminate)
self._funsor_plate_to_dim = plate_to_dim
Expand All @@ -427,11 +437,12 @@ def _sample_aux_values(
plate_to_dim.update({f.name: f.dim for f in particle_plates})
factors = {}
for d, inputs in self._funsor_factor_inputs.items():
precision_chol = deep_getattr(self.precision_chols, d)
precision = precision_chol @ precision_chol.transpose(-1, -2)
sqrt = deep_getattr(self.factors, d)
if self._funsor_broken_vars:
raise NotImplementedError("TODO break plates in sqrt")
precision = sqrt @ sqrt.transpose(-1, -2)
info_vec = precision.new_zeros(()).expand(precision.shape[:-1])
factors[d] = funsor.gaussian.Gaussian(info_vec, precision, inputs)
factors[d]._precision_chol = precision_chol # avoid recomputing

# Perform Gaussian tensor variable elimination.
samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
Expand All @@ -443,7 +454,7 @@ def _sample_aux_values(

# Convert funsor to torch.
samples = {
k: funsor.to_data(v[self._broken_plates[k]], name_to_dim=plate_to_dim)
k: funsor.to_data(v[self._funsor_broken_vars[k]], name_to_dim=plate_to_dim)
for k, v in samples.items()
}
log_density = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
Expand Down

0 comments on commit 670fdda

Please sign in to comment.